OmniSciDB  c0231cc57d
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionCompilationContext.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <llvm/IR/InstIterator.h>
20 #include <llvm/IR/Verifier.h>
21 #include <llvm/Support/raw_os_ostream.h>
22 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
23 #include <algorithm>
24 #include <boost/algorithm/string.hpp>
25 
28 
29 namespace {
30 
31 llvm::Function* generate_entry_point(const CgenState* cgen_state) {
32  auto& ctx = cgen_state->context_;
33  const auto pi8_type = llvm::PointerType::get(get_int_type(8, ctx), 0);
34  const auto ppi8_type = llvm::PointerType::get(pi8_type, 0);
35  const auto pi64_type = llvm::PointerType::get(get_int_type(64, ctx), 0);
36  const auto ppi64_type = llvm::PointerType::get(pi64_type, 0);
37  const auto i32_type = get_int_type(32, ctx);
38 
39  const auto func_type = llvm::FunctionType::get(
40  i32_type,
41  {pi8_type, ppi8_type, pi64_type, ppi8_type, ppi64_type, ppi8_type, pi64_type},
42  false);
43 
44  auto func = llvm::Function::Create(func_type,
45  llvm::Function::ExternalLinkage,
46  "call_table_function",
47  cgen_state->module_);
48  auto arg_it = func->arg_begin();
49  const auto mgr_arg = &*arg_it;
50  mgr_arg->setName("mgr_ptr");
51  const auto input_cols_arg = &*(++arg_it);
52  input_cols_arg->setName("input_col_buffers");
53  const auto input_row_counts = &*(++arg_it);
54  input_row_counts->setName("input_row_counts");
55  const auto input_str_dict_proxies = &*(++arg_it);
56  input_str_dict_proxies->setName("input_str_dict_proxies");
57  const auto output_buffers = &*(++arg_it);
58  output_buffers->setName("output_buffers");
59  const auto output_str_dict_proxies = &*(++arg_it);
60  output_str_dict_proxies->setName("output_str_dict_proxies");
61  const auto output_row_count = &*(++arg_it);
62  output_row_count->setName("output_row_count");
63  return func;
64 }
65 
67  llvm::LLVMContext& ctx) {
68  if (elem_ti.is_fp()) {
69  return get_fp_ptr_type(elem_ti.get_size() * 8, ctx);
70  } else if (elem_ti.is_boolean()) {
71  return get_int_ptr_type(8, ctx);
72  } else if (elem_ti.is_integer()) {
73  return get_int_ptr_type(elem_ti.get_size() * 8, ctx);
74  } else if (elem_ti.is_string()) {
75  if (elem_ti.get_compression() == kENCODING_DICT) {
76  return get_int_ptr_type(elem_ti.get_size() * 8, ctx);
77  }
78  CHECK(elem_ti.is_bytes()); // None encoded string
79  return get_int_ptr_type(8, ctx);
80  } else if (elem_ti.is_timestamp()) {
81  return get_int_ptr_type(elem_ti.get_size() * 8, ctx);
82  } else if (elem_ti.is_array()) {
83  return get_int_ptr_type(8, ctx);
84  }
85  LOG(FATAL) << "get_llvm_type_from_sql_column_type: not implemented for "
86  << ::toString(elem_ti);
87  return nullptr;
88 }
89 
90 void initialize_ptr_member(llvm::Value* member_ptr,
91  llvm::Type* member_llvm_type,
92  llvm::Value* value_ptr,
93  llvm::IRBuilder<>& ir_builder) {
94  if (value_ptr != nullptr) {
95  if (value_ptr->getType() == member_llvm_type->getPointerElementType()) {
96  ir_builder.CreateStore(value_ptr, member_ptr);
97  } else {
98  auto tmp = ir_builder.CreateBitCast(value_ptr, member_llvm_type);
99  ir_builder.CreateStore(tmp, member_ptr);
100  }
101  } else {
102  ir_builder.CreateStore(llvm::Constant::getNullValue(member_llvm_type), member_ptr);
103  }
104 }
105 
106 void initialize_int64_member(llvm::Value* member_ptr,
107  llvm::Value* value,
108  int64_t default_value,
109  llvm::LLVMContext& ctx,
110  llvm::IRBuilder<>& ir_builder) {
111  llvm::Value* val = nullptr;
112  if (value != nullptr) {
113  auto value_type = value->getType();
114  if (value_type->isPointerTy()) {
115  CHECK(value_type->getPointerElementType()->isIntegerTy(64));
116  val = ir_builder.CreateLoad(value->getType()->getPointerElementType(), value);
117  } else {
118  CHECK(value_type->isIntegerTy(64));
119  val = value;
120  }
121  ir_builder.CreateStore(val, member_ptr);
122  } else {
123  auto const_default =
124  llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), default_value, true);
125  ir_builder.CreateStore(const_default, member_ptr);
126  }
127 }
128 
129 std::tuple<llvm::Value*, llvm::Value*> alloc_column(std::string col_name,
130  const size_t index,
131  const SQLTypeInfo& data_target_info,
132  llvm::Value* data_ptr,
133  llvm::Value* data_size,
134  llvm::Value* data_str_dict_proxy_ptr,
135  llvm::LLVMContext& ctx,
136  llvm::IRBuilder<>& ir_builder) {
137  /*
138  Creates a new Column instance of given element type and initialize
139  its data ptr and sz members when specified. If data ptr or sz are
140  unspecified (have nullptr values) then the corresponding members
141  are initialized with NULL and -1, respectively.
142 
143  If we are allocating a TextEncodingDict Column type, this function
144  adds and populates a int8* pointer to a StringDictProxy object.
145 
146  Return a pair of Column allocation (caller should apply
147  builder.CreateLoad to it in order to construct a Column instance
148  as a value) and a pointer to the Column instance.
149  */
150  const bool is_text_encoding_dict_type =
151  data_target_info.is_string() &&
152  data_target_info.get_compression() == kENCODING_DICT;
153  llvm::StructType* col_struct_type;
154  llvm::Type* data_ptr_llvm_type =
155  get_llvm_type_from_sql_column_type(data_target_info, ctx);
156  if (is_text_encoding_dict_type) {
157  col_struct_type = llvm::StructType::get(
158  ctx,
159  {
160  data_ptr_llvm_type, /* T* ptr */
161  llvm::Type::getInt64Ty(ctx), /* int64_t sz */
162  llvm::Type::getInt8PtrTy(ctx) /* int8_t* string_dictionary_ptr */
163  });
164  } else {
165  col_struct_type =
166  llvm::StructType::get(ctx,
167  {
168  data_ptr_llvm_type, /* T* ptr */
169  llvm::Type::getInt64Ty(ctx) /* int64_t sz */
170  });
171  }
172 
173  auto col = ir_builder.CreateAlloca(col_struct_type);
174  col->setName(col_name);
175  auto col_ptr_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 0);
176  auto col_sz_ptr = ir_builder.CreateStructGEP(col_struct_type, col, 1);
177  auto col_str_dict_ptr = is_text_encoding_dict_type
178  ? ir_builder.CreateStructGEP(col_struct_type, col, 2)
179  : nullptr;
180  col_ptr_ptr->setName(col_name + ".ptr");
181  col_sz_ptr->setName(col_name + ".sz");
182  if (is_text_encoding_dict_type) {
183  col_str_dict_ptr->setName(col_name + ".string_dict_proxy");
184  }
185 
186  initialize_ptr_member(col_ptr_ptr, data_ptr_llvm_type, data_ptr, ir_builder);
187  initialize_int64_member(col_sz_ptr, data_size, -1, ctx, ir_builder);
188  if (is_text_encoding_dict_type) {
189  initialize_ptr_member(col_str_dict_ptr,
190  llvm::Type::getInt8PtrTy(ctx),
191  data_str_dict_proxy_ptr,
192  ir_builder);
193  }
194  auto col_ptr = ir_builder.CreatePointerCast(
195  col_ptr_ptr, llvm::PointerType::get(llvm::Type::getInt8Ty(ctx), 0));
196  col_ptr->setName(col_name + "_ptr");
197  return {col, col_ptr};
198 }
199 
200 llvm::Value* alloc_column_list(std::string col_list_name,
201  const SQLTypeInfo& data_target_info,
202  llvm::Value* data_ptrs,
203  int length,
204  llvm::Value* data_size,
205  llvm::Value* data_str_dict_proxy_ptrs,
206  llvm::LLVMContext& ctx,
207  llvm::IRBuilder<>& ir_builder) {
208  /*
209  Creates a new ColumnList instance of given element type and initialize
210  its members. If data ptr or size are unspecified (have nullptr
211  values) then the corresponding members are initialized with NULL
212  and -1, respectively.
213  */
214  llvm::Type* data_ptrs_llvm_type = llvm::Type::getInt8PtrTy(ctx);
215  const bool is_text_encoding_dict_type =
216  data_target_info.is_string() &&
217  data_target_info.get_compression() == kENCODING_DICT;
218 
219  llvm::StructType* col_list_struct_type =
220  is_text_encoding_dict_type
221  ? llvm::StructType::get(
222  ctx,
223  {
224  data_ptrs_llvm_type, /* int8_t* ptrs */
225  llvm::Type::getInt64Ty(ctx), /* int64_t length */
226  llvm::Type::getInt64Ty(ctx), /* int64_t size */
227  data_ptrs_llvm_type /* int8_t* str_dict_proxy_ptrs */
228  })
229  : llvm::StructType::get(ctx,
230  {
231  data_ptrs_llvm_type, /* int8_t* ptrs */
232  llvm::Type::getInt64Ty(ctx), /* int64_t length */
233  llvm::Type::getInt64Ty(ctx) /* int64_t size */
234  });
235 
236  auto col_list = ir_builder.CreateAlloca(col_list_struct_type);
237  col_list->setName(col_list_name);
238  auto col_list_ptr_ptr = ir_builder.CreateStructGEP(col_list_struct_type, col_list, 0);
239  auto col_list_length_ptr =
240  ir_builder.CreateStructGEP(col_list_struct_type, col_list, 1);
241  auto col_list_size_ptr = ir_builder.CreateStructGEP(col_list_struct_type, col_list, 2);
242  auto col_str_dict_ptr_ptr =
243  is_text_encoding_dict_type
244  ? ir_builder.CreateStructGEP(col_list_struct_type, col_list, 3)
245  : nullptr;
246 
247  col_list_ptr_ptr->setName(col_list_name + ".ptrs");
248  col_list_length_ptr->setName(col_list_name + ".length");
249  col_list_size_ptr->setName(col_list_name + ".size");
250  if (is_text_encoding_dict_type) {
251  col_str_dict_ptr_ptr->setName(col_list_name + ".string_dict_proxies");
252  }
253 
254  initialize_ptr_member(col_list_ptr_ptr, data_ptrs_llvm_type, data_ptrs, ir_builder);
255 
256  CHECK(length >= 0);
257  auto const_length = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), length, true);
258  ir_builder.CreateStore(const_length, col_list_length_ptr);
259 
260  initialize_int64_member(col_list_size_ptr, data_size, -1, ctx, ir_builder);
261 
262  if (is_text_encoding_dict_type) {
263  initialize_ptr_member(col_str_dict_ptr_ptr,
264  data_str_dict_proxy_ptrs->getType(),
265  data_str_dict_proxy_ptrs,
266  ir_builder);
267  }
268 
269  auto col_list_ptr = ir_builder.CreatePointerCast(
270  col_list_ptr_ptr, llvm::PointerType::get(llvm::Type::getInt8Ty(ctx), 0));
271  col_list_ptr->setName(col_list_name + "_ptrs");
272  return col_list_ptr;
273 }
274 
275 static bool columnTypeRequiresCasting(const SQLTypeInfo& ti) {
276  /*
277  Returns whether a column requires casting before table function execution based on its
278  underlying SQL type
279  */
280 
281  if (!ti.is_column()) {
282  return false;
283  }
284 
285  // TIMESTAMP columns should always have nanosecond precision
286  if (ti.get_subtype() == kTIMESTAMP && ti.get_precision() != 9) {
287  return true;
288  }
289 
290  return false;
291 }
292 
293 llvm::Value* cast_value(llvm::Value* value,
294  SQLTypeInfo& orig_ti,
295  SQLTypeInfo& dest_ti,
296  bool nullable,
297  CodeGenerator& codeGenerator) {
298  /*
299  Codegens a cast of a value from a given origin type to a given destination type, if such
300  implementation is available. Errors for unsupported casts.
301  */
302  if (orig_ti.is_timestamp() && dest_ti.is_timestamp()) {
303  return codeGenerator.codegenCastBetweenTimestamps(
304  value, orig_ti, dest_ti, !dest_ti.get_notnull());
305  } else {
306  throw std::runtime_error("Unsupported cast from " + orig_ti.get_type_name() + " to " +
307  dest_ti.get_type_name() + " during UDTF code generation.");
308  }
309 }
310 
311 void cast_column(llvm::Value* col_base_ptr,
312  llvm::Function* func,
313  SQLTypeInfo& orig_ti,
314  SQLTypeInfo& dest_ti,
315  std::string index,
316  llvm::IRBuilder<>& ir_builder,
317  llvm::LLVMContext& ctx,
318  CodeGenerator& codeGenerator) {
319  /*
320  Generates code to cast a Column instance from a given origin
321  SQLType to a new destinaton SQLType. To do so, it generates a
322  loop with the following overall structure:
323 
324  --------------
325  | pre_header |
326  | i = 0 |
327  --------------
328  |
329  v
330  ---------------- ----------------
331  | cond | | body |
332  | i < col.size | (True) -> | cast(col[i]) |
333  ---------------- | i++ |
334  (False) ^ ----------------
335  | \____________________/
336  |
337  v
338  ---------------
339  | end |
340  ---------------
341 
342  The correctness of the cast as well as error handling/early
343  exiting in case of cast failures are left to the CodeGenerator
344  which generates the code for the cast operation itself.
345  */
346 
347  llvm::BasicBlock* for_pre =
348  llvm::BasicBlock::Create(ctx, "for_pre_cast." + index, func);
349  llvm::BasicBlock* for_cond =
350  llvm::BasicBlock::Create(ctx, "for_cond_cast." + index, func);
351  llvm::BasicBlock* for_body =
352  llvm::BasicBlock::Create(ctx, "for_body_cast." + index, func);
353  llvm::BasicBlock* for_end =
354  llvm::BasicBlock::Create(ctx, "for_end_cast." + index, func);
355  ir_builder.CreateBr(for_pre);
356 
357  // pre-header: load column ptr and size
358  ir_builder.SetInsertPoint(for_pre);
359  llvm::Type* data_type = get_llvm_type_from_sql_column_type(orig_ti, ctx);
360  llvm::StructType* col_struct_type =
361  llvm::StructType::get(ctx, {data_type, ir_builder.getInt64Ty()});
362  llvm::Value* struct_cast = ir_builder.CreateBitCast(
363  col_base_ptr, col_struct_type->getPointerTo(), "col_struct." + index);
364  llvm::Value* gep_ptr = ir_builder.CreateStructGEP(
365  col_struct_type, struct_cast, 0, "col_ptr_addr." + index);
366  llvm::Value* col_ptr = ir_builder.CreateLoad(data_type, gep_ptr, "col_ptr." + index);
367  llvm::Value* gep_sz =
368  ir_builder.CreateStructGEP(col_struct_type, struct_cast, 1, "col_sz_addr." + index);
369  llvm::Value* col_sz =
370  ir_builder.CreateLoad(ir_builder.getInt64Ty(), gep_sz, "col_sz." + index);
371  ir_builder.CreateBr(for_cond);
372 
373  // condition: check induction variable against loop predicate
374  ir_builder.SetInsertPoint(for_cond);
375  llvm::PHINode* for_ind_var =
376  ir_builder.CreatePHI(ir_builder.getInt64Ty(), 2, "for_ind_var." + index);
377  for_ind_var->addIncoming(ir_builder.getInt64(0), for_pre);
378  llvm::Value* for_pred =
379  ir_builder.CreateICmpSLT(for_ind_var, col_sz, "for_pred." + index);
380  ir_builder.CreateCondBr(for_pred, for_body, for_end);
381 
382  // body: perform value cast, increment induction variable
383  ir_builder.SetInsertPoint(for_body);
384  ;
385  llvm::Value* val_gep = ir_builder.CreateInBoundsGEP(
386  ir_builder.getInt64Ty(), col_ptr, for_ind_var, "val_gep." + index);
387  llvm::Value* val_load =
388  ir_builder.CreateLoad(ir_builder.getInt64Ty(), val_gep, "val_load." + index);
389  llvm::Value* cast_result = cast_value(val_load, orig_ti, dest_ti, false, codeGenerator);
390  cast_result->setName("cast_result." + index);
391  ir_builder.CreateStore(cast_result, val_gep);
392  llvm::Value* for_inc =
393  ir_builder.CreateAdd(for_ind_var, ir_builder.getInt64(1), "for_inc." + index);
394  ir_builder.CreateBr(for_cond);
395  // the cast codegening may have generated extra blocks, so for_body does not necessarily
396  // jump to for_cond directly
397  llvm::Instruction* inc_as_inst = llvm::cast<llvm::Instruction>(for_inc);
398  for_ind_var->addIncoming(for_inc, inc_as_inst->getParent());
399  ir_builder.SetInsertPoint(for_end);
400 }
401 
402 std::string exprsKey(const std::vector<Analyzer::Expr*>& exprs) {
403  std::string result;
404  for (const auto& expr : exprs) {
405  const auto& ti = expr->get_type_info();
406  result += ti.to_string() + ", ";
407  }
408  return result;
409 }
410 
411 } // namespace
412 
// Returns the compilation context for the given table function execution
// unit. A cache lookup through tf_code_accessor is done first and a hit is
// returned directly; otherwise the entry point is generated, the result of
// finalize() is stored in the cache under the same key, and the cached value
// is returned.
//
// exe_unit:               the table function plus its input/target expressions.
// emit_only_preflight_fn: forwarded to generateEntryPoint() and finalize(), and
//                         also made part of the cache key.
413 std::shared_ptr<CompilationContext> TableFunctionCompilationContext::compile(
414  const TableFunctionExecutionUnit& exe_unit,
415  bool emit_only_preflight_fn) {
416  auto timer = DEBUG_TIMER(__func__);
417 
418  // Here we assume that Executor::tf_code_accessor is cleared when a
419  // UDTF implementation is changed. TODO: Ideally, the key should
420  // contain a hash of an UDTF implementation string. This could be
421  // achieved by including the hash value to the prefix of the UDTF
422  // name, for instance.
423  CodeCacheKey key{exe_unit.table_func.getName(),
424  exprsKey(exe_unit.input_exprs),
425  exprsKey(exe_unit.target_exprs),
426  std::to_string(emit_only_preflight_fn),
428 
// Cache hit: reuse the previously compiled code.
429  auto cached_code = QueryEngine::getInstance()->tf_code_accessor->get_or_wait(key);
430  if (cached_code) {
431  return *cached_code;
432  }
433 
// Cache miss: the cgen state must not have a module attached yet; it is given
// a shallow copy of the executor's runtime module to generate code into.
434  auto cgen_state = executor_->getCgenStatePtr();
435  CHECK(cgen_state);
436  CHECK(cgen_state->module_ == nullptr);
437  cgen_state->set_module_shallow_copy(executor_->get_rt_module());
438 
440 
441  generateEntryPoint(exe_unit, emit_only_preflight_fn);
442 
444  CHECK(!emit_only_preflight_fn);
446  }
447 
// Store the finalized code under the key, then return the cached value.
448  QueryEngine::getInstance()->tf_code_accessor->swap(key,
449  finalize(emit_only_preflight_fn));
450  return QueryEngine::getInstance()->tf_code_accessor->get_value(key);
451 }
452 
454  const TableFunctionExecutionUnit& exe_unit) {
455  bool is_gpu = co_.device_type == ExecutorDeviceType::GPU;
456  auto mod = executor_->get_rt_udf_module(is_gpu).get();
457  if (mod != nullptr) {
458  auto* flag = mod->getModuleFlag("pass_column_arguments_by_value");
459  if (auto* cnt = llvm::mdconst::extract_or_null<llvm::ConstantInt>(flag)) {
460  return cnt->getZExtValue();
461  }
462  }
463 
464  // fallback to original behavior
465  return exe_unit.table_func.isRuntime();
466 }
467 
469  const TableFunctionExecutionUnit& exe_unit,
470  const std::vector<llvm::Value*>& func_args,
471  llvm::BasicBlock* bb_exit,
472  llvm::Value* output_row_count_ptr,
473  bool emit_only_preflight_fn) {
474  auto cgen_state = executor_->getCgenStatePtr();
475  // Emit llvm IR code to call the table function
476  llvm::LLVMContext& ctx = cgen_state->context_;
477  llvm::IRBuilder<>* ir_builder = &cgen_state->ir_builder_;
478 
479  std::string func_name =
480  (emit_only_preflight_fn ? exe_unit.table_func.getPreFlightFnName()
481  : exe_unit.table_func.getName(false, true));
482  llvm::Value* table_func_return =
483  cgen_state->emitExternalCall(func_name, get_int_type(32, ctx), func_args);
484 
485  table_func_return->setName(emit_only_preflight_fn ? "preflight_check_func_ret"
486  : "table_func_ret");
487 
488  // If table_func_return is non-negative then store the value in
489  // output_row_count and return zero. Otherwise, return
490  // table_func_return that negative value contains the error code.
491  llvm::BasicBlock* bb_exit_0 =
492  llvm::BasicBlock::Create(ctx, ".exit0", entry_point_func_);
493 
494  llvm::Constant* const_zero =
495  llvm::ConstantInt::get(table_func_return->getType(), 0, true);
496  llvm::Value* is_ok = ir_builder->CreateICmpSGE(table_func_return, const_zero);
497  ir_builder->CreateCondBr(is_ok, bb_exit_0, bb_exit);
498 
499  ir_builder->SetInsertPoint(bb_exit_0);
500  llvm::Value* r =
501  ir_builder->CreateIntCast(table_func_return, get_int_type(64, ctx), true);
502  ir_builder->CreateStore(r, output_row_count_ptr);
503  ir_builder->CreateRet(const_zero);
504 
505  ir_builder->SetInsertPoint(bb_exit);
506  ir_builder->CreateRet(table_func_return);
507 }
508 
510  const TableFunctionExecutionUnit& exe_unit,
511  bool emit_only_preflight_fn) {
512  auto timer = DEBUG_TIMER(__func__);
514  CHECK_EQ(entry_point_func_->arg_size(), 7);
515  auto arg_it = entry_point_func_->arg_begin();
516  const auto mgr_ptr = &*arg_it;
517  const auto input_cols_arg = &*(++arg_it);
518  const auto input_row_counts_arg = &*(++arg_it);
519  const auto input_str_dict_proxies_arg = &*(++arg_it);
520  const auto output_buffers_arg = &*(++arg_it);
521  const auto output_str_dict_proxies_arg = &*(++arg_it);
522  const auto output_row_count_ptr = &*(++arg_it);
523  auto cgen_state = executor_->getCgenStatePtr();
524  CHECK(cgen_state);
525  auto& ctx = cgen_state->context_;
526 
527  llvm::BasicBlock* bb_entry =
528  llvm::BasicBlock::Create(ctx, ".entry", entry_point_func_, 0);
529  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
530 
531  llvm::BasicBlock* bb_exit = llvm::BasicBlock::Create(ctx, ".exit", entry_point_func_);
532 
533  llvm::BasicBlock* func_body_bb = llvm::BasicBlock::Create(
534  ctx, ".func_body0", cgen_state->ir_builder_.GetInsertBlock()->getParent());
535 
536  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
537  cgen_state->ir_builder_.CreateBr(func_body_bb);
538 
539  cgen_state->ir_builder_.SetInsertPoint(func_body_bb);
540  auto col_heads = generate_column_heads_load(
541  exe_unit.input_exprs.size(), input_cols_arg, cgen_state->ir_builder_, ctx);
542  CHECK_EQ(exe_unit.input_exprs.size(), col_heads.size());
543  auto row_count_heads = generate_column_heads_load(
544  exe_unit.input_exprs.size(), input_row_counts_arg, cgen_state->ir_builder_, ctx);
545 
546  auto input_str_dict_proxy_heads = std::vector<llvm::Value*>();
547  if (!emit_only_preflight_fn and co_.device_type == ExecutorDeviceType::CPU) {
548  input_str_dict_proxy_heads = generate_column_heads_load(exe_unit.input_exprs.size(),
549  input_str_dict_proxies_arg,
550  cgen_state->ir_builder_,
551  ctx);
552  }
553  // The column arguments of C++ UDTFs processed by clang must be
554  // passed by reference, see rbc issues 200 and 289.
555  auto pass_column_by_value = passColumnsByValue(exe_unit);
556  std::vector<llvm::Value*> func_args;
557  std::vector<std::pair<llvm::Value*, const SQLTypeInfo>> columns_to_cast;
558  size_t func_arg_index = 0;
559  if (exe_unit.table_func.usesManager()) {
560  func_args.push_back(mgr_ptr);
561  func_arg_index++;
562  }
563  int col_index = -1;
564 
565  for (size_t i = 0; i < exe_unit.input_exprs.size(); i++) {
566  const auto& expr = exe_unit.input_exprs[i];
567  const auto& ti = expr->get_type_info();
568  if (col_index == -1) {
569  func_arg_index += 1;
570  }
571  if (ti.is_fp()) {
572  auto r = cgen_state->ir_builder_.CreateBitCast(
573  col_heads[i], get_fp_ptr_type(get_bit_width(ti), ctx));
574  llvm::LoadInst* scalar_fp = cgen_state->ir_builder_.CreateLoad(
575  r->getType()->getPointerElementType(),
576  r,
577  "input_scalar_fp." + std::to_string(func_arg_index));
578  func_args.push_back(scalar_fp);
579  CHECK_EQ(col_index, -1);
580  } else if (ti.is_integer() || ti.is_boolean() || ti.is_timestamp()) {
581  auto r = cgen_state->ir_builder_.CreateBitCast(
582  col_heads[i], get_int_ptr_type(get_bit_width(ti), ctx));
583  llvm::LoadInst* scalar_int = cgen_state->ir_builder_.CreateLoad(
584  r->getType()->getPointerElementType(),
585  r,
586  "input_scalar_int." + std::to_string(func_arg_index));
587  func_args.push_back(scalar_int);
588  CHECK_EQ(col_index, -1);
589  } else if (ti.is_bytes()) {
590  auto varchar_size =
591  cgen_state->ir_builder_.CreateBitCast(col_heads[i], get_int_ptr_type(64, ctx));
592  auto varchar_ptr = cgen_state->ir_builder_.CreateGEP(
593  col_heads[i]->getType()->getScalarType()->getPointerElementType(),
594  col_heads[i],
595  cgen_state->llInt(8));
596  auto [varchar_struct, varchar_struct_ptr] = alloc_column(
597  std::string("input_varchar_literal.") + std::to_string(func_arg_index),
598  i,
599  ti,
600  varchar_ptr,
601  varchar_size,
602  nullptr,
603  ctx,
604  cgen_state->ir_builder_);
605  func_args.push_back(
606  (pass_column_by_value
607  ? cgen_state->ir_builder_.CreateLoad(
608  varchar_struct->getType()->getPointerElementType(), varchar_struct)
609  : varchar_struct_ptr));
610  CHECK_EQ(col_index, -1);
611  } else if (ti.is_column()) {
612  auto [col, col_ptr] = alloc_column(
613  std::string("input_col.") + std::to_string(func_arg_index),
614  i,
615  ti.get_elem_type(),
616  col_heads[i],
617  row_count_heads[i],
618  (co_.device_type != ExecutorDeviceType::CPU || emit_only_preflight_fn)
619  ? nullptr
620  : input_str_dict_proxy_heads[i],
621  ctx,
622  cgen_state->ir_builder_);
623  func_args.push_back((pass_column_by_value
624  ? cgen_state->ir_builder_.CreateLoad(
625  col->getType()->getPointerElementType(), col)
626  : col_ptr));
627 
628  if (columnTypeRequiresCasting(ti) &&
630  columns_to_cast.push_back(std::make_pair(col_ptr, ti));
631  }
632  CHECK_EQ(col_index, -1);
633  } else if (ti.is_column_list()) {
634  if (col_index == -1) {
635  auto col_list = alloc_column_list(
636  std::string("input_col_list.") + std::to_string(func_arg_index),
637  ti.get_elem_type(),
638  col_heads[i],
639  ti.get_dimension(),
640  row_count_heads[i],
641  (emit_only_preflight_fn) ? nullptr : input_str_dict_proxy_heads[i],
642  ctx,
643  cgen_state->ir_builder_);
644  func_args.push_back(col_list);
645  }
646  col_index++;
647  if (col_index + 1 == ti.get_dimension()) {
648  col_index = -1;
649  }
650  } else {
651  throw std::runtime_error(
652  "Only integer and floating point columns or scalars are supported as inputs to "
653  "table "
654  "functions, got " +
655  ti.get_type_name());
656  }
657  }
658  auto output_str_dict_proxy_heads =
660  ? (generate_column_heads_load(exe_unit.target_exprs.size(),
661  output_str_dict_proxies_arg,
662  cgen_state->ir_builder_,
663  ctx))
664  : std::vector<llvm::Value*>();
665 
666  std::vector<llvm::Value*> output_col_args;
667  for (size_t i = 0; i < exe_unit.target_exprs.size(); i++) {
668  auto* gep = cgen_state->ir_builder_.CreateGEP(
669  output_buffers_arg->getType()->getScalarType()->getPointerElementType(),
670  output_buffers_arg,
671  cgen_state->llInt(i));
672  auto output_load =
673  cgen_state->ir_builder_.CreateLoad(gep->getType()->getPointerElementType(), gep);
674  const auto& expr = exe_unit.target_exprs[i];
675  const auto& ti = expr->get_type_info();
676  CHECK(!ti.is_column()); // UDTF output column type is its data type
677  CHECK(!ti.is_column_list()); // TODO: when UDTF outputs column_list, convert it to
678  // output columns
679  auto [col, col_ptr] = alloc_column(
680  std::string("output_col.") + std::to_string(i),
681  i,
682  ti,
684  ? output_load
685  : nullptr), // CPU: set_output_row_size will set the output
686  // Column ptr member
687  output_row_count_ptr,
688  co_.device_type == ExecutorDeviceType::CPU ? output_str_dict_proxy_heads[i]
689  : nullptr,
690  ctx,
691  cgen_state->ir_builder_);
692  if (co_.device_type == ExecutorDeviceType::CPU && !emit_only_preflight_fn) {
693  cgen_state->emitExternalCall(
694  "TableFunctionManager_register_output_column",
695  llvm::Type::getVoidTy(ctx),
696  {mgr_ptr, llvm::ConstantInt::get(get_int_type(32, ctx), i, true), col_ptr});
697  }
698  output_col_args.push_back((pass_column_by_value ? col : col_ptr));
699  }
700 
701  // output column members must be set before loading column when
702  // column instances are passed by value
703  if ((exe_unit.table_func.hasOutputSizeKnownPreLaunch() ||
704  exe_unit.table_func.hasPreFlightOutputSizer()) &&
705  (co_.device_type == ExecutorDeviceType::CPU) && !emit_only_preflight_fn) {
706  cgen_state->emitExternalCall(
707  "TableFunctionManager_set_output_row_size",
708  llvm::Type::getVoidTy(ctx),
709  {mgr_ptr,
710  cgen_state->ir_builder_.CreateLoad(
711  output_row_count_ptr->getType()->getPointerElementType(),
712  output_row_count_ptr)});
713  }
714 
715  if (!emit_only_preflight_fn) {
716  for (auto& col : output_col_args) {
717  func_args.push_back((pass_column_by_value
718  ? cgen_state->ir_builder_.CreateLoad(
719  col->getType()->getPointerElementType(), col)
720  : col));
721  }
722  }
723 
724  if (exe_unit.table_func.mayRequireCastingInputTypes() && !emit_only_preflight_fn) {
725  generateCastsForInputTypes(exe_unit, columns_to_cast, mgr_ptr);
726  }
727 
729  exe_unit, func_args, bb_exit, output_row_count_ptr, emit_only_preflight_fn);
730 
731  // std::cout << "=================================" << std::endl;
732  // entry_point_func_->print(llvm::outs());
733  // std::cout << "=================================" << std::endl;
734 
736 }
737 
739  const TableFunctionExecutionUnit& exe_unit,
740  const std::vector<std::pair<llvm::Value*, const SQLTypeInfo>>& columns_to_cast,
741  llvm::Value* mgr_ptr) {
742  auto* cgen_state = executor_->getCgenStatePtr();
743  llvm::LLVMContext& ctx = cgen_state->context_;
744  llvm::IRBuilder<>* ir_builder = &cgen_state->ir_builder_;
745  CodeGenerator codeGenerator = CodeGenerator(cgen_state, executor_->getPlanStatePtr());
746  llvm::Function* old_func = cgen_state->current_func_;
747  cgen_state->current_func_ =
748  entry_point_func_; // update cgen_state current func for CodeGenerator
749 
750  for (unsigned i = 0; i < columns_to_cast.size(); ++i) {
751  auto [col_ptr, ti] = columns_to_cast[i];
752 
753  if (ti.is_column() && ti.get_subtype() == kTIMESTAMP && ti.get_precision() != 9) {
754  // TIMESTAMP columns should always have nanosecond precision
755  SQLTypeInfo orig_ti = SQLTypeInfo(
756  ti.get_subtype(), ti.get_precision(), ti.get_dimension(), ti.get_notnull());
757  SQLTypeInfo dest_ti =
758  SQLTypeInfo(kTIMESTAMP, 9, ti.get_dimension(), ti.get_notnull());
759  cast_column(col_ptr,
761  orig_ti,
762  dest_ti,
763  std::to_string(i + 1),
764  *ir_builder,
765  ctx,
766  codeGenerator);
767  }
768  }
769 
770  // The QueryEngine CodeGenerator will codegen return values corresponding to
771  // QueryExecutionError codes. Since at the table function level we'd like error handling
772  // to be done by the TableFunctionManager, we replace the codegen'd returns by calls to
773  // the appropriate Manager functions.
775  // TableFunctionManager is not supported on GPU, so leave the QueryExecutionError code
776  return;
777  }
778 
779  std::vector<llvm::ReturnInst*> rets_to_replace;
780  for (llvm::BasicBlock& BB : *entry_point_func_) {
781  for (llvm::Instruction& I : BB) {
782  if (!llvm::isa<llvm::ReturnInst>(&I)) {
783  continue;
784  }
785  llvm::ReturnInst* RI = llvm::cast<llvm::ReturnInst>(&I);
786  llvm::Value* retValue = RI->getReturnValue();
787  if (!retValue || !llvm::isa<llvm::ConstantInt>(retValue)) {
788  continue;
789  }
790  llvm::ConstantInt* retConst = llvm::cast<llvm::ConstantInt>(retValue);
791  if (retConst->getValue() == 7) {
792  // ret 7 = underflow/overflow during casting attempt
793  rets_to_replace.push_back(RI);
794  }
795  }
796  }
797 
798  auto prev_insert_point = ir_builder->saveIP();
799  for (llvm::ReturnInst* RI : rets_to_replace) {
800  ir_builder->SetInsertPoint(RI);
801  llvm::Value* err_msg = ir_builder->CreateGlobalStringPtr(
802  "Underflow or overflow during casting of input types!", "cast_err_str");
803  llvm::Value* error_call;
804  if (exe_unit.table_func.usesManager()) {
805  error_call = cgen_state->emitExternalCall("TableFunctionManager_error_message",
806  ir_builder->getInt32Ty(),
807  {mgr_ptr, err_msg});
808  } else {
809  error_call = cgen_state->emitExternalCall(
810  "table_function_error", ir_builder->getInt32Ty(), {err_msg});
811  }
812  llvm::ReplaceInstWithInst(RI, llvm::ReturnInst::Create(ctx, error_call));
813  }
814  ir_builder->restoreIP(prev_insert_point);
815 
816  cgen_state->current_func_ = old_func;
817 }
818 
820  auto timer = DEBUG_TIMER(__func__);
822  std::vector<llvm::Type*> arg_types;
823  arg_types.reserve(entry_point_func_->arg_size());
824  std::for_each(entry_point_func_->arg_begin(),
825  entry_point_func_->arg_end(),
826  [&arg_types](const auto& arg) { arg_types.push_back(arg.getType()); });
827  CHECK_EQ(arg_types.size(), entry_point_func_->arg_size());
828 
829  auto cgen_state = executor_->getCgenStatePtr();
830  CHECK(cgen_state);
831  auto& ctx = cgen_state->context_;
832 
833  std::vector<llvm::Type*> wrapper_arg_types(arg_types.size() + 1);
834  wrapper_arg_types[0] = llvm::PointerType::get(get_int_type(32, ctx), 0);
835  wrapper_arg_types[1] = arg_types[0];
836 
837  for (size_t i = 1; i < arg_types.size(); ++i) {
838  wrapper_arg_types[i + 1] = arg_types[i];
839  }
840 
841  auto wrapper_ft =
842  llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), wrapper_arg_types, false);
843  kernel_func_ = llvm::Function::Create(wrapper_ft,
844  llvm::Function::ExternalLinkage,
845  "table_func_kernel",
846  cgen_state->module_);
847 
848  auto wrapper_bb_entry = llvm::BasicBlock::Create(ctx, ".entry", kernel_func_, 0);
849  llvm::IRBuilder<> b(ctx);
850  b.SetInsertPoint(wrapper_bb_entry);
851  std::vector<llvm::Value*> loaded_args = {kernel_func_->arg_begin() + 1};
852  for (size_t i = 2; i < wrapper_arg_types.size(); ++i) {
853  loaded_args.push_back(kernel_func_->arg_begin() + i);
854  }
855  auto error_lv = b.CreateCall(entry_point_func_, loaded_args);
856  b.CreateStore(error_lv, kernel_func_->arg_begin());
857  b.CreateRetVoid();
858 }
859 
// Produces the compilation context for the generated table function code.
// If the executor has a runtime UDF module for the target device, it is first
// linked into the current module with OverrideFromSrc. The GPU path
// initializes the NVPTX backend and builds a GPUTarget; the CPU path wraps the
// execution engine in a CpuCompilationContext whose function pointer is set
// from entry_point_func_.
//
// emit_only_preflight_fn: only selects the heading used when logging the IR
//                         in the lines visible here.
860 std::shared_ptr<CompilationContext> TableFunctionCompilationContext::finalize(
861  bool emit_only_preflight_fn) {
862  auto timer = DEBUG_TIMER(__func__);
863  /*
864  TODO 1: eliminate need for OverrideFromSrc
865  TODO 2: detect and link only the udf's that are needed
866  */
867  auto cgen_state = executor_->getCgenStatePtr();
868  auto is_gpu = co_.device_type == ExecutorDeviceType::GPU;
// Link the runtime UDF module for this device, when one exists.
869  if (executor_->has_rt_udf_module(is_gpu)) {
870  CodeGenerator::link_udf_module(executor_->get_rt_udf_module(is_gpu),
871  *(cgen_state->module_),
872  cgen_state,
873  llvm::Linker::Flags::OverrideFromSrc);
874  }
875 
876  LOG(IR) << (emit_only_preflight_fn ? "Pre Flight Function Entry Point IR\n"
877  : "Table Function Entry Point IR\n")
879  std::shared_ptr<CompilationContext> code;
880  if (is_gpu) {
// GPU: the kernel wrapper (kernel_func_) is what gets logged and compiled.
881  LOG(IR) << "Table Function Kernel IR\n" << serialize_llvm_object(kernel_func_);
882 
883  CHECK(executor_);
884  executor_->initializeNVPTXBackend();
885 
886  CodeGenerator::GPUTarget gpu_target{executor_->nvptx_target_machine_.get(),
887  executor_->cudaMgr(),
888  executor_->blockSize(),
889  cgen_state,
890  false};
893  kernel_func_,
895  /*is_gpu_smem_used=*/false,
896  co_,
897  gpu_target);
898  } else {
// CPU: the execution engine is moved into the context, and entry_point_func_
// is registered as its function pointer.
899  auto ee =
901  auto cpu_code = std::make_shared<CpuCompilationContext>(std::move(ee));
902  cpu_code->setFunctionPointer(entry_point_func_);
903  code = cpu_code;
904  }
905  LOG(IR) << "End of IR";
906 
907  return code;
908 }
llvm::Type * get_fp_ptr_type(const int width, llvm::LLVMContext &context)
HOST DEVICE SQLTypes get_subtype() const
Definition: sqltypes.h:405
#define CHECK_EQ(x, y)
Definition: Logger.h:230
std::string exprsKey(const std::vector< Analyzer::Expr * > &exprs)
HOST DEVICE int get_size() const
Definition: sqltypes.h:414
bool is_timestamp() const
Definition: sqltypes.h:1020
void generateEntryPoint(const TableFunctionExecutionUnit &exe_unit, bool emit_only_preflight_fn)
void initialize_ptr_member(llvm::Value *member_ptr, llvm::Type *member_llvm_type, llvm::Value *value_ptr, llvm::IRBuilder<> &ir_builder)
bool passColumnsByValue(const TableFunctionExecutionUnit &exe_unit)
std::vector< Analyzer::Expr * > input_exprs
const table_functions::TableFunction table_func
#define LOG(tag)
Definition: Logger.h:216
bool is_fp() const
Definition: sqltypes.h:604
std::tuple< llvm::Value *, llvm::Value * > alloc_column(std::string col_name, const size_t index, const SQLTypeInfo &data_target_info, llvm::Value *data_ptr, llvm::Value *data_size, llvm::Value *data_str_dict_proxy_ptr, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder)
llvm::Function * generate_entry_point(const CgenState *cgen_state)
std::shared_ptr< CompilationContext > finalize(bool emit_only_preflight_fn)
std::vector< std::string > CodeCacheKey
Definition: CodeCache.h:25
void generateCastsForInputTypes(const TableFunctionExecutionUnit &exe_unit, const std::vector< std::pair< llvm::Value *, const SQLTypeInfo >> &columns_to_cast, llvm::Value *mgr_ptr)
std::string toString(const QueryDescriptionType &type)
Definition: Types.h:64
static ExecutionEngineWrapper generateNativeCPUCode(llvm::Function *func, const std::unordered_set< llvm::Function * > &live_funcs, const CompilationOptions &co)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
llvm::Value * alloc_column_list(std::string col_list_name, const SQLTypeInfo &data_target_info, llvm::Value *data_ptrs, int length, llvm::Value *data_size, llvm::Value *data_str_dict_proxy_ptrs, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder)
std::string to_string(char const *&&v)
llvm::Module * module_
Definition: CgenState.h:430
void verify_function_ir(const llvm::Function *func)
size_t get_bit_width(const SQLTypeInfo &ti)
llvm::LLVMContext & context_
Definition: CgenState.h:439
bool is_integer() const
Definition: sqltypes.h:602
bool is_boolean() const
Definition: sqltypes.h:607
static void link_udf_module(const std::unique_ptr< llvm::Module > &udf_module, llvm::Module &module, CgenState *cgen_state, llvm::Linker::Flags flags=llvm::Linker::Flags::None)
void generateTableFunctionCall(const TableFunctionExecutionUnit &exe_unit, const std::vector< llvm::Value * > &func_args, llvm::BasicBlock *bb_exit, llvm::Value *output_row_count_ptr, bool emit_only_preflight_fn)
std::string getName(const bool drop_suffix=false, const bool lower=false) const
int get_precision() const
Definition: sqltypes.h:407
ExecutorDeviceType device_type
bool is_column() const
Definition: sqltypes.h:613
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:412
std::string serialize_llvm_object(const T *llvm_obj)
static std::shared_ptr< GpuCompilationContext > generateNativeGPUCode(Executor *executor, llvm::Function *func, llvm::Function *wrapper_func, const std::unordered_set< llvm::Function * > &live_funcs, const bool is_gpu_smem_used, const CompilationOptions &co, const GPUTarget &gpu_target)
std::vector< llvm::Value * > generate_column_heads_load(const int num_columns, llvm::Value *byte_stream_arg, llvm::IRBuilder<> &ir_builder, llvm::LLVMContext &ctx)
std::string get_type_name() const
Definition: sqltypes.h:528
void cast_column(llvm::Value *col_base_ptr, llvm::Function *func, SQLTypeInfo &orig_ti, SQLTypeInfo &dest_ti, std::string index, llvm::IRBuilder<> &ir_builder, llvm::LLVMContext &ctx, CodeGenerator &codeGenerator)
llvm::Value * codegenCastBetweenTimestamps(llvm::Value *ts_lv, const SQLTypeInfo &operand_dimen, const SQLTypeInfo &target_dimen, const bool nullable)
Definition: CastIR.cpp:194
llvm::Type * get_llvm_type_from_sql_column_type(const SQLTypeInfo elem_ti, llvm::LLVMContext &ctx)
bool is_bytes() const
Definition: sqltypes.h:623
#define CHECK(condition)
Definition: Logger.h:222
#define DEBUG_TIMER(name)
Definition: Logger.h:371
void initialize_int64_member(llvm::Value *member_ptr, llvm::Value *value, int64_t default_value, llvm::LLVMContext &ctx, llvm::IRBuilder<> &ir_builder)
static std::shared_ptr< QueryEngine > getInstance()
Definition: QueryEngine.h:81
std::shared_ptr< CompilationContext > compile(const TableFunctionExecutionUnit &exe_unit, bool emit_only_preflight_fn)
std::vector< Analyzer::Expr * > target_exprs
bool is_string() const
Definition: sqltypes.h:600
HOST DEVICE bool get_notnull() const
Definition: sqltypes.h:411
llvm::Type * get_int_ptr_type(const int width, llvm::LLVMContext &context)
bool is_array() const
Definition: sqltypes.h:608
llvm::Value * cast_value(llvm::Value *value, SQLTypeInfo &orig_ti, SQLTypeInfo &dest_ti, bool nullable, CodeGenerator &codeGenerator)