OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionCompilationContext.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <llvm/IR/InstIterator.h>
20 #include <llvm/IR/Verifier.h>
21 #include <llvm/Support/raw_os_ostream.h>
22 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
23 #include <algorithm>
24 #include <boost/algorithm/string.hpp>
25 
28 
29 namespace {
30 
31 llvm::Function* generate_entry_point(const CgenState* cgen_state) {
32  auto& ctx = cgen_state->context_;
33  const auto pi8_type = llvm::PointerType::get(get_int_type(8, ctx), 0);
34  const auto ppi8_type = llvm::PointerType::get(pi8_type, 0);
35  const auto pi64_type = llvm::PointerType::get(get_int_type(64, ctx), 0);
36  const auto ppi64_type = llvm::PointerType::get(pi64_type, 0);
37  const auto i32_type = get_int_type(32, ctx);
38 
39  const auto func_type = llvm::FunctionType::get(
40  i32_type,
41  {pi8_type, ppi8_type, pi64_type, ppi8_type, ppi64_type, ppi8_type, pi64_type},
42  false);
43 
44  auto func = llvm::Function::Create(func_type,
45  llvm::Function::ExternalLinkage,
46  "call_table_function",
47  cgen_state->module_);
48  auto arg_it = func->arg_begin();
49  const auto mgr_arg = &*arg_it;
50  mgr_arg->setName("mgr_ptr");
51  const auto input_cols_arg = &*(++arg_it);
52  input_cols_arg->setName("input_col_buffers");
53  const auto input_row_counts = &*(++arg_it);
54  input_row_counts->setName("input_row_counts");
55  const auto input_str_dict_proxies = &*(++arg_it);
56  input_str_dict_proxies->setName("input_str_dict_proxies");
57  const auto output_buffers = &*(++arg_it);
58  output_buffers->setName("output_buffers");
59  const auto output_str_dict_proxies = &*(++arg_it);
60  output_str_dict_proxies->setName("output_str_dict_proxies");
61  const auto output_row_count = &*(++arg_it);
62  output_row_count->setName("output_row_count");
63  return func;
64 }
65 
67  llvm::LLVMContext& ctx) {
  // Picks the LLVM pointer type used for the data buffer of a column whose
  // elements have SQL type `elem_ti`:
  //  - floating point       -> fp pointer, width = get_size() * 8 bits
  //  - boolean              -> int pointer of width 8
  //  - integer              -> int pointer, width = get_size() * 8 bits
  //  - dict-encoded string  -> int pointer, width = get_size() * 8 bits
  //  - other strings        -> must satisfy is_bytes(); int pointer of width 8
  //  - timestamp            -> int pointer, width = get_size() * 8 bits
  //  - array                -> int pointer of width 8
  // Any other type is reported through LOG(FATAL).
68  if (elem_ti.is_fp()) {
69  return get_fp_ptr_type(elem_ti.get_size() * 8, ctx);
70  } else if (elem_ti.is_boolean()) {
71  return get_int_ptr_type(8, ctx);
72  } else if (elem_ti.is_integer()) {
73  return get_int_ptr_type(elem_ti.get_size() * 8, ctx);
74  } else if (elem_ti.is_string()) {
75  if (elem_ti.get_compression() == kENCODING_DICT) {
76  return get_int_ptr_type(elem_ti.get_size() * 8, ctx);
77  }
78  CHECK(elem_ti.is_bytes()); // None encoded string
79  return get_int_ptr_type(8, ctx);
80  } else if (elem_ti.is_timestamp()) {
81  return get_int_ptr_type(elem_ti.get_size() * 8, ctx);
82  } else if (elem_ti.is_array()) {
83  return get_int_ptr_type(8, ctx);
84  }
85  LOG(FATAL) << "get_llvm_type_from_sql_column_type: not implemented for "
86  << ::toString(elem_ti);
  // Fallback return value for the unsupported-type path after the fatal log.
87  return nullptr;
88 }
89 
90 void initialize_ptr_member(llvm::Value* member_ptr,
91  llvm::Type* member_llvm_type,
92  llvm::Value* value_ptr,
93  llvm::IRBuilder<>& ir_builder) {
94  if (value_ptr != nullptr) {
95  if (value_ptr->getType() == member_llvm_type->getPointerElementType()) {
96  ir_builder.CreateStore(value_ptr, member_ptr);
97  } else {
98  auto tmp = ir_builder.CreateBitCast(value_ptr, member_llvm_type);
99  ir_builder.CreateStore(tmp, member_ptr);
100  }
101  } else {
102  ir_builder.CreateStore(llvm::Constant::getNullValue(member_llvm_type), member_ptr);
103  }
104 }
105 
106 void initialize_int64_member(llvm::Value* member_ptr,
107  llvm::Value* value,
108  int64_t default_value,
109  llvm::LLVMContext& ctx,
110  llvm::IRBuilder<>& ir_builder) {
111  llvm::Value* val = nullptr;
112  if (value != nullptr) {
113  auto value_type = value->getType();
114  if (value_type->isPointerTy()) {
115  CHECK(value_type->getPointerElementType()->isIntegerTy(64));
116  val = ir_builder.CreateLoad(value->getType()->getPointerElementType(), value);
117  } else {
118  CHECK(value_type->isIntegerTy(64));
119  val = value;
120  }
121  ir_builder.CreateStore(val, member_ptr);
122  } else {
123  auto const_default =
124  llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), default_value, true);
125  ir_builder.CreateStore(const_default, member_ptr);
126  }
127 }
128 
129 std::tuple<llvm::Value*, llvm::Value*> alloc_column(std::string col_name,
130  const size_t index,
131  const SQLTypeInfo& data_target_info,
132  llvm::Value* data_ptr,
133  llvm::Value* data_size,
134  llvm::Value* data_str_dict_proxy_ptr,
135  llvm::LLVMContext& ctx,
136  llvm::IRBuilder<>& ir_builder) {
137  /*
138  Creates a new Column instance of given element type and initialize
139  its data ptr and sz members when specified. If data ptr or sz are
140  unspecified (have nullptr values) then the corresponding members
141  are initialized with NULL and -1, respectively.
142 
143  If we are allocating a TextEncodingDict Column type, this function
144  adds and populates a int8* pointer to a StringDictProxy object.
145 
146  Return a pair of Column allocation (caller should apply
147  builder.CreateLoad to it in order to construct a Column instance
148  as a value) and a pointer to the Column instance.
149  */
150  const bool is_text_encoding_dict_type =
151  data_target_info.is_string() &&
152  data_target_info.get_compression() == kENCODING_DICT;
153  llvm::StructType* col_struct_type;
154  llvm::Type* data_ptr_llvm_type =
155  get_llvm_type_from_sql_column_type(data_target_info, ctx);
156  if (is_text_encoding_dict_type) {
157  col_struct_type = llvm::StructType::get(
158  ctx,
159  {
160  data_ptr_llvm_type, /* T* ptr */
161  llvm::Type::getInt64Ty(ctx), /* int64_t sz */
162  llvm::Type::getInt8PtrTy(ctx) /* int8_t* string_dictionary_ptr */
163  });
164  } else {
165  col_struct_type =
166  llvm::StructType::get(ctx,
167  {
168  data_ptr_llvm_type, /* T* ptr */
169  llvm::Type::getInt64Ty(ctx) /* int64_t sz */
170  });
171  }
172 
173  auto col = ir_builder.CreateAlloca(col_struct_type);
174  col->setName(col_name);
175  auto col_ptr_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 0);
176  auto col_sz_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 1);
177  auto col_str_dict_ptr = is_text_encoding_dict_type
178  ? ir_builder.CreateStructGEP(col_struct_type, col, 2)
179  : nullptr;
180  col_ptr_ptr->setName(col_name + ".ptr");
181  col_sz_ptr->setName(col_name + ".sz");
182  if (is_text_encoding_dict_type) {
183  col_str_dict_ptr->setName(col_name + ".string_dict_proxy");
184  }
185 
186  initialize_ptr_member(col_ptr_ptr, data_ptr_llvm_type, data_ptr, ir_builder);
187  initialize_int64_member(col_sz_ptr, data_size, -1, ctx, ir_builder);
188  if (is_text_encoding_dict_type) {
189  initialize_ptr_member(col_str_dict_ptr,
190  llvm::Type::getInt8PtrTy(ctx),
191  data_str_dict_proxy_ptr,
192  ir_builder);
193  }
194  auto col_ptr = ir_builder.CreatePointerCast(
195  col_ptr_ptr, llvm::PointerType::get(llvm::Type::getInt8Ty(ctx), 0));
196  col_ptr->setName(col_name + "_ptr");
197  return {col, col_ptr};
198 }
199 
200 llvm::Value* alloc_column_list(std::string col_list_name,
201  const SQLTypeInfo& data_target_info,
202  llvm::Value* data_ptrs,
203  int length,
204  llvm::Value* data_size,
205  llvm::Value* data_str_dict_proxy_ptrs,
206  llvm::LLVMContext& ctx,
207  llvm::IRBuilder<>& ir_builder) {
208  /*
209  Creates a new ColumnList instance of given element type and initialize
210  its members. If data ptr or size are unspecified (have nullptr
211  values) then the corresponding members are initialized with NULL
212  and -1, respectively.
213  */
214  llvm::Type* data_ptrs_llvm_type = llvm::Type::getInt8PtrTy(ctx);
215  const bool is_text_encoding_dict_type =
216  data_target_info.is_string() &&
217  data_target_info.get_compression() == kENCODING_DICT;
218 
219  llvm::StructType* col_list_struct_type =
220  is_text_encoding_dict_type
221  ? llvm::StructType::get(
222  ctx,
223  {
224  data_ptrs_llvm_type, /* int8_t* ptrs */
225  llvm::Type::getInt64Ty(ctx), /* int64_t length */
226  llvm::Type::getInt64Ty(ctx), /* int64_t size */
227  data_ptrs_llvm_type /* int8_t* str_dict_proxy_ptrs */
228  })
229  : llvm::StructType::get(ctx,
230  {
231  data_ptrs_llvm_type, /* int8_t* ptrs */
232  llvm::Type::getInt64Ty(ctx), /* int64_t length */
233  llvm::Type::getInt64Ty(ctx) /* int64_t size */
234  });
235 
236  auto col_list = ir_builder.CreateAlloca(col_list_struct_type);
237  col_list->setName(col_list_name);
238  auto col_list_ptr_ptr = ir_builder.CreateStructGEP(col_list_struct_type, col_list, 0);
239  auto col_list_length_ptr =
240  ir_builder.CreateStructGEP(col_list_struct_type, col_list, 1);
241  auto col_list_size_ptr = ir_builder.CreateStructGEP(col_list_struct_type, col_list, 2);
242  auto col_str_dict_ptr_ptr =
243  is_text_encoding_dict_type
244  ? ir_builder.CreateStructGEP(col_list_struct_type, col_list, 3)
245  : nullptr;
246 
247  col_list_ptr_ptr->setName(col_list_name + ".ptrs");
248  col_list_length_ptr->setName(col_list_name + ".length");
249  col_list_size_ptr->setName(col_list_name + ".size");
250  if (is_text_encoding_dict_type) {
251  col_str_dict_ptr_ptr->setName(col_list_name + ".string_dict_proxies");
252  }
253 
254  initialize_ptr_member(col_list_ptr_ptr, data_ptrs_llvm_type, data_ptrs, ir_builder);
255 
256  CHECK(length >= 0);
257  auto const_length = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), length, true);
258  ir_builder.CreateStore(const_length, col_list_length_ptr);
259 
260  initialize_int64_member(col_list_size_ptr, data_size, -1, ctx, ir_builder);
261 
262  if (is_text_encoding_dict_type) {
263  initialize_ptr_member(col_str_dict_ptr_ptr,
264  data_str_dict_proxy_ptrs->getType(),
265  data_str_dict_proxy_ptrs,
266  ir_builder);
267  }
268 
269  auto col_list_ptr = ir_builder.CreatePointerCast(
270  col_list_ptr_ptr, llvm::PointerType::get(llvm::Type::getInt8Ty(ctx), 0));
271  col_list_ptr->setName(col_list_name + "_ptrs");
272  return col_list_ptr;
273 }
274 
275 static bool columnTypeRequiresCasting(const SQLTypeInfo& ti) {
276  /*
277  Returns whether a column requires casting before table function execution based on its
278  underlying SQL type
279  */
280 
281  if (!ti.is_column()) {
282  return false;
283  }
284 
285  // TIMESTAMP columns should always have nanosecond precision
286  if (ti.get_subtype() == kTIMESTAMP && ti.get_precision() != 9) {
287  return true;
288  }
289 
290  return false;
291 }
292 
293 llvm::Value* cast_value(llvm::Value* value,
294  SQLTypeInfo& orig_ti,
295  SQLTypeInfo& dest_ti,
296  bool nullable,
297  CodeGenerator& codeGenerator) {
298  /*
299  Codegens a cast of a value from a given origin type to a given destination type, if such
300  implementation is available. Errors for unsupported casts.
301  */
302  if (orig_ti.is_timestamp() && dest_ti.is_timestamp()) {
303  return codeGenerator.codegenCastBetweenTimestamps(
304  value, orig_ti, dest_ti, !dest_ti.get_notnull());
305  } else {
306  throw std::runtime_error("Unsupported cast from " + orig_ti.get_type_name() + " to " +
307  dest_ti.get_type_name() + " during UDTF code generation.");
308  }
309 }
310 
311 void cast_column(llvm::Value* col_base_ptr,
312  llvm::Function* func,
313  SQLTypeInfo& orig_ti,
314  SQLTypeInfo& dest_ti,
315  std::string index,
316  llvm::IRBuilder<>& ir_builder,
317  llvm::LLVMContext& ctx,
318  CodeGenerator& codeGenerator) {
319  /*
320  Generates code to cast a Column instance from a given origin
321  SQLType to a new destinaton SQLType. To do so, it generates a
322  loop with the following overall structure:
323 
324  --------------
325  | pre_header |
326  | i = 0 |
327  --------------
328  |
329  v
330  ---------------- ----------------
331  | cond | | body |
332  | i < col.size | (True) -> | cast(col[i]) |
333  ---------------- | i++ |
334  (False) ^ ----------------
335  | \____________________/
336  |
337  v
338  ---------------
339  | end |
340  ---------------
341 
342  The correctness of the cast as well as error handling/early
343  exiting in case of cast failures are left to the CodeGenerator
344  which generates the code for the cast operation itself.
345  */
346 
347  llvm::BasicBlock* for_pre =
348  llvm::BasicBlock::Create(ctx, "for_pre_cast." + index, func);
349  llvm::BasicBlock* for_cond =
350  llvm::BasicBlock::Create(ctx, "for_cond_cast." + index, func);
351  llvm::BasicBlock* for_body =
352  llvm::BasicBlock::Create(ctx, "for_body_cast." + index, func);
353  llvm::BasicBlock* for_end =
354  llvm::BasicBlock::Create(ctx, "for_end_cast." + index, func);
355  ir_builder.CreateBr(for_pre);
356 
357  // pre-header: load column ptr and size
358  ir_builder.SetInsertPoint(for_pre);
359  llvm::Type* data_type = get_llvm_type_from_sql_column_type(orig_ti, ctx);
360  llvm::StructType* col_struct_type =
361  llvm::StructType::get(ctx, {data_type, ir_builder.getInt64Ty()});
362  llvm::Value* struct_cast = ir_builder.CreateBitCast(
363  col_base_ptr, col_struct_type->getPointerTo(), "col_struct." + index);
364  llvm::Value* gep_ptr = ir_builder.CreateStructGEP(
365  col_struct_type, struct_cast, 0, "col_ptr_addr." + index);
366  llvm::Value* col_ptr = ir_builder.CreateLoad(data_type, gep_ptr, "col_ptr." + index);
367  llvm::Value* gep_sz =
368  ir_builder.CreateStructGEP(col_struct_type, struct_cast, 1, "col_sz_addr." + index);
369  llvm::Value* col_sz =
370  ir_builder.CreateLoad(ir_builder.getInt64Ty(), gep_sz, "col_sz." + index);
371  ir_builder.CreateBr(for_cond);
372 
373  // condition: check induction variable against loop predicate
374  ir_builder.SetInsertPoint(for_cond);
375  llvm::PHINode* for_ind_var =
376  ir_builder.CreatePHI(ir_builder.getInt64Ty(), 2, "for_ind_var." + index);
377  for_ind_var->addIncoming(ir_builder.getInt64(0), for_pre);
378  llvm::Value* for_pred =
379  ir_builder.CreateICmpSLT(for_ind_var, col_sz, "for_pred." + index);
380  ir_builder.CreateCondBr(for_pred, for_body, for_end);
381 
382  // body: perform value cast, increment induction variable
383  ir_builder.SetInsertPoint(for_body);
384  ;
385  llvm::Value* val_gep = ir_builder.CreateInBoundsGEP(
386  ir_builder.getInt64Ty(), col_ptr, for_ind_var, "val_gep." + index);
387  llvm::Value* val_load =
388  ir_builder.CreateLoad(ir_builder.getInt64Ty(), val_gep, "val_load." + index);
389  llvm::Value* cast_result = cast_value(val_load, orig_ti, dest_ti, false, codeGenerator);
390  cast_result->setName("cast_result." + index);
391  ir_builder.CreateStore(cast_result, val_gep);
392  llvm::Value* for_inc =
393  ir_builder.CreateAdd(for_ind_var, ir_builder.getInt64(1), "for_inc." + index);
394  ir_builder.CreateBr(for_cond);
395  // the cast codegening may have generated extra blocks, so for_body does not necessarily
396  // jump to for_cond directly
397  llvm::Instruction* inc_as_inst = llvm::cast<llvm::Instruction>(for_inc);
398  for_ind_var->addIncoming(for_inc, inc_as_inst->getParent());
399  ir_builder.SetInsertPoint(for_end);
400 }
401 
402 std::string exprsKey(const std::vector<Analyzer::Expr*>& exprs) {
403  std::string result;
404  for (const auto& expr : exprs) {
405  const auto& ti = expr->get_type_info();
406  result += ti.to_string() + ", ";
407  }
408  return result;
409 }
410 
411 } // namespace
412 
// Compiles the entry point for the table function in `exe_unit` and returns
// the resulting compilation context. Results are cached in the QueryEngine's
// tf_code_accessor; a cache hit is returned without generating any code.
// `emit_only_preflight_fn` is forwarded to generateEntryPoint() and finalize()
// and is also part of the cache key.
413 std::shared_ptr<CompilationContext> TableFunctionCompilationContext::compile(
414  const TableFunctionExecutionUnit& exe_unit,
415  bool emit_only_preflight_fn) {
416  auto timer = DEBUG_TIMER(__func__);
417 
418  // Here we assume that Executor::tf_code_accessor is cleared when a
419  // UDTF implementation is changed. TODO: Ideally, the key should
420  // contain a hash of an UDTF implementation string. This could be
421  // achieved by including the hash value to the prefix of the UDTF
422  // name, for instance.
  // Key parts visible here: function name, input expression types, target
  // expression types, and the preflight flag.
423  CodeCacheKey key{exe_unit.table_func.getName(),
424  exprsKey(exe_unit.input_exprs),
425  exprsKey(exe_unit.target_exprs),
426  std::to_string(emit_only_preflight_fn),
428 
  // Return the cached code when the accessor already has an entry for the key.
429  auto cached_code = QueryEngine::getInstance()->tf_code_accessor->get_or_wait(key);
430  if (cached_code) {
431  return *cached_code;
432  }
433 
  // Code generation requires a code generation state that has no module yet;
  // it is given a shallow copy of the executor's runtime module.
434  auto cgen_state = executor_->getCgenStatePtr();
435  CHECK(cgen_state);
436  CHECK(cgen_state->module_ == nullptr);
437  cgen_state->set_module_shallow_copy(executor_->get_rt_module());
438 
440 
441  generateEntryPoint(exe_unit, emit_only_preflight_fn);
442 
444  CHECK(!emit_only_preflight_fn);
446  }
447 
  // Publish the finalized code under the key, then return the stored value.
448  QueryEngine::getInstance()->tf_code_accessor->swap(key,
449  finalize(emit_only_preflight_fn));
450  return QueryEngine::getInstance()->tf_code_accessor->get_value(key);
451 }
452 
454  const TableFunctionExecutionUnit& exe_unit) {
  // Decides whether Column arguments are passed to the table function by value.
  // If the runtime UDF module for the current device type exists and carries an
  // integer module flag named "pass_column_arguments_by_value", that flag's
  // value decides; otherwise the answer is exe_unit.table_func.isRuntime().
455  bool is_gpu = co_.device_type == ExecutorDeviceType::GPU;
456  auto mod = executor_->get_rt_udf_module(is_gpu).get();
457  if (mod != nullptr) {
458  auto* flag = mod->getModuleFlag("pass_column_arguments_by_value");
  // extract_or_null yields nullptr when the flag is absent or not a ConstantInt.
459  if (auto* cnt = llvm::mdconst::extract_or_null<llvm::ConstantInt>(flag)) {
460  return cnt->getZExtValue();
461  }
462  }
463 
464  // fallback to original behavior
465  return exe_unit.table_func.isRuntime();
466 }
467 
469  const TableFunctionExecutionUnit& exe_unit,
470  const std::vector<llvm::Value*>& func_args,
471  llvm::BasicBlock* bb_exit,
472  llvm::Value* output_row_count_ptr,
473  bool emit_only_preflight_fn) {
  // Emits the call to the table function (or to its pre-flight function when
  // `emit_only_preflight_fn` is set) with `func_args`, followed by the code
  // that turns the i32 result into the entry point's return value:
  //   result >= 0 : store the result, widened to i64, through
  //                 `output_row_count_ptr` and return 0 (block ".exit0")
  //   result <  0 : return the result unchanged from `bb_exit`
  // On return the builder is positioned in `bb_exit`, after its `ret`.
474  auto cgen_state = executor_->getCgenStatePtr();
475  // Emit llvm IR code to call the table function
476  llvm::LLVMContext& ctx = cgen_state->context_;
477  llvm::IRBuilder<>* ir_builder = &cgen_state->ir_builder_;
478 
479  std::string func_name =
480  (emit_only_preflight_fn ? exe_unit.table_func.getPreFlightFnName()
481  : exe_unit.table_func.getName(false, true));
482  llvm::Value* table_func_return =
483  cgen_state->emitExternalCall(func_name, get_int_type(32, ctx), func_args);
484 
485  table_func_return->setName(emit_only_preflight_fn ? "preflight_check_func_ret"
486  : "table_func_ret");
487 
488  // If table_func_return is non-negative then store the value in
489  // output_row_count and return zero. Otherwise, return
490  // table_func_return that negative value contains the error code.
491  llvm::BasicBlock* bb_exit_0 =
492  llvm::BasicBlock::Create(ctx, ".exit0", entry_point_func_);
493 
494  llvm::Constant* const_zero =
495  llvm::ConstantInt::get(table_func_return->getType(), 0, true);
496  llvm::Value* is_ok = ir_builder->CreateICmpSGE(table_func_return, const_zero);
497  ir_builder->CreateCondBr(is_ok, bb_exit_0, bb_exit);
498 
  // Success path: signed widening of the i32 result to i64 before the store.
499  ir_builder->SetInsertPoint(bb_exit_0);
500  llvm::Value* r =
501  ir_builder->CreateIntCast(table_func_return, get_int_type(64, ctx), true);
502  ir_builder->CreateStore(r, output_row_count_ptr);
503  ir_builder->CreateRet(const_zero);
504 
  // Error path: propagate the negative return value as-is.
505  ir_builder->SetInsertPoint(bb_exit);
506  ir_builder->CreateRet(table_func_return);
507 }
508 
510  const TableFunctionExecutionUnit& exe_unit,
511  bool emit_only_preflight_fn) {
  // Emits the body of the entry point function: unpacks the seven entry point
  // arguments, builds one call argument per table function input (scalars are
  // loaded, varchar literals / columns / column lists are wrapped in stack
  // allocated structs), allocates the output columns, and finally emits the
  // call to the table function (or to its pre-flight function).
512  auto timer = DEBUG_TIMER(__func__);
514  CHECK_EQ(entry_point_func_->arg_size(), 7);
515  auto arg_it = entry_point_func_->arg_begin();
516  const auto mgr_ptr = &*arg_it;
517  const auto input_cols_arg = &*(++arg_it);
518  const auto input_row_counts_arg = &*(++arg_it);
519  const auto input_str_dict_proxies_arg = &*(++arg_it);
520  const auto output_buffers_arg = &*(++arg_it);
521  const auto output_str_dict_proxies_arg = &*(++arg_it);
522  const auto output_row_count_ptr = &*(++arg_it);
523  auto cgen_state = executor_->getCgenStatePtr();
524  CHECK(cgen_state);
525  auto& ctx = cgen_state->context_;
526 
  // Block layout: ".entry" branches straight to ".func_body0", where all of
  // the argument setup below is emitted; ".exit" is handed to the call emitter.
527  llvm::BasicBlock* bb_entry =
528  llvm::BasicBlock::Create(ctx, ".entry", entry_point_func_, 0);
529  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
530 
531  llvm::BasicBlock* bb_exit = llvm::BasicBlock::Create(ctx, ".exit", entry_point_func_);
532 
533  llvm::BasicBlock* func_body_bb = llvm::BasicBlock::Create(
534  ctx, ".func_body0", cgen_state->ir_builder_.GetInsertBlock()->getParent());
535 
536  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
537  cgen_state->ir_builder_.CreateBr(func_body_bb);
538 
539  cgen_state->ir_builder_.SetInsertPoint(func_body_bb);
  // One loaded head per input expression, for the column buffers and for the
  // row counts.
540  auto col_heads = generate_column_heads_load(
541  exe_unit.input_exprs.size(), input_cols_arg, cgen_state->ir_builder_, ctx);
542  CHECK_EQ(exe_unit.input_exprs.size(), col_heads.size());
543  auto row_count_heads = generate_column_heads_load(
544  exe_unit.input_exprs.size(), input_row_counts_arg, cgen_state->ir_builder_, ctx);
545 
  // String dictionary proxies are only loaded on CPU and outside pre-flight
  // mode; in every other configuration this vector stays empty.
546  auto input_str_dict_proxy_heads = std::vector<llvm::Value*>();
547  if (!emit_only_preflight_fn and co_.device_type == ExecutorDeviceType::CPU) {
548  input_str_dict_proxy_heads = generate_column_heads_load(exe_unit.input_exprs.size(),
549  input_str_dict_proxies_arg,
550  cgen_state->ir_builder_,
551  ctx);
552  }
553  // The column arguments of C++ UDTFs processed by clang must be
554  // passed by reference, see rbc issues 200 and 289.
555  auto pass_column_by_value = passColumnsByValue(exe_unit);
556  std::vector<llvm::Value*> func_args;
557  std::vector<std::pair<llvm::Value*, const SQLTypeInfo>> columns_to_cast;
558  size_t func_arg_index = 0;
559  if (exe_unit.table_func.usesManager()) {
560  func_args.push_back(mgr_ptr);
561  func_arg_index++;
562  }
  // col_index is -1 outside a column list; while the input expressions of a
  // column list are being consumed it counts the members seen so far.
563  int col_index = -1;
564 
565  for (size_t i = 0; i < exe_unit.input_exprs.size(); i++) {
566  const auto& expr = exe_unit.input_exprs[i];
567  const auto& ti = expr->get_type_info();
568  if (col_index == -1) {
569  func_arg_index += 1;
570  }
571  if (ti.is_fp()) {
  // Floating point scalar: load the value from the head of its buffer.
572  auto r = cgen_state->ir_builder_.CreateBitCast(
573  col_heads[i], get_fp_ptr_type(get_bit_width(ti), ctx));
574  llvm::LoadInst* scalar_fp = cgen_state->ir_builder_.CreateLoad(
575  r->getType()->getPointerElementType(),
576  r,
577  "input_scalar_fp." + std::to_string(func_arg_index));
578  func_args.push_back(scalar_fp);
579  CHECK_EQ(col_index, -1);
580  } else if (ti.is_integer() || ti.is_boolean() || ti.is_timestamp() ||
581  ti.is_timeinterval()) {
  // Integer-like scalar: load the value from the head of its buffer.
582  auto r = cgen_state->ir_builder_.CreateBitCast(
583  col_heads[i], get_int_ptr_type(get_bit_width(ti), ctx));
584  llvm::LoadInst* scalar_int = cgen_state->ir_builder_.CreateLoad(
585  r->getType()->getPointerElementType(),
586  r,
587  "input_scalar_int." + std::to_string(func_arg_index));
588  func_args.push_back(scalar_int);
589  CHECK_EQ(col_index, -1);
590  } else if (ti.is_bytes()) {
  // Varchar literal: the buffer head is read as an i64 size and the
  // character data pointer is taken 8 elements past the head.
591  auto varchar_size =
592  cgen_state->ir_builder_.CreateBitCast(col_heads[i], get_int_ptr_type(64, ctx));
593  auto varchar_ptr = cgen_state->ir_builder_.CreateGEP(
594  col_heads[i]->getType()->getScalarType()->getPointerElementType(),
595  col_heads[i],
596  cgen_state->llInt(8));
597  auto [varchar_struct, varchar_struct_ptr] = alloc_column(
598  std::string("input_varchar_literal.") + std::to_string(func_arg_index),
599  i,
600  ti,
601  varchar_ptr,
602  varchar_size,
603  nullptr,
604  ctx,
605  cgen_state->ir_builder_);
606  func_args.push_back(
607  (pass_column_by_value
608  ? cgen_state->ir_builder_.CreateLoad(
609  varchar_struct->getType()->getPointerElementType(), varchar_struct)
610  : varchar_struct_ptr));
611  CHECK_EQ(col_index, -1);
612  } else if (ti.is_column()) {
  // Column: wrap buffer, row count and (CPU, non pre-flight only) string
  // dictionary proxy in a Column struct.
613  auto [col, col_ptr] = alloc_column(
614  std::string("input_col.") + std::to_string(func_arg_index),
615  i,
616  ti.get_elem_type(),
617  col_heads[i],
618  row_count_heads[i],
619  (co_.device_type != ExecutorDeviceType::CPU || emit_only_preflight_fn)
620  ? nullptr
621  : input_str_dict_proxy_heads[i],
622  ctx,
623  cgen_state->ir_builder_);
624  func_args.push_back((pass_column_by_value
625  ? cgen_state->ir_builder_.CreateLoad(
626  col->getType()->getPointerElementType(), col)
627  : col_ptr));
628 
629  if (columnTypeRequiresCasting(ti) &&
631  columns_to_cast.push_back(std::make_pair(col_ptr, ti));
632  }
633  CHECK_EQ(col_index, -1);
634  } else if (ti.is_column_list()) {
  // Column list: a single ColumnList argument is created when the first
  // member is seen; the remaining members only advance col_index.
  // NOTE(review): unlike the column branch above, the proxy lookup below is
  // guarded only by emit_only_preflight_fn, while input_str_dict_proxy_heads
  // is left empty on non-CPU devices -- confirm that indexing it here is safe
  // for GPU compilation.
635  if (col_index == -1) {
636  auto col_list = alloc_column_list(
637  std::string("input_col_list.") + std::to_string(func_arg_index),
638  ti.get_elem_type(),
639  col_heads[i],
640  ti.get_dimension(),
641  row_count_heads[i],
642  (emit_only_preflight_fn) ? nullptr : input_str_dict_proxy_heads[i],
643  ctx,
644  cgen_state->ir_builder_);
645  func_args.push_back(col_list);
646  }
647  col_index++;
648  if (col_index + 1 == ti.get_dimension()) {
649  col_index = -1;
650  }
651  } else {
652  throw std::runtime_error(
653  "Only integer and floating point columns or scalars are supported as inputs to "
654  "table "
655  "functions, got " +
656  ti.get_type_name());
657  }
658  }
659  auto output_str_dict_proxy_heads =
661  ? (generate_column_heads_load(exe_unit.target_exprs.size(),
662  output_str_dict_proxies_arg,
663  cgen_state->ir_builder_,
664  ctx))
665  : std::vector<llvm::Value*>();
666 
  // Output columns: one Column struct per target expression. On CPU outside
  // pre-flight mode each one is also registered with the TableFunctionManager.
667  std::vector<llvm::Value*> output_col_args;
668  for (size_t i = 0; i < exe_unit.target_exprs.size(); i++) {
669  auto* gep = cgen_state->ir_builder_.CreateGEP(
670  output_buffers_arg->getType()->getScalarType()->getPointerElementType(),
671  output_buffers_arg,
672  cgen_state->llInt(i));
673  auto output_load =
674  cgen_state->ir_builder_.CreateLoad(gep->getType()->getPointerElementType(), gep);
675  const auto& expr = exe_unit.target_exprs[i];
676  const auto& ti = expr->get_type_info();
677  CHECK(!ti.is_column()); // UDTF output column type is its data type
678  CHECK(!ti.is_column_list()); // TODO: when UDTF outputs column_list, convert it to
679  // output columns
680  auto [col, col_ptr] = alloc_column(
681  std::string("output_col.") + std::to_string(i),
682  i,
683  ti,
685  ? output_load
686  : nullptr), // CPU: set_output_row_size will set the output
687  // Column ptr member
688  output_row_count_ptr,
689  co_.device_type == ExecutorDeviceType::CPU ? output_str_dict_proxy_heads[i]
690  : nullptr,
691  ctx,
692  cgen_state->ir_builder_);
693  if (co_.device_type == ExecutorDeviceType::CPU && !emit_only_preflight_fn) {
694  cgen_state->emitExternalCall(
695  "TableFunctionManager_register_output_column",
696  llvm::Type::getVoidTy(ctx),
697  {mgr_ptr, llvm::ConstantInt::get(get_int_type(32, ctx), i, true), col_ptr});
698  }
699  output_col_args.push_back((pass_column_by_value ? col : col_ptr));
700  }
701 
702  // output column members must be set before loading column when
703  // column instances are passed by value
704  if ((exe_unit.table_func.hasOutputSizeKnownPreLaunch() ||
705  exe_unit.table_func.hasPreFlightOutputSizer()) &&
706  (co_.device_type == ExecutorDeviceType::CPU) && !emit_only_preflight_fn) {
707  cgen_state->emitExternalCall(
708  "TableFunctionManager_set_output_row_size",
709  llvm::Type::getVoidTy(ctx),
710  {mgr_ptr,
711  cgen_state->ir_builder_.CreateLoad(
712  output_row_count_ptr->getType()->getPointerElementType(),
713  output_row_count_ptr)});
714  }
715 
  // Output columns become call arguments only for the table function itself;
  // the pre-flight function receives the input arguments alone.
716  if (!emit_only_preflight_fn) {
717  for (auto& col : output_col_args) {
718  func_args.push_back((pass_column_by_value
719  ? cgen_state->ir_builder_.CreateLoad(
720  col->getType()->getPointerElementType(), col)
721  : col));
722  }
723  }
724 
  // Input casts are emitted after the argument setup and before the call.
725  if (exe_unit.table_func.mayRequireCastingInputTypes() && !emit_only_preflight_fn) {
726  generateCastsForInputTypes(exe_unit, columns_to_cast, mgr_ptr);
727  }
728 
730  exe_unit, func_args, bb_exit, output_row_count_ptr, emit_only_preflight_fn);
731 
732  // std::cout << "=================================" << std::endl;
733  // entry_point_func_->print(llvm::outs());
734  // std::cout << "=================================" << std::endl;
735 
737 }
738 
740  const TableFunctionExecutionUnit& exe_unit,
741  const std::vector<std::pair<llvm::Value*, const SQLTypeInfo>>& columns_to_cast,
742  llvm::Value* mgr_ptr) {
  // Emits, for every (column pointer, type) pair in `columns_to_cast`, an
  // in-place cast loop via cast_column(), and then rewrites the `ret 7`
  // instructions found in the entry point into calls to the table function
  // error reporting routines.
743  auto* cgen_state = executor_->getCgenStatePtr();
744  llvm::LLVMContext& ctx = cgen_state->context_;
745  llvm::IRBuilder<>* ir_builder = &cgen_state->ir_builder_;
746  CodeGenerator codeGenerator = CodeGenerator(cgen_state, executor_->getPlanStatePtr());
747  llvm::Function* old_func = cgen_state->current_func_;
748  cgen_state->current_func_ =
749  entry_point_func_; // update cgen_state current func for CodeGenerator
750 
751  for (unsigned i = 0; i < columns_to_cast.size(); ++i) {
752  auto [col_ptr, ti] = columns_to_cast[i];
753 
754  if (ti.is_column() && ti.get_subtype() == kTIMESTAMP && ti.get_precision() != 9) {
755  // TIMESTAMP columns should always have nanosecond precision
  // orig_ti describes the element type as stored; dest_ti is the same
  // timestamp type with precision 9.
756  SQLTypeInfo orig_ti = SQLTypeInfo(
757  ti.get_subtype(), ti.get_precision(), ti.get_dimension(), ti.get_notnull());
758  SQLTypeInfo dest_ti =
759  SQLTypeInfo(kTIMESTAMP, 9, ti.get_dimension(), ti.get_notnull());
760  cast_column(col_ptr,
762  orig_ti,
763  dest_ti,
764  std::to_string(i + 1),
765  *ir_builder,
766  ctx,
767  codeGenerator);
768  }
769  }
770 
771  // The QueryEngine CodeGenerator will codegen return values corresponding to
772  // QueryExecutionError codes. Since at the table function level we'd like error handling
773  // to be done by the TableFunctionManager, we replace the codegen'd returns by calls to
774  // the appropriate Manager functions.
  // NOTE(review): the early return below happens before current_func_ is set
  // back to old_func at the end of this function -- confirm that leaving
  // current_func_ pointing at the entry point is intended on this path.
776  // TableFunctionManager is not supported on GPU, so leave the QueryExecutionError code
777  return;
778  }
779 
  // First pass: collect the returns of a constant 7 without touching the IR,
  // so the instruction lists are not modified while they are being iterated.
780  std::vector<llvm::ReturnInst*> rets_to_replace;
781  for (llvm::BasicBlock& BB : *entry_point_func_) {
782  for (llvm::Instruction& I : BB) {
783  if (!llvm::isa<llvm::ReturnInst>(&I)) {
784  continue;
785  }
786  llvm::ReturnInst* RI = llvm::cast<llvm::ReturnInst>(&I);
787  llvm::Value* retValue = RI->getReturnValue();
788  if (!retValue || !llvm::isa<llvm::ConstantInt>(retValue)) {
789  continue;
790  }
791  llvm::ConstantInt* retConst = llvm::cast<llvm::ConstantInt>(retValue);
792  if (retConst->getValue() == 7) {
793  // ret 7 = underflow/overflow during casting attempt
794  rets_to_replace.push_back(RI);
795  }
796  }
797  }
798 
  // Second pass: in front of each collected return, emit a call that reports
  // the error (through the manager when the table function uses one), then
  // return that call's i32 result instead of the constant.
799  auto prev_insert_point = ir_builder->saveIP();
800  for (llvm::ReturnInst* RI : rets_to_replace) {
801  ir_builder->SetInsertPoint(RI);
802  llvm::Value* err_msg = ir_builder->CreateGlobalStringPtr(
803  "Underflow or overflow during casting of input types!", "cast_err_str");
804  llvm::Value* error_call;
805  if (exe_unit.table_func.usesManager()) {
806  error_call = cgen_state->emitExternalCall("TableFunctionManager_error_message",
807  ir_builder->getInt32Ty(),
808  {mgr_ptr, err_msg});
809  } else {
810  error_call = cgen_state->emitExternalCall(
811  "table_function_error", ir_builder->getInt32Ty(), {err_msg});
812  }
813  llvm::ReplaceInstWithInst(RI, llvm::ReturnInst::Create(ctx, error_call));
814  }
815  ir_builder->restoreIP(prev_insert_point);
816 
817  cgen_state->current_func_ = old_func;
818 }
819 
  // Creates the "table_func_kernel" wrapper around the entry point. The wrapper
  // returns void and takes an i32* followed by the entry point's own argument
  // types; it calls the entry point with those trailing arguments and stores
  // the returned i32 through the leading pointer.
821  auto timer = DEBUG_TIMER(__func__);
823  std::vector<llvm::Type*> arg_types;
824  arg_types.reserve(entry_point_func_->arg_size());
825  std::for_each(entry_point_func_->arg_begin(),
826  entry_point_func_->arg_end(),
827  [&arg_types](const auto& arg) { arg_types.push_back(arg.getType()); });
828  CHECK_EQ(arg_types.size(), entry_point_func_->arg_size());
829 
830  auto cgen_state = executor_->getCgenStatePtr();
831  CHECK(cgen_state);
832  auto& ctx = cgen_state->context_;
833 
  // Wrapper signature: slot 0 is the i32* that receives the result; slots
  // 1..N repeat the entry point's argument types in order.
834  std::vector<llvm::Type*> wrapper_arg_types(arg_types.size() + 1);
835  wrapper_arg_types[0] = llvm::PointerType::get(get_int_type(32, ctx), 0);
836  wrapper_arg_types[1] = arg_types[0];
837 
838  for (size_t i = 1; i < arg_types.size(); ++i) {
839  wrapper_arg_types[i + 1] = arg_types[i];
840  }
841 
842  auto wrapper_ft =
843  llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), wrapper_arg_types, false);
844  kernel_func_ = llvm::Function::Create(wrapper_ft,
845  llvm::Function::ExternalLinkage,
846  "table_func_kernel",
847  cgen_state->module_);
848 
849  auto wrapper_bb_entry = llvm::BasicBlock::Create(ctx, ".entry", kernel_func_, 0);
850  llvm::IRBuilder<> b(ctx);
851  b.SetInsertPoint(wrapper_bb_entry);
  // Forward wrapper arguments 1..N to the entry point unchanged.
852  std::vector<llvm::Value*> loaded_args = {kernel_func_->arg_begin() + 1};
853  for (size_t i = 2; i < wrapper_arg_types.size(); ++i) {
854  loaded_args.push_back(kernel_func_->arg_begin() + i);
855  }
856  auto error_lv = b.CreateCall(entry_point_func_, loaded_args);
857  b.CreateStore(error_lv, kernel_func_->arg_begin());
858  b.CreateRetVoid();
859 }
860 
// Links the runtime UDF module (when one exists for the current device type)
// into the module being compiled, logs the generated IR, and produces the
// compilation context for the selected device. `emit_only_preflight_fn` is
// used here to pick the IR log header.
861 std::shared_ptr<CompilationContext> TableFunctionCompilationContext::finalize(
862  bool emit_only_preflight_fn) {
863  auto timer = DEBUG_TIMER(__func__);
864  /*
865  TODO 1: eliminate need for OverrideFromSrc
866  TODO 2: detect and link only the udf's that are needed
867  */
868  auto cgen_state = executor_->getCgenStatePtr();
869  auto is_gpu = co_.device_type == ExecutorDeviceType::GPU;
870  if (executor_->has_rt_udf_module(is_gpu)) {
871  CodeGenerator::link_udf_module(executor_->get_rt_udf_module(is_gpu),
872  *(cgen_state->module_),
873  cgen_state,
874  llvm::Linker::Flags::OverrideFromSrc);
875  }
876 
877  LOG(IR) << (emit_only_preflight_fn ? "Pre Flight Function Entry Point IR\n"
878  : "Table Function Entry Point IR\n")
880  std::shared_ptr<CompilationContext> code;
881  if (is_gpu) {
  // GPU: the kernel wrapper is logged and the NVPTX backend is initialized
  // before the GPU target description is built.
882  LOG(IR) << "Table Function Kernel IR\n" << serialize_llvm_object(kernel_func_);
883 
884  CHECK(executor_);
885  executor_->initializeNVPTXBackend();
886 
887  CodeGenerator::GPUTarget gpu_target{
888  executor_->nvptx_target_machine_.get(), executor_->cudaMgr(), cgen_state, false};
891  kernel_func_,
893  /*is_gpu_smem_used=*/false,
894  co_,
895  gpu_target);
896  } else {
  // CPU: the execution engine is moved into a CpuCompilationContext whose
  // function pointer is set from the entry point function.
897  auto ee =
899  auto cpu_code = std::make_shared<CpuCompilationContext>(std::move(ee));
900  cpu_code->setFunctionPointer(entry_point_func_);
901  code = cpu_code;
902  }
903  LOG(IR) << "End of IR";
904 
905  return code;
906 }
llvm::Type * get_fp_ptr_type(const int width, llvm::LLVMContext &context)
HOST DEVICE SQLTypes get_subtype() const
Definition: sqltypes.h:382
#define CHECK_EQ(x, y)
Definition: Logger.h:301
std::string exprsKey(const std::vector< Analyzer::Expr * > &exprs)
HOST DEVICE int get_size() const
Definition: sqltypes.h:393
bool is_timestamp() const
Definition: sqltypes.h:1014
void generateEntryPoint(const TableFunctionExecutionUnit &exe_unit, bool emit_only_preflight_fn)
void initialize_ptr_member(llvm::Value *member_ptr, llvm::Type *member_llvm_type, llvm::Value *value_ptr, llvm::IRBuilder<> &ir_builder)
bool passColumnsByValue(const TableFunctionExecutionUnit &exe_unit)
std::vector< Analyzer::Expr * > input_exprs
const table_functions::TableFunction table_func
#define LOG(tag)
Definition: Logger.h:285
bool is_fp() const
Definition: sqltypes.h:584
std::tuple< llvm::Value *, llvm::Value * > alloc_column(std::string col_name, const size_t index, const SQLTypeInfo &data_target_info, llvm::Value *data_ptr, llvm::Value *data_size, llvm::Value *data_str_dict_proxy_ptr, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder)
llvm::Function * generate_entry_point(const CgenState *cgen_state)
std::shared_ptr< CompilationContext > finalize(bool emit_only_preflight_fn)
std::vector< std::string > CodeCacheKey
Definition: CodeCache.h:25
void generateCastsForInputTypes(const TableFunctionExecutionUnit &exe_unit, const std::vector< std::pair< llvm::Value *, const SQLTypeInfo >> &columns_to_cast, llvm::Value *mgr_ptr)
static ExecutionEngineWrapper generateNativeCPUCode(llvm::Function *func, const std::unordered_set< llvm::Function * > &live_funcs, const CompilationOptions &co)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
llvm::Value * alloc_column_list(std::string col_list_name, const SQLTypeInfo &data_target_info, llvm::Value *data_ptrs, int length, llvm::Value *data_size, llvm::Value *data_str_dict_proxy_ptrs, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder)
std::string to_string(char const *&&v)
llvm::Module * module_
Definition: CgenState.h:366
void verify_function_ir(const llvm::Function *func)
size_t get_bit_width(const SQLTypeInfo &ti)
llvm::LLVMContext & context_
Definition: CgenState.h:375
bool is_integer() const
Definition: sqltypes.h:582
std::string toString(const ExecutorDeviceType &device_type)
bool is_boolean() const
Definition: sqltypes.h:587
static void link_udf_module(const std::unique_ptr< llvm::Module > &udf_module, llvm::Module &module, CgenState *cgen_state, llvm::Linker::Flags flags=llvm::Linker::Flags::None)
void generateTableFunctionCall(const TableFunctionExecutionUnit &exe_unit, const std::vector< llvm::Value * > &func_args, llvm::BasicBlock *bb_exit, llvm::Value *output_row_count_ptr, bool emit_only_preflight_fn)
std::string getName(const bool drop_suffix=false, const bool lower=false) const
int get_precision() const
Definition: sqltypes.h:384
ExecutorDeviceType device_type
bool is_column() const
Definition: sqltypes.h:593
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:389
std::string serialize_llvm_object(const T *llvm_obj)
static std::shared_ptr< GpuCompilationContext > generateNativeGPUCode(Executor *executor, llvm::Function *func, llvm::Function *wrapper_func, const std::unordered_set< llvm::Function * > &live_funcs, const bool is_gpu_smem_used, const CompilationOptions &co, const GPUTarget &gpu_target)
std::vector< llvm::Value * > generate_column_heads_load(const int num_columns, llvm::Value *byte_stream_arg, llvm::IRBuilder<> &ir_builder, llvm::LLVMContext &ctx)
std::string get_type_name() const
Definition: sqltypes.h:507
void cast_column(llvm::Value *col_base_ptr, llvm::Function *func, SQLTypeInfo &orig_ti, SQLTypeInfo &dest_ti, std::string index, llvm::IRBuilder<> &ir_builder, llvm::LLVMContext &ctx, CodeGenerator &codeGenerator)
llvm::Value * codegenCastBetweenTimestamps(llvm::Value *ts_lv, const SQLTypeInfo &operand_dimen, const SQLTypeInfo &target_dimen, const bool nullable)
Definition: CastIR.cpp:199
llvm::Type * get_llvm_type_from_sql_column_type(const SQLTypeInfo elem_ti, llvm::LLVMContext &ctx)
bool is_bytes() const
Definition: sqltypes.h:603
#define CHECK(condition)
Definition: Logger.h:291
#define DEBUG_TIMER(name)
Definition: Logger.h:411
void initialize_int64_member(llvm::Value *member_ptr, llvm::Value *value, int64_t default_value, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder)
static std::shared_ptr< QueryEngine > getInstance()
Definition: QueryEngine.h:81
std::shared_ptr< CompilationContext > compile(const TableFunctionExecutionUnit &exe_unit, bool emit_only_preflight_fn)
std::vector< Analyzer::Expr * > target_exprs
bool is_string() const
Definition: sqltypes.h:580
HOST DEVICE bool get_notnull() const
Definition: sqltypes.h:388
llvm::Type * get_int_ptr_type(const int width, llvm::LLVMContext &context)
bool is_array() const
Definition: sqltypes.h:588
llvm::Value * cast_value(llvm::Value *value, SQLTypeInfo &orig_ti, SQLTypeInfo &dest_ti, bool nullable, CodeGenerator &codeGenerator)