OmniSciDB  1dac507f6e
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ExtensionsIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 #include "ExtensionFunctions.hpp"
23 
24 extern std::unique_ptr<llvm::Module> udf_gpu_module;
25 extern std::unique_ptr<llvm::Module> udf_cpu_module;
26 
27 namespace {
28 
30  llvm::LLVMContext& ctx) {
31  switch (ext_arg_type) {
32  case ExtArgumentType::Bool: // pass thru to Int8
34  return get_int_type(8, ctx);
36  return get_int_type(16, ctx);
38  return get_int_type(32, ctx);
40  return get_int_type(64, ctx);
42  return llvm::Type::getFloatTy(ctx);
44  return llvm::Type::getDoubleTy(ctx);
45  default:
46  CHECK(false);
47  }
48  CHECK(false);
49  return nullptr;
50 }
51 
53  CHECK(ll_type);
54  const auto bits = ll_type->getPrimitiveSizeInBits();
55 
56  if (ll_type->isFloatingPointTy()) {
57  switch (bits) {
58  case 32:
59  return SQLTypeInfo(kFLOAT, false);
60  case 64:
61  return SQLTypeInfo(kDOUBLE, false);
62  default:
63  LOG(FATAL) << "Unsupported llvm floating point type: " << bits
64  << ", only 32 and 64 bit floating point is supported.";
65  }
66  } else {
67  switch (bits) {
68  case 1:
69  return SQLTypeInfo(kBOOLEAN, false);
70  case 8:
71  return SQLTypeInfo(kTINYINT, false);
72  case 16:
73  return SQLTypeInfo(kSMALLINT, false);
74  case 32:
75  return SQLTypeInfo(kINT, false);
76  case 64:
77  return SQLTypeInfo(kBIGINT, false);
78  default:
79  LOG(FATAL) << "Unrecognized llvm type for SQL type: "
80  << bits; // TODO let's get the real name here
81  }
82  }
83  UNREACHABLE();
84  return SQLTypeInfo();
85 }
86 
88  for (size_t i = 0; i < function_oper->getArity(); ++i) {
89  const auto arg = function_oper->getArg(i);
90  const auto& arg_ti = arg->get_type_info();
91  if (!arg_ti.get_notnull() && !arg_ti.is_array() && !arg_ti.is_geometry()) {
92  return true;
93  }
94  }
95  return false;
96 }
97 
98 } // namespace
99 
100 #include "../Shared/sql_type_to_string.h"
101 
102 extern "C" void register_buffer_with_executor_rsm(int64_t exec, int8_t* buffer) {
103  Executor* exec_ptr = reinterpret_cast<Executor*>(exec);
104  if (buffer != nullptr) {
105  exec_ptr->getRowSetMemoryOwner()->addVarlenBuffer(buffer);
106  }
107 }
108 
110  const Analyzer::FunctionOper* function_oper,
111  const CompilationOptions& co) {
112  auto ext_func_sig = bind_function(function_oper);
113 
114  const auto& ret_ti = function_oper->get_type_info();
115  CHECK(ret_ti.is_integer() || ret_ti.is_fp() || ret_ti.is_boolean());
116  const auto ret_ty =
117  ext_arg_type_to_llvm_type(ext_func_sig.getRet(), cgen_state_->context_);
118  const auto current_bb = cgen_state_->ir_builder_.GetInsertBlock();
119  for (auto it : cgen_state_->ext_call_cache_) {
120  if (*it.foper == *function_oper) {
121  auto inst = llvm::dyn_cast<llvm::Instruction>(it.lv);
122  if (inst && inst->getParent() == current_bb) {
123  return it.lv;
124  }
125  }
126  }
127  std::vector<llvm::Value*> orig_arg_lvs;
128  std::unordered_map<llvm::Value*, llvm::Value*> const_arr_size;
129  for (size_t i = 0; i < function_oper->getArity(); ++i) {
130  const auto arg = function_oper->getArg(i);
131  const auto arg_cast = dynamic_cast<const Analyzer::UOper*>(arg);
132  const auto arg0 =
133  (arg_cast && arg_cast->get_optype() == kCAST) ? arg_cast->get_operand() : arg;
134  const auto array_expr_arg = dynamic_cast<const Analyzer::ArrayExpr*>(arg0);
135  auto is_local_alloc = (array_expr_arg && array_expr_arg->isLocalAlloc());
136  const auto& arg_ti = arg->get_type_info();
137  const auto arg_lvs = codegen(arg, true, co);
138  // TODO(adb / d): Assuming no const array cols for geo (for now)
139  if (arg_ti.is_geometry()) {
140  CHECK_EQ(static_cast<size_t>(arg_ti.get_physical_coord_cols()), arg_lvs.size());
141  for (size_t i = 0; i < arg_lvs.size(); i++) {
142  orig_arg_lvs.push_back(arg_lvs[i]);
143  }
144  } else {
145  if (arg_lvs.size() > 1) {
146  CHECK(arg_ti.is_array());
147  CHECK_EQ(size_t(2), arg_lvs.size());
148  const_arr_size[arg_lvs.front()] = arg_lvs.back();
149  } else {
150  CHECK_EQ(size_t(1), arg_lvs.size());
151  if (is_local_alloc && arg_ti.get_size() > 0) {
152  const_arr_size[arg_lvs.front()] = cgen_state_->llInt(arg_ti.get_size());
153  }
154  }
155  orig_arg_lvs.push_back(arg_lvs.front());
156  }
157  }
158  // The extension function implementations don't handle NULL, they work under
159  // the assumption that the inputs are validated before calling them. Generate
160  // code to do the check at the call site: if any argument is NULL, return NULL
161  // without calling the function at all.
162  const auto bbs = beginArgsNullcheck(function_oper, orig_arg_lvs);
163  CHECK_GE(orig_arg_lvs.size(), function_oper->getArity());
164  // Arguments must be converted to the types the extension function can handle.
165  const auto args = codegenFunctionOperCastArgs(
166  function_oper, &ext_func_sig, orig_arg_lvs, const_arr_size, co);
167  const auto ext_call =
168  cgen_state_->emitExternalCall(ext_func_sig.getName(), ret_ty, args);
169  auto ext_call_nullcheck = endArgsNullcheck(bbs, ext_call, function_oper);
170 
171  // Cast the return of the extension function to match the FunctionOper
172  const auto extension_ret_ti = get_sql_type_from_llvm_type(ret_ty);
173  if (bbs.args_null_bb &&
174  extension_ret_ti.get_type() != function_oper->get_type_info().get_type()) {
175  ext_call_nullcheck = codegenCast(
176  ext_call_nullcheck, extension_ret_ti, function_oper->get_type_info(), false, co);
177  }
178 
179  cgen_state_->ext_call_cache_.push_back({function_oper, ext_call_nullcheck});
180 
181  return ext_call_nullcheck;
182 }
183 
184 // Start the control flow needed for a call site check of NULL arguments.
186  const Analyzer::FunctionOper* function_oper,
187  const std::vector<llvm::Value*>& orig_arg_lvs) {
188  llvm::BasicBlock* args_null_bb{nullptr};
189  llvm::BasicBlock* args_notnull_bb{nullptr};
190  llvm::BasicBlock* orig_bb = cgen_state_->ir_builder_.GetInsertBlock();
191  // Only generate the check if required (at least one argument must be nullable).
192  if (ext_func_call_requires_nullcheck(function_oper)) {
193  const auto args_notnull_lv = cgen_state_->ir_builder_.CreateNot(
194  codegenFunctionOperNullArg(function_oper, orig_arg_lvs));
195  args_notnull_bb = llvm::BasicBlock::Create(
196  cgen_state_->context_, "args_notnull", cgen_state_->row_func_);
197  args_null_bb = llvm::BasicBlock::Create(
198  cgen_state_->context_, "args_null", cgen_state_->row_func_);
199  cgen_state_->ir_builder_.CreateCondBr(args_notnull_lv, args_notnull_bb, args_null_bb);
200  cgen_state_->ir_builder_.SetInsertPoint(args_notnull_bb);
201  }
202  return {args_null_bb, args_notnull_bb, orig_bb};
203 }
204 
205 // Wrap up the control flow needed for NULL argument handling.
207  const ArgNullcheckBBs& bbs,
208  llvm::Value* fn_ret_lv,
209  const Analyzer::FunctionOper* function_oper) {
210  if (bbs.args_null_bb) {
211  CHECK(bbs.args_notnull_bb);
212  cgen_state_->ir_builder_.CreateBr(bbs.args_null_bb);
213  cgen_state_->ir_builder_.SetInsertPoint(bbs.args_null_bb);
214 
215  // The pre-cast SQL equivalent of the type returned by the extension function.
216  const auto extension_ret_ti = get_sql_type_from_llvm_type(fn_ret_lv->getType());
217 
218  auto ext_call_phi = cgen_state_->ir_builder_.CreatePHI(
219  extension_ret_ti.is_fp()
220  ? get_fp_type(extension_ret_ti.get_size() * 8, cgen_state_->context_)
221  : get_int_type(extension_ret_ti.get_size() * 8, cgen_state_->context_),
222  2);
223 
224  ext_call_phi->addIncoming(fn_ret_lv, bbs.args_notnull_bb);
225 
226  const auto null_lv =
227  extension_ret_ti.is_fp()
228  ? static_cast<llvm::Value*>(cgen_state_->inlineFpNull(extension_ret_ti))
229  : static_cast<llvm::Value*>(cgen_state_->inlineIntNull(extension_ret_ti));
230  ext_call_phi->addIncoming(null_lv, bbs.orig_bb);
231  return ext_call_phi;
232  }
233  return fn_ret_lv;
234 }
235 
236 namespace {
237 
239  const auto& ret_ti = function_oper->get_type_info();
240  if (!ret_ti.is_integer() && !ret_ti.is_fp()) {
241  return true;
242  }
243  for (size_t i = 0; i < function_oper->getArity(); ++i) {
244  const auto arg = function_oper->getArg(i);
245  const auto& arg_ti = arg->get_type_info();
246  if (!arg_ti.is_integer() && !arg_ti.is_fp()) {
247  return true;
248  }
249  }
250  return false;
251 }
252 
253 } // namespace
254 
257  const CompilationOptions& co) {
258  if (call_requires_custom_type_handling(function_oper)) {
259  // Some functions need the return type to be the same as the input type.
260  if (function_oper->getName() == "FLOOR" || function_oper->getName() == "CEIL") {
261  CHECK_EQ(size_t(1), function_oper->getArity());
262  const auto arg = function_oper->getArg(0);
263  const auto& arg_ti = arg->get_type_info();
264  CHECK(arg_ti.is_decimal());
265  const auto arg_lvs = codegen(arg, true, co);
266  CHECK_EQ(size_t(1), arg_lvs.size());
267  const auto arg_lv = arg_lvs.front();
268  CHECK(arg_lv->getType()->isIntegerTy(64));
269  const auto bbs = beginArgsNullcheck(function_oper, {arg_lvs});
270  const std::string func_name =
271  (function_oper->getName() == "FLOOR") ? "decimal_floor" : "decimal_ceil";
272  const auto covar_result_lv = cgen_state_->emitCall(
273  func_name, {arg_lv, cgen_state_->llInt(exp_to_scale(arg_ti.get_scale()))});
274  const auto ret_ti = function_oper->get_type_info();
275  CHECK(ret_ti.is_decimal());
276  CHECK_EQ(0, ret_ti.get_scale());
277  const auto result_lv = cgen_state_->ir_builder_.CreateSDiv(
278  covar_result_lv, cgen_state_->llInt(exp_to_scale(arg_ti.get_scale())));
279  return endArgsNullcheck(bbs, result_lv, function_oper);
280  } else if (function_oper->getName() == "ROUND" &&
281  function_oper->getArg(0)->get_type_info().is_decimal()) {
282  CHECK_EQ(size_t(2), function_oper->getArity());
283 
284  const auto arg0 = function_oper->getArg(0);
285  const auto& arg0_ti = arg0->get_type_info();
286  const auto arg0_lvs = codegen(arg0, true, co);
287  CHECK_EQ(size_t(1), arg0_lvs.size());
288  const auto arg0_lv = arg0_lvs.front();
289  CHECK(arg0_lv->getType()->isIntegerTy(64));
290 
291  const auto arg1 = function_oper->getArg(1);
292  const auto& arg1_ti = arg1->get_type_info();
293  CHECK(arg1_ti.is_integer());
294  const auto arg1_lvs = codegen(arg1, true, co);
295  auto arg1_lv = arg1_lvs.front();
296  if (arg1_ti.get_type() != kINT) {
297  arg1_lv = codegenCast(arg1_lv, arg1_ti, SQLTypeInfo(kINT, true), false, co);
298  }
299 
300  const auto bbs0 = beginArgsNullcheck(function_oper, {arg0_lv, arg1_lvs.front()});
301 
302  const std::string func_name = "Round__4";
303  const auto ret_ti = function_oper->get_type_info();
304  CHECK(ret_ti.is_decimal());
305  const auto result_lv = cgen_state_->emitExternalCall(
306  func_name,
308  {arg0_lv, arg1_lv, cgen_state_->llInt(arg0_ti.get_scale())});
309 
310  return endArgsNullcheck(bbs0, result_lv, function_oper);
311  }
312  throw std::runtime_error("Type combination not supported for function " +
313  function_oper->getName());
314  }
315  return codegenFunctionOper(function_oper, co);
316 }
317 
318 // Generates code which returns true iff at least one of the arguments is NULL.
320  const Analyzer::FunctionOper* function_oper,
321  const std::vector<llvm::Value*>& orig_arg_lvs) {
322  llvm::Value* one_arg_null =
323  llvm::ConstantInt::get(llvm::IntegerType::getInt1Ty(cgen_state_->context_), false);
324  for (size_t i = 0; i < function_oper->getArity(); ++i) {
325  const auto arg = function_oper->getArg(i);
326  const auto& arg_ti = arg->get_type_info();
327  if (arg_ti.get_notnull() || arg_ti.is_array() || arg_ti.is_geometry()) {
328  continue;
329  }
330  CHECK(arg_ti.is_number());
331  one_arg_null = cgen_state_->ir_builder_.CreateOr(
332  one_arg_null, codegenIsNullNumber(orig_arg_lvs[i], arg_ti));
333  }
334  return one_arg_null;
335 }
336 
337 llvm::StructType* CodeGenerator::createArrayStructType(const std::string& udf_func_name,
338  size_t param_num) {
339  llvm::Function* udf_func = cgen_state_->module_->getFunction(udf_func_name);
340  llvm::Module* module_for_lookup = cgen_state_->module_;
341 
342  CHECK(udf_func);
343 
344  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
345  CHECK(param_num < udf_func_type->getNumParams());
346  llvm::Type* param_type = udf_func_type->getParamType(param_num);
347  CHECK(param_type->isPointerTy());
348  llvm::Type* struct_type = param_type->getPointerElementType();
349  CHECK(struct_type->isStructTy());
350  CHECK(struct_type->getStructNumElements() == 3);
351 
352  if (llvm::cast<llvm::StructType>(struct_type)->isLiteral()) {
353  return llvm::cast<llvm::StructType>(struct_type);
354  }
355 
356  llvm::StringRef struct_name = struct_type->getStructName();
357 
358  llvm::StructType* array_type = module_for_lookup->getTypeByName(struct_name);
359  CHECK(array_type);
360 
361  return (array_type);
362 }
363 
364 void CodeGenerator::codegenArrayArgs(const std::string& udf_func_name,
365  size_t param_num,
366  llvm::Value* array_buf,
367  llvm::Value* array_size,
368  llvm::Value* array_null,
369  std::vector<llvm::Value*>& output_args) {
370  CHECK(array_buf);
371  CHECK(array_size);
372  CHECK(array_null);
373 
374  auto array_abstraction = createArrayStructType(udf_func_name, param_num);
375  auto alloc_mem = cgen_state_->ir_builder_.CreateAlloca(array_abstraction, nullptr);
376 
377  auto array_buf_ptr =
378  cgen_state_->ir_builder_.CreateStructGEP(array_abstraction, alloc_mem, 0);
379  cgen_state_->ir_builder_.CreateStore(array_buf, array_buf_ptr);
380 
381  auto array_size_ptr =
382  cgen_state_->ir_builder_.CreateStructGEP(array_abstraction, alloc_mem, 1);
383  cgen_state_->ir_builder_.CreateStore(array_size, array_size_ptr);
384 
385  auto bool_extended_type = llvm::Type::getInt8Ty(cgen_state_->context_);
386  auto array_null_extended =
387  cgen_state_->ir_builder_.CreateZExt(array_null, bool_extended_type);
388  auto array_is_null_ptr =
389  cgen_state_->ir_builder_.CreateStructGEP(array_abstraction, alloc_mem, 2);
390  cgen_state_->ir_builder_.CreateStore(array_null_extended, array_is_null_ptr);
391  output_args.push_back(alloc_mem);
392 }
393 
394 llvm::StructType* CodeGenerator::createPointStructType(const std::string& udf_func_name,
395  size_t param_num) {
396  llvm::Function* udf_func = cgen_state_->module_->getFunction(udf_func_name);
397  llvm::Module* module_for_lookup = cgen_state_->module_;
398 
399  CHECK(udf_func);
400 
401  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
402  CHECK(param_num < udf_func_type->getNumParams());
403  llvm::Type* param_type = udf_func_type->getParamType(param_num);
404  CHECK(param_type->isPointerTy());
405  llvm::Type* struct_type = param_type->getPointerElementType();
406  CHECK(struct_type->isStructTy());
407  CHECK(struct_type->getStructNumElements() == 5);
408 
409  llvm::StringRef struct_name = struct_type->getStructName();
410 
411  llvm::StructType* point_type = module_for_lookup->getTypeByName(struct_name);
412  CHECK(point_type);
413 
414  return (point_type);
415 }
416 
417 void CodeGenerator::codegenGeoPointArgs(const std::string& udf_func_name,
418  size_t param_num,
419  llvm::Value* point_buf,
420  llvm::Value* point_size,
421  llvm::Value* compression,
422  llvm::Value* input_srid,
423  llvm::Value* output_srid,
424  std::vector<llvm::Value*>& output_args) {
425  CHECK(point_buf);
426  CHECK(point_size);
427  CHECK(compression);
428  CHECK(input_srid);
429  CHECK(output_srid);
430 
431  auto point_abstraction = createPointStructType(udf_func_name, param_num);
432  auto alloc_mem = cgen_state_->ir_builder_.CreateAlloca(point_abstraction, nullptr);
433 
434  auto point_buf_ptr =
435  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 0);
436  cgen_state_->ir_builder_.CreateStore(point_buf, point_buf_ptr);
437 
438  auto point_size_ptr =
439  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 1);
440  cgen_state_->ir_builder_.CreateStore(point_size, point_size_ptr);
441 
442  auto point_compression_ptr =
443  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 2);
444  cgen_state_->ir_builder_.CreateStore(compression, point_compression_ptr);
445 
446  auto input_srid_ptr =
447  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 3);
448  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
449 
450  auto output_srid_ptr =
451  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 4);
452  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
453 
454  output_args.push_back(alloc_mem);
455 }
456 
458  const std::string& udf_func_name,
459  size_t param_num) {
460  llvm::Function* udf_func = cgen_state_->module_->getFunction(udf_func_name);
461  llvm::Module* module_for_lookup = cgen_state_->module_;
462 
463  CHECK(udf_func);
464 
465  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
466  CHECK(param_num < udf_func_type->getNumParams());
467  llvm::Type* param_type = udf_func_type->getParamType(param_num);
468  CHECK(param_type->isPointerTy());
469  llvm::Type* struct_type = param_type->getPointerElementType();
470  CHECK(struct_type->isStructTy());
471  CHECK(struct_type->getStructNumElements() == 5);
472 
473  llvm::StringRef struct_name = struct_type->getStructName();
474 
475  llvm::StructType* line_string_type = module_for_lookup->getTypeByName(struct_name);
476  CHECK(line_string_type);
477 
478  return (line_string_type);
479 }
480 
481 void CodeGenerator::codegenGeoLineStringArgs(const std::string& udf_func_name,
482  size_t param_num,
483  llvm::Value* line_string_buf,
484  llvm::Value* line_string_size,
485  llvm::Value* compression,
486  llvm::Value* input_srid,
487  llvm::Value* output_srid,
488  std::vector<llvm::Value*>& output_args) {
489  CHECK(line_string_buf);
490  CHECK(line_string_size);
491  CHECK(compression);
492  CHECK(input_srid);
493  CHECK(output_srid);
494 
495  auto line_string_abstraction = createLineStringStructType(udf_func_name, param_num);
496  auto alloc_mem =
497  cgen_state_->ir_builder_.CreateAlloca(line_string_abstraction, nullptr);
498 
499  auto line_string_buf_ptr =
500  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 0);
501  cgen_state_->ir_builder_.CreateStore(line_string_buf, line_string_buf_ptr);
502 
503  auto line_string_size_ptr =
504  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 1);
505  cgen_state_->ir_builder_.CreateStore(line_string_size, line_string_size_ptr);
506 
507  auto line_string_compression_ptr =
508  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 2);
509  cgen_state_->ir_builder_.CreateStore(compression, line_string_compression_ptr);
510 
511  auto input_srid_ptr =
512  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 3);
513  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
514 
515  auto output_srid_ptr =
516  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 4);
517  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
518 
519  output_args.push_back(alloc_mem);
520 }
521 
522 llvm::StructType* CodeGenerator::createPolygonStructType(const std::string& udf_func_name,
523  size_t param_num) {
524  llvm::Function* udf_func = cgen_state_->module_->getFunction(udf_func_name);
525  llvm::Module* module_for_lookup = cgen_state_->module_;
526 
527  CHECK(udf_func);
528 
529  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
530  CHECK(param_num < udf_func_type->getNumParams());
531  llvm::Type* param_type = udf_func_type->getParamType(param_num);
532  CHECK(param_type->isPointerTy());
533  llvm::Type* struct_type = param_type->getPointerElementType();
534  CHECK(struct_type->isStructTy());
535  CHECK(struct_type->getStructNumElements() == 7);
536 
537  llvm::StringRef struct_name = struct_type->getStructName();
538 
539  llvm::StructType* polygon_type = module_for_lookup->getTypeByName(struct_name);
540  CHECK(polygon_type);
541 
542  return (polygon_type);
543 }
544 
545 void CodeGenerator::codegenGeoPolygonArgs(const std::string& udf_func_name,
546  size_t param_num,
547  llvm::Value* polygon_buf,
548  llvm::Value* polygon_size,
549  llvm::Value* ring_sizes_buf,
550  llvm::Value* num_rings,
551  llvm::Value* compression,
552  llvm::Value* input_srid,
553  llvm::Value* output_srid,
554  std::vector<llvm::Value*>& output_args) {
555  CHECK(polygon_buf);
556  CHECK(polygon_size);
557  CHECK(ring_sizes_buf);
558  CHECK(num_rings);
559  CHECK(compression);
560  CHECK(input_srid);
561  CHECK(output_srid);
562 
563  auto polygon_abstraction = createPolygonStructType(udf_func_name, param_num);
564  auto alloc_mem = cgen_state_->ir_builder_.CreateAlloca(polygon_abstraction, nullptr);
565 
566  auto polygon_buf_ptr =
567  cgen_state_->ir_builder_.CreateStructGEP(polygon_abstraction, alloc_mem, 0);
568  cgen_state_->ir_builder_.CreateStore(polygon_buf, polygon_buf_ptr);
569 
570  auto polygon_size_ptr =
571  cgen_state_->ir_builder_.CreateStructGEP(polygon_abstraction, alloc_mem, 1);
572  cgen_state_->ir_builder_.CreateStore(polygon_size, polygon_size_ptr);
573 
574  auto ring_sizes_buf_ptr =
575  cgen_state_->ir_builder_.CreateStructGEP(polygon_abstraction, alloc_mem, 2);
576  cgen_state_->ir_builder_.CreateStore(ring_sizes_buf, ring_sizes_buf_ptr);
577 
578  auto ring_size_ptr =
579  cgen_state_->ir_builder_.CreateStructGEP(polygon_abstraction, alloc_mem, 3);
580  cgen_state_->ir_builder_.CreateStore(num_rings, ring_size_ptr);
581 
582  auto polygon_compression_ptr =
583  cgen_state_->ir_builder_.CreateStructGEP(polygon_abstraction, alloc_mem, 4);
584  cgen_state_->ir_builder_.CreateStore(compression, polygon_compression_ptr);
585 
586  auto input_srid_ptr =
587  cgen_state_->ir_builder_.CreateStructGEP(polygon_abstraction, alloc_mem, 5);
588  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
589 
590  auto output_srid_ptr =
591  cgen_state_->ir_builder_.CreateStructGEP(polygon_abstraction, alloc_mem, 6);
592  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
593 
594  output_args.push_back(alloc_mem);
595 }
596 
597 // Generate CAST operations for arguments in `orig_arg_lvs` to the types required by
598 // `ext_func_sig`.
600  const Analyzer::FunctionOper* function_oper,
601  const ExtensionFunction* ext_func_sig,
602  const std::vector<llvm::Value*>& orig_arg_lvs,
603  const std::unordered_map<llvm::Value*, llvm::Value*>& const_arr_size,
604  const CompilationOptions& co) {
605  CHECK(ext_func_sig);
606  const auto& ext_func_args = ext_func_sig->getArgs();
607  CHECK_LE(function_oper->getArity(), ext_func_args.size());
608  std::vector<llvm::Value*> args;
609  // i: argument in RA for the function op
610  // j: extra offset in orig_arg_lvs (to account for additional values required for a col,
611  // e.g. array cols) k: orig_arg_lvs counter
612  for (size_t i = 0, j = 0, k = 0; i < function_oper->getArity(); ++i, ++k) {
613  const auto arg = function_oper->getArg(i);
614  const auto& arg_ti = arg->get_type_info();
615  llvm::Value* arg_lv{nullptr};
616  if (arg_ti.is_array()) {
617  bool const_arr = (const_arr_size.count(orig_arg_lvs[k]) > 0);
618  const auto elem_ti = arg_ti.get_elem_type();
619  // TODO: switch to fast fixlen variants
620  const auto ptr_lv = (const_arr)
621  ? orig_arg_lvs[k]
623  "array_buff",
624  llvm::Type::getInt8PtrTy(cgen_state_->context_),
625  {orig_arg_lvs[k], posArg(arg)});
626  const auto len_lv =
627  (const_arr) ? const_arr_size.at(orig_arg_lvs[k])
629  "array_size",
631  {orig_arg_lvs[k],
632  posArg(arg),
633  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
634 
635  if (!is_ext_arg_type_array(ext_func_args[i])) {
636  args.push_back(castArrayPointer(ptr_lv, elem_ti));
637  args.push_back(cgen_state_->ir_builder_.CreateZExt(
638  len_lv, get_int_type(64, cgen_state_->context_)));
639  j++;
640  } else {
641  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
642  auto builder = cgen_state_->ir_builder_;
643  auto array_size_arg =
644  builder.CreateZExt(len_lv, get_int_type(64, cgen_state_->context_));
645  auto array_null_arg =
646  cgen_state_->emitExternalCall("array_is_null",
648  {orig_arg_lvs[k], posArg(arg)});
649  codegenArrayArgs(ext_func_sig->getName(),
650  k,
651  array_buf_arg,
652  array_size_arg,
653  array_null_arg,
654  args);
655  }
656 
657  } else if (arg_ti.is_geometry()) {
658  // Coords
659  bool const_arr = (const_arr_size.count(orig_arg_lvs[k]) > 0);
660  // NOTE(adb): We're generating code to handle the TINYINT array only -- the actual
661  // geo encoding (or lack thereof) does not matter here
662  const auto elem_ti = SQLTypeInfo(SQLTypes::kARRAY,
663  0,
664  0,
665  false,
667  0,
669  .get_elem_type();
670  llvm::Value* ptr_lv;
671  llvm::Value* len_lv;
672  int32_t fixlen = -1;
673  if (arg_ti.get_type() == kPOINT) {
674  const auto col_var = dynamic_cast<const Analyzer::ColumnVar*>(arg);
675  if (col_var) {
676  const auto coords_cd = executor()->getPhysicalColumnDescriptor(col_var, 1);
677  if (coords_cd && coords_cd->columnType.get_type() == kARRAY) {
678  fixlen = coords_cd->columnType.get_size();
679  }
680  }
681  }
682  if (fixlen > 0) {
683  ptr_lv =
684  cgen_state_->emitExternalCall("fast_fixlen_array_buff",
685  llvm::Type::getInt8PtrTy(cgen_state_->context_),
686  {orig_arg_lvs[k], posArg(arg)});
687  len_lv = cgen_state_->llInt(int64_t(fixlen));
688  } else {
689  // TODO: remove const_arr and related code if it's not needed
690  ptr_lv = (const_arr) ? orig_arg_lvs[k]
692  "array_buff",
693  llvm::Type::getInt8PtrTy(cgen_state_->context_),
694  {orig_arg_lvs[k], posArg(arg)});
695  len_lv = (const_arr)
696  ? const_arr_size.at(orig_arg_lvs[k])
698  "array_size",
700  {orig_arg_lvs[k],
701  posArg(arg),
702  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
703  }
704 
705  if (is_ext_arg_type_geo(ext_func_args[i])) {
706  if (arg_ti.get_type() == kPOINT || arg_ti.get_type() == kLINESTRING) {
707  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
708  auto builder = cgen_state_->ir_builder_;
709  auto array_size_arg =
710  builder.CreateZExt(len_lv, get_int_type(64, cgen_state_->context_));
711  int32_t compression = (arg_ti.get_compression() == kENCODING_GEOINT &&
712  arg_ti.get_comp_param() == 32)
713  ? 1
714  : 0;
715  auto compression_val = cgen_state_->llInt(compression);
716  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
717  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
718 
719  if (arg_ti.get_type() == kPOINT) {
720  codegenGeoPointArgs(ext_func_sig->getName(),
721  k,
722  array_buf_arg,
723  array_size_arg,
724  compression_val,
725  input_srid_val,
726  output_srid_val,
727  args);
728  } else {
729  codegenGeoLineStringArgs(ext_func_sig->getName(),
730  k,
731  array_buf_arg,
732  array_size_arg,
733  compression_val,
734  input_srid_val,
735  output_srid_val,
736  args);
737  }
738  }
739  } else {
740  args.push_back(castArrayPointer(ptr_lv, elem_ti));
741  args.push_back(cgen_state_->ir_builder_.CreateZExt(
742  len_lv, get_int_type(64, cgen_state_->context_)));
743  j++;
744  }
745 
746  switch (arg_ti.get_type()) {
747  case kPOINT:
748  case kLINESTRING:
749  break;
750  case kPOLYGON: {
751  if (ext_func_args[i] == ExtArgumentType::GeoPolygon) {
752  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
753  auto builder = cgen_state_->ir_builder_;
754  auto array_size_arg =
755  builder.CreateZExt(len_lv, get_int_type(64, cgen_state_->context_));
756  int32_t compression = (arg_ti.get_compression() == kENCODING_GEOINT &&
757  arg_ti.get_comp_param() == 32)
758  ? 1
759  : 0;
760  auto compression_val = cgen_state_->llInt(compression);
761  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
762  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
763  k++;
764  // Ring Sizes
765  const auto elem_ti = SQLTypeInfo(SQLTypes::kARRAY,
766  0,
767  0,
768  false,
770  0,
772  .get_elem_type();
773  const auto ptr_lv = cgen_state_->emitExternalCall(
774  "array_buff",
775  llvm::Type::getInt32PtrTy(cgen_state_->context_),
776  {orig_arg_lvs[k], posArg(arg)});
777  const auto len_lv = cgen_state_->emitExternalCall(
778  "array_size",
780  {orig_arg_lvs[k],
781  posArg(arg),
782  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
783  auto ring_size_buf_arg = castArrayPointer(ptr_lv, elem_ti);
784  auto ring_size_arg =
785  builder.CreateZExt(len_lv, get_int_type(64, cgen_state_->context_));
786 
787  codegenGeoPolygonArgs(ext_func_sig->getName(),
788  k - 1,
789  array_buf_arg,
790  array_size_arg,
791  ring_size_buf_arg,
792  ring_size_arg,
793  compression_val,
794  input_srid_val,
795  output_srid_val,
796  args);
797 
798  } else {
799  k++;
800  // Ring Sizes
801  const auto elem_ti = SQLTypeInfo(SQLTypes::kARRAY,
802  0,
803  0,
804  false,
806  0,
808  .get_elem_type();
809  const auto ptr_lv = cgen_state_->emitExternalCall(
810  "array_buff",
811  llvm::Type::getInt32PtrTy(cgen_state_->context_),
812  {orig_arg_lvs[k], posArg(arg)});
813  const auto len_lv = cgen_state_->emitExternalCall(
814  "array_size",
816  {orig_arg_lvs[k],
817  posArg(arg),
818  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
819 
820  args.push_back(castArrayPointer(ptr_lv, elem_ti));
821  args.push_back(cgen_state_->ir_builder_.CreateZExt(
822  len_lv, get_int_type(64, cgen_state_->context_)));
823  j++;
824  }
825  break;
826  }
827  case kMULTIPOLYGON: {
828  k++;
829  // Ring Sizes
830  {
831  const auto elem_ti = SQLTypeInfo(SQLTypes::kARRAY,
832  0,
833  0,
834  false,
836  0,
838  .get_elem_type();
839  const auto ptr_lv = cgen_state_->emitExternalCall(
840  "array_buff",
841  llvm::Type::getInt32PtrTy(cgen_state_->context_),
842  {orig_arg_lvs[k], posArg(arg)});
843  const auto len_lv = cgen_state_->emitExternalCall(
844  "array_size",
845  get_int_type(32, cgen_state_->context_),
846  {orig_arg_lvs[k],
847  posArg(arg),
848  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
849  args.push_back(castArrayPointer(ptr_lv, elem_ti));
850  args.push_back(cgen_state_->ir_builder_.CreateZExt(
851  len_lv, get_int_type(64, cgen_state_->context_)));
852  }
853  j++, k++;
854 
855  // Poly Rings
856  {
857  const auto elem_ti = SQLTypeInfo(SQLTypes::kARRAY,
858  0,
859  0,
860  false,
862  0,
864  .get_elem_type();
865  const auto ptr_lv = cgen_state_->emitExternalCall(
866  "array_buff",
867  llvm::Type::getInt32PtrTy(cgen_state_->context_),
868  {orig_arg_lvs[k], posArg(arg)});
869  const auto len_lv = cgen_state_->emitExternalCall(
870  "array_size",
871  get_int_type(32, cgen_state_->context_),
872  {orig_arg_lvs[k],
873  posArg(arg),
874  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
875  args.push_back(castArrayPointer(ptr_lv, elem_ti));
876  args.push_back(cgen_state_->ir_builder_.CreateZExt(
877  len_lv, get_int_type(64, cgen_state_->context_)));
878  }
879  j++;
880  break;
881  }
882  default:
883  CHECK(false);
884  }
885  } else {
886  const auto arg_target_ti = ext_arg_type_to_type_info(ext_func_args[k + j]);
887  if (arg_ti.get_type() != arg_target_ti.get_type()) {
888  arg_lv = codegenCast(orig_arg_lvs[k], arg_ti, arg_target_ti, false, co);
889  } else {
890  arg_lv = orig_arg_lvs[k];
891  }
892  CHECK_EQ(arg_lv->getType(),
893  ext_arg_type_to_llvm_type(ext_func_args[k + j], cgen_state_->context_));
894  args.push_back(arg_lv);
895  }
896  }
897  return args;
898 }
899 
900 llvm::Value* CodeGenerator::castArrayPointer(llvm::Value* ptr,
901  const SQLTypeInfo& elem_ti) {
902  if (elem_ti.get_type() == kFLOAT) {
903  return cgen_state_->ir_builder_.CreatePointerCast(
904  ptr, llvm::Type::getFloatPtrTy(cgen_state_->context_));
905  }
906  if (elem_ti.get_type() == kDOUBLE) {
907  return cgen_state_->ir_builder_.CreatePointerCast(
908  ptr, llvm::Type::getDoublePtrTy(cgen_state_->context_));
909  }
910  CHECK(elem_ti.is_integer() || elem_ti.is_boolean() ||
911  (elem_ti.is_string() && elem_ti.get_compression() == kENCODING_DICT));
912  switch (elem_ti.get_size()) {
913  case 1:
914  return cgen_state_->ir_builder_.CreatePointerCast(
915  ptr, llvm::Type::getInt8PtrTy(cgen_state_->context_));
916  case 2:
917  return cgen_state_->ir_builder_.CreatePointerCast(
918  ptr, llvm::Type::getInt16PtrTy(cgen_state_->context_));
919  case 4:
920  return cgen_state_->ir_builder_.CreatePointerCast(
921  ptr, llvm::Type::getInt32PtrTy(cgen_state_->context_));
922  case 8:
923  return cgen_state_->ir_builder_.CreatePointerCast(
924  ptr, llvm::Type::getInt64PtrTy(cgen_state_->context_));
925  default:
926  CHECK(false);
927  }
928  return nullptr;
929 }
llvm::StructType * createLineStringStructType(const std::string &udf_func_name, size_t param_num)
bool is_boolean() const
Definition: sqltypes.h:484
#define CHECK_EQ(x, y)
Definition: Logger.h:198
const std::vector< ExtArgumentType > & getArgs() const
const std::string & getName() const
llvm::BasicBlock * args_notnull_bb
size_t getArity() const
Definition: Analyzer.h:1309
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:334
std::unique_ptr< llvm::Module > udf_gpu_module
bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type)
CgenState * cgen_state_
void codegenGeoPolygonArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *polygon_buf, llvm::Value *polygon_size, llvm::Value *ring_sizes_buf, llvm::Value *num_rings, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
std::vector< llvm::Value * > codegenFunctionOperCastArgs(const Analyzer::FunctionOper *, const ExtensionFunction *, const std::vector< llvm::Value * > &, const std::unordered_map< llvm::Value *, llvm::Value * > &, const CompilationOptions &)
#define LOG(tag)
Definition: Logger.h:185
llvm::StructType * createArrayStructType(const std::string &udf_func_name, size_t param_num)
ExtensionFunction bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< ExtensionFunction > &ext_funcs)
llvm::Value * codegenFunctionOperNullArg(const Analyzer::FunctionOper *, const std::vector< llvm::Value * > &)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:269
llvm::Value * posArg(const Analyzer::Expr *) const
Definition: ColumnIR.cpp:503
llvm::Value * castArrayPointer(llvm::Value *ptr, const SQLTypeInfo &elem_ti)
#define UNREACHABLE()
Definition: Logger.h:234
HOST DEVICE int get_size() const
Definition: sqltypes.h:336
#define CHECK_GE(x, y)
Definition: Logger.h:203
Definition: sqldefs.h:49
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
llvm::StructType * createPointStructType(const std::string &udf_func_name, size_t param_num)
bool call_requires_custom_type_handling(const Analyzer::FunctionOper *function_oper)
ArgNullcheckBBs beginArgsNullcheck(const Analyzer::FunctionOper *function_oper, const std::vector< llvm::Value * > &orig_arg_lvs)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
bool ext_func_call_requires_nullcheck(const Analyzer::FunctionOper *function_oper)
SQLTypeInfo get_sql_type_from_llvm_type(const llvm::Type *ll_type)
std::vector< FunctionOperValue > ext_call_cache_
Definition: CgenState.h:275
llvm::Function * row_func_
Definition: CgenState.h:265
void codegenArrayArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *array_buf, llvm::Value *array_size, llvm::Value *array_is_null, std::vector< llvm::Value * > &output_args)
llvm::Module * module_
Definition: CgenState.h:264
llvm::LLVMContext & context_
Definition: CgenState.h:267
CHECK(cgen_state)
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:326
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:24
llvm::Value * codegenFunctionOper(const Analyzer::FunctionOper *, const CompilationOptions &)
llvm::BasicBlock * args_null_bb
llvm::Type * ext_arg_type_to_llvm_type(const ExtArgumentType ext_arg_type, llvm::LLVMContext &ctx)
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:852
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:78
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:134
void codegenGeoPointArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *point_buf, llvm::Value *point_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:25
llvm::Value * emitExternalCall(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, const std::vector< llvm::Attribute::AttrKind > &fnattrs={})
Definition: CgenState.h:205
#define CHECK_LE(x, y)
Definition: Logger.h:201
llvm::StructType * createPolygonStructType(const std::string &udf_func_name, size_t param_num)
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:1311
const Expr * get_operand() const
Definition: Analyzer.h:365
bool is_integer() const
Definition: sqltypes.h:479
llvm::Value * endArgsNullcheck(const ArgNullcheckBBs &, llvm::Value *, const Analyzer::FunctionOper *)
std::unique_ptr< llvm::Module > udf_cpu_module
llvm::Value * codegenFunctionOperWithCustomTypeHandling(const Analyzer::FunctionOperWithCustomTypeHandling *, const CompilationOptions &)
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:248
bool is_string() const
Definition: sqltypes.h:477
llvm::Value * codegenIsNullNumber(llvm::Value *, const SQLTypeInfo &)
Definition: LogicalIR.cpp:397
uint64_t exp_to_scale(const unsigned exp)
llvm::Value * codegenCast(const Analyzer::UOper *, const CompilationOptions &)
Definition: CastIR.cpp:20
uint32_t log2_bytes(const uint32_t bytes)
Definition: Execute.h:127
Definition: sqltypes.h:48
SQLTypeInfoCore get_elem_type() const
Definition: sqltypes.h:659
std::string getName() const
Definition: Analyzer.h:1307
bool is_decimal() const
Definition: sqltypes.h:480
void codegenGeoLineStringArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *line_string_buf, llvm::Value *line_string_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
void register_buffer_with_executor_rsm(int64_t exec, int8_t *buffer)
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
llvm::ConstantFP * inlineFpNull(const SQLTypeInfo &)
Definition: CgenState.cpp:62
Executor * executor() const