OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionsIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 #include "ExtensionFunctions.hpp"
22 
23 #include <tuple>
24 
25 extern std::unique_ptr<llvm::Module> udf_gpu_module;
26 extern std::unique_ptr<llvm::Module> udf_cpu_module;
27 
28 namespace {
29 
30 llvm::StructType* get_buffer_struct_type(CgenState* cgen_state,
31  const std::string& ext_func_name,
32  size_t param_num,
33  llvm::Type* elem_type) {
34  CHECK(elem_type);
35  CHECK(elem_type->isPointerTy());
36  llvm::StructType* generated_struct_type =
37  llvm::StructType::get(cgen_state->context_,
38  {elem_type,
39  llvm::Type::getInt64Ty(cgen_state->context_),
40  llvm::Type::getInt8Ty(cgen_state->context_)},
41  false);
42  llvm::Function* udf_func = cgen_state->module_->getFunction(ext_func_name);
43  if (udf_func) {
44  // Compare expected array struct type with type from the function
45  // definition from the UDF module, but use the type from the
46  // module
47  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
48  CHECK_LE(param_num, udf_func_type->getNumParams());
49  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
50  CHECK(param_pointer_type->isPointerTy());
51  llvm::Type* param_type = param_pointer_type->getPointerElementType();
52  CHECK(param_type->isStructTy());
53  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
54  CHECK_GE(struct_type->getStructNumElements(),
55  generated_struct_type->getStructNumElements())
56  << serialize_llvm_object(struct_type);
57 
58  const auto expected_elems = generated_struct_type->elements();
59  const auto current_elems = struct_type->elements();
60  for (size_t i = 0; i < expected_elems.size(); i++) {
61  CHECK_EQ(expected_elems[i], current_elems[i])
62  << "[" << ::toString(expected_elems[i]) << ", " << ::toString(current_elems[i])
63  << "]";
64  }
65 
66  if (struct_type->isLiteral()) {
67  return struct_type;
68  }
69 
70  llvm::StringRef struct_name = struct_type->getStructName();
71 #if LLVM_VERSION_MAJOR >= 12
72  return struct_type->getTypeByName(cgen_state->context_, struct_name);
73 #else
74  return cgen_state->module_->getTypeByName(struct_name);
75 #endif
76  }
77  return generated_struct_type;
78 }
79 
81  llvm::LLVMContext& ctx) {
82  switch (ext_arg_type) {
83  case ExtArgumentType::Bool: // pass thru to Int8
85  return get_int_type(8, ctx);
87  return get_int_type(16, ctx);
89  return get_int_type(32, ctx);
91  return get_int_type(64, ctx);
93  return llvm::Type::getFloatTy(ctx);
95  return llvm::Type::getDoubleTy(ctx);
97  return get_int_type(32, ctx);
122  return llvm::Type::getVoidTy(ctx);
123  default:
124  CHECK(false);
125  }
126  CHECK(false);
127  return nullptr;
128 }
129 
131  CHECK(ll_type);
132  const auto bits = ll_type->getPrimitiveSizeInBits();
133 
134  if (ll_type->isFloatingPointTy()) {
135  switch (bits) {
136  case 32:
137  return SQLTypeInfo(kFLOAT, false);
138  case 64:
139  return SQLTypeInfo(kDOUBLE, false);
140  default:
141  LOG(FATAL) << "Unsupported llvm floating point type: " << bits
142  << ", only 32 and 64 bit floating point is supported.";
143  }
144  } else {
145  switch (bits) {
146  case 1:
147  return SQLTypeInfo(kBOOLEAN, false);
148  case 8:
149  return SQLTypeInfo(kTINYINT, false);
150  case 16:
151  return SQLTypeInfo(kSMALLINT, false);
152  case 32:
153  return SQLTypeInfo(kINT, false);
154  case 64:
155  return SQLTypeInfo(kBIGINT, false);
156  default:
157  LOG(FATAL) << "Unrecognized llvm type for SQL type: "
158  << bits; // TODO let's get the real name here
159  }
160  }
161  UNREACHABLE();
162  return SQLTypeInfo();
163 }
164 
166  llvm::LLVMContext& ctx) {
167  CHECK(ti.is_buffer());
168  if (ti.is_bytes()) {
169  return llvm::Type::getInt8PtrTy(ctx);
170  }
171 
172  const auto& elem_ti = ti.get_elem_type();
173  if (elem_ti.is_fp()) {
174  switch (elem_ti.get_size()) {
175  case 4:
176  return llvm::Type::getFloatPtrTy(ctx);
177  case 8:
178  return llvm::Type::getDoublePtrTy(ctx);
179  }
180  }
181 
182  if (elem_ti.is_text_encoding_dict()) {
183  return llvm::Type::getInt32PtrTy(ctx);
184  }
185 
186  if (elem_ti.is_boolean()) {
187  return llvm::Type::getInt8PtrTy(ctx);
188  }
189 
190  CHECK(elem_ti.is_integer());
191  switch (elem_ti.get_size()) {
192  case 1:
193  return llvm::Type::getInt8PtrTy(ctx);
194  case 2:
195  return llvm::Type::getInt16PtrTy(ctx);
196  case 4:
197  return llvm::Type::getInt32PtrTy(ctx);
198  case 8:
199  return llvm::Type::getInt64PtrTy(ctx);
200  }
201 
202  UNREACHABLE();
203  return nullptr;
204 }
205 
207  const auto& func_ti = function_oper->get_type_info();
208  for (size_t i = 0; i < function_oper->getArity(); ++i) {
209  const auto arg = function_oper->getArg(i);
210  const auto& arg_ti = arg->get_type_info();
211  if ((func_ti.is_array() && arg_ti.is_array()) ||
212  (func_ti.is_bytes() && arg_ti.is_bytes()) ||
213  (func_ti.is_text_encoding_dict() && arg_ti.is_text_encoding_dict()) ||
214  (func_ti.is_text_encoding_dict_array() && arg_ti.is_text_encoding_dict())) {
215  // If the function returns an array and any of the arguments are arrays, allow NULL
216  // scalars.
217  // TODO: Make this a property of the FunctionOper following `RETURN NULL ON NULL`
218  // semantics.
219  return false;
220  } else if (!arg_ti.get_notnull() && !arg_ti.is_buffer()) {
221  // Nullable geometry args will trigger a null check
222  return true;
223  } else {
224  continue;
225  }
226  }
227  return false;
228 }
229 
230 } // namespace
231 
233  int8_t* buffer) {
234  Executor* exec_ptr = reinterpret_cast<Executor*>(exec);
235  if (buffer != nullptr) {
236  exec_ptr->getRowSetMemoryOwner()->addVarlenBuffer(buffer);
237  }
238 }
239 
241  const Analyzer::FunctionOper* function_oper,
242  const CompilationOptions& co) {
244  ExtensionFunction ext_func_sig = [=]() {
246  try {
247  return bind_function(function_oper, /* is_gpu= */ true);
248  } catch (ExtensionFunctionBindingError& e) {
249  LOG(WARNING) << "codegenFunctionOper[GPU]: " << e.what() << " Redirecting "
250  << function_oper->getName() << " to run on CPU.";
251  throw QueryMustRunOnCpu();
252  }
253  } else {
254  try {
255  return bind_function(function_oper, /* is_gpu= */ false);
256  } catch (ExtensionFunctionBindingError& e) {
257  LOG(WARNING) << "codegenFunctionOper[CPU]: " << e.what();
258  throw;
259  }
260  }
261  }();
262 
263  const auto& ret_ti = function_oper->get_type_info();
264  CHECK(ret_ti.is_integer() || ret_ti.is_fp() || ret_ti.is_boolean() ||
265  ret_ti.is_buffer() || ret_ti.is_text_encoding_dict());
266  if (ret_ti.is_buffer() && co.device_type == ExecutorDeviceType::GPU) {
267  // TODO: This is not necessary for runtime UDFs because RBC does
268  // not generated GPU LLVM IR when the UDF is using Buffer objects.
269  // However, we cannot remove it until C++ UDFs can be defined for
270  // different devices independently.
271  throw QueryMustRunOnCpu();
272  }
273 
274  auto ret_ty = ext_arg_type_to_llvm_type(ext_func_sig.getRet(), cgen_state_->context_);
275  const auto current_bb = cgen_state_->ir_builder_.GetInsertBlock();
276  for (auto it : cgen_state_->ext_call_cache_) {
277  if (*it.foper == *function_oper) {
278  auto inst = llvm::dyn_cast<llvm::Instruction>(it.lv);
279  if (inst && inst->getParent() == current_bb) {
280  return it.lv;
281  }
282  }
283  }
284  std::vector<llvm::Value*> orig_arg_lvs;
285  std::vector<size_t> orig_arg_lvs_index;
286  std::unordered_map<llvm::Value*, llvm::Value*> const_arr_size;
287 
288  for (size_t i = 0; i < function_oper->getArity(); ++i) {
289  orig_arg_lvs_index.push_back(orig_arg_lvs.size());
290  const auto arg = function_oper->getArg(i);
291  const auto arg_cast = dynamic_cast<const Analyzer::UOper*>(arg);
292  const auto arg0 =
293  (arg_cast && arg_cast->get_optype() == kCAST) ? arg_cast->get_operand() : arg;
294  const auto array_expr_arg = dynamic_cast<const Analyzer::ArrayExpr*>(arg0);
295  auto is_local_alloc = array_expr_arg && array_expr_arg->isLocalAlloc();
296  const auto& arg_ti = arg->get_type_info();
297  const auto arg_lvs = codegen(arg, true, co);
298  auto geo_uoper_arg = dynamic_cast<const Analyzer::GeoUOper*>(arg);
299  auto geo_binoper_arg = dynamic_cast<const Analyzer::GeoBinOper*>(arg);
300  auto geo_expr_arg = dynamic_cast<const Analyzer::GeoExpr*>(arg);
301  // TODO(adb / d): Assuming no const array cols for geo (for now)
302  if ((geo_uoper_arg || geo_binoper_arg) && arg_ti.is_geometry()) {
303  // Extract arr sizes and put them in the map, forward arr pointers
304  CHECK_EQ(2 * static_cast<size_t>(arg_ti.get_physical_coord_cols()), arg_lvs.size());
305  for (size_t i = 0; i < arg_lvs.size(); i++) {
306  auto arr = arg_lvs[i++];
307  auto size = arg_lvs[i];
308  orig_arg_lvs.push_back(arr);
309  const_arr_size[arr] = size;
310  }
311  } else if (geo_expr_arg && geo_expr_arg->get_type_info().is_geometry()) {
312  CHECK(geo_expr_arg->get_type_info().get_type() == kPOINT);
313  CHECK_EQ(arg_lvs.size(), size_t(2));
314  for (size_t j = 0; j < arg_lvs.size(); j++) {
315  orig_arg_lvs.push_back(arg_lvs[j]);
316  }
317  } else if (arg_ti.is_geometry()) {
318  CHECK_EQ(static_cast<size_t>(arg_ti.get_physical_coord_cols()), arg_lvs.size());
319  for (size_t j = 0; j < arg_lvs.size(); j++) {
320  orig_arg_lvs.push_back(arg_lvs[j]);
321  }
322  } else if (arg_ti.is_bytes()) {
323  CHECK_EQ(size_t(3), arg_lvs.size());
324  // arg_lvs contains:
325  // arg_lvs[0] StringView struct { i8*, i64 }
326  // arg_lvs[1] i8* pointer
327  // arg_lvs[2] i32 string length (truncated from i64)
328  for (size_t j = 0; j < arg_lvs.size(); j++) {
329  orig_arg_lvs.push_back(arg_lvs[j]);
330  }
331  } else if (arg_ti.is_text_encoding_dict()) {
332  CHECK_EQ(size_t(1), arg_lvs.size());
333  orig_arg_lvs.push_back(arg_lvs[0]);
334  } else {
335  if (arg_lvs.size() > 1) {
336  CHECK(arg_ti.is_array());
337  CHECK_EQ(size_t(2), arg_lvs.size());
338  const_arr_size[arg_lvs.front()] = arg_lvs.back();
339  } else {
340  CHECK_EQ(size_t(1), arg_lvs.size());
341  /* arg_lvs contains:
342  &col_buf1
343  */
344  if (is_local_alloc && arg_ti.get_size() > 0) {
345  const_arr_size[arg_lvs.front()] = cgen_state_->llInt(arg_ti.get_size());
346  }
347  }
348  orig_arg_lvs.push_back(arg_lvs.front());
349  }
350  }
351  // The extension function implementations don't handle NULL, they work under
352  // the assumption that the inputs are validated before calling them. Generate
353  // code to do the check at the call site: if any argument is NULL, return NULL
354  // without calling the function at all.
355  const auto [bbs, null_buffer_ptr] = beginArgsNullcheck(function_oper, orig_arg_lvs);
356  CHECK_GE(orig_arg_lvs.size(), function_oper->getArity());
357  // Arguments must be converted to the types the extension function can handle.
359  function_oper, &ext_func_sig, orig_arg_lvs, orig_arg_lvs_index, const_arr_size, co);
360 
361  if (ext_func_sig.usesManager()) {
363  throw QueryMustRunOnCpu();
364  }
365  llvm::Value* row_func_mgr = get_arg_by_name(cgen_state_->row_func_, "row_func_mgr");
366  args.insert(args.begin(), row_func_mgr);
367  }
368 
369  llvm::Value* buffer_ret{nullptr};
370  if (ret_ti.is_buffer()) {
371  // codegen buffer return as first arg
372  CHECK(ret_ti.is_array() || ret_ti.is_bytes());
373  ret_ty = llvm::Type::getVoidTy(cgen_state_->context_);
374  const auto struct_ty = get_buffer_struct_type(
375  cgen_state_,
376  function_oper->getName(),
377  0,
379  buffer_ret = cgen_state_->ir_builder_.CreateAlloca(struct_ty);
380  args.insert(args.begin(), buffer_ret);
381  }
382 
383  const auto ext_call = cgen_state_->emitExternalCall(
384  ext_func_sig.getName(), ret_ty, args, {}, ret_ti.is_buffer());
385  auto ext_call_nullcheck = endArgsNullcheck(
386  bbs, ret_ti.is_buffer() ? buffer_ret : ext_call, null_buffer_ptr, function_oper);
387 
388  // Cast the return of the extension function to match the FunctionOper
389  if (!(ret_ti.is_buffer() || ret_ti.is_text_encoding_dict())) {
390  const auto extension_ret_ti = get_sql_type_from_llvm_type(ret_ty);
391  if (bbs.args_null_bb &&
392  extension_ret_ti.get_type() != function_oper->get_type_info().get_type() &&
393  // Skip i1-->i8 casts for ST_ functions.
394  // function_oper ret type is i1, extension ret type is 'upgraded' to i8
395  // during type deserialization to 'handle' NULL returns, hence i1-->i8.
396  // ST_ functions can't return NULLs, we just need to check arg nullness
397  // and if any args are NULL then ST_ function is not called
398  function_oper->getName().substr(0, 3) != std::string("ST_")) {
399  ext_call_nullcheck = codegenCast(ext_call_nullcheck,
400  extension_ret_ti,
401  function_oper->get_type_info(),
402  false,
403  co);
404  }
405  }
406 
407  cgen_state_->ext_call_cache_.push_back({function_oper, ext_call_nullcheck});
408  return ext_call_nullcheck;
409 }
410 
411 // Start the control flow needed for a call site check of NULL arguments.
412 std::tuple<CodeGenerator::ArgNullcheckBBs, llvm::Value*>
414  const std::vector<llvm::Value*>& orig_arg_lvs) {
416  llvm::BasicBlock* args_null_bb{nullptr};
417  llvm::BasicBlock* args_notnull_bb{nullptr};
418  llvm::BasicBlock* orig_bb = cgen_state_->ir_builder_.GetInsertBlock();
419  llvm::Value* null_array_alloca{nullptr};
420  // Only generate the check if required (at least one argument must be nullable).
421  if (ext_func_call_requires_nullcheck(function_oper)) {
422  const auto func_ti = function_oper->get_type_info();
423  if (func_ti.is_buffer()) {
424  const auto arr_struct_ty = get_buffer_struct_type(
425  cgen_state_,
426  function_oper->getName(),
427  0,
429  null_array_alloca = cgen_state_->ir_builder_.CreateAlloca(arr_struct_ty);
430  }
431  const auto args_notnull_lv = cgen_state_->ir_builder_.CreateNot(
432  codegenFunctionOperNullArg(function_oper, orig_arg_lvs));
433  args_notnull_bb = llvm::BasicBlock::Create(
434  cgen_state_->context_, "args_notnull", cgen_state_->current_func_);
435  args_null_bb = llvm::BasicBlock::Create(
437  cgen_state_->ir_builder_.CreateCondBr(args_notnull_lv, args_notnull_bb, args_null_bb);
438  cgen_state_->ir_builder_.SetInsertPoint(args_notnull_bb);
439  }
440  return std::make_tuple(
441  CodeGenerator::ArgNullcheckBBs{args_null_bb, args_notnull_bb, orig_bb},
442  null_array_alloca);
443 }
444 
445 // Wrap up the control flow needed for NULL argument handling.
447  const ArgNullcheckBBs& bbs,
448  llvm::Value* fn_ret_lv,
449  llvm::Value* null_array_ptr,
450  const Analyzer::FunctionOper* function_oper) {
452  if (bbs.args_null_bb) {
453  CHECK(bbs.args_notnull_bb);
454  cgen_state_->ir_builder_.CreateBr(bbs.args_null_bb);
455  cgen_state_->ir_builder_.SetInsertPoint(bbs.args_null_bb);
456 
457  llvm::PHINode* ext_call_phi{nullptr};
458  llvm::Value* null_lv{nullptr};
459  const auto func_ti = function_oper->get_type_info();
460  if (!func_ti.is_buffer()) {
461  // The pre-cast SQL equivalent of the type returned by the extension function.
462  const auto extension_ret_ti = get_sql_type_from_llvm_type(fn_ret_lv->getType());
463 
464  ext_call_phi = cgen_state_->ir_builder_.CreatePHI(
465  extension_ret_ti.is_fp()
466  ? get_fp_type(extension_ret_ti.get_size() * 8, cgen_state_->context_)
467  : get_int_type(extension_ret_ti.get_size() * 8, cgen_state_->context_),
468  2);
469 
470  null_lv =
471  extension_ret_ti.is_fp()
472  ? static_cast<llvm::Value*>(cgen_state_->inlineFpNull(extension_ret_ti))
473  : static_cast<llvm::Value*>(cgen_state_->inlineIntNull(extension_ret_ti));
474  } else {
475  const auto arr_struct_ty = get_buffer_struct_type(
476  cgen_state_,
477  function_oper->getName(),
478  0,
480  ext_call_phi =
481  cgen_state_->ir_builder_.CreatePHI(llvm::PointerType::get(arr_struct_ty, 0), 2);
482 
483  CHECK(null_array_ptr);
484  const auto arr_null_bool =
485  cgen_state_->ir_builder_.CreateStructGEP(arr_struct_ty, null_array_ptr, 2);
486  cgen_state_->ir_builder_.CreateStore(
487  llvm::ConstantInt::get(get_int_type(8, cgen_state_->context_), 1),
488  arr_null_bool);
489 
490  const auto arr_null_size =
491  cgen_state_->ir_builder_.CreateStructGEP(arr_struct_ty, null_array_ptr, 1);
492  cgen_state_->ir_builder_.CreateStore(
493  llvm::ConstantInt::get(get_int_type(64, cgen_state_->context_), 0),
494  arr_null_size);
495  }
496  ext_call_phi->addIncoming(fn_ret_lv, bbs.args_notnull_bb);
497  ext_call_phi->addIncoming(func_ti.is_buffer() ? null_array_ptr : null_lv,
498  bbs.orig_bb);
499 
500  return ext_call_phi;
501  }
502  return fn_ret_lv;
503 }
504 
505 namespace {
506 
508  const auto& ret_ti = function_oper->get_type_info();
509  if (!ret_ti.is_integer() && !ret_ti.is_fp()) {
510  return true;
511  }
512  for (size_t i = 0; i < function_oper->getArity(); ++i) {
513  const auto arg = function_oper->getArg(i);
514  const auto& arg_ti = arg->get_type_info();
515  if (!arg_ti.is_integer() && !arg_ti.is_fp()) {
516  return true;
517  }
518  }
519  return false;
520 }
521 
522 } // namespace
523 
526  const CompilationOptions& co) {
528  if (call_requires_custom_type_handling(function_oper)) {
529  // Some functions need the return type to be the same as the input type.
530  if (function_oper->getName() == "FLOOR" || function_oper->getName() == "CEIL") {
531  CHECK_EQ(size_t(1), function_oper->getArity());
532  const auto arg = function_oper->getArg(0);
533  const auto& arg_ti = arg->get_type_info();
534  CHECK(arg_ti.is_decimal());
535  const auto arg_lvs = codegen(arg, true, co);
536  CHECK_EQ(size_t(1), arg_lvs.size());
537  const auto arg_lv = arg_lvs.front();
538  CHECK(arg_lv->getType()->isIntegerTy(64));
540  std::tie(bbs, std::ignore) = beginArgsNullcheck(function_oper, {arg_lvs});
541  const std::string func_name =
542  (function_oper->getName() == "FLOOR") ? "decimal_floor" : "decimal_ceil";
543  const auto covar_result_lv = cgen_state_->emitCall(
544  func_name, {arg_lv, cgen_state_->llInt(exp_to_scale(arg_ti.get_scale()))});
545  const auto ret_ti = function_oper->get_type_info();
546  CHECK(ret_ti.is_decimal());
547  CHECK_EQ(0, ret_ti.get_scale());
548  const auto result_lv = cgen_state_->ir_builder_.CreateSDiv(
549  covar_result_lv, cgen_state_->llInt(exp_to_scale(arg_ti.get_scale())));
550  return endArgsNullcheck(bbs, result_lv, nullptr, function_oper);
551  } else if (function_oper->getName() == "ROUND" &&
552  function_oper->getArg(0)->get_type_info().is_decimal()) {
553  CHECK_EQ(size_t(2), function_oper->getArity());
554 
555  const auto arg0 = function_oper->getArg(0);
556  const auto& arg0_ti = arg0->get_type_info();
557  const auto arg0_lvs = codegen(arg0, true, co);
558  CHECK_EQ(size_t(1), arg0_lvs.size());
559  const auto arg0_lv = arg0_lvs.front();
560  CHECK(arg0_lv->getType()->isIntegerTy(64));
561 
562  const auto arg1 = function_oper->getArg(1);
563  const auto& arg1_ti = arg1->get_type_info();
564  CHECK(arg1_ti.is_integer());
565  const auto arg1_lvs = codegen(arg1, true, co);
566  auto arg1_lv = arg1_lvs.front();
567  if (arg1_ti.get_type() != kINT) {
568  arg1_lv = codegenCast(arg1_lv, arg1_ti, SQLTypeInfo(kINT, true), false, co);
569  }
570 
572  std::tie(bbs0, std::ignore) =
573  beginArgsNullcheck(function_oper, {arg0_lv, arg1_lvs.front()});
574 
575  const std::string func_name = "Round__4";
576  const auto ret_ti = function_oper->get_type_info();
577  CHECK(ret_ti.is_decimal());
578  const auto result_lv = cgen_state_->emitExternalCall(
579  func_name,
581  {arg0_lv, arg1_lv, cgen_state_->llInt(arg0_ti.get_scale())});
582 
583  return endArgsNullcheck(bbs0, result_lv, nullptr, function_oper);
584  }
585  throw std::runtime_error("Type combination not supported for function " +
586  function_oper->getName());
587  }
588  return codegenFunctionOper(function_oper, co);
589 }
590 
591 // Generates code which returns true iff at least one of the arguments is NULL.
593  const Analyzer::FunctionOper* function_oper,
594  const std::vector<llvm::Value*>& orig_arg_lvs) {
596  llvm::Value* one_arg_null =
597  llvm::ConstantInt::get(llvm::IntegerType::getInt1Ty(cgen_state_->context_), false);
598  size_t physical_coord_cols = 0;
599  for (size_t i = 0, j = 0; i < function_oper->getArity();
600  ++i, j += std::max(size_t(1), physical_coord_cols)) {
601  const auto arg = function_oper->getArg(i);
602  const auto& arg_ti = arg->get_type_info();
603  physical_coord_cols = arg_ti.get_physical_coord_cols();
604  if (arg_ti.get_notnull()) {
605  continue;
606  }
607  auto geo_expr_arg = dynamic_cast<const Analyzer::GeoExpr*>(arg);
608  if (geo_expr_arg && arg_ti.is_geometry()) {
609  CHECK(arg_ti.get_type() == kPOINT);
610  auto is_null_lv = cgen_state_->ir_builder_.CreateICmp(
611  llvm::CmpInst::ICMP_EQ,
612  orig_arg_lvs[j],
613  llvm::ConstantPointerNull::get( // TODO: centralize logic; in geo expr?
614  arg_ti.get_compression() == kENCODING_GEOINT
615  ? llvm::Type::getInt32PtrTy(cgen_state_->context_)
616  : llvm::Type::getDoublePtrTy(cgen_state_->context_)));
617  one_arg_null = cgen_state_->ir_builder_.CreateOr(one_arg_null, is_null_lv);
618  physical_coord_cols = 2; // number of lvs to advance
619  continue;
620  }
621 #ifdef ENABLE_GEOS
622  // If geo arg is coming from geos, skip the null check, assume it's a valid geo
623  if (arg_ti.is_geometry()) {
624  auto* coords_load = llvm::dyn_cast<llvm::LoadInst>(orig_arg_lvs[i]);
625  if (coords_load) {
626  continue;
627  }
628  }
629 #endif
630  if (arg_ti.is_geometry()) {
631  auto* coords_alloca = llvm::dyn_cast<llvm::AllocaInst>(orig_arg_lvs[j]);
632  auto* coords_phi = llvm::dyn_cast<llvm::PHINode>(orig_arg_lvs[j]);
633  if (coords_alloca || coords_phi) {
634  // TODO: null check dynamically generated geometries
635  continue;
636  }
637  }
638  if (arg_ti.is_text_encoding_dict()) {
639  one_arg_null = cgen_state_->ir_builder_.CreateOr(
640  one_arg_null, codegenIsNullNumber(orig_arg_lvs[j], arg_ti));
641  continue;
642  }
643  if (arg_ti.is_buffer() || arg_ti.is_geometry()) {
644  // POINT [un]compressed coord check requires custom checker and chunk iterator
645  // Non-POINT NULL geographies will have a normally encoded null coord array
646  auto fname =
647  (arg_ti.get_type() == kPOINT) ? "point_coord_array_is_null" : "array_is_null";
648  auto is_null_lv = cgen_state_->emitExternalCall(
649  fname, get_int_type(1, cgen_state_->context_), {orig_arg_lvs[j], posArg(arg)});
650  one_arg_null = cgen_state_->ir_builder_.CreateOr(one_arg_null, is_null_lv);
651  continue;
652  }
653  CHECK(arg_ti.is_number() or arg_ti.is_boolean());
654  one_arg_null = cgen_state_->ir_builder_.CreateOr(
655  one_arg_null, codegenIsNullNumber(orig_arg_lvs[j], arg_ti));
656  }
657  return one_arg_null;
658 }
659 
660 llvm::Value* CodeGenerator::codegenCompression(const SQLTypeInfo& type_info) {
662  int32_t compression = (type_info.get_compression() == kENCODING_GEOINT &&
663  type_info.get_comp_param() == 32)
664  ? 1
665  : 0;
666 
667  return cgen_state_->llInt(compression);
668 }
669 
670 std::pair<llvm::Value*, llvm::Value*> CodeGenerator::codegenArrayBuff(
671  llvm::Value* chunk,
672  llvm::Value* row_pos,
673  SQLTypes array_type,
674  bool cast_and_extend) {
676  const auto elem_ti =
677  SQLTypeInfo(
678  SQLTypes::kARRAY, 0, 0, false, EncodingType::kENCODING_NONE, 0, array_type)
679  .get_elem_type();
680 
681  auto buff = cgen_state_->emitExternalCall(
682  "array_buff", llvm::Type::getInt32PtrTy(cgen_state_->context_), {chunk, row_pos});
683 
684  auto len = cgen_state_->emitExternalCall(
685  "array_size",
686  get_int_type(32, cgen_state_->context_),
687  {chunk, row_pos, cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
688 
689  if (cast_and_extend) {
690  buff = castArrayPointer(buff, elem_ti);
691  len =
692  cgen_state_->ir_builder_.CreateZExt(len, get_int_type(64, cgen_state_->context_));
693  }
694 
695  return std::make_pair(buff, len);
696 }
697 
698 void CodeGenerator::codegenBufferArgs(const std::string& ext_func_name,
699  size_t param_num,
700  llvm::Value* buffer_buf,
701  llvm::Value* buffer_size,
702  llvm::Value* buffer_null,
703  std::vector<llvm::Value*>& output_args) {
705  CHECK(buffer_buf);
706  CHECK(buffer_size);
707 
708  auto buffer_abstraction = get_buffer_struct_type(
709  cgen_state_, ext_func_name, param_num, buffer_buf->getType());
710  auto alloc_mem = cgen_state_->ir_builder_.CreateAlloca(buffer_abstraction);
711 
712  auto buffer_buf_ptr =
713  cgen_state_->ir_builder_.CreateStructGEP(buffer_abstraction, alloc_mem, 0);
714  cgen_state_->ir_builder_.CreateStore(buffer_buf, buffer_buf_ptr);
715 
716  auto buffer_size_ptr =
717  cgen_state_->ir_builder_.CreateStructGEP(buffer_abstraction, alloc_mem, 1);
718  cgen_state_->ir_builder_.CreateStore(buffer_size, buffer_size_ptr);
719 
720  auto bool_extended_type = llvm::Type::getInt8Ty(cgen_state_->context_);
721  auto buffer_null_extended =
722  cgen_state_->ir_builder_.CreateZExt(buffer_null, bool_extended_type);
723  auto buffer_is_null_ptr =
724  cgen_state_->ir_builder_.CreateStructGEP(buffer_abstraction, alloc_mem, 2);
725  cgen_state_->ir_builder_.CreateStore(buffer_null_extended, buffer_is_null_ptr);
726  output_args.push_back(alloc_mem);
727 }
728 
729 llvm::StructType* CodeGenerator::createPointStructType(const std::string& udf_func_name,
730  size_t param_num) {
731  llvm::Module* module_for_lookup = cgen_state_->module_;
732  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
733 
734  llvm::StructType* generated_struct_type =
735  llvm::StructType::get(cgen_state_->context_,
736  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
737  llvm::Type::getInt32Ty(cgen_state_->context_),
738  llvm::Type::getInt32Ty(cgen_state_->context_),
739  llvm::Type::getInt32Ty(cgen_state_->context_),
740  llvm::Type::getInt32Ty(cgen_state_->context_)},
741  false);
742 
743  if (udf_func) {
744  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
745  CHECK(param_num < udf_func_type->getNumParams());
746  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
747  CHECK(param_pointer_type->isPointerTy());
748  llvm::Type* param_type = param_pointer_type->getPointerElementType();
749  CHECK(param_type->isStructTy());
750  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
751  CHECK_EQ(struct_type->getStructNumElements(), 5u)
752  << serialize_llvm_object(struct_type);
753  const auto expected_elems = generated_struct_type->elements();
754  const auto current_elems = struct_type->elements();
755  for (size_t i = 0; i < expected_elems.size(); i++) {
756  CHECK_EQ(expected_elems[i], current_elems[i]);
757  }
758  if (struct_type->isLiteral()) {
759  return struct_type;
760  }
761 
762  llvm::StringRef struct_name = struct_type->getStructName();
763 #if LLVM_VERSION_MAJOR >= 12
764  llvm::StructType* point_type =
765  struct_type->getTypeByName(cgen_state_->context_, struct_name);
766 #else
767  llvm::StructType* point_type = module_for_lookup->getTypeByName(struct_name);
768 #endif
769  CHECK(point_type);
770 
771  return point_type;
772  }
773  return generated_struct_type;
774 }
775 
776 void CodeGenerator::codegenGeoPointArgs(const std::string& udf_func_name,
777  size_t param_num,
778  llvm::Value* point_buf,
779  llvm::Value* point_size,
780  llvm::Value* compression,
781  llvm::Value* input_srid,
782  llvm::Value* output_srid,
783  std::vector<llvm::Value*>& output_args) {
785  CHECK(point_buf);
786  CHECK(point_size);
787  CHECK(compression);
788  CHECK(input_srid);
789  CHECK(output_srid);
790 
791  auto point_abstraction = createPointStructType(udf_func_name, param_num);
792  auto alloc_mem = cgen_state_->ir_builder_.CreateAlloca(point_abstraction, nullptr);
793 
794  auto point_buf_ptr =
795  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 0);
796  cgen_state_->ir_builder_.CreateStore(point_buf, point_buf_ptr);
797 
798  auto point_size_ptr =
799  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 1);
800  cgen_state_->ir_builder_.CreateStore(point_size, point_size_ptr);
801 
802  auto point_compression_ptr =
803  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 2);
804  cgen_state_->ir_builder_.CreateStore(compression, point_compression_ptr);
805 
806  auto input_srid_ptr =
807  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 3);
808  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
809 
810  auto output_srid_ptr =
811  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 4);
812  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
813 
814  output_args.push_back(alloc_mem);
815 }
816 
818  const std::string& udf_func_name,
819  size_t param_num) {
820  llvm::Module* module_for_lookup = cgen_state_->module_;
821  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
822 
823  llvm::StructType* generated_struct_type =
824  llvm::StructType::get(cgen_state_->context_,
825  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
826  llvm::Type::getInt32Ty(cgen_state_->context_),
827  llvm::Type::getInt32Ty(cgen_state_->context_),
828  llvm::Type::getInt32Ty(cgen_state_->context_),
829  llvm::Type::getInt32Ty(cgen_state_->context_)},
830  false);
831 
832  if (udf_func) {
833  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
834  CHECK(param_num < udf_func_type->getNumParams());
835  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
836  CHECK(param_pointer_type->isPointerTy());
837  llvm::Type* param_type = param_pointer_type->getPointerElementType();
838  CHECK(param_type->isStructTy());
839  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
840  CHECK(struct_type->isStructTy());
841  CHECK_EQ(struct_type->getStructNumElements(), 5u);
842 
843  const auto expected_elems = generated_struct_type->elements();
844  const auto current_elems = struct_type->elements();
845  for (size_t i = 0; i < expected_elems.size(); i++) {
846  CHECK_EQ(expected_elems[i], current_elems[i]);
847  }
848  if (struct_type->isLiteral()) {
849  return struct_type;
850  }
851 
852  llvm::StringRef struct_name = struct_type->getStructName();
853 #if LLVM_VERSION_MAJOR >= 12
854  llvm::StructType* multi_point_type =
855  struct_type->getTypeByName(cgen_state_->context_, struct_name);
856 #else
857  llvm::StructType* multi_point_type = module_for_lookup->getTypeByName(struct_name);
858 #endif
859  CHECK(multi_point_type);
860 
861  return multi_point_type;
862  }
863  return generated_struct_type;
864 }
865 
866 void CodeGenerator::codegenGeoMultiPointArgs(const std::string& udf_func_name,
867  size_t param_num,
868  llvm::Value* multi_point_buf,
869  llvm::Value* multi_point_size,
870  llvm::Value* compression,
871  llvm::Value* input_srid,
872  llvm::Value* output_srid,
873  std::vector<llvm::Value*>& output_args) {
875  CHECK(multi_point_buf);
876  CHECK(multi_point_size);
877  CHECK(compression);
878  CHECK(input_srid);
879  CHECK(output_srid);
880 
881  auto multi_point_abstraction = createMultiPointStructType(udf_func_name, param_num);
882  auto alloc_mem =
883  cgen_state_->ir_builder_.CreateAlloca(multi_point_abstraction, nullptr);
884 
885  auto multi_point_buf_ptr =
886  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 0);
887  cgen_state_->ir_builder_.CreateStore(multi_point_buf, multi_point_buf_ptr);
888 
889  auto multi_point_size_ptr =
890  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 1);
891  cgen_state_->ir_builder_.CreateStore(multi_point_size, multi_point_size_ptr);
892 
893  auto compression_ptr =
894  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 2);
895  cgen_state_->ir_builder_.CreateStore(compression, compression_ptr);
896 
897  auto input_srid_ptr =
898  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 3);
899  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
900 
901  auto output_srid_ptr =
902  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 4);
903  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
904 
905  output_args.push_back(alloc_mem);
906 }
907 
909  const std::string& udf_func_name,
910  size_t param_num) {
911  llvm::Module* module_for_lookup = cgen_state_->module_;
912  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
913 
914  llvm::StructType* generated_struct_type =
915  llvm::StructType::get(cgen_state_->context_,
916  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
917  llvm::Type::getInt32Ty(cgen_state_->context_),
918  llvm::Type::getInt32Ty(cgen_state_->context_),
919  llvm::Type::getInt32Ty(cgen_state_->context_),
920  llvm::Type::getInt32Ty(cgen_state_->context_)},
921  false);
922 
923  if (udf_func) {
924  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
925  CHECK(param_num < udf_func_type->getNumParams());
926  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
927  CHECK(param_pointer_type->isPointerTy());
928  llvm::Type* param_type = param_pointer_type->getPointerElementType();
929  CHECK(param_type->isStructTy());
930  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
931  CHECK(struct_type->isStructTy());
932  CHECK_EQ(struct_type->getStructNumElements(), 5u);
933 
934  const auto expected_elems = generated_struct_type->elements();
935  const auto current_elems = struct_type->elements();
936  for (size_t i = 0; i < expected_elems.size(); i++) {
937  CHECK_EQ(expected_elems[i], current_elems[i]);
938  }
939  if (struct_type->isLiteral()) {
940  return struct_type;
941  }
942 
943  llvm::StringRef struct_name = struct_type->getStructName();
944 #if LLVM_VERSION_MAJOR >= 12
945  llvm::StructType* line_string_type =
946  struct_type->getTypeByName(cgen_state_->context_, struct_name);
947 #else
948  llvm::StructType* line_string_type = module_for_lookup->getTypeByName(struct_name);
949 #endif
950  CHECK(line_string_type);
951 
952  return line_string_type;
953  }
954  return generated_struct_type;
955 }
956 
957 void CodeGenerator::codegenGeoLineStringArgs(const std::string& udf_func_name,
958  size_t param_num,
959  llvm::Value* line_string_buf,
960  llvm::Value* line_string_size,
961  llvm::Value* compression,
962  llvm::Value* input_srid,
963  llvm::Value* output_srid,
964  std::vector<llvm::Value*>& output_args) {
966  CHECK(line_string_buf);
967  CHECK(line_string_size);
968  CHECK(compression);
969  CHECK(input_srid);
970  CHECK(output_srid);
971 
972  auto line_string_abstraction = createLineStringStructType(udf_func_name, param_num);
973  auto alloc_mem =
974  cgen_state_->ir_builder_.CreateAlloca(line_string_abstraction, nullptr);
975 
976  auto line_string_buf_ptr =
977  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 0);
978  cgen_state_->ir_builder_.CreateStore(line_string_buf, line_string_buf_ptr);
979 
980  auto line_string_size_ptr =
981  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 1);
982  cgen_state_->ir_builder_.CreateStore(line_string_size, line_string_size_ptr);
983 
984  auto line_string_compression_ptr =
985  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 2);
986  cgen_state_->ir_builder_.CreateStore(compression, line_string_compression_ptr);
987 
988  auto input_srid_ptr =
989  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 3);
990  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
991 
992  auto output_srid_ptr =
993  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 4);
994  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
995 
996  output_args.push_back(alloc_mem);
997 }
998 
1000  const std::string& udf_func_name,
1001  size_t param_num) {
1002  llvm::Module* module_for_lookup = cgen_state_->module_;
1003  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
1004 
1005  llvm::StructType* generated_struct_type =
1006  llvm::StructType::get(cgen_state_->context_,
1007  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
1008  llvm::Type::getInt32Ty(cgen_state_->context_),
1009  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1010  llvm::Type::getInt32Ty(cgen_state_->context_),
1011  llvm::Type::getInt32Ty(cgen_state_->context_),
1012  llvm::Type::getInt32Ty(cgen_state_->context_),
1013  llvm::Type::getInt32Ty(cgen_state_->context_)},
1014  false);
1015 
1016  if (udf_func) {
1017  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
1018  CHECK(param_num < udf_func_type->getNumParams());
1019  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
1020  CHECK(param_pointer_type->isPointerTy());
1021  llvm::Type* param_type = param_pointer_type->getPointerElementType();
1022  CHECK(param_type->isStructTy());
1023  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
1024  CHECK(struct_type->isStructTy());
1025  CHECK_EQ(struct_type->getStructNumElements(), 7u);
1026 
1027  const auto expected_elems = generated_struct_type->elements();
1028  const auto current_elems = struct_type->elements();
1029  for (size_t i = 0; i < expected_elems.size(); i++) {
1030  CHECK_EQ(expected_elems[i], current_elems[i]);
1031  }
1032  if (struct_type->isLiteral()) {
1033  return struct_type;
1034  }
1035 
1036  llvm::StringRef struct_name = struct_type->getStructName();
1037 #if LLVM_VERSION_MAJOR >= 12
1038  llvm::StructType* multi_linestring_type =
1039  struct_type->getTypeByName(cgen_state_->context_, struct_name);
1040 #else
1041  llvm::StructType* multi_linestring_type =
1042  module_for_lookup->getTypeByName(struct_name);
1043 #endif
1044  CHECK(multi_linestring_type);
1045 
1046  return multi_linestring_type;
1047  }
1048  return generated_struct_type;
1049 }
1050 
1052  const std::string& udf_func_name,
1053  size_t param_num,
1054  llvm::Value* multi_linestring_coords,
1055  llvm::Value* multi_linestring_coords_size,
1056  llvm::Value* linestring_sizes,
1057  llvm::Value* linestring_sizes_size,
1058  llvm::Value* compression,
1059  llvm::Value* input_srid,
1060  llvm::Value* output_srid,
1061  std::vector<llvm::Value*>& output_args) {
1063  CHECK(multi_linestring_coords);
1064  CHECK(multi_linestring_coords_size);
1065  CHECK(linestring_sizes);
1066  CHECK(linestring_sizes_size);
1067  CHECK(compression);
1068  CHECK(input_srid);
1069  CHECK(output_srid);
1070 
1071  auto multi_linestring_abstraction =
1072  createMultiLineStringStructType(udf_func_name, param_num);
1073  auto alloc_mem =
1074  cgen_state_->ir_builder_.CreateAlloca(multi_linestring_abstraction, nullptr);
1075 
1076  auto multi_linestring_coords_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1077  multi_linestring_abstraction, alloc_mem, 0);
1078  cgen_state_->ir_builder_.CreateStore(multi_linestring_coords,
1079  multi_linestring_coords_ptr);
1080 
1081  auto multi_linestring_coords_size_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1082  multi_linestring_abstraction, alloc_mem, 1);
1083  cgen_state_->ir_builder_.CreateStore(multi_linestring_coords_size,
1084  multi_linestring_coords_size_ptr);
1085 
1086  auto linestring_sizes_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1087  multi_linestring_abstraction, alloc_mem, 2);
1088  const auto linestring_sizes_ptr_ty =
1089  llvm::dyn_cast<llvm::PointerType>(linestring_sizes_ptr->getType());
1090  CHECK(linestring_sizes_ptr_ty);
1091  cgen_state_->ir_builder_.CreateStore(
1092  cgen_state_->ir_builder_.CreateBitCast(
1093  linestring_sizes, linestring_sizes_ptr_ty->getPointerElementType()),
1094  linestring_sizes_ptr);
1095 
1096  auto linestring_sizes_size_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1097  multi_linestring_abstraction, alloc_mem, 3);
1098  cgen_state_->ir_builder_.CreateStore(linestring_sizes_size, linestring_sizes_size_ptr);
1099 
1100  auto multi_linestring_compression_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1101  multi_linestring_abstraction, alloc_mem, 4);
1102  cgen_state_->ir_builder_.CreateStore(compression, multi_linestring_compression_ptr);
1103 
1104  auto input_srid_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1105  multi_linestring_abstraction, alloc_mem, 5);
1106  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
1107 
1108  auto output_srid_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1109  multi_linestring_abstraction, alloc_mem, 6);
1110  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
1111 
1112  output_args.push_back(alloc_mem);
1113 }
1114 
1115 llvm::StructType* CodeGenerator::createPolygonStructType(const std::string& udf_func_name,
1116  size_t param_num) {
1117  llvm::Module* module_for_lookup = cgen_state_->module_;
1118  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
1119 
1120  llvm::StructType* generated_struct_type =
1121  llvm::StructType::get(cgen_state_->context_,
1122  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
1123  llvm::Type::getInt32Ty(cgen_state_->context_),
1124  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1125  llvm::Type::getInt32Ty(cgen_state_->context_),
1126  llvm::Type::getInt32Ty(cgen_state_->context_),
1127  llvm::Type::getInt32Ty(cgen_state_->context_),
1128  llvm::Type::getInt32Ty(cgen_state_->context_)},
1129  false);
1130 
1131  if (udf_func) {
1132  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
1133  CHECK(param_num < udf_func_type->getNumParams());
1134  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
1135  CHECK(param_pointer_type->isPointerTy());
1136  llvm::Type* param_type = param_pointer_type->getPointerElementType();
1137  CHECK(param_type->isStructTy());
1138  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
1139 
1140  CHECK(struct_type->isStructTy());
1141  CHECK_EQ(struct_type->getStructNumElements(), 7u);
1142 
1143  const auto expected_elems = generated_struct_type->elements();
1144  const auto current_elems = struct_type->elements();
1145  for (size_t i = 0; i < expected_elems.size(); i++) {
1146  CHECK_EQ(expected_elems[i], current_elems[i]);
1147  }
1148  if (struct_type->isLiteral()) {
1149  return struct_type;
1150  }
1151 
1152  llvm::StringRef struct_name = struct_type->getStructName();
1153 
1154 #if LLVM_VERSION_MAJOR >= 12
1155  llvm::StructType* polygon_type =
1156  struct_type->getTypeByName(cgen_state_->context_, struct_name);
1157 #else
1158  llvm::StructType* polygon_type = module_for_lookup->getTypeByName(struct_name);
1159 #endif
1160  CHECK(polygon_type);
1161 
1162  return polygon_type;
1163  }
1164  return generated_struct_type;
1165 }
1166 
1167 void CodeGenerator::codegenGeoPolygonArgs(const std::string& udf_func_name,
1168  size_t param_num,
1169  llvm::Value* polygon_buf,
1170  llvm::Value* polygon_size,
1171  llvm::Value* ring_sizes_buf,
1172  llvm::Value* num_rings,
1173  llvm::Value* compression,
1174  llvm::Value* input_srid,
1175  llvm::Value* output_srid,
1176  std::vector<llvm::Value*>& output_args) {
1178  CHECK(polygon_buf);
1179  CHECK(polygon_size);
1180  CHECK(ring_sizes_buf);
1181  CHECK(num_rings);
1182  CHECK(compression);
1183  CHECK(input_srid);
1184  CHECK(output_srid);
1185 
1186  auto& builder = cgen_state_->ir_builder_;
1187 
1188  auto polygon_abstraction = createPolygonStructType(udf_func_name, param_num);
1189  auto alloc_mem = builder.CreateAlloca(polygon_abstraction, nullptr);
1190 
1191  const auto polygon_buf_ptr = builder.CreateStructGEP(polygon_abstraction, alloc_mem, 0);
1192  builder.CreateStore(polygon_buf, polygon_buf_ptr);
1193 
1194  const auto polygon_size_ptr =
1195  builder.CreateStructGEP(polygon_abstraction, alloc_mem, 1);
1196  builder.CreateStore(polygon_size, polygon_size_ptr);
1197 
1198  const auto ring_sizes_buf_ptr =
1199  builder.CreateStructGEP(polygon_abstraction, alloc_mem, 2);
1200  const auto ring_sizes_ptr_ty =
1201  llvm::dyn_cast<llvm::PointerType>(ring_sizes_buf_ptr->getType());
1202  CHECK(ring_sizes_ptr_ty);
1203  builder.CreateStore(
1204  builder.CreateBitCast(ring_sizes_buf, ring_sizes_ptr_ty->getPointerElementType()),
1205  ring_sizes_buf_ptr);
1206 
1207  const auto ring_size_ptr = builder.CreateStructGEP(polygon_abstraction, alloc_mem, 3);
1208  builder.CreateStore(num_rings, ring_size_ptr);
1209 
1210  const auto polygon_compression_ptr =
1211  builder.CreateStructGEP(polygon_abstraction, alloc_mem, 4);
1212  builder.CreateStore(compression, polygon_compression_ptr);
1213 
1214  const auto input_srid_ptr = builder.CreateStructGEP(polygon_abstraction, alloc_mem, 5);
1215  builder.CreateStore(input_srid, input_srid_ptr);
1216 
1217  const auto output_srid_ptr = builder.CreateStructGEP(polygon_abstraction, alloc_mem, 6);
1218  builder.CreateStore(output_srid, output_srid_ptr);
1219 
1220  output_args.push_back(alloc_mem);
1221 }
1222 
1224  const std::string& udf_func_name,
1225  size_t param_num) {
1226  llvm::Function* udf_func = cgen_state_->module_->getFunction(udf_func_name);
1227 
1228  llvm::StructType* generated_struct_type =
1229  llvm::StructType::get(cgen_state_->context_,
1230  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
1231  llvm::Type::getInt32Ty(cgen_state_->context_),
1232  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1233  llvm::Type::getInt32Ty(cgen_state_->context_),
1234  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1235  llvm::Type::getInt32Ty(cgen_state_->context_),
1236  llvm::Type::getInt32Ty(cgen_state_->context_),
1237  llvm::Type::getInt32Ty(cgen_state_->context_),
1238  llvm::Type::getInt32Ty(cgen_state_->context_)},
1239  false);
1240 
1241  if (udf_func) {
1242  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
1243  CHECK(param_num < udf_func_type->getNumParams());
1244  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
1245  CHECK(param_pointer_type->isPointerTy());
1246  llvm::Type* param_type = param_pointer_type->getPointerElementType();
1247  CHECK(param_type->isStructTy());
1248  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
1249  CHECK(struct_type->isStructTy());
1250  CHECK_EQ(struct_type->getStructNumElements(), 9u);
1251  const auto expected_elems = generated_struct_type->elements();
1252  const auto current_elems = struct_type->elements();
1253  for (size_t i = 0; i < expected_elems.size(); i++) {
1254  CHECK_EQ(expected_elems[i], current_elems[i]);
1255  }
1256  if (struct_type->isLiteral()) {
1257  return struct_type;
1258  }
1259  llvm::StringRef struct_name = struct_type->getStructName();
1260 
1261 #if LLVM_VERSION_MAJOR >= 12
1262  llvm::StructType* polygon_type =
1263  struct_type->getTypeByName(cgen_state_->context_, struct_name);
1264 #else
1265  llvm::StructType* polygon_type = cgen_state_->module_->getTypeByName(struct_name);
1266 #endif
1267  CHECK(polygon_type);
1268 
1269  return polygon_type;
1270  }
1271  return generated_struct_type;
1272 }
1273 
1274 void CodeGenerator::codegenGeoMultiPolygonArgs(const std::string& udf_func_name,
1275  size_t param_num,
1276  llvm::Value* polygon_coords,
1277  llvm::Value* polygon_coords_size,
1278  llvm::Value* ring_sizes_buf,
1279  llvm::Value* ring_sizes,
1280  llvm::Value* polygon_bounds,
1281  llvm::Value* polygon_bounds_sizes,
1282  llvm::Value* compression,
1283  llvm::Value* input_srid,
1284  llvm::Value* output_srid,
1285  std::vector<llvm::Value*>& output_args) {
1287  CHECK(polygon_coords);
1288  CHECK(polygon_coords_size);
1289  CHECK(ring_sizes_buf);
1290  CHECK(ring_sizes);
1291  CHECK(polygon_bounds);
1292  CHECK(polygon_bounds_sizes);
1293  CHECK(compression);
1294  CHECK(input_srid);
1295  CHECK(output_srid);
1296 
1297  auto& builder = cgen_state_->ir_builder_;
1298 
1299  auto multi_polygon_abstraction = createMultiPolygonStructType(udf_func_name, param_num);
1300  auto alloc_mem = builder.CreateAlloca(multi_polygon_abstraction, nullptr);
1301 
1302  const auto polygon_coords_ptr =
1303  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 0);
1304  builder.CreateStore(polygon_coords, polygon_coords_ptr);
1305 
1306  const auto polygon_coords_size_ptr =
1307  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 1);
1308  builder.CreateStore(polygon_coords_size, polygon_coords_size_ptr);
1309 
1310  const auto ring_sizes_buf_ptr =
1311  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 2);
1312  const auto ring_sizes_ptr_ty =
1313  llvm::dyn_cast<llvm::PointerType>(ring_sizes_buf_ptr->getType());
1314  CHECK(ring_sizes_ptr_ty);
1315  builder.CreateStore(
1316  builder.CreateBitCast(ring_sizes_buf, ring_sizes_ptr_ty->getPointerElementType()),
1317  ring_sizes_buf_ptr);
1318 
1319  const auto ring_sizes_ptr =
1320  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 3);
1321  builder.CreateStore(ring_sizes, ring_sizes_ptr);
1322 
1323  const auto polygon_bounds_buf_ptr =
1324  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 4);
1325  const auto bounds_ptr_ty =
1326  llvm::dyn_cast<llvm::PointerType>(polygon_bounds_buf_ptr->getType());
1327  CHECK(bounds_ptr_ty);
1328  builder.CreateStore(
1329  builder.CreateBitCast(polygon_bounds, bounds_ptr_ty->getPointerElementType()),
1330  polygon_bounds_buf_ptr);
1331 
1332  const auto polygon_bounds_sizes_ptr =
1333  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 5);
1334  builder.CreateStore(polygon_bounds_sizes, polygon_bounds_sizes_ptr);
1335 
1336  const auto polygon_compression_ptr =
1337  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 6);
1338  builder.CreateStore(compression, polygon_compression_ptr);
1339 
1340  const auto input_srid_ptr =
1341  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 7);
1342  builder.CreateStore(input_srid, input_srid_ptr);
1343 
1344  const auto output_srid_ptr =
1345  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 8);
1346  builder.CreateStore(output_srid, output_srid_ptr);
1347 
1348  output_args.push_back(alloc_mem);
1349 }
1350 
1351 // Generate CAST operations for arguments in `orig_arg_lvs` to the types required by
1352 // `ext_func_sig`.
// Returns the flattened LLVM argument list for the extension-function call.
// How many LLVM values each SQL argument becomes depends on its type:
//  - none-encoded text and Array ext args go through codegenBufferArgs;
//  - pointer-style array args and PInt8 geo args expand to (pointer, i64 size),
//    plus extra (buffer, size) pairs for multi-part geometries;
//  - Geo* ext args become one pointer produced by the codegenGeo*Args helpers;
//  - scalars are passed as one value, cast when the SQL type differs.
1354  const Analyzer::FunctionOper* function_oper,
1355  const ExtensionFunction* ext_func_sig,
1356  const std::vector<llvm::Value*>& orig_arg_lvs,
1357  const std::vector<size_t>& orig_arg_lvs_index,
1358  const std::unordered_map<llvm::Value*, llvm::Value*>& const_arr_size,
1359  const CompilationOptions& co) {
1361  CHECK(ext_func_sig);
1362  const auto& ext_func_args = ext_func_sig->getInputArgs();
1363  CHECK_LE(function_oper->getArity(), ext_func_args.size());
1364  const auto func_ti = function_oper->get_type_info();
1365  std::vector<llvm::Value*> args;
1366  /*
1367  i: argument in RA for the function operand
1368  j: extra offset in ext_func_args
1369  k: orig_arg_lvs counter, equal to orig_arg_lvs_index[i]
1370  ij: ext_func_args counter, equal to i + j
1371  dj: offset when UDF implementation first argument corresponds to return value
1372  */
1373  for (size_t i = 0, j = 0, dj = (func_ti.is_buffer() ? 1 : 0);
1374  i < function_oper->getArity();
1375  ++i) {
1376  size_t k = orig_arg_lvs_index[i];
1377  size_t ij = i + j;
1378  const auto arg = function_oper->getArg(i);
1379  const auto ext_func_arg = ext_func_args[ij];
1380  const auto& arg_ti = arg->get_type_info();
1381  llvm::Value* arg_lv{nullptr};
1382  if (arg_ti.is_bytes()) {
  // None-encoded text: orig_arg_lvs[k + 1] / [k + 2] hold the string pointer and
  // length; they are passed to codegenBufferArgs with a zero third component.
1383  CHECK(ext_func_arg == ExtArgumentType::TextEncodingNone)
1384  << ::toString(ext_func_arg);
1385  const auto ptr_lv = orig_arg_lvs[k + 1];
1386  const auto len_lv = orig_arg_lvs[k + 2];
1387  auto& builder = cgen_state_->ir_builder_;
1388  auto string_buf_arg = builder.CreatePointerCast(
1389  ptr_lv, llvm::Type::getInt8PtrTy(cgen_state_->context_));
1390  auto string_size_arg =
1391  builder.CreateZExt(len_lv, get_int_type(64, cgen_state_->context_));
1392  auto padding = ll_int<int8_t>(0, cgen_state_->context_);
1393  codegenBufferArgs(ext_func_sig->getName(),
1394  ij + dj,
1395  string_buf_arg,
1396  string_size_arg,
1397  padding,
1398  args);
1399  } else if (arg_ti.is_text_encoding_dict()) {
  // Dictionary-encoded text is passed through unchanged (one LLVM value).
1400  CHECK(ext_func_arg == ExtArgumentType::TextEncodingDict)
1401  << ::toString(ext_func_arg);
1402  arg_lv = orig_arg_lvs[k];
1403  args.push_back(arg_lv);
1404  } else if (arg_ti.is_array()) {
  // Arrays: a constant array's buffer is orig_arg_lvs[k] itself and its size is
  // looked up in const_arr_size; otherwise both come from runtime calls.
1405  bool const_arr = (const_arr_size.count(orig_arg_lvs[k]) > 0);
1406  const auto elem_ti = arg_ti.get_elem_type();
1407  // TODO: switch to fast fixlen variants
1408  const auto ptr_lv = (const_arr)
1409  ? orig_arg_lvs[k]
1411  "array_buff",
1412  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1413  {orig_arg_lvs[k], posArg(arg)});
1414  const auto len_lv =
1415  (const_arr) ? const_arr_size.at(orig_arg_lvs[k])
1417  "array_size",
1419  {orig_arg_lvs[k],
1420  posArg(arg),
1421  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
1422 
1423  if (is_ext_arg_type_pointer(ext_func_arg)) {
  // Pointer-style ext arg: emit (typed pointer, i64 size) -- two LLVM args for
  // one ext_func_args slot group, hence j++.
1424  args.push_back(castArrayPointer(ptr_lv, elem_ti));
1425  args.push_back(cgen_state_->ir_builder_.CreateZExt(
1426  len_lv, get_int_type(64, cgen_state_->context_)));
1427  j++;
1428  } else if (is_ext_arg_type_array(ext_func_arg)) {
  // Array-style ext arg: pass (typed pointer, i64 size, null flag) through
  // codegenBufferArgs. Constant arrays (GEP into a literal) are never null.
1429  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
1430  auto& builder = cgen_state_->ir_builder_;
1431  auto array_size_arg =
1432  builder.CreateZExt(len_lv, get_int_type(64, cgen_state_->context_));
1433  llvm::Value* array_null_arg = nullptr;
1434  if (auto gep = llvm::dyn_cast<llvm::GetElementPtrInst>(ptr_lv)) {
1435  CHECK(gep->getSourceElementType()->isArrayTy());
1436  // gep has the form
1437  // %17 = getelementptr [9 x i32], [9 x i32]* %7, i32 0
1438  // and was created by passing a const array to the UDF function:
1439  // select array_append({11, 22, 33}, 4);
1440  array_null_arg = ll_bool(false, cgen_state_->context_);
1441  } else {
1442  array_null_arg =
1443  cgen_state_->emitExternalCall("array_is_null",
1445  {orig_arg_lvs[k], posArg(arg)});
1446  }
1447  codegenBufferArgs(ext_func_sig->getName(),
1448  ij + dj,
1449  array_buf_arg,
1450  array_size_arg,
1451  array_null_arg,
1452  args);
1453  } else {
1454  UNREACHABLE();
1455  }
1456 
1457  } else if (arg_ti.is_geometry()) {
  // Geometry arguments. A GeoExpr argument is already a (buffer, size) pair and
  // is passed as (i8*, i64 size) immediately; column/array-backed geometries
  // continue below.
1458  auto geo_expr_arg = dynamic_cast<const Analyzer::GeoExpr*>(arg);
1459  if (geo_expr_arg) {
1460  auto ptr_lv = cgen_state_->ir_builder_.CreateBitCast(
1461  orig_arg_lvs[k], llvm::Type::getInt8PtrTy(cgen_state_->context_));
1462  args.push_back(ptr_lv);
1463  // TODO: remove when we normalize extension functions geo sizes to int32
1464  auto size_lv = cgen_state_->ir_builder_.CreateSExt(
1465  orig_arg_lvs[k + 1], llvm::Type::getInt64Ty(cgen_state_->context_));
1466  args.push_back(size_lv);
1467  j++;
1468  continue;
1469  }
1470  // Coords
1471  bool const_arr = (const_arr_size.count(orig_arg_lvs[k]) > 0);
1472  // NOTE(adb): We're generating code to handle the TINYINT array only -- the actual
1473  // geo encoding (or lack thereof) does not matter here
1474  const auto elem_ti = SQLTypeInfo(SQLTypes::kARRAY,
1475  0,
1476  0,
1477  false,
1479  0,
1481  .get_elem_type();
1482  llvm::Value* ptr_lv;
1483  llvm::Value* len_lv;
  // For POINT columns whose physical coords column is a fixed-size array, use
  // fast_fixlen_array_buff with the column's fixed byte size as the length.
1484  int32_t fixlen = -1;
1485  if (arg_ti.get_type() == kPOINT) {
1486  const auto col_var = dynamic_cast<const Analyzer::ColumnVar*>(arg);
1487  if (col_var) {
1488  const auto coords_cd = executor()->getPhysicalColumnDescriptor(col_var, 1);
1489  if (coords_cd && coords_cd->columnType.get_type() == kARRAY) {
1490  fixlen = coords_cd->columnType.get_size();
1491  }
1492  }
1493  }
1494  if (fixlen > 0) {
1495  ptr_lv =
1496  cgen_state_->emitExternalCall("fast_fixlen_array_buff",
1497  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1498  {orig_arg_lvs[k], posArg(arg)});
1499  len_lv = cgen_state_->llInt(int32_t(fixlen));
1500  } else {
1501  // TODO: remove const_arr and related code if it's not needed
1502  ptr_lv = (const_arr) ? orig_arg_lvs[k]
1504  "array_buff",
1505  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1506  {orig_arg_lvs[k], posArg(arg)});
1507  len_lv = (const_arr)
1508  ? const_arr_size.at(orig_arg_lvs[k])
1510  "array_size",
1512  {orig_arg_lvs[k],
1513  posArg(arg),
1514  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
1515  }
1516 
1517  if (is_ext_arg_type_geo(ext_func_arg)) {
  // Struct-based geo ext arg. POINT, MULTIPOINT and LINESTRING are complete
  // after this block; multi-part types are built in the switch below.
1518  if (arg_ti.get_type() == kPOINT || arg_ti.get_type() == kLINESTRING ||
1519  arg_ti.get_type() == kMULTIPOINT) {
1520  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
1521  auto compression_val = codegenCompression(arg_ti);
1522  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
1523  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
1524 
1525  if (arg_ti.get_type() == kPOINT) {
1526  CHECK_EQ(k, ij);
1527  codegenGeoPointArgs(ext_func_sig->getName(),
1528  ij + dj,
1529  array_buf_arg,
1530  len_lv,
1531  compression_val,
1532  input_srid_val,
1533  output_srid_val,
1534  args);
1535  } else if (arg_ti.get_type() == kMULTIPOINT) {
1536  CHECK_EQ(k, ij);
1537  codegenGeoMultiPointArgs(ext_func_sig->getName(),
1538  ij + dj,
1539  array_buf_arg,
1540  len_lv,
1541  compression_val,
1542  input_srid_val,
1543  output_srid_val,
1544  args);
1545  } else {
1546  CHECK_EQ(k, ij);
1547  codegenGeoLineStringArgs(ext_func_sig->getName(),
1548  ij + dj,
1549  array_buf_arg,
1550  len_lv,
1551  compression_val,
1552  input_srid_val,
1553  output_srid_val,
1554  args);
1555  }
1556  }
1557  } else {
  // PInt8 ext arg: coords go out as (i8*, i64 size); any further parts of a
  // multi-part geometry are appended in the switch below.
1558  CHECK(ext_func_arg == ExtArgumentType::PInt8);
1559  args.push_back(castArrayPointer(ptr_lv, elem_ti));
1560  args.push_back(cgen_state_->ir_builder_.CreateZExt(
1561  len_lv, get_int_type(64, cgen_state_->context_)));
1562  j++;
1563  }
1564 
  // Multi-part geometries: either build the Geo* struct (when the ext arg is
  // the matching Geo* type) or append the extra (buffer, size) pairs for the
  // PInt8 calling convention, advancing j by the number of extra LLVM args.
1565  switch (arg_ti.get_type()) {
1566  case kPOINT:
1567  case kMULTIPOINT:
1568  case kLINESTRING:
1569  break;
1570  case kMULTILINESTRING: {
1571  if (ext_func_arg == ExtArgumentType::GeoMultiLineString) {
1572  auto multi_linestring_coords = castArrayPointer(ptr_lv, elem_ti);
1573  auto compression_val = codegenCompression(arg_ti);
1574  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
1575  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
1576 
1577  auto [linestring_sizes, linestring_sizes_size] =
1578  codegenArrayBuff(orig_arg_lvs[k + 1],
1579  posArg(arg),
1581  /*cast_and_extend=*/false);
1582  CHECK_EQ(k, ij);
1583  codegenGeoMultiLineStringArgs(ext_func_sig->getName(),
1584  ij + dj,
1585  multi_linestring_coords,
1586  len_lv,
1587  linestring_sizes,
1588  linestring_sizes_size,
1589  compression_val,
1590  input_srid_val,
1591  output_srid_val,
1592  args);
1593  } else {
1594  CHECK(ext_func_arg == ExtArgumentType::PInt8);
1595  // Linestring Sizes
1596  auto const_arr = const_arr_size.count(orig_arg_lvs[k + 1]) > 0;
1597  auto [linestring_sizes, linestring_sizes_size] =
1598  (const_arr) ? std::make_pair(orig_arg_lvs[k + 1],
1599  const_arr_size.at(orig_arg_lvs[k + 1]))
1600  : codegenArrayBuff(orig_arg_lvs[k + 1],
1601  posArg(arg),
1603  /*cast_and_extend=*/true);
1604  args.push_back(linestring_sizes);
1605  args.push_back(linestring_sizes_size);
1606  j += 2;
1607  }
1608  break;
1609  }
1610  case kPOLYGON: {
1611  if (ext_func_arg == ExtArgumentType::GeoPolygon) {
1612  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
1613  auto compression_val = codegenCompression(arg_ti);
1614  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
1615  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
1616 
1617  auto [ring_size_buff, ring_size] =
1618  codegenArrayBuff(orig_arg_lvs[k + 1],
1619  posArg(arg),
1621  /*cast_and_extend=*/false);
1622  CHECK_EQ(k, ij);
1623  codegenGeoPolygonArgs(ext_func_sig->getName(),
1624  ij + dj,
1625  array_buf_arg,
1626  len_lv,
1627  ring_size_buff,
1628  ring_size,
1629  compression_val,
1630  input_srid_val,
1631  output_srid_val,
1632  args);
1633  } else {
1634  CHECK(ext_func_arg == ExtArgumentType::PInt8);
1635  // Ring Sizes
1636  auto const_arr = const_arr_size.count(orig_arg_lvs[k + 1]) > 0;
1637  auto [ring_size_buff, ring_size] =
1638  (const_arr) ? std::make_pair(orig_arg_lvs[k + 1],
1639  const_arr_size.at(orig_arg_lvs[k + 1]))
1640  : codegenArrayBuff(orig_arg_lvs[k + 1],
1641  posArg(arg),
1643  /*cast_and_extend=*/true);
1644  args.push_back(ring_size_buff);
1645  args.push_back(ring_size);
1646  j += 2;
1647  }
1648  break;
1649  }
1650  case kMULTIPOLYGON: {
1651  if (ext_func_arg == ExtArgumentType::GeoMultiPolygon) {
1652  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
1653  auto compression_val = codegenCompression(arg_ti);
1654  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
1655  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
1656 
1657  auto [ring_size_buff, ring_size] =
1658  codegenArrayBuff(orig_arg_lvs[k + 1],
1659  posArg(arg),
1661  /*cast_and_extend=*/false);
1662 
1663  auto [poly_bounds_buff, poly_bounds_size] =
1664  codegenArrayBuff(orig_arg_lvs[k + 2],
1665  posArg(arg),
1667  /*cast_and_extend=*/false);
1668  CHECK_EQ(k, ij);
1669  codegenGeoMultiPolygonArgs(ext_func_sig->getName(),
1670  ij + dj,
1671  array_buf_arg,
1672  len_lv,
1673  ring_size_buff,
1674  ring_size,
1675  poly_bounds_buff,
1676  poly_bounds_size,
1677  compression_val,
1678  input_srid_val,
1679  output_srid_val,
1680  args);
1681  } else {
1682  CHECK(ext_func_arg == ExtArgumentType::PInt8);
1683  // Ring Sizes
1684  {
1685  auto const_arr = const_arr_size.count(orig_arg_lvs[k + 1]) > 0;
1686  auto [ring_size_buff, ring_size] =
1687  (const_arr) ? std::make_pair(orig_arg_lvs[k + 1],
1688  const_arr_size.at(orig_arg_lvs[k + 1]))
1689  : codegenArrayBuff(orig_arg_lvs[k + 1],
1690  posArg(arg),
1692  /*cast_and_extend=*/true);
1693 
1694  args.push_back(ring_size_buff);
1695  args.push_back(ring_size);
1696  }
1697  // Poly Rings
  // (the polygon bounds buffer from orig_arg_lvs[k + 2] and its size)
1698  {
1699  auto const_arr = const_arr_size.count(orig_arg_lvs[k + 2]) > 0;
1700  auto [poly_bounds_buff, poly_bounds_size] =
1701  (const_arr)
1702  ? std::make_pair(orig_arg_lvs[k + 2],
1703  const_arr_size.at(orig_arg_lvs[k + 2]))
1704  : codegenArrayBuff(
1705  orig_arg_lvs[k + 2], posArg(arg), SQLTypes::kINT, true);
1706 
1707  args.push_back(poly_bounds_buff);
1708  args.push_back(poly_bounds_size);
1709  }
  // Four extra LLVM args were appended above (two (buffer, size) pairs).
1710  j += 4;
1711  }
1712  break;
1713  }
1714  default:
1715  CHECK(false);
1716  }
1717  } else {
  // Scalars: cast to the ext function's declared type when the SQL type differs,
  // then CHECK the LLVM type matches the declared ext argument type.
1718  CHECK(is_ext_arg_type_scalar(ext_func_arg));
1719  const auto arg_target_ti = ext_arg_type_to_type_info(ext_func_arg);
1720  if (arg_ti.get_type() != arg_target_ti.get_type()) {
1721  arg_lv = codegenCast(orig_arg_lvs[k], arg_ti, arg_target_ti, false, co);
1722  } else {
1723  arg_lv = orig_arg_lvs[k];
1724  }
1725  CHECK_EQ(arg_lv->getType(),
1726  ext_arg_type_to_llvm_type(ext_func_arg, cgen_state_->context_));
1727  args.push_back(arg_lv);
1728  }
1729  }
1730  return args;
1731 }
1732 
1733 llvm::Value* CodeGenerator::castArrayPointer(llvm::Value* ptr,
1734  const SQLTypeInfo& elem_ti) {
1736  if (elem_ti.get_type() == kFLOAT) {
1737  return cgen_state_->ir_builder_.CreatePointerCast(
1738  ptr, llvm::Type::getFloatPtrTy(cgen_state_->context_));
1739  }
1740  if (elem_ti.get_type() == kDOUBLE) {
1741  return cgen_state_->ir_builder_.CreatePointerCast(
1742  ptr, llvm::Type::getDoublePtrTy(cgen_state_->context_));
1743  }
1744  CHECK(elem_ti.is_integer() || elem_ti.is_boolean() ||
1745  (elem_ti.is_string() && elem_ti.get_compression() == kENCODING_DICT));
1746  switch (elem_ti.get_size()) {
1747  case 1:
1748  return cgen_state_->ir_builder_.CreatePointerCast(
1749  ptr, llvm::Type::getInt8PtrTy(cgen_state_->context_));
1750  case 2:
1751  return cgen_state_->ir_builder_.CreatePointerCast(
1752  ptr, llvm::Type::getInt16PtrTy(cgen_state_->context_));
1753  case 4:
1754  return cgen_state_->ir_builder_.CreatePointerCast(
1755  ptr, llvm::Type::getInt32PtrTy(cgen_state_->context_));
1756  case 8:
1757  return cgen_state_->ir_builder_.CreatePointerCast(
1758  ptr, llvm::Type::getInt64PtrTy(cgen_state_->context_));
1759  default:
1760  CHECK(false);
1761  }
1762  return nullptr;
1763 }
1764 
1765 // Reflects struct StringView defined in Shared/Datum.h
// Builds the LLVM struct {i8*, i64}, names it "StringView", and returns it.
// The field meaning (presumably data pointer + length) should be confirmed
// against Shared/Datum.h.
// NOTE(review): llvm::StructType::get returns a literal (uniqued) struct type.
// In LLVM, StructType::getName() asserts !isLiteral(), and setName() calls it,
// so this may trip an assertion in assertion-enabled LLVM builds. Consider
// llvm::StructType::create(...) for a named type -- verify before changing.
1767  auto* const string_view_type =
1768  llvm::StructType::get(cgen_state_->context_,
1769  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
1770  llvm::Type::getInt64Ty(cgen_state_->context_)});
1771  string_view_type->setName("StringView");
1772  return string_view_type;
1773 }
llvm::StructType * createLineStringStructType(const std::string &udf_func_name, size_t param_num)
void codegenGeoMultiPolygonArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *polygon_coords, llvm::Value *polygon_coords_size, llvm::Value *ring_sizes_buf, llvm::Value *ring_sizes, llvm::Value *polygon_bounds, llvm::Value *polygon_bounds_sizes, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
#define CHECK_EQ(x, y)
Definition: Logger.h:297
HOST DEVICE int get_size() const
Definition: sqltypes.h:390
bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type)
llvm::BasicBlock * args_notnull_bb
size_t getArity() const
Definition: Analyzer.h:2404
SQLTypes
Definition: sqltypes.h:53
std::unique_ptr< llvm::Module > udf_gpu_module
CgenState * cgen_state_
const ExtArgumentType getRet() const
void codegenGeoPolygonArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *polygon_buf, llvm::Value *polygon_size, llvm::Value *ring_sizes_buf, llvm::Value *num_rings, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
llvm::StructType * createMultiPointStructType(const std::string &udf_func_name, size_t param_num)
#define LOG(tag)
Definition: Logger.h:283
std::vector< llvm::Value * > codegenFunctionOperCastArgs(const Analyzer::FunctionOper *, const ExtensionFunction *, const std::vector< llvm::Value * > &, const std::vector< size_t > &, const std::unordered_map< llvm::Value *, llvm::Value * > &, const CompilationOptions &)
llvm::StructType * createMultiLineStringStructType(const std::string &udf_func_name, size_t param_num)
llvm::Value * codegenFunctionOperNullArg(const Analyzer::FunctionOper *, const std::vector< llvm::Value * > &)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:375
llvm::Value * posArg(const Analyzer::Expr *) const
Definition: ColumnIR.cpp:582
llvm::Value * castArrayPointer(llvm::Value *ptr, const SQLTypeInfo &elem_ti)
#define UNREACHABLE()
Definition: Logger.h:333
#define CHECK_GE(x, y)
Definition: Logger.h:302
Definition: sqldefs.h:48
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
llvm::StructType * createPointStructType(const std::string &udf_func_name, size_t param_num)
bool call_requires_custom_type_handling(const Analyzer::FunctionOper *function_oper)
const std::string getName(bool keep_suffix=true) const
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:380
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
bool ext_func_call_requires_nullcheck(const Analyzer::FunctionOper *function_oper)
SQLTypeInfo get_sql_type_from_llvm_type(const llvm::Type *ll_type)
llvm::StructType * get_buffer_struct_type(CgenState *cgen_state, const std::string &ext_func_name, size_t param_num, llvm::Type *elem_type)
std::vector< FunctionOperValue > ext_call_cache_
Definition: CgenState.h:381
void codegenBufferArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *buffer_buf, llvm::Value *buffer_size, llvm::Value *buffer_is_null, std::vector< llvm::Value * > &output_args)
llvm::Function * row_func_
Definition: CgenState.h:365
RUNTIME_EXPORT void register_buffer_with_executor_rsm(int64_t exec, int8_t *buffer)
std::pair< llvm::Value *, llvm::Value * > codegenArrayBuff(llvm::Value *chunk, llvm::Value *row_pos, SQLTypes array_type, bool cast_and_extend)
llvm::Module * module_
Definition: CgenState.h:364
Supported runtime functions management and retrieval.
llvm::LLVMContext & context_
Definition: CgenState.h:373
llvm::Function * current_func_
Definition: CgenState.h:367
std::tuple< ArgNullcheckBBs, llvm::Value * > beginArgsNullcheck(const Analyzer::FunctionOper *function_oper, const std::vector< llvm::Value * > &orig_arg_lvs)
llvm::Value * emitExternalCall(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, const std::vector< llvm::Attribute::AttrKind > &fnattrs={}, const bool has_struct_return=false)
Definition: CgenState.cpp:396
llvm::Value * get_arg_by_name(llvm::Function *func, const std::string &name)
Definition: Execute.h:166
bool is_integer() const
Definition: sqltypes.h:578
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:64
bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type)
void codegenGeoMultiPointArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *multi_point_buf, llvm::Value *multi_point_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
llvm::Value * codegenFunctionOper(const Analyzer::FunctionOper *, const CompilationOptions &)
llvm::Type * get_llvm_type_from_sql_array_type(const SQLTypeInfo ti, llvm::LLVMContext &ctx)
std::string toString(const ExecutorDeviceType &device_type)
bool is_boolean() const
Definition: sqltypes.h:583
llvm::BasicBlock * args_null_bb
#define AUTOMATIC_IR_METADATA(CGENSTATE)
llvm::Type * ext_arg_type_to_llvm_type(const ExtArgumentType ext_arg_type, llvm::LLVMContext &ctx)
void codegenGeoMultiLineStringArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *multi_linestring_coords, llvm::Value *multi_linestring_size, llvm::Value *linestring_sizes, llvm::Value *linestring_sizes_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:83
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:216
bool is_buffer() const
Definition: sqltypes.h:608
ExecutorDeviceType device_type
void codegenGeoPointArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *point_buf, llvm::Value *point_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
#define RUNTIME_EXPORT
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:30
#define CHECK_LE(x, y)
Definition: Logger.h:300
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:388
std::string serialize_llvm_object(const T *llvm_obj)
bool isLocalAlloc() const
Definition: Analyzer.h:2662
llvm::StructType * createPolygonStructType(const std::string &udf_func_name, size_t param_num)
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:2406
const Expr * get_operand() const
Definition: Analyzer.h:380
llvm::Value * endArgsNullcheck(const ArgNullcheckBBs &, llvm::Value *, llvm::Value *, const Analyzer::FunctionOper *)
const std::vector< ExtArgumentType > & getInputArgs() const
llvm::StructType * createStringViewStructType()
std::unique_ptr< llvm::Module > udf_cpu_module
HOST DEVICE int get_comp_param() const
Definition: sqltypes.h:389
llvm::Value * codegenFunctionOperWithCustomTypeHandling(const Analyzer::FunctionOperWithCustomTypeHandling *, const CompilationOptions &)
bool is_bytes() const
Definition: sqltypes.h:599
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:240
#define CHECK(condition)
Definition: Logger.h:289
llvm::Value * codegenIsNullNumber(llvm::Value *, const SQLTypeInfo &)
Definition: LogicalIR.cpp:412
uint64_t exp_to_scale(const unsigned exp)
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
llvm::Value * codegenCompression(const SQLTypeInfo &type_info)
llvm::Value * codegenCast(const Analyzer::UOper *, const CompilationOptions &)
Definition: CastIR.cpp:21
uint32_t log2_bytes(const uint32_t bytes)
Definition: Execute.h:176
Definition: sqltypes.h:60
bool is_string() const
Definition: sqltypes.h:576
std::string getName() const
Definition: Analyzer.h:2402
void codegenGeoLineStringArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *line_string_buf, llvm::Value *line_string_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
bool is_ext_arg_type_pointer(const ExtArgumentType ext_arg_type)
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:957
bool is_decimal() const
Definition: sqltypes.h:579
int get_physical_coord_cols() const
Definition: sqltypes.h:430
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
llvm::ConstantFP * inlineFpNull(const SQLTypeInfo &)
Definition: CgenState.cpp:103
Executor * executor() const
llvm::StructType * createMultiPolygonStructType(const std::string &udf_func_name, size_t param_num)