OmniSciDB  0bd2ec9cf4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
UDFCompiler.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "UDFCompiler.h"
18 #include <clang/AST/AST.h>
19 #include <clang/AST/ASTConsumer.h>
20 #include <clang/AST/RecursiveASTVisitor.h>
21 #include <clang/Driver/Compilation.h>
22 #include <clang/Driver/Driver.h>
23 #include <clang/Frontend/CompilerInstance.h>
24 #include <clang/Frontend/FrontendActions.h>
25 #include <clang/Frontend/TextDiagnosticPrinter.h>
26 #include <clang/Parse/ParseAST.h>
27 #include <clang/Tooling/CommonOptionsParser.h>
28 #include <clang/Tooling/Tooling.h>
29 #include <llvm/Support/Program.h>
30 #include <llvm/Support/raw_ostream.h>
31 #include <boost/process/search_path.hpp>
32 #include <memory>
33 #include "Execute.h"
34 #include "Shared/Logger.h"
35 
36 using namespace clang;
37 using namespace clang::tooling;
38 
// Option category supplied to CommonOptionsParser when parseToAst() builds
// its tool invocation.
39 static llvm::cl::OptionCategory ToolingSampleCategory("UDF Tooling");
40 
41 namespace {
42 
43 // By implementing RecursiveASTVisitor, we can specify which AST nodes
44 // we're interested in by overriding relevant methods.
45 
46 class FunctionDeclVisitor : public RecursiveASTVisitor<FunctionDeclVisitor> {
47  public:
48  FunctionDeclVisitor(llvm::raw_fd_ostream& ast_file,
49  SourceManager& s_manager,
50  ASTContext& context)
51  : ast_file_(ast_file), source_manager_(s_manager), context_(context) {
52  source_manager_.getDiagnostics().setShowColors();
53  }
54 
55  bool VisitFunctionDecl(FunctionDecl* f) {
56  // Only function definitions (with bodies), not declarations.
57  if (f->hasBody()) {
58  if (getMainFileName() == getFuncDeclFileName(f)) {
59  auto printing_policy = context_.getPrintingPolicy();
60  printing_policy.FullyQualifiedName = 1;
61  printing_policy.UseVoidForZeroParams = 1;
62  printing_policy.PolishForDeclaration = 1;
63  printing_policy.TerseOutput = 1;
64  f->print(ast_file_, printing_policy);
65  ast_file_ << "\n";
66  }
67  }
68 
69  return true;
70  }
71 
72  private:
73  std::string getMainFileName() const {
74  auto f_entry = source_manager_.getFileEntryForID(source_manager_.getMainFileID());
75  return f_entry->getName().str();
76  }
77 
78  std::string getFuncDeclFileName(FunctionDecl* f) const {
79  SourceLocation spell_loc = source_manager_.getSpellingLoc(f->getLocation());
80  PresumedLoc p_loc = source_manager_.getPresumedLoc(spell_loc);
81 
82  return std::string(p_loc.getFilename());
83  }
84 
85  private:
86  llvm::raw_fd_ostream& ast_file_;
87  SourceManager& source_manager_;
88  ASTContext& context_;
89 };
90 
91 // Implementation of the ASTConsumer interface for reading an AST produced
92 // by the Clang parser.
93 class DeclASTConsumer : public ASTConsumer {
94  public:
// Forwards the output stream, source manager and AST context to the
// constructor of the visitor_ member.
95  DeclASTConsumer(llvm::raw_fd_ostream& ast_file,
96  SourceManager& s_manager,
97  ASTContext& context)
98  : visitor_(ast_file, s_manager, context) {}
99 
100  // Override the method that gets called for each parsed top-level
101  // declaration.
// Every declaration in the group is traversed with visitor_; the method
// always returns true.
102  bool HandleTopLevelDecl(DeclGroupRef decl_reference) override {
103  for (DeclGroupRef::iterator b = decl_reference.begin(), e = decl_reference.end();
104  b != e;
105  ++b) {
106  // Traverse the declaration using our AST visitor.
107  visitor_.TraverseDecl(*b);
108  }
109  return true;
110  }
111 
112  private:
// NOTE(review): listing line 113 is missing from this view. The constructor
// initializes a member named visitor_ with (ast_file, s_manager, context),
// so this is presumably where `FunctionDeclVisitor visitor_;` is declared —
// confirm against the original file.
114 };
115 
116 // For each source file provided to the tool, a new FrontendAction is created.
117 class HandleDeclAction : public ASTFrontendAction {
118  public:
119  HandleDeclAction(llvm::raw_fd_ostream& ast_file) : ast_file_(ast_file) {}
120 
121  ~HandleDeclAction() override {}
122 
123  std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance& instance,
124  StringRef file) override {
125  return llvm::make_unique<DeclASTConsumer>(
126  ast_file_, instance.getSourceManager(), instance.getASTContext());
127  }
128 
129  private:
130  llvm::raw_fd_ostream& ast_file_;
131 };
132 
133 class ToolFactory : public FrontendActionFactory {
134  public:
135  ToolFactory(llvm::raw_fd_ostream& ast_file) : ast_file_(ast_file) {}
136 
137  clang::FrontendAction* create() override { return new HandleDeclAction(ast_file_); }
138 
139  private:
140  llvm::raw_fd_ostream& ast_file_;
141 };
142 
143 bool on_search_path(const std::string file) {
144  boost::filesystem::path p = boost::process::search_path(file);
145  return boost::filesystem::exists(p);
146 }
147 } // namespace
148 
// NOTE(review): the line carrying this constructor's signature (listing line
// 149) is missing from this view; only the member initializer list and the
// empty body are visible — confirm the signature against the original file.
// The clang++ location is looked up with findProgramByName() and .get() is
// called on the result without checking whether the lookup failed.
150  : clang_path(llvm::sys::findProgramByName("clang++").get().c_str())
// Diagnostics setup: options, a text printer writing to llvm::errs(), the
// diagnostic IDs, and the diagnostics engine built from those three.
151  , diag_options(new DiagnosticOptions())
152  , diag_client(new TextDiagnosticPrinter(llvm::errs(), diag_options.get()))
153  , diag_id(new clang::DiagnosticIDs())
154  , diags(diag_id, diag_options.get(), diag_client)
// The client is taken back out of the diagnostics engine and stored in
// diag_client_owner.
155  , diag_client_owner(diags.takeClient())
// The driver is created from the clang++ path, the default target triple and
// the diagnostics engine.
156  , the_driver(clang_path.c_str(), llvm::sys::getDefaultTargetTriple(), diags) {}
157 
158 std::string UdfCompiler::removeFileExtension(const std::string& path) {
159  if (path == "." || path == "..") {
160  return path;
161  }
162 
163  size_t pos = path.find_last_of("\\/.");
164  if (pos != std::string::npos && path[pos] == '.') {
165  return path.substr(0, pos);
166  }
167 
168  return path;
169 }
170 
171 std::string UdfCompiler::getFileExt(std::string& s) {
172  size_t i = s.rfind('.', s.length());
173  if (1 != std::string::npos) {
174  return (s.substr(i + 1, s.length() - i));
175  }
176 }
177 
178 void UdfCompiler::replaceExtn(std::string& s, const std::string& new_ext) {
179  std::string::size_type i = s.rfind('.', s.length());
180 
181  if (i != std::string::npos) {
182  s.replace(i + 1, getFileExt(s).length(), new_ext);
183  }
184 }
185 
186 std::string UdfCompiler::genGpuIrFilename(const char* udf_file_name) {
187  std::string gpu_file_name(removeFileExtension(udf_file_name));
188 
189  gpu_file_name += "_gpu.bc";
190  return gpu_file_name;
191 }
192 
193 std::string UdfCompiler::genCpuIrFilename(const char* udf_fileName) {
194  std::string cpu_file_name(removeFileExtension(udf_fileName));
195 
196  cpu_file_name += "_cpu.bc";
197  return cpu_file_name;
198 }
199 
200 int UdfCompiler::compileFromCommandLine(std::vector<const char*>& command_line) {
201  UdfClangDriver compiler_driver;
202  auto the_driver(compiler_driver.getClangDriver());
203 
204  the_driver->CCPrintOptions = 0;
205  std::unique_ptr<driver::Compilation> compilation(
206  the_driver->BuildCompilation(command_line));
207 
208  if (!compilation) {
209  LOG(FATAL) << "failed to build compilation object!\n";
210  }
211 
212  llvm::SmallVector<std::pair<int, const driver::Command*>, 10> failing_commands;
213  int res = the_driver->ExecuteCompilation(*compilation, failing_commands);
214 
215  if (res < 0) {
216  for (const std::pair<int, const driver::Command*>& p : failing_commands) {
217  if (p.first) {
218  the_driver->generateCompilationDiagnostics(*compilation, *p.second);
219  }
220  }
221  }
222 
223  return res;
224 }
225 
226 int UdfCompiler::compileToGpuByteCode(const char* udf_file_name, bool cpu_mode) {
227  auto a_path = llvm::sys::findProgramByName("clang++");
228  auto clang_path = a_path.get();
229 
230  std::string gpu_outName(genGpuIrFilename(udf_file_name));
231 
232  std::vector<const char*> command_line{clang_path.c_str(),
233  "-c",
234  "-O2",
235  "-emit-llvm",
236  "-o",
237  gpu_outName.c_str(),
238  "-std=c++14"};
239 
240  // If we are not compiling for cpu mode, then target the gpu
241  // Otherwise assume we can generic ir that will
242  // be translated to gpu code during target code generation
243  if (!cpu_mode) {
244  command_line.emplace_back("--cuda-gpu-arch=sm_30");
245  command_line.emplace_back("--cuda-device-only");
246  command_line.emplace_back("-xcuda");
247  }
248 
249  command_line.emplace_back(udf_file_name);
250 
251  return compileFromCommandLine(command_line);
252 }
253 
254 int UdfCompiler::compileToCpuByteCode(const char* udf_file_name) {
255  auto a_path = llvm::sys::findProgramByName("clang++");
256  auto clang_path = a_path.get();
257 
258  std::string cpu_outName(genCpuIrFilename(udf_file_name));
259 
260  std::vector<const char*> command_line{clang_path.c_str(),
261  "-c",
262  "-O2",
263  "-emit-llvm",
264  "-o",
265  cpu_outName.c_str(),
266  "-std=c++14",
267  udf_file_name};
268 
269  return compileFromCommandLine(command_line);
270 }
271 
272 int UdfCompiler::parseToAst(const char* file_name) {
273  UdfClangDriver the_driver;
274  std::string resource_path = the_driver.getClangDriver()->ResourceDir;
275  std::string include_option =
276  std::string("-I") + resource_path + std::string("/include");
277 
278  const char arg0[] = "astparser";
279  const char* arg1 = file_name;
280  const char arg2[] = "--";
281  const char* arg3 = include_option.c_str();
282  const char* arg_vector[] = {arg0, arg1, arg2, arg3};
283 
284  int num_args = sizeof(arg_vector) / sizeof(arg_vector[0]);
285  CommonOptionsParser op(num_args, arg_vector, ToolingSampleCategory);
286  ClangTool tool(op.getCompilations(), op.getSourcePathList());
287 
288  std::string out_name(file_name);
289  std::string file_ext("ast");
290  replaceExtn(out_name, file_ext);
291 
292  std::error_code out_error_info;
293  llvm::raw_fd_ostream out_file(
294  llvm::StringRef(out_name), out_error_info, llvm::sys::fs::F_None);
295 
296  auto factory = llvm::make_unique<ToolFactory>(out_file);
297  return tool.run(factory.get());
298 }
299 
300 const std::string& UdfCompiler::getAstFileName() const {
301  return udf_ast_file_name_;
302 }
303 
// Stores `file_name` as the UDF source file name; the AST file name member is
// initialized to the same value.
// NOTE(review): listing line 306 (inside the constructor body) is missing
// from this view, so any statement on it is not shown here — confirm against
// the original file.
304 UdfCompiler::UdfCompiler(const std::string& file_name)
305  : udf_file_name_(file_name), udf_ast_file_name_(file_name) {
307 }
308 
// NOTE(review): the signature line (listing line 309) is missing from this
// view; the symbol index after the listing names `void readCpuCompiledModule()`,
// which is presumably the function this body belongs to — confirm.
// Derives the CPU bitcode file name from udf_file_name_, logs it at VLOG(1)
// and passes it to read_udf_cpu_module().
310  std::string cpu_ir_file(genCpuIrFilename(udf_file_name_.c_str()));
311 
312  VLOG(1) << "UDFCompiler cpu bc file = " << cpu_ir_file << std::endl;
313 
314  read_udf_cpu_module(cpu_ir_file);
315 }
316 
// NOTE(review): the signature line (listing line 317) is missing from this
// view; the symbol index after the listing names `void readGpuCompiledModule()`,
// which is presumably the function this body belongs to — confirm.
// Derives the GPU bitcode file name from udf_file_name_, logs it at VLOG(1)
// and passes it to read_udf_gpu_module().
318  std::string gpu_ir_file(genGpuIrFilename(udf_file_name_.c_str()));
319 
320  VLOG(1) << "UDFCompiler gpu bc file = " << gpu_ir_file << std::endl;
321 
322  read_udf_gpu_module(gpu_ir_file);
323 }
324 
328 }
329 
// NOTE(review): the signature line (listing line 330) is missing from this
// view; the symbol index after the listing names `int compileForGpu()`, which
// is presumably the function this body belongs to — confirm.
// Returns the result of the last compileToGpuByteCode() call made below.
// The result starts out non-zero so that the fallback below also runs when
// nvcc is not on the search path and no CUDA compile was attempted.
331  int gpu_compile_result = 1;
332 
// Attempt a CUDA-targeted compile (cpu_mode == false) only if nvcc is found.
333  if (on_search_path("nvcc")) {
334  gpu_compile_result = compileToGpuByteCode(udf_file_name_.c_str(), false);
335  }
336 
337  // If gpu compilation fails but cpu compilation has succeeded, try compiling
338  // for the cpu with the assumption the user does not have the CUDA toolkit
339  // installed
// The retry passes cpu_mode == true, which omits the CUDA-specific flags.
340  if (gpu_compile_result != 0) {
341  gpu_compile_result = compileToGpuByteCode(udf_file_name_.c_str(), true);
342  }
343 
344  return gpu_compile_result;
345 }
346 
// NOTE(review): the signature line (listing line 347) is missing from this
// view; the symbol index after the listing names `int compileUdf()`, which is
// presumably the function this body belongs to — confirm.
// Returns 0 when every visible step succeeds and 1 on each failure path; every
// failure is also reported through LOG(FATAL).
// Nothing can be compiled unless clang++ is found on the search path.
348  if (on_search_path("clang++")) {
349  LOG(INFO) << "UDFCompiler filename to compiler: " << udf_file_name_ << std::endl;
// The UDF source file itself has to exist.
350  if (!boost::filesystem::exists(udf_file_name_)) {
351  LOG(FATAL) << "User defined function file " << udf_file_name_ << " does not exist.";
352  return 1;
353  }
354 
// Step 1: produce the AST file; a non-zero result is treated as failure.
355  auto ast_result = parseToAst(udf_file_name_.c_str());
356 
357  if (ast_result == 0) {
358  // Compile udf file to generate cpu and gpu bytecode files
359 
// Step 2: CPU bitcode. The GPU result variable exists only in HAVE_CUDA
// builds.
360  int cpu_compile_result = compileToCpuByteCode(udf_file_name_.c_str());
361 #ifdef HAVE_CUDA
362  int gpu_compile_result = 1;
363 #endif
364 
365  if (cpu_compile_result == 0) {
// NOTE(review): listing line 366 is missing from this view; whatever runs
// after a successful CPU compile is not shown here — confirm.
367 #ifdef HAVE_CUDA
// Step 3 (HAVE_CUDA builds only): GPU bitcode via compileForGpu().
368  gpu_compile_result = compileForGpu();
369 
370  if (gpu_compile_result == 0) {
// NOTE(review): listing line 371 is missing from this view; whatever runs
// after a successful GPU compile is not shown here — confirm.
372  } else {
373  LOG(FATAL) << "Unable to compile UDF file for gpu" << std::endl;
374  return 1;
375  }
376 #endif
377  } else {
378  LOG(FATAL) << "Unable to compile UDF file for cpu" << std::endl;
379  return 1;
380  }
381  } else {
382  LOG(FATAL) << "Unable to create AST file for udf compilation" << std::endl;
383  return 1;
384  }
385  } else {
386  LOG(FATAL) << "Unable to compile udfs due to absence of clang++" << std::endl;
387  return 1;
388  }
389 
390  return 0;
391 }
ToolFactory(llvm::raw_fd_ostream &ast_file)
clang::driver::Driver * getClangDriver()
Definition: UDFCompiler.h:39
std::string genCpuIrFilename(const char *udf_file_name)
int compileToCpuByteCode(const char *udf_file_name)
void readCompiledModules()
UdfCompiler(const std::string &)
const std::string & getAstFileName() const
#define LOG(tag)
Definition: Logger.h:188
void readCpuCompiledModule()
void read_udf_cpu_module(const std::string &udf_ir_filename)
static llvm::cl::OptionCategory ToolingSampleCategory("UDF Tooling")
void read_udf_gpu_module(const std::string &udf_ir_filename)
External interface for parsing AST and bitcode files.
std::string genGpuIrFilename(const char *udf_file_name)
std::string removeFileExtension(const std::string &path)
DeclASTConsumer(llvm::raw_fd_ostream &ast_file, SourceManager &s_manager, ASTContext &context)
Definition: UDFCompiler.cpp:95
int compileToGpuByteCode(const char *udf_file_name, bool cpu_mode)
std::unique_ptr< ASTConsumer > CreateASTConsumer(CompilerInstance &instance, StringRef file) override
std::string udf_ast_file_name_
Definition: UDFCompiler.h:74
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t int32_t * error_code
int parseToAst(const char *file_name)
bool HandleTopLevelDecl(DeclGroupRef decl_reference) override
std::string udf_file_name_
Definition: UDFCompiler.h:73
int compileUdf()
void replaceExtn(std::string &s, const std::string &new_ext)
std::string getFuncDeclFileName(FunctionDecl *f) const
Definition: UDFCompiler.cpp:78
bool on_search_path(const std::string file)
void readGpuCompiledModule()
int compileFromCommandLine(std::vector< const char * > &command_line)
clang::FrontendAction * create() override
std::string getFileExt(std::string &s)
int compileForGpu()
FunctionDeclVisitor(llvm::raw_fd_ostream &ast_file, SourceManager &s_manager, ASTContext &context)
Definition: UDFCompiler.cpp:48
#define VLOG(n)
Definition: Logger.h:283