OmniSciDB  94e8789169
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
TableFunctionCompilationContext.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <llvm/IR/Verifier.h>
20 #include <llvm/Support/raw_os_ostream.h>
21 #include <algorithm>
22 #include <boost/algorithm/string.hpp>
23 
25 
26 extern std::unique_ptr<llvm::Module> g_rt_module;
27 extern std::unique_ptr<llvm::Module> rt_udf_cpu_module;
28 extern std::unique_ptr<llvm::Module> rt_udf_gpu_module;
29 
30 namespace {
31 
32 llvm::Function* generate_entry_point(const CgenState* cgen_state) {
33  auto& ctx = cgen_state->context_;
34  const auto pi8_type = llvm::PointerType::get(get_int_type(8, ctx), 0);
35  const auto ppi8_type = llvm::PointerType::get(pi8_type, 0);
36  const auto pi64_type = llvm::PointerType::get(get_int_type(64, ctx), 0);
37  const auto ppi64_type = llvm::PointerType::get(pi64_type, 0);
38  const auto i32_type = get_int_type(32, ctx);
39 
40  const auto func_type = llvm::FunctionType::get(
41  i32_type, {ppi8_type, pi64_type, ppi64_type, pi64_type}, false);
42 
43  auto func = llvm::Function::Create(func_type,
44  llvm::Function::ExternalLinkage,
45  "call_table_function",
46  cgen_state->module_);
47  auto arg_it = func->arg_begin();
48  const auto input_cols_arg = &*arg_it;
49  input_cols_arg->setName("input_col_buffers");
50  const auto input_row_counts = &*(++arg_it);
51  input_row_counts->setName("input_row_counts");
52  const auto output_buffers = &*(++arg_it);
53  output_buffers->setName("output_buffers");
54  const auto output_row_count = &*(++arg_it);
55  output_row_count->setName("output_row_count");
56  return func;
57 }
58 
60  llvm::LLVMContext& ctx) {
61  if (elem_ti.is_fp()) {
62  switch (elem_ti.get_size()) {
63  case 4:
64  return llvm::Type::getFloatPtrTy(ctx);
65  case 8:
66  return llvm::Type::getDoublePtrTy(ctx);
67  }
68  }
69  if (elem_ti.is_boolean()) {
70  return llvm::Type::getInt8PtrTy(ctx);
71  }
72  CHECK(elem_ti.is_integer());
73  switch (elem_ti.get_size()) {
74  case 1:
75  return llvm::Type::getInt8PtrTy(ctx);
76  case 2:
77  return llvm::Type::getInt16PtrTy(ctx);
78  case 4:
79  return llvm::Type::getInt32PtrTy(ctx);
80  case 8:
81  return llvm::Type::getInt64PtrTy(ctx);
82  }
83  LOG(FATAL) << "get_llvm_type_from_sql_column_type: not implemented for "
84  << ::toString(elem_ti);
85  return nullptr;
86 }
87 
88 llvm::Value* alloc_column(std::string col_name,
89  const SQLTypeInfo& data_target_info,
90  llvm::Value* data_ptr,
91  llvm::Value* data_size,
92  llvm::LLVMContext& ctx,
93  llvm::IRBuilder<>& ir_builder,
94  bool byval) {
95  /*
96  Creates a new Column instance of given element type and initialize
97  its data ptr and sz members. If data ptr or sz are unspecified
98  (have nullptr values) then the corresponding members are
99  initialized with NULL and -1, respectively.
100  */
101  llvm::Type* data_ptr_llvm_type =
102  get_llvm_type_from_sql_column_type(data_target_info, ctx);
103  llvm::StructType* col_struct_type =
104  llvm::StructType::get(ctx,
105  {
106  data_ptr_llvm_type, /* T* ptr */
107  llvm::Type::getInt64Ty(ctx) /* int64_t sz */
108  });
109  auto col = ir_builder.CreateAlloca(col_struct_type);
110  col->setName(col_name);
111  auto col_ptr_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 0);
112  auto col_sz_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 1);
113  col_ptr_ptr->setName(col_name + ".ptr");
114  col_sz_ptr->setName(col_name + ".sz");
115 
116  if (data_ptr != nullptr) {
117  if (data_ptr->getType() == data_ptr_llvm_type->getPointerElementType()) {
118  ir_builder.CreateStore(data_ptr, col_ptr_ptr);
119  } else {
120  auto tmp = ir_builder.CreateBitCast(data_ptr, data_ptr_llvm_type);
121  ir_builder.CreateStore(tmp, col_ptr_ptr);
122  }
123  } else {
124  ir_builder.CreateStore(llvm::Constant::getNullValue(data_ptr_llvm_type), col_ptr_ptr);
125  }
126  if (data_size != nullptr) {
127  auto data_size_type = data_size->getType();
128  if (data_size_type->isPointerTy()) {
129  CHECK(data_size_type->getPointerElementType()->isIntegerTy(64));
130  auto val = ir_builder.CreateLoad(data_size);
131  ir_builder.CreateStore(val, col_sz_ptr);
132  } else {
133  CHECK(data_size_type->isIntegerTy(64));
134  ir_builder.CreateStore(data_size, col_sz_ptr);
135  }
136  } else {
137  auto const_minus1 = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), -1, true);
138  ir_builder.CreateStore(const_minus1, col_sz_ptr);
139  }
140 
141  if (byval) {
142  return ir_builder.CreateLoad(col);
143  } else {
144  auto col_ptr = ir_builder.CreatePointerCast(
145  col_ptr_ptr, llvm::PointerType::get(llvm::Type::getInt8Ty(ctx), 0));
146  col_ptr->setName(col_name + "_ptr");
147  return col_ptr;
148  }
149 }
150 
151 } // namespace
152 
154  : cgen_state_(std::make_unique<CgenState>(/*num_query_infos=*/0,
155  /*contains_left_deep_outer_join=*/false)) {
156  auto cgen_state = cgen_state_.get();
157  CHECK(cgen_state);
158 
159  std::unique_ptr<llvm::Module> module(runtime_module_shallow_copy(cgen_state));
160  cgen_state->module_ = module.get();
161 
163  module_ = std::move(module);
164 }
165 
167  const CompilationOptions& co,
168  Executor* executor) {
169  generateEntryPoint(exe_unit);
172  }
173  finalize(co, executor);
174 }
175 
177  const TableFunctionExecutionUnit& exe_unit) {
179  auto arg_it = entry_point_func_->arg_begin();
180  const auto input_cols_arg = &*arg_it;
181  const auto input_row_counts_arg = &*(++arg_it);
182  const auto output_buffers_arg = &*(++arg_it);
183  const auto output_row_count_ptr = &*(++arg_it);
184 
185  auto cgen_state = cgen_state_.get();
186  CHECK(cgen_state);
187  auto& ctx = cgen_state->context_;
188 
189  const auto bb_entry = llvm::BasicBlock::Create(ctx, ".entry", entry_point_func_, 0);
190  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
191 
192  const auto bb_exit = llvm::BasicBlock::Create(ctx, ".exit", entry_point_func_);
193 
194  const auto func_body_bb = llvm::BasicBlock::Create(
195  ctx, ".func_body", cgen_state->ir_builder_.GetInsertBlock()->getParent());
196  cgen_state->ir_builder_.SetInsertPoint(func_body_bb);
197 
198  auto col_heads = generate_column_heads_load(
199  exe_unit.input_exprs.size(), input_cols_arg, cgen_state->ir_builder_, ctx);
200  CHECK_EQ(exe_unit.input_exprs.size(), col_heads.size());
201 
202  auto row_count_heads = generate_column_heads_load(
203  exe_unit.input_exprs.size(), input_row_counts_arg, cgen_state->ir_builder_, ctx);
204 
205  // The column arguments of C++ UDTFs processed by clang must be
206  // passed by reference, see rbc issue 200.
207  auto pass_column_by_value = exe_unit.table_func.isRuntime();
208  std::vector<llvm::Value*> func_args;
209  for (size_t i = 0; i < exe_unit.input_exprs.size(); i++) {
210  const auto& expr = exe_unit.input_exprs[i];
211  const auto& ti = expr->get_type_info();
212  if (ti.is_fp()) {
213  auto r = cgen_state->ir_builder_.CreateBitCast(
214  col_heads[i], llvm::PointerType::get(get_fp_type(get_bit_width(ti), ctx), 0));
215  func_args.push_back(cgen_state->ir_builder_.CreateLoad(r));
216  } else if (ti.is_integer()) {
217  auto r = cgen_state->ir_builder_.CreateBitCast(
218  col_heads[i], llvm::PointerType::get(get_int_type(get_bit_width(ti), ctx), 0));
219  func_args.push_back(cgen_state->ir_builder_.CreateLoad(r));
220  } else if (ti.is_column()) {
221  auto col = alloc_column(std::string("input_col.") + std::to_string(i),
222  ti.get_elem_type(),
223  col_heads[i],
224  row_count_heads[i],
225  ctx,
226  cgen_state_->ir_builder_,
227  pass_column_by_value);
228  func_args.push_back(col);
229  } else {
230  throw std::runtime_error(
231  "Only integer and floating point columns or scalars are supported as inputs to "
232  "table "
233  "functions, got " +
234  ti.get_type_name());
235  }
236  }
237  std::vector<llvm::Value*> output_col_args;
238  for (size_t i = 0; i < exe_unit.target_exprs.size(); i++) {
239  auto output_load = cgen_state->ir_builder_.CreateLoad(
240  cgen_state->ir_builder_.CreateGEP(output_buffers_arg, cgen_state_->llInt(i)));
241  const auto& expr = exe_unit.target_exprs[i];
242  const auto& ti = expr->get_type_info();
243  CHECK(!ti.is_column()); // UDTF output column type is its data type
244  auto col = alloc_column(std::string("output_col.") + std::to_string(i),
245  ti,
246  output_load,
247  output_row_count_ptr,
248  ctx,
249  cgen_state_->ir_builder_,
250  pass_column_by_value);
251  func_args.push_back(col);
252  }
253  auto func_name = exe_unit.table_func.getName();
254  boost::algorithm::to_lower(func_name);
255  const auto table_func_return =
256  cgen_state->emitExternalCall(func_name, get_int_type(32, ctx), func_args);
257  table_func_return->setName("table_func_ret");
258 
259  // If table_func_return is non-negative then store the value in
260  // output_row_count and return zero. Otherwise, return
261  // table_func_return, whose negative value contains the error code.
262  const auto bb_exit_0 = llvm::BasicBlock::Create(ctx, ".exit0", entry_point_func_);
263 
264  auto const_zero = llvm::ConstantInt::get(table_func_return->getType(), 0, true);
265  auto is_ok = cgen_state_->ir_builder_.CreateICmpSGE(table_func_return, const_zero);
266  cgen_state_->ir_builder_.CreateCondBr(is_ok, bb_exit_0, bb_exit);
267 
268  cgen_state_->ir_builder_.SetInsertPoint(bb_exit_0);
269  auto r = cgen_state->ir_builder_.CreateIntCast(
270  table_func_return, get_int_type(64, ctx), true);
271  cgen_state->ir_builder_.CreateStore(r, output_row_count_ptr);
272  cgen_state->ir_builder_.CreateRet(const_zero);
273 
274  cgen_state->ir_builder_.SetInsertPoint(bb_exit);
275  cgen_state->ir_builder_.CreateRet(table_func_return);
276 
277  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
278  cgen_state->ir_builder_.CreateBr(func_body_bb);
279 
280  /*
281  std::cout << "=================================" << std::endl;
282  entry_point_func_->print(llvm::outs());
283  std::cout << "=================================" << std::endl;
284  */
285 
287 }
288 
291  std::vector<llvm::Type*> arg_types;
292  arg_types.reserve(entry_point_func_->arg_size());
293  std::for_each(entry_point_func_->arg_begin(),
294  entry_point_func_->arg_end(),
295  [&arg_types](const auto& arg) { arg_types.push_back(arg.getType()); });
296  CHECK_EQ(arg_types.size(), entry_point_func_->arg_size());
297 
298  auto cgen_state = cgen_state_.get();
299  CHECK(cgen_state);
300  auto& ctx = cgen_state->context_;
301 
302  std::vector<llvm::Type*> wrapper_arg_types(arg_types.size() + 1);
303  wrapper_arg_types[0] = llvm::PointerType::get(get_int_type(32, ctx), 0);
304  wrapper_arg_types[1] = arg_types[0];
305 
306  for (size_t i = 1; i < arg_types.size(); ++i) {
307  wrapper_arg_types[i + 1] = arg_types[i];
308  }
309 
310  auto wrapper_ft =
311  llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), wrapper_arg_types, false);
312  kernel_func_ = llvm::Function::Create(wrapper_ft,
313  llvm::Function::ExternalLinkage,
314  "table_func_kernel",
315  cgen_state->module_);
316 
317  auto wrapper_bb_entry = llvm::BasicBlock::Create(ctx, ".entry", kernel_func_, 0);
318  llvm::IRBuilder<> b(ctx);
319  b.SetInsertPoint(wrapper_bb_entry);
320  std::vector<llvm::Value*> loaded_args = {kernel_func_->arg_begin() + 1};
321  for (size_t i = 2; i < wrapper_arg_types.size(); ++i) {
322  loaded_args.push_back(kernel_func_->arg_begin() + i);
323  }
324  auto error_lv = b.CreateCall(entry_point_func_, loaded_args);
325  b.CreateStore(error_lv, kernel_func_->arg_begin());
326  b.CreateRetVoid();
327 }
328 
330  Executor* executor) {
331  /*
332  TODO 1: eliminate need for OverrideFromSrc
333  TODO 2: detect and link only the udf's that are needed
334  */
335  if (co.device_type == ExecutorDeviceType::GPU && rt_udf_gpu_module != nullptr) {
337  *module_,
338  cgen_state_.get(),
339  llvm::Linker::Flags::OverrideFromSrc);
340  }
341  if (co.device_type == ExecutorDeviceType::CPU && rt_udf_cpu_module != nullptr) {
343  *module_,
344  cgen_state_.get(),
345  llvm::Linker::Flags::OverrideFromSrc);
346  }
347 
348  module_.release();
349  // Add code to cache?
350 
351  LOG(IR) << "Table Function Entry Point IR\n"
353 
355  LOG(IR) << "Table Function Kernel IR\n" << serialize_llvm_object(kernel_func_);
356 
357  CHECK(executor);
358  executor->initializeNVPTXBackend();
359  const auto cuda_mgr = executor->catalog_->getDataMgr().getCudaMgr();
360  CHECK(cuda_mgr);
361 
362  CodeGenerator::GPUTarget gpu_target{executor->nvptx_target_machine_.get(),
363  cuda_mgr,
364  executor->blockSize(),
365  cgen_state_.get(),
366  false};
368  kernel_func_,
370  co,
371  gpu_target);
372  } else {
373  auto ee =
375  func_ptr = reinterpret_cast<FuncPtr>(ee->getPointerToFunction(entry_point_func_));
376  own_execution_engine_ = std::move(ee);
377  }
378 
379  LOG(IR) << "End of IR";
380 }
std::string to_lower(const std::string &str)
#define CHECK_EQ(x, y)
Definition: Logger.h:205
std::unique_ptr< llvm::Module > rt_udf_cpu_module
HOST DEVICE int get_size() const
Definition: sqltypes.h:321
std::string toString(const ExtArgumentType &sig_type)
std::unique_ptr< llvm::Module > runtime_module_shallow_copy(CgenState *cgen_state)
std::vector< Analyzer::Expr * > input_exprs
const table_functions::TableFunction table_func
void generateEntryPoint(const TableFunctionExecutionUnit &exe_unit)
#define LOG(tag)
Definition: Logger.h:188
std::unique_ptr< llvm::Module > rt_udf_gpu_module
bool is_fp() const
Definition: sqltypes.h:482
std::shared_ptr< GpuCompilationContext > gpu_code_
llvm::Function * generate_entry_point(const CgenState *cgen_state)
int32_t(*)(const int8_t **input_cols, const int64_t *input_row_count, int64_t **out, int64_t *output_row_count) FuncPtr
std::unique_ptr< llvm::Module > module_
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
static ExecutionEngineWrapper generateNativeCPUCode(llvm::Function *func, const std::unordered_set< llvm::Function * > &live_funcs, const CompilationOptions &co)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
std::string to_string(char const *&&v)
static std::shared_ptr< GpuCompilationContext > generateNativeGPUCode(llvm::Function *func, llvm::Function *wrapper_func, const std::unordered_set< llvm::Function * > &live_funcs, const CompilationOptions &co, const GPUTarget &gpu_target)
llvm::Module * module_
Definition: CgenState.h:318
void verify_function_ir(const llvm::Function *func)
size_t get_bit_width(const SQLTypeInfo &ti)
llvm::LLVMContext & context_
Definition: CgenState.h:327
bool is_integer() const
Definition: sqltypes.h:480
std::unique_ptr< llvm::Module > g_rt_module
llvm::Value * alloc_column(std::string col_name, const SQLTypeInfo &data_target_info, llvm::Value *data_ptr, llvm::Value *data_size, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder, bool byval)
bool is_boolean() const
Definition: sqltypes.h:485
static void link_udf_module(const std::unique_ptr< llvm::Module > &udf_module, llvm::Module &module, CgenState *cgen_state, llvm::Linker::Flags flags=llvm::Linker::Flags::None)
ExecutorDeviceType device_type
void finalize(const CompilationOptions &co, Executor *executor)
std::string serialize_llvm_object(const T *llvm_obj)
std::vector< llvm::Value * > generate_column_heads_load(const int num_columns, llvm::Value *byte_stream_arg, llvm::IRBuilder<> &ir_builder, llvm::LLVMContext &ctx)
llvm::Type * get_llvm_type_from_sql_column_type(const SQLTypeInfo elem_ti, llvm::LLVMContext &ctx)
bool g_enable_watchdog false
Definition: Execute.cpp:76
#define CHECK(condition)
Definition: Logger.h:197
std::vector< Analyzer::Expr * > target_exprs
void compile(const TableFunctionExecutionUnit &exe_unit, const CompilationOptions &co, Executor *executor)