OmniSciDB  8a228a1076
TableFunctionCompilationContext.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <llvm/IR/Verifier.h>
20 #include <llvm/Support/raw_os_ostream.h>
21 #include <algorithm>
22 #include <boost/algorithm/string.hpp>
23 
25 
26 extern std::unique_ptr<llvm::Module> g_rt_module;
27 extern std::unique_ptr<llvm::Module> rt_udf_cpu_module;
28 
29 namespace {
30 
31 llvm::Function* generate_entry_point(const CgenState* cgen_state) {
32  auto& ctx = cgen_state->context_;
33  const auto pi8_type = llvm::PointerType::get(get_int_type(8, ctx), 0);
34  const auto ppi8_type = llvm::PointerType::get(pi8_type, 0);
35  const auto pi64_type = llvm::PointerType::get(get_int_type(64, ctx), 0);
36  const auto ppi64_type = llvm::PointerType::get(pi64_type, 0);
37  const auto i32_type = get_int_type(32, ctx);
38 
39  const auto func_type = llvm::FunctionType::get(
40  i32_type, {ppi8_type, pi64_type, ppi64_type, pi64_type}, false);
41 
42  auto func = llvm::Function::Create(func_type,
43  llvm::Function::ExternalLinkage,
44  "call_table_function",
45  cgen_state->module_);
46  auto arg_it = func->arg_begin();
47  const auto input_cols_arg = &*arg_it;
48  input_cols_arg->setName("input_col_buffers");
49  const auto input_row_count = &*(++arg_it);
50  input_row_count->setName("input_row_count");
51  const auto output_buffers = &*(++arg_it);
52  output_buffers->setName("output_buffers");
53  const auto output_row_count = &*(++arg_it);
54  output_row_count->setName("output_row_count");
55  return func;
56 }
57 
59  llvm::LLVMContext& ctx) {
  // Maps a column SQL type to the LLVM pointer type of its element buffer:
  //   4/8-byte floating point -> float* / double*
  //   1/2/4/8-byte integer    -> i8* / i16* / i32* / i64*
  // A floating-point element of any other size falls through to the integer
  // CHECK below; an integer of any other size reaches UNREACHABLE().
60  CHECK(ti.is_column());
61  const auto& elem_ti = ti.get_elem_type();
62  if (elem_ti.is_fp()) {
63  switch (elem_ti.get_size()) {
64  case 4:
65  return llvm::Type::getFloatPtrTy(ctx);
66  case 8:
67  return llvm::Type::getDoublePtrTy(ctx);
68  }
69  }
70  CHECK(elem_ti.is_integer());
71  switch (elem_ti.get_size()) {
72  case 1:
73  return llvm::Type::getInt8PtrTy(ctx);
74  case 2:
75  return llvm::Type::getInt16PtrTy(ctx);
76  case 4:
77  return llvm::Type::getInt32PtrTy(ctx);
78  case 8:
79  return llvm::Type::getInt64PtrTy(ctx);
80  }
81 
82  UNREACHABLE();
  // Not reached; keeps compilers that cannot see through UNREACHABLE() quiet.
83  return nullptr;
84 }
85 
86 llvm::Value* alloc_column(std::string col_name,
87  const SQLTypeInfo& data_target_info,
88  llvm::Value* data_ptr,
89  llvm::Value* data_size,
90  llvm::LLVMContext& ctx,
91  llvm::IRBuilder<>& ir_builder) {
92  /*
93  Creates a new Column instance of given element type and initialize
94  its data ptr and sz members. If data ptr or sz are unspecified
95  (have nullptr values) then the corresponding members are
96  initialized with NULL and -1, respectively.
97  */
98  llvm::Type* data_ptr_llvm_type =
99  get_llvm_type_from_sql_column_type(data_target_info, ctx);
100  llvm::StructType* col_struct_type =
101  llvm::StructType::get(ctx,
102  {
103  data_ptr_llvm_type, /* T* ptr */
104  llvm::Type::getInt64Ty(ctx) /* int64_t sz */
105  });
106  auto col = ir_builder.CreateAlloca(col_struct_type);
107  col->setName(col_name);
108  auto col_ptr_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 0);
109  auto col_sz_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 1);
110  col_ptr_ptr->setName(col_name + ".ptr");
111  col_sz_ptr->setName(col_name + ".sz");
112 
113  if (data_ptr != nullptr) {
114  if (data_ptr->getType() == data_ptr_llvm_type->getPointerElementType()) {
115  ir_builder.CreateStore(data_ptr, col_ptr_ptr);
116  } else {
117  auto tmp = ir_builder.CreateBitCast(data_ptr, data_ptr_llvm_type);
118  ir_builder.CreateStore(tmp, col_ptr_ptr);
119  }
120  } else {
121  ir_builder.CreateStore(llvm::Constant::getNullValue(data_ptr_llvm_type), col_ptr_ptr);
122  }
123  if (data_size != nullptr) {
124  auto data_size_type = data_size->getType();
125  if (data_size_type->isPointerTy()) {
126  CHECK(data_size_type->getPointerElementType()->isIntegerTy(64));
127  auto val = ir_builder.CreateLoad(data_size);
128  ir_builder.CreateStore(val, col_sz_ptr);
129  } else {
130  CHECK(data_size_type->isIntegerTy(64));
131  ir_builder.CreateStore(data_size, col_sz_ptr);
132  }
133  } else {
134  auto const_minus1 = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), -1, true);
135  ir_builder.CreateStore(const_minus1, col_sz_ptr);
136  }
137  return col;
138 }
139 
140 } // namespace
141 
143  : cgen_state_(std::make_unique<CgenState>(std::vector<InputTableInfo>{}, false)) {
  // All code generation for this context goes through a fresh CgenState
  // constructed with no input table infos.
  // NOTE(review): the meaning of the `false` flag is defined by CgenState — confirm.
144  auto cgen_state = cgen_state_.get();
145  CHECK(cgen_state);
146 
  // Generate into the module returned by runtime_module_shallow_copy.
  // cgen_state keeps a raw (non-owning) pointer to it, and module_ takes
  // ownership.
147  std::unique_ptr<llvm::Module> module(runtime_module_shallow_copy(cgen_state));
148  cgen_state->module_ = module.get();
149 
151  module_ = std::move(module);
152 }
153 
155  const CompilationOptions& co,
156  Executor* executor) {
  // Emit the IR for the `call_table_function` entry point of this unit.
157  generateEntryPoint(exe_unit);
160  }
  // Link, log and lower the generated module to native code (see finalize()).
161  finalize(co, executor);
162 }
163 
165  const TableFunctionExecutionUnit& exe_unit) {
  // Emits the body of the entry point. Steps:
  //   1. Load the input column heads.
  //   2. Marshal every input and output into arguments for the table function
  //      named by exe_unit.table_func_name (lower-cased).
  //   3. Call that function.
  //   4. Translate its int32 result (return convention described at bb_exit_0).
  // Entry point arguments, in declaration order.
167  auto arg_it = entry_point_func_->arg_begin();
168  const auto input_cols_arg = &*arg_it;
169  const auto input_row_count = &*(++arg_it);
170  const auto output_buffers_arg = &*(++arg_it);
171  const auto output_row_count_ptr = &*(++arg_it);
172 
173  auto cgen_state = cgen_state_.get();
174  CHECK(cgen_state);
175  auto& ctx = cgen_state->context_;
176 
  // Block layout:
  //   .entry     - only an unconditional branch to .func_body, emitted at the end.
  //   .func_body - argument marshalling and the call.
  //   .exit      - returns the (negative) error code.
  //   .exit0     - success path.
177  const auto bb_entry = llvm::BasicBlock::Create(ctx, ".entry", entry_point_func_, 0);
178  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
179 
180  const auto bb_exit = llvm::BasicBlock::Create(ctx, ".exit", entry_point_func_);
181 
182  const auto func_body_bb = llvm::BasicBlock::Create(
183  ctx, ".func_body", cgen_state->ir_builder_.GetInsertBlock()->getParent());
184  cgen_state->ir_builder_.SetInsertPoint(func_body_bb);
185 
  // Presumably one column-head pointer per input expression, read from
  // input_cols_arg; the count is CHECKed below.
186  auto col_heads = generate_column_heads_load(
187  exe_unit.input_exprs.size(), input_cols_arg, cgen_state->ir_builder_, ctx);
188  CHECK_EQ(exe_unit.input_exprs.size(), col_heads.size());
189 
  // Arguments for the table function: all inputs first, then all outputs.
190  std::vector<llvm::Value*> func_args;
191  for (size_t i = 0; i < exe_unit.input_exprs.size(); i++) {
192  const auto& expr = exe_unit.input_exprs[i];
193  const auto& ti = expr->get_type_info();
  // Scalars: reinterpret the column head as T* and pass the loaded value.
194  if (ti.is_fp()) {
195  auto r = cgen_state->ir_builder_.CreateBitCast(
196  col_heads[i], llvm::PointerType::get(get_fp_type(get_bit_width(ti), ctx), 0));
197  func_args.push_back(cgen_state->ir_builder_.CreateLoad(r));
198  } else if (ti.is_integer()) {
199  auto r = cgen_state->ir_builder_.CreateBitCast(
200  col_heads[i], llvm::PointerType::get(get_int_type(get_bit_width(ti), ctx), 0));
201  func_args.push_back(cgen_state->ir_builder_.CreateLoad(r));
202  } else if (ti.is_column()) {
  // Columns: pass a {T* ptr, i64 sz} struct by value. sz is read from
  // *input_row_count.
203  auto col = alloc_column(std::string("input_col.") + std::to_string(i),
204  ti,
205  col_heads[i],
206  input_row_count,
207  ctx,
208  cgen_state_->ir_builder_);
209  func_args.push_back(cgen_state->ir_builder_.CreateLoad(col));
210  } else {
211  throw std::runtime_error(
212  "Only integer and floating point columns or scalars are supported as inputs to "
213  "table "
214  "functions.");
215  }
216  }
217 
  // NOTE(review): output_col_args is never used; outputs are appended to
  // func_args after the inputs.
218  std::vector<llvm::Value*> output_col_args;
219  for (size_t i = 0; i < exe_unit.target_exprs.size(); i++) {
  // output_load = output_buffers[i] (an i64* buffer pointer).
220  auto output_load = cgen_state->ir_builder_.CreateLoad(
221  cgen_state->ir_builder_.CreateGEP(output_buffers_arg, cgen_state_->llInt(i)));
222  const auto& expr = exe_unit.target_exprs[i];
223  const auto& ti = expr->get_type_info();
224  /*
225  How to specify output scalars in C++ code?
226  */
  // Scalar-typed outputs: pass the output buffer itself, reinterpreted as T*.
227  if (ti.is_fp()) {
228  func_args.push_back(cgen_state->ir_builder_.CreateBitCast(
229  output_load, llvm::PointerType::get(get_fp_type(get_bit_width(ti), ctx), 0)));
230  } else if (ti.is_integer()) {
231  func_args.push_back(cgen_state->ir_builder_.CreateBitCast(
232  output_load, llvm::PointerType::get(get_int_type(get_bit_width(ti), ctx), 0)));
233  } else if (ti.is_column()) {
  // Column outputs: sz is read from *output_row_count here, i.e. before the
  // table function is called.
234  auto col = alloc_column(std::string("output_col.") + std::to_string(i),
235  ti,
236  output_load,
237  output_row_count_ptr,
238  ctx,
239  cgen_state_->ir_builder_);
240  func_args.push_back(cgen_state->ir_builder_.CreateLoad(col));
241  } else {
242  throw std::runtime_error(
243  "Only integer and floating point columns are supported as outputs to table "
244  "functions.");
245  }
246  }
247 
248  auto func_name = exe_unit.table_func_name;
249  boost::algorithm::to_lower(func_name);
250  const auto table_func_return =
251  cgen_state->emitExternalCall(func_name, get_int_type(32, ctx), func_args);
252  table_func_return->setName("table_func_ret");
253 
254  // If table_func_return is non-negative then store the value in
255  // output_row_count and return zero. Otherwise, return
256  // table_func_return, whose negative value is the error code.
257  const auto bb_exit_0 = llvm::BasicBlock::Create(ctx, ".exit0", entry_point_func_);
258 
259  auto const_zero = llvm::ConstantInt::get(table_func_return->getType(), 0, true);
260  auto is_ok = cgen_state_->ir_builder_.CreateICmpSGE(table_func_return, const_zero);
261  cgen_state_->ir_builder_.CreateCondBr(is_ok, bb_exit_0, bb_exit);
262 
263  cgen_state_->ir_builder_.SetInsertPoint(bb_exit_0);
  // Sign-extend the i32 row count to i64 before storing it.
264  auto r = cgen_state->ir_builder_.CreateIntCast(
265  table_func_return, get_int_type(64, ctx), true);
266  cgen_state->ir_builder_.CreateStore(r, output_row_count_ptr);
267  cgen_state->ir_builder_.CreateRet(const_zero);
268 
269  cgen_state->ir_builder_.SetInsertPoint(bb_exit);
270  cgen_state->ir_builder_.CreateRet(table_func_return);
271 
  // Terminate .entry by branching into the body emitted above.
272  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
273  cgen_state->ir_builder_.CreateBr(func_body_bb);
274 
275  /*
276  std::cout << "=================================" << std::endl;
277  entry_point_func_->print(llvm::outs());
278  std::cout << "=================================" << std::endl;
279  */
280 
282 }
283 
// Builds `table_func_kernel`, a void wrapper around entry_point_func_.
// Its first parameter is an i32* that receives the entry point's return code;
// the remaining parameters mirror the entry point's parameters one-for-one.
// Presumably used by the GPU path: finalize passes kernel_func_ alongside
// gpu_target.
286  std::vector<llvm::Type*> arg_types;
287  arg_types.reserve(entry_point_func_->arg_size());
288  std::for_each(entry_point_func_->arg_begin(),
289  entry_point_func_->arg_end(),
290  [&arg_types](const auto& arg) { arg_types.push_back(arg.getType()); });
291  CHECK_EQ(arg_types.size(), entry_point_func_->arg_size());
292 
293  auto cgen_state = cgen_state_.get();
294  CHECK(cgen_state);
295  auto& ctx = cgen_state->context_;
296 
  // Wrapper signature: (i32* error, <entry point params...>) -> void.
  // NOTE(review): arg_types[0] requires the entry point to have at least one
  // parameter.
297  std::vector<llvm::Type*> wrapper_arg_types(arg_types.size() + 1);
298  wrapper_arg_types[0] = llvm::PointerType::get(get_int_type(32, ctx), 0);
299  wrapper_arg_types[1] = arg_types[0];
300 
301  for (size_t i = 1; i < arg_types.size(); ++i) {
302  wrapper_arg_types[i + 1] = arg_types[i];
303  }
304 
305  auto wrapper_ft =
306  llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), wrapper_arg_types, false);
307  kernel_func_ = llvm::Function::Create(wrapper_ft,
308  llvm::Function::ExternalLinkage,
309  "table_func_kernel",
310  cgen_state->module_);
311 
312  auto wrapper_bb_entry = llvm::BasicBlock::Create(ctx, ".entry", kernel_func_, 0);
313  llvm::IRBuilder<> b(ctx);
314  b.SetInsertPoint(wrapper_bb_entry);
  // Forward wrapper params 1..N unchanged to the entry point, then store its
  // i32 result through the wrapper's first param.
315  std::vector<llvm::Value*> loaded_args = {kernel_func_->arg_begin() + 1};
316  for (size_t i = 2; i < wrapper_arg_types.size(); ++i) {
317  loaded_args.push_back(kernel_func_->arg_begin() + i);
318  }
319  auto error_lv = b.CreateCall(entry_point_func_, loaded_args);
320  b.CreateStore(error_lv, kernel_func_->arg_begin());
321  b.CreateRetVoid();
322 }
323 
325  Executor* executor) {
  // Links the CPU UDF module (if loaded) into module_, logs the generated IR
  // and lowers it to native code. Two paths:
  //   - GPU: kernel_func_ with gpu_target.
  //   - CPU: func_ptr obtained from an execution engine.
  // Link rt_udf_cpu_module into module_; OverrideFromSrc lets the UDF
  // module's symbols shadow existing ones.
326  if (rt_udf_cpu_module != nullptr) {
327  /*
328  TODO 1: eliminate need for OverrideFromSrc
329  TODO 2: detect and link only the udf's that are needed
330  */
332  *module_,
333  cgen_state_.get(),
334  llvm::Linker::Flags::OverrideFromSrc);
335  }
336 
  // Drops unique_ptr ownership without freeing the module.
  // cgen_state_->module_ still points at it.
  // NOTE(review): no new owner is visible here — confirm the module is not leaked.
337  module_.release();
338  // Add code to cache?
339 
340  LOG(IR) << "Table Function Entry Point IR\n"
342 
344  LOG(IR) << "Table Function Kernel IR\n" << serialize_llvm_object(kernel_func_);
345 
  // NVPTX/CUDA setup and GPU code generation for kernel_func_.
346  CHECK(executor);
347  executor->initializeNVPTXBackend();
348  const auto cuda_mgr = executor->catalog_->getDataMgr().getCudaMgr();
349  CHECK(cuda_mgr);
350 
351  CodeGenerator::GPUTarget gpu_target{executor->nvptx_target_machine_.get(),
352  cuda_mgr,
353  executor->blockSize(),
354  cgen_state_.get(),
355  false};
357  kernel_func_,
359  co,
360  gpu_target);
361  } else {
  // CPU path: look up the JIT-compiled entry point. The execution engine is
  // kept in own_execution_engine_, so this object owns it alongside func_ptr.
362  auto ee =
364  func_ptr = reinterpret_cast<FuncPtr>(ee->getPointerToFunction(entry_point_func_));
365  own_execution_engine_ = std::move(ee);
366  }
367 
368  LOG(IR) << "End of IR";
369 }
std::string to_lower(const std::string &str)
#define CHECK_EQ(x, y)
Definition: Logger.h:205
std::unique_ptr< llvm::Module > g_rt_module
const std::string table_func_name
std::unique_ptr< llvm::Module > runtime_module_shallow_copy(CgenState *cgen_state)
static ExecutionEngineWrapper generateNativeCPUCode(llvm::Function *func, const std::unordered_set< llvm::Function *> &live_funcs, const CompilationOptions &co)
std::vector< Analyzer::Expr * > input_exprs
void generateEntryPoint(const TableFunctionExecutionUnit &exe_unit)
#define LOG(tag)
Definition: Logger.h:188
std::shared_ptr< GpuCompilationContext > gpu_code_
llvm::Function * generate_entry_point(const CgenState *cgen_state)
#define UNREACHABLE()
Definition: Logger.h:241
int32_t(*)(const int8_t **input_cols, const int64_t *input_row_count, int64_t **out, int64_t *output_row_count) FuncPtr
std::unique_ptr< llvm::Module > module_
static std::shared_ptr< GpuCompilationContext > generateNativeGPUCode(llvm::Function *func, llvm::Function *wrapper_func, const std::unordered_set< llvm::Function *> &live_funcs, const CompilationOptions &co, const GPUTarget &gpu_target)
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
std::string to_string(char const *&&v)
llvm::Module * module_
Definition: CgenState.h:314
void verify_function_ir(const llvm::Function *func)
size_t get_bit_width(const SQLTypeInfo &ti)
llvm::LLVMContext & context_
Definition: CgenState.h:317
bool is_column() const
Definition: sqltypes.h:429
static void link_udf_module(const std::unique_ptr< llvm::Module > &udf_module, llvm::Module &module, CgenState *cgen_state, llvm::Linker::Flags flags=llvm::Linker::Flags::None)
std::unique_ptr< llvm::Module > rt_udf_cpu_module
ExecutorDeviceType device_type
void finalize(const CompilationOptions &co, Executor *executor)
std::string serialize_llvm_object(const T *llvm_obj)
llvm::Type * get_llvm_type_from_sql_column_type(const SQLTypeInfo ti, llvm::LLVMContext &ctx)
std::vector< llvm::Value * > generate_column_heads_load(const int num_columns, llvm::Value *byte_stream_arg, llvm::IRBuilder<> &ir_builder, llvm::LLVMContext &ctx)
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:623
#define CHECK(condition)
Definition: Logger.h:197
std::vector< Analyzer::Expr * > target_exprs
llvm::Value * alloc_column(std::string col_name, const SQLTypeInfo &data_target_info, llvm::Value *data_ptr, llvm::Value *data_size, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder)
void compile(const TableFunctionExecutionUnit &exe_unit, const CompilationOptions &co, Executor *executor)