OmniSciDB  340b00dbf6
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
TableFunctionCompilationContext.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <llvm/IR/Verifier.h>
20 #include <llvm/Support/raw_os_ostream.h>
21 #include <algorithm>
22 #include <boost/algorithm/string.hpp>
23 
25 
26 extern std::unique_ptr<llvm::Module> g_rt_module;
27 extern std::unique_ptr<llvm::Module> rt_udf_cpu_module;
28 extern std::unique_ptr<llvm::Module> rt_udf_gpu_module;
29 
30 namespace {
31 
32 llvm::Function* generate_entry_point(const CgenState* cgen_state) {
33  auto& ctx = cgen_state->context_;
34  const auto pi8_type = llvm::PointerType::get(get_int_type(8, ctx), 0);
35  const auto ppi8_type = llvm::PointerType::get(pi8_type, 0);
36  const auto pi64_type = llvm::PointerType::get(get_int_type(64, ctx), 0);
37  const auto ppi64_type = llvm::PointerType::get(pi64_type, 0);
38  const auto i32_type = get_int_type(32, ctx);
39 
40  const auto func_type = llvm::FunctionType::get(
41  i32_type, {ppi8_type, pi64_type, ppi64_type, pi64_type}, false);
42 
43  auto func = llvm::Function::Create(func_type,
44  llvm::Function::ExternalLinkage,
45  "call_table_function",
46  cgen_state->module_);
47  auto arg_it = func->arg_begin();
48  const auto input_cols_arg = &*arg_it;
49  input_cols_arg->setName("input_col_buffers");
50  const auto input_row_count = &*(++arg_it);
51  input_row_count->setName("input_row_count");
52  const auto output_buffers = &*(++arg_it);
53  output_buffers->setName("output_buffers");
54  const auto output_row_count = &*(++arg_it);
55  output_row_count->setName("output_row_count");
56  return func;
57 }
58 
60  llvm::LLVMContext& ctx) {
  // Maps a column element SQL type to the LLVM pointer type used for the
  // data pointer of a Column struct:
  //   floating point of size 4 -> float*, size 8 -> double*
  //   boolean                  -> i8*
  //   integer of size 1/2/4/8  -> i8*/i16*/i32*/i64*
  // Any other type aborts (via CHECK or LOG(FATAL)).
61  if (elem_ti.is_fp()) {
62  switch (elem_ti.get_size()) {
63  case 4:
64  return llvm::Type::getFloatPtrTy(ctx);
65  case 8:
66  return llvm::Type::getDoublePtrTy(ctx);
67  }
  // NOTE(review): a floating-point type whose size is neither 4 nor 8 falls
  // through to the integer CHECK below instead of reaching the descriptive
  // LOG(FATAL) message at the end.
68  }
69  if (elem_ti.is_boolean()) {
  // Booleans are stored one byte per value.
70  return llvm::Type::getInt8PtrTy(ctx);
71  }
72  CHECK(elem_ti.is_integer());
73  switch (elem_ti.get_size()) {
74  case 1:
75  return llvm::Type::getInt8PtrTy(ctx);
76  case 2:
77  return llvm::Type::getInt16PtrTy(ctx);
78  case 4:
79  return llvm::Type::getInt32PtrTy(ctx);
80  case 8:
81  return llvm::Type::getInt64PtrTy(ctx);
82  }
83  LOG(FATAL) << "get_llvm_type_from_sql_column_type: not implemented for "
84  << ::toString(elem_ti);
  // Presumably unreachable if LOG(FATAL) aborts; keeps the compiler satisfied.
85  return nullptr;
86 }
87 
88 llvm::Value* alloc_column(std::string col_name,
89  const SQLTypeInfo& data_target_info,
90  llvm::Value* data_ptr,
91  llvm::Value* data_size,
92  llvm::LLVMContext& ctx,
93  llvm::IRBuilder<>& ir_builder,
94  bool byval) {
95  /*
96  Creates a new Column instance of given element type and initialize
97  its data ptr and sz members. If data ptr or sz are unspecified
98  (have nullptr values) then the corresponding members are
99  initialized with NULL and -1, respectively.
100  */
101  llvm::Type* data_ptr_llvm_type =
102  get_llvm_type_from_sql_column_type(data_target_info, ctx);
103  llvm::StructType* col_struct_type =
104  llvm::StructType::get(ctx,
105  {
106  data_ptr_llvm_type, /* T* ptr */
107  llvm::Type::getInt64Ty(ctx) /* int64_t sz */
108  });
109  auto col = ir_builder.CreateAlloca(col_struct_type);
110  col->setName(col_name);
111  auto col_ptr_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 0);
112  auto col_sz_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 1);
113  col_ptr_ptr->setName(col_name + ".ptr");
114  col_sz_ptr->setName(col_name + ".sz");
115 
116  if (data_ptr != nullptr) {
117  if (data_ptr->getType() == data_ptr_llvm_type->getPointerElementType()) {
118  ir_builder.CreateStore(data_ptr, col_ptr_ptr);
119  } else {
120  auto tmp = ir_builder.CreateBitCast(data_ptr, data_ptr_llvm_type);
121  ir_builder.CreateStore(tmp, col_ptr_ptr);
122  }
123  } else {
124  ir_builder.CreateStore(llvm::Constant::getNullValue(data_ptr_llvm_type), col_ptr_ptr);
125  }
126  if (data_size != nullptr) {
127  auto data_size_type = data_size->getType();
128  if (data_size_type->isPointerTy()) {
129  CHECK(data_size_type->getPointerElementType()->isIntegerTy(64));
130  auto val = ir_builder.CreateLoad(data_size);
131  ir_builder.CreateStore(val, col_sz_ptr);
132  } else {
133  CHECK(data_size_type->isIntegerTy(64));
134  ir_builder.CreateStore(data_size, col_sz_ptr);
135  }
136  } else {
137  auto const_minus1 = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), -1, true);
138  ir_builder.CreateStore(const_minus1, col_sz_ptr);
139  }
140 
141  if (byval) {
142  return ir_builder.CreateLoad(col);
143  } else {
144  auto col_ptr = ir_builder.CreatePointerCast(
145  col_ptr_ptr, llvm::PointerType::get(llvm::Type::getInt8Ty(ctx), 0));
146  col_ptr->setName(col_name + "_ptr");
147  return col_ptr;
148  }
149 }
150 
151 } // namespace
152 
// Constructor: creates a fresh CgenState (no query infos, no left-deep outer
// join) and a shallow copy of the runtime module for it to emit code into.
// cgen_state_->module_ keeps a non-owning pointer; module_ owns the module.
154  : cgen_state_(std::make_unique<CgenState>(/*num_query_infos=*/0,
155  /*contains_left_deep_outer_join=*/false)) {
156  auto cgen_state = cgen_state_.get();
157  CHECK(cgen_state);
158 
159  std::unique_ptr<llvm::Module> module(runtime_module_shallow_copy(cgen_state));
160  cgen_state->module_ = module.get();
161 
  // NOTE(review): a statement between these lines is not visible in this
  // listing; presumably it creates entry_point_func_ in the new module
  // (generate_entry_point fits) -- confirm against the full source.
163  module_ = std::move(module);
164 }
165 
167  const CompilationOptions& co,
168  Executor* executor) {
  // Emits the entry-point IR for exe_unit, then links and compiles the
  // module for the device selected by co via finalize().
169  generateEntryPoint(exe_unit);
  // NOTE(review): the statement(s) opening the block closed below are not
  // visible in this listing -- confirm against the full source.
172  }
173  finalize(co, executor);
174 }
175 
177  const TableFunctionExecutionUnit& exe_unit) {
  // Emits the body of entry_point_func_. It unpacks the input buffers, builds
  // the table function's argument list (scalars and Column structs), calls the
  // table function, and maps its int32 result to an output row count plus a
  // status code.
  //
  // Block layout: .entry (only a branch to .func_body, emitted last),
  // .func_body (argument marshalling + call), .exit0 (success), .exit (error).
179  auto arg_it = entry_point_func_->arg_begin();
180  const auto input_cols_arg = &*arg_it;
181  const auto input_row_count = &*(++arg_it);
182  const auto output_buffers_arg = &*(++arg_it);
183  const auto output_row_count_ptr = &*(++arg_it);
184 
185  auto cgen_state = cgen_state_.get();
186  CHECK(cgen_state);
187  auto& ctx = cgen_state->context_;
188 
189  const auto bb_entry = llvm::BasicBlock::Create(ctx, ".entry", entry_point_func_, 0);
190  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
191 
192  const auto bb_exit = llvm::BasicBlock::Create(ctx, ".exit", entry_point_func_);
193 
194  const auto func_body_bb = llvm::BasicBlock::Create(
195  ctx, ".func_body", cgen_state->ir_builder_.GetInsertBlock()->getParent());
196  cgen_state->ir_builder_.SetInsertPoint(func_body_bb);
197 
  // One pointer per input expression, loaded from input_cols_arg.
198  auto col_heads = generate_column_heads_load(
199  exe_unit.input_exprs.size(), input_cols_arg, cgen_state->ir_builder_, ctx);
200  CHECK_EQ(exe_unit.input_exprs.size(), col_heads.size());
201 
202  // The column arguments of C++ UDTFs processed by clang must be
203  // passed by reference, see rbc issue 200. Column structs are therefore
  // passed by value only when exe_unit.table_func.isRuntime() is true.
204  auto pass_column_by_value = exe_unit.table_func.isRuntime();
205  std::vector<llvm::Value*> func_args;
206  for (size_t i = 0; i < exe_unit.input_exprs.size(); i++) {
207  const auto& expr = exe_unit.input_exprs[i];
208  const auto& ti = expr->get_type_info();
  // Scalar inputs: reinterpret the buffer pointer with the scalar's width
  // and pass the loaded value.
209  if (ti.is_fp()) {
210  auto r = cgen_state->ir_builder_.CreateBitCast(
211  col_heads[i], llvm::PointerType::get(get_fp_type(get_bit_width(ti), ctx), 0));
212  func_args.push_back(cgen_state->ir_builder_.CreateLoad(r));
213  } else if (ti.is_integer()) {
214  auto r = cgen_state->ir_builder_.CreateBitCast(
215  col_heads[i], llvm::PointerType::get(get_int_type(get_bit_width(ti), ctx), 0));
216  func_args.push_back(cgen_state->ir_builder_.CreateLoad(r));
217  } else if (ti.is_column()) {
  // Column inputs: wrap the buffer and input_row_count in a Column struct
  // of the column's element type.
218  auto col = alloc_column(std::string("input_col.") + std::to_string(i),
219  ti.get_elem_type(),
220  col_heads[i],
221  input_row_count,
222  ctx,
223  cgen_state_->ir_builder_,
224  pass_column_by_value);
225  func_args.push_back(col);
226  } else {
227  throw std::runtime_error(
228  "Only integer and floating point columns or scalars are supported as inputs to "
229  "table "
230  "functions, got " +
231  ti.get_type_name());
232  }
233  }
  // NOTE(review): output_col_args is never used; output columns are appended
  // to func_args in the loop below instead.
234  std::vector<llvm::Value*> output_col_args;
235  for (size_t i = 0; i < exe_unit.target_exprs.size(); i++) {
  // output_buffers[i] becomes the data pointer of output Column i; its size
  // member is initialized from output_row_count_ptr.
236  auto output_load = cgen_state->ir_builder_.CreateLoad(
237  cgen_state->ir_builder_.CreateGEP(output_buffers_arg, cgen_state_->llInt(i)));
238  const auto& expr = exe_unit.target_exprs[i];
239  const auto& ti = expr->get_type_info();
240  CHECK(!ti.is_column()); // UDTF output column type is its data type
241  auto col = alloc_column(std::string("output_col.") + std::to_string(i),
242  ti,
243  output_load,
244  output_row_count_ptr,
245  ctx,
246  cgen_state_->ir_builder_,
247  pass_column_by_value);
248  func_args.push_back(col);
249  }
  // The external symbol called is the lower-cased table function name.
250  auto func_name = exe_unit.table_func.getName();
251  boost::algorithm::to_lower(func_name);
252  const auto table_func_return =
253  cgen_state->emitExternalCall(func_name, get_int_type(32, ctx), func_args);
254  table_func_return->setName("table_func_ret");
255 
256  // If table_func_return is non-negative, sign-extend it to i64, store it in
257  // *output_row_count and return zero. Otherwise return table_func_return
258  // unchanged: a negative value is the error code.
259  const auto bb_exit_0 = llvm::BasicBlock::Create(ctx, ".exit0", entry_point_func_);
260 
261  auto const_zero = llvm::ConstantInt::get(table_func_return->getType(), 0, true);
262  auto is_ok = cgen_state_->ir_builder_.CreateICmpSGE(table_func_return, const_zero);
263  cgen_state_->ir_builder_.CreateCondBr(is_ok, bb_exit_0, bb_exit);
264 
265  cgen_state_->ir_builder_.SetInsertPoint(bb_exit_0);
266  auto r = cgen_state->ir_builder_.CreateIntCast(
267  table_func_return, get_int_type(64, ctx), true);
268  cgen_state->ir_builder_.CreateStore(r, output_row_count_ptr);
269  cgen_state->ir_builder_.CreateRet(const_zero);
270 
271  cgen_state->ir_builder_.SetInsertPoint(bb_exit);
272  cgen_state->ir_builder_.CreateRet(table_func_return);
273 
  // Terminate .entry last, once .func_body exists.
  // NOTE(review): the Column allocas created by alloc_column are emitted in
  // .func_body, not in the entry block; LLVM's mem2reg only promotes
  // entry-block allocas -- consider whether this matters for performance.
274  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
275  cgen_state->ir_builder_.CreateBr(func_body_bb);
276 
  // Disabled debug dump of the generated entry point IR.
277  /*
278  std::cout << "=================================" << std::endl;
279  entry_point_func_->print(llvm::outs());
280  std::cout << "=================================" << std::endl;
281  */
282 
284 }
285 
  // Builds "table_func_kernel", a void wrapper around entry_point_func_.
  // Wrapper parameter 0 is an i32* that receives the entry point's i32
  // return value (error_lv below). Wrapper parameters 1..N have the entry
  // point's parameter types and are forwarded to it unchanged.
288  std::vector<llvm::Type*> arg_types;
289  arg_types.reserve(entry_point_func_->arg_size());
290  std::for_each(entry_point_func_->arg_begin(),
291  entry_point_func_->arg_end(),
292  [&arg_types](const auto& arg) { arg_types.push_back(arg.getType()); });
293  CHECK_EQ(arg_types.size(), entry_point_func_->arg_size());
294 
295  auto cgen_state = cgen_state_.get();
296  CHECK(cgen_state);
297  auto& ctx = cgen_state->context_;
298 
  // Wrapper signature: (i32* error_out, <entry point params...>).
  // Note: arg_types[0] assumes the entry point has at least one parameter.
299  std::vector<llvm::Type*> wrapper_arg_types(arg_types.size() + 1);
300  wrapper_arg_types[0] = llvm::PointerType::get(get_int_type(32, ctx), 0);
301  wrapper_arg_types[1] = arg_types[0];
302 
303  for (size_t i = 1; i < arg_types.size(); ++i) {
304  wrapper_arg_types[i + 1] = arg_types[i];
305  }
306 
307  auto wrapper_ft =
308  llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), wrapper_arg_types, false);
309  kernel_func_ = llvm::Function::Create(wrapper_ft,
310  llvm::Function::ExternalLinkage,
311  "table_func_kernel",
312  cgen_state->module_);
313 
314  auto wrapper_bb_entry = llvm::BasicBlock::Create(ctx, ".entry", kernel_func_, 0);
315  llvm::IRBuilder<> b(ctx);
316  b.SetInsertPoint(wrapper_bb_entry);
  // Forward every wrapper argument except the first (the i32* out slot).
317  std::vector<llvm::Value*> loaded_args = {kernel_func_->arg_begin() + 1};
318  for (size_t i = 2; i < wrapper_arg_types.size(); ++i) {
319  loaded_args.push_back(kernel_func_->arg_begin() + i);
320  }
  // Call the entry point and write its return value through parameter 0.
321  auto error_lv = b.CreateCall(entry_point_func_, loaded_args);
322  b.CreateStore(error_lv, kernel_func_->arg_begin());
323  b.CreateRetVoid();
324 }
325 
327  Executor* executor) {
  // Links the runtime UDF module matching co.device_type (when one is loaded)
  // into module_, logs the generated IR, and compiles it. The GPU path builds
  // a GPUTarget from the executor; the else-branch stores a native entry-point
  // pointer in func_ptr and keeps the execution engine alive. Some lines of
  // this function (the link calls, the GPU/CPU condition and the codegen
  // calls) are not visible in this listing.
328  /*
329  TODO 1: eliminate need for OverrideFromSrc
330  TODO 2: detect and link only the udf's that are needed
331  */
332  if (co.device_type == ExecutorDeviceType::GPU && rt_udf_gpu_module != nullptr) {
334  *module_,
335  cgen_state_.get(),
336  llvm::Linker::Flags::OverrideFromSrc);
337  }
338  if (co.device_type == ExecutorDeviceType::CPU && rt_udf_cpu_module != nullptr) {
340  *module_,
341  cgen_state_.get(),
342  llvm::Linker::Flags::OverrideFromSrc);
343  }
344 
  // NOTE(review): release() drops unique ownership and discards the returned
  // pointer. Presumably the module is owned from here on by the code
  // generation backend (e.g. the execution engine); confirm, otherwise this
  // leaks the module.
345  module_.release();
346  // Add code to cache?
347 
348  LOG(IR) << "Table Function Entry Point IR\n"
350 
352  LOG(IR) << "Table Function Kernel IR\n" << serialize_llvm_object(kernel_func_);
353 
354  CHECK(executor);
355  executor->initializeNVPTXBackend();
356  const auto cuda_mgr = executor->catalog_->getDataMgr().getCudaMgr();
357  CHECK(cuda_mgr);
358 
359  CodeGenerator::GPUTarget gpu_target{executor->nvptx_target_machine_.get(),
360  cuda_mgr,
361  executor->blockSize(),
362  cgen_state_.get(),
363  false};
365  kernel_func_,
367  co,
368  gpu_target);
369  } else {
  // Host path: resolve the entry point to a callable pointer and keep the
  // execution engine (which backs that pointer) alive in this object.
370  auto ee =
372  func_ptr = reinterpret_cast<FuncPtr>(ee->getPointerToFunction(entry_point_func_));
373  own_execution_engine_ = std::move(ee);
374  }
375 
376  LOG(IR) << "End of IR";
377 }
std::string to_lower(const std::string &str)
#define CHECK_EQ(x, y)
Definition: Logger.h:205
std::unique_ptr< llvm::Module > rt_udf_cpu_module
HOST DEVICE int get_size() const
Definition: sqltypes.h:340
std::string toString(const ExtArgumentType &sig_type)
std::unique_ptr< llvm::Module > runtime_module_shallow_copy(CgenState *cgen_state)
std::vector< Analyzer::Expr * > input_exprs
const table_functions::TableFunction table_func
void generateEntryPoint(const TableFunctionExecutionUnit &exe_unit)
#define LOG(tag)
Definition: Logger.h:188
std::unique_ptr< llvm::Module > rt_udf_gpu_module
bool is_fp() const
Definition: sqltypes.h:491
std::shared_ptr< GpuCompilationContext > gpu_code_
llvm::Function * generate_entry_point(const CgenState *cgen_state)
int32_t(*)(const int8_t **input_cols, const int64_t *input_row_count, int64_t **out, int64_t *output_row_count) FuncPtr
std::unique_ptr< llvm::Module > module_
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
static ExecutionEngineWrapper generateNativeCPUCode(llvm::Function *func, const std::unordered_set< llvm::Function * > &live_funcs, const CompilationOptions &co)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
std::string to_string(char const *&&v)
static std::shared_ptr< GpuCompilationContext > generateNativeGPUCode(llvm::Function *func, llvm::Function *wrapper_func, const std::unordered_set< llvm::Function * > &live_funcs, const CompilationOptions &co, const GPUTarget &gpu_target)
llvm::Module * module_
Definition: CgenState.h:318
void verify_function_ir(const llvm::Function *func)
size_t get_bit_width(const SQLTypeInfo &ti)
llvm::LLVMContext & context_
Definition: CgenState.h:327
bool is_integer() const
Definition: sqltypes.h:489
std::unique_ptr< llvm::Module > g_rt_module
llvm::Value * alloc_column(std::string col_name, const SQLTypeInfo &data_target_info, llvm::Value *data_ptr, llvm::Value *data_size, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder, bool byval)
bool is_boolean() const
Definition: sqltypes.h:494
static void link_udf_module(const std::unique_ptr< llvm::Module > &udf_module, llvm::Module &module, CgenState *cgen_state, llvm::Linker::Flags flags=llvm::Linker::Flags::None)
ExecutorDeviceType device_type
void finalize(const CompilationOptions &co, Executor *executor)
std::string serialize_llvm_object(const T *llvm_obj)
std::vector< llvm::Value * > generate_column_heads_load(const int num_columns, llvm::Value *byte_stream_arg, llvm::IRBuilder<> &ir_builder, llvm::LLVMContext &ctx)
llvm::Type * get_llvm_type_from_sql_column_type(const SQLTypeInfo elem_ti, llvm::LLVMContext &ctx)
bool g_enable_watchdog false
Definition: Execute.cpp:73
#define CHECK(condition)
Definition: Logger.h:197
std::vector< Analyzer::Expr * > target_exprs
void compile(const TableFunctionExecutionUnit &exe_unit, const CompilationOptions &co, Executor *executor)