OmniSciDB  6686921089
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 #include "ExternalExecutor.h"
20 
21 // A rather crude function binding logic based on the types of the arguments.
22 // We want it to be possible to write specialized versions of functions to be
23 // exposed as SQL extensions. This is important especially for performance
24 // reasons, since double operations can be significantly slower than float. We
25 // compute a score for each candidate signature based on conversions required to
26 // get from the function arguments as specified in the SQL query to the versions in
27 // ExtensionFunctions.hpp.
28 
29 /*
30  New implementation for binding a SQL function operator to the
31  optimal candidate within all available extension functions.
32  */
33 
34 namespace {
35 
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
54  default:
55  UNREACHABLE();
56  }
57  return ExtArgumentType{};
58 }
59 
61  const ExtArgumentType ext_arg_column_list_type) {
62  switch (ext_arg_column_list_type) {
64  return ExtArgumentType::Int8;
76  return ExtArgumentType::Bool;
79  default:
80  UNREACHABLE();
81  }
82  return ExtArgumentType{};
83 }
84 
86  switch (ext_arg_array_type) {
88  return ExtArgumentType::Int8;
100  return ExtArgumentType::Bool;
101  default:
102  UNREACHABLE();
103  }
104  return ExtArgumentType{};
105 }
106 
107 static int match_numeric_argument(const SQLTypeInfo& arg_type_info,
108  const bool is_arg_literal,
109  const ExtArgumentType& sig_ext_arg_type,
110  int32_t& penalty_score) {
111  const auto arg_type = arg_type_info.get_type();
112  CHECK(arg_type == kBOOLEAN || arg_type == kTINYINT || arg_type == kSMALLINT ||
113  arg_type == kINT || arg_type == kBIGINT || arg_type == kFLOAT ||
114  arg_type == kDOUBLE || arg_type == kDECIMAL || arg_type == kNUMERIC);
115  // Todo (todd): Add support for timestamp, date, and time types
116  const auto sig_type_info = ext_arg_type_to_type_info(sig_ext_arg_type);
117  const auto sig_type = sig_type_info.get_type();
118 
119  // If we can't legally auto-cast to sig_type, abort
120  if (!arg_type_info.is_numeric_scalar_auto_castable(sig_type_info)) {
121  return -1;
122  }
123 
124  // We now compare a measure of the scale of the sig_type with the arg_type,
125  // which provides a basis for scoring the match between the two.
126  // Note that get_numeric_scalar_scale for the most part returns the logical
127  // byte width of the type, with a few caveats for decimals and timestamps
128  // described in more depth in comments in the function itself.
129  // Also even though for example float and int types return 4 (as in 4 bytes),
130  // and double and bigint types return 8, a fp32 type cannot express every 32-bit integer
131  // (even if it can cover a larger absolute range), and an fp64 type likewise cannot
132  // express every 64-bit integer. However for the purposes of extension function casting
133  // we consider floats and ints to be equal scale, and similarly doubles and bigints.
134  const auto arg_type_relative_scale = arg_type_info.get_numeric_scalar_scale();
135  CHECK_GE(arg_type_relative_scale, 1);
136  CHECK_LE(arg_type_relative_scale, 8);
137  const auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
138  CHECK_GE(sig_type_relative_scale, 1);
139  CHECK_LE(sig_type_relative_scale, 8);
140 
141  // Current we do not allow auto-casting to types with less scale/precision, with the
142  // caveats around floating point and decimal types mentioned above
143  CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
144 
145  // Calculate the ratio of the sig_type by the arg_type, per the above check will be >= 1
146  const auto sig_type_scale_gain_ratio =
147  sig_type_relative_scale / arg_type_relative_scale;
148  CHECK_GE(sig_type_scale_gain_ratio, 1);
149 
150  const bool is_integer_to_fp_cast = (arg_type == kTINYINT || arg_type == kSMALLINT ||
151  arg_type == kINT || arg_type == kBIGINT) &&
152  (sig_type == kFLOAT || sig_type == kDOUBLE);
153 
154  // Following the old bespoke scoring logic this function replaces, we heavily penalize
155  // any casts that move ints to floats/doubles for the precision-loss reasons above
156  // Arguably all integers in the tinyint and smallint can be fully specified with both
157  // float and double types, but we treat them the same as int and bigint types here.
158  const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
159 
160  int32_t scale_cast_penalty_score;
161 
162  // The following logic is new. Basically there are strong reasons to
163  // prefer the promotion of constant literals to the most precise type possible, as
164  // rather than the type being inherent in the data - that is a column or columns where
165  // a user specified a type (and with any expressions on those columns following our
166  // standard sql casting logic), literal types are given to us by Calcite and do not
167  // necessarily convey any semantic intent (i.e. 10 will be an int, but 10.0 a decimal)
168  // Hence it is better to promote these types to the most precise sig_type available,
169  // while at the same time keeping column expressions as close as possible to the input
170  // types (mainly for performance, we have many float versions of various functions
171  // to allow for greater performance when the underlying data is not of double precision,
172  // and hence there is little benefit of the extra cost of computing double precision
173  // operators on this data)
174  if (is_arg_literal) {
175  scale_cast_penalty_score =
176  (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
177  } else {
178  scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
179  }
180 
181  const auto cast_penalty_score =
182  type_family_cast_penalty_score + scale_cast_penalty_score;
183  CHECK_GT(cast_penalty_score, 0);
184  penalty_score += cast_penalty_score;
185  return 1;
186 }
187 
188 static int match_arguments(const SQLTypeInfo& arg_type,
189  const bool is_arg_literal,
190  int sig_pos,
191  const std::vector<ExtArgumentType>& sig_types,
192  int& penalty_score) {
193  /*
194  Returns non-negative integer `offset` if `arg_type` and
195  `sig_types[sig_pos:sig_pos + offset]` match.
196 
197  The `offset` value can be interpreted as the number of extension
198  function arguments that is consumed by the given `arg_type`. For
199  instance, for scalar types the offset is always 1, for array
200  types the offset is 2: one argument for array pointer value and
201  one argument for the array size value, etc.
202 
203  Returns -1 when the types of an argument and the corresponding
204  extension function argument(s) mismatch, or when downcasting would
205  be effective.
206 
207  In case of non-negative `offset` result, the function updates
208  penalty_score argument as follows:
209 
210  add 1000 if arg_type is non-scalar, otherwise:
211  add 1000 * sizeof(sig_type) / sizeof(arg_type)
212  add 1000000 if type kinds differ (integer vs double, for instance)
213 
214  */
215  int max_pos = sig_types.size() - 1;
216  if (sig_pos > max_pos) {
217  return -1;
218  }
219  auto sig_type = sig_types[sig_pos];
220  switch (arg_type.get_type()) {
221  case kBOOLEAN:
222  case kTINYINT:
223  case kSMALLINT:
224  case kINT:
225  case kBIGINT:
226  case kFLOAT:
227  case kDOUBLE:
228  case kDECIMAL:
229  case kNUMERIC:
230  return match_numeric_argument(arg_type, is_arg_literal, sig_type, penalty_score);
231  case kPOINT:
232  case kLINESTRING:
233  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
234  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
235  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble) &&
236  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
237  penalty_score += 1000;
238  return 2;
239  } else if (sig_type == ExtArgumentType::GeoPoint ||
240  sig_type == ExtArgumentType::GeoLineString) {
241  penalty_score += 1000;
242  return 1;
243  }
244  return -1;
245  case kARRAY:
246  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
247  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
248  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
249  sig_type == ExtArgumentType::PBool) &&
250  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
251  penalty_score += 1000;
252  return 2;
253  } else if (is_ext_arg_type_array(sig_type)) {
254  // array arguments must match exactly
255  CHECK(arg_type.is_array());
256  const auto sig_type_ti =
258  if (arg_type.get_elem_type() == kBOOLEAN && sig_type_ti.get_type() == kTINYINT) {
259  /* Boolean array has the same low-level structure as Int8 array. */
260  penalty_score += 1000;
261  return 1;
262  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
263  penalty_score += 1000;
264  return 1;
265  } else {
266  return -1;
267  }
268  }
269  break;
270  case kPOLYGON:
271  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
272  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
273  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
274  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
275  penalty_score += 1000;
276  return 4;
277  } else if (sig_type == ExtArgumentType::GeoPolygon) {
278  penalty_score += 1000;
279  return 1;
280  }
281  break;
282  case kMULTIPOLYGON:
283  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
284  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
285  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
286  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
287  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
288  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
289  penalty_score += 1000;
290  return 6;
291  } else if (sig_type == ExtArgumentType::GeoMultiPolygon) {
292  penalty_score += 1000;
293  return 1;
294  }
295  break;
296  case kNULLT: // NULL maps to a pointer and size argument
297  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
298  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
299  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
300  sig_type == ExtArgumentType::PBool) &&
301  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
302  penalty_score += 1000;
303  return 2;
304  }
305  break;
306  case kCOLUMN:
307  if (is_ext_arg_type_column(sig_type)) {
308  // column arguments must match exactly
309  const auto sig_type_ti =
311  if (arg_type.get_elem_type() == kBOOLEAN && sig_type_ti.get_type() == kTINYINT) {
312  /* Boolean column has the same low-level structure as Int8 column. */
313  penalty_score += 1000;
314  return 1;
315  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
316  penalty_score += 1000;
317  return 1;
318  } else {
319  return -1;
320  }
321  }
322  break;
323  case kCOLUMN_LIST:
324  if (is_ext_arg_type_column_list(sig_type)) {
325  // column_list arguments must match exactly
326  const auto sig_type_ti =
328  if (arg_type.get_elem_type() == kBOOLEAN && sig_type_ti.get_type() == kTINYINT) {
329  /* Boolean column_list has the same low-level structure as Int8 column_list. */
330  penalty_score += 10000;
331  return 1;
332  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
333  penalty_score += 10000;
334  return 1;
335  } else {
336  return -1;
337  }
338  }
339  break;
340  case kVARCHAR:
341  if (sig_type != ExtArgumentType::TextEncodingNone) {
342  return -1;
343  }
344  switch (arg_type.get_compression()) {
345  case kENCODING_NONE:
346  penalty_score += 1000;
347  return 1;
348  case kENCODING_DICT:
349  return -1;
350  // Todo (todd): Evaluate when and where we can tranlate to dictionary-encoded
351  default:
352  UNREACHABLE();
353  }
354  case kTEXT:
355  if (sig_type != ExtArgumentType::TextEncodingNone) {
356  return -1;
357  }
358  switch (arg_type.get_compression()) {
359  case kENCODING_NONE:
360  penalty_score += 1000;
361  return 1;
362  case kENCODING_DICT:
363  return -1;
364  default:
365  UNREACHABLE();
366  }
367  /* Not implemented types:
368  kCHAR
369  kTIME
370  kTIMESTAMP
371  kDATE
372  kINTERVAL_DAY_TIME
373  kINTERVAL_YEAR_MONTH
374  kGEOMETRY
375  kGEOGRAPHY
376  kEVAL_CONTEXT_TYPE
377  kVOID
378  kCURSOR
379  */
380  default:
381  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
382  ": support for " + arg_type.get_type_name() +
383  "(type=" + std::to_string(arg_type.get_type()) + ")" +
384  +" not implemented: \n pos=" + std::to_string(sig_pos) +
385  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
386  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
387  }
388  return -1;
389 }
390 
/**
 * Tells whether `str` is acceptable as a UDF/UDTF name.
 *
 * A valid name is non-empty, starts with a letter or an underscore, and
 * continues with letters, digits or underscores only. Letters and digits are
 * classified with std::isalpha / std::isalnum, i.e. according to the current
 * C locale.
 *
 * @param str  candidate function name
 * @return true when `str` satisfies the rules above, false otherwise
 */
bool is_valid_identifier(const std::string& str) {
  if (str.empty()) {
    return false;
  }

  // The <cctype> classifiers have undefined behaviour for negative arguments
  // other than EOF, which is what a non-ASCII byte becomes when `char` is
  // signed. Taking the character as unsigned char keeps every byte in range.
  const auto is_leading_char = [](const unsigned char c) {
    return std::isalpha(c) || c == '_';
  };
  const auto is_trailing_char = [](const unsigned char c) {
    return std::isalnum(c) || c == '_';
  };

  if (!is_leading_char(str.front())) {
    return false;
  }
  return std::all_of(str.begin() + 1, str.end(), is_trailing_char);
}
408 
409 } // namespace
410 
411 template <typename T>
412 std::tuple<T, std::vector<SQLTypeInfo>> bind_function(
413  std::string name,
414  Analyzer::ExpressionPtrVector func_args, // function args from sql query
415  const std::vector<T>& ext_funcs, // list of functions registered
416  const std::string processor) {
417  /* worker function
418 
419  Template type T must implement the following methods:
420 
421  std::vector<ExtArgumentType> getInputArgs()
422  */
423  /*
424  Return extension function/table function that has the following
425  properties
426 
427  1. each argument type in `arg_types` matches with extension
428  function argument types.
429 
430  For scalar types, the matching means that the types are either
431  equal or the argument type is smaller than the corresponding
432  the extension function argument type. This ensures that no
433  information is lost when casting of argument values is
434  required.
435 
436  For array and geo types, the matching means that the argument
437  type matches exactly with a group of extension function
438  argument types. See `match_arguments`.
439 
440  2. has minimal penalty score among all implementations of the
441  extension function with given `name`, see `get_penalty_score`
442  for the definition of penalty score.
443 
444  It is assumed that function_oper and extension functions in
445  ext_funcs have the same name.
446  */
447  if (!is_valid_identifier(name)) {
448  throw NativeExecutionError(
449  "Cannot bind function with invalid UDF/UDTF function name: " + name);
450  }
451 
452  int minimal_score = std::numeric_limits<int>::max();
453  int index = -1;
454  int optimal = -1;
455  int optimal_variant = -1;
456 
457  std::vector<SQLTypeInfo> type_infos_input;
458  std::vector<bool> args_are_constants;
459  for (auto atype : func_args) {
460  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
461  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
462  SQLTypeInfo type_info = atype->get_type_info();
463  if (atype->get_type_info().get_type() == kTEXT) {
464  auto ti = generate_column_type(type_info.get_type(), // subtype
465  type_info.get_compression(), // compression
466  type_info.get_comp_param()); // comp_param
467  type_infos_input.push_back(ti);
468  args_are_constants.push_back(false);
469  } else {
470  auto ti = generate_column_type(type_info.get_type());
471  type_infos_input.push_back(ti);
472  args_are_constants.push_back(true);
473  }
474  continue;
475  }
476  }
477  type_infos_input.push_back(atype->get_type_info());
478  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
479  args_are_constants.push_back(true);
480  } else {
481  args_are_constants.push_back(false);
482  }
483  }
484  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
485 
486  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
487  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
488  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
489  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
490  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
491  }
492  std::vector<SQLTypeInfo> empty_type_info_variant(0);
493  return {ext_funcs[0], empty_type_info_variant};
494  }
495 
496  // clang-format off
497  /*
498  Table functions may have arguments such as ColumnList that collect
499  neighboring columns with the same data type into a single object.
500  Here we compute all possible combinations of mapping a subset of
501  columns into columns sets. For example, if the types of function
502  arguments are (as given in func_args argument)
503 
504  (Column<int>, Column<int>, Column<int>, int)
505 
506  then the computed variants will be
507 
508  (Column<int>, Column<int>, Column<int>, int)
509  (Column<int>, Column<int>, ColumnList[1]<int>, int)
510  (Column<int>, ColumnList[1]<int>, Column<int>, int)
511  (Column<int>, ColumnList[2]<int>, int)
512  (ColumnList[1]<int>, Column<int>, Column<int>, int)
513  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
514  (ColumnList[2]<int>, Column<int>, int)
515  (ColumnList[3]<int>, int)
516 
517  where the integers in [..] indicate the number of collected
518  columns. In the SQLTypeInfo instance, this number is stored in the
519  SQLTypeInfo dimension attribute.
520 
521  As an example, let us consider a SQL query containing the
522  following expression calling a UDTF foo:
523 
524  table(foo(cursor(select a, b, c from tableofints), 1))
525 
526  Here follows a list of table functions and the corresponding
527  optimal argument type variants that are computed for the given
528  query expression:
529 
530  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
531  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
532 
533  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
534  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
535 
536  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
537  (Column<int>, Column<int>, Column<int>, int)
538  */
539  // clang-format on
540  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
541  for (auto ti : type_infos_input) {
542  if (type_infos_variants.begin() == type_infos_variants.end()) {
543  type_infos_variants.push_back({ti});
544  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
545  if (ti.is_column()) {
546  auto mti = generate_column_list_type(ti.get_subtype());
547  mti.set_dimension(1);
548  type_infos_variants.push_back({mti});
549  }
550  }
551  continue;
552  }
553  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
554  for (auto& type_infos : type_infos_variants) {
555  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
556  if (ti.is_column()) {
557  auto new_type_infos = type_infos; // makes a copy
558  const auto& last = type_infos.back();
559  if (last.is_column_list() && last.get_subtype() == ti.get_subtype()) {
560  // last column_list consumes column argument if types match
561  new_type_infos.back().set_dimension(last.get_dimension() + 1);
562  } else {
563  // add column as column_list argument
564  auto mti = generate_column_list_type(ti.get_subtype());
565  mti.set_dimension(1);
566  new_type_infos.push_back(mti);
567  }
568  new_type_infos_variants.push_back(new_type_infos);
569  }
570  }
571  type_infos.push_back(ti);
572  }
573  type_infos_variants.insert(type_infos_variants.end(),
574  new_type_infos_variants.begin(),
575  new_type_infos_variants.end());
576  }
577 
578  // Find extension function that gives the best match on the set of
579  // argument type variants:
580  for (auto ext_func : ext_funcs) {
581  index++;
582 
583  auto ext_func_args = ext_func.getInputArgs();
584  int index_variant = -1;
585  for (const auto& type_infos : type_infos_variants) {
586  index_variant++;
587  int penalty_score = 0;
588  int pos = 0;
589  int original_input_idx = 0;
590  CHECK_LE(type_infos.size(), args_are_constants.size());
591  // for (size_t ti_idx = 0; ti_idx != type_infos.size(); ++ti_idx) {
592  for (const auto& ti : type_infos) {
593  int offset = match_arguments(ti,
594  args_are_constants[original_input_idx],
595  pos,
596  ext_func_args,
597  penalty_score);
598  if (offset < 0) {
599  // atype does not match with ext_func argument
600  pos = -1;
601  break;
602  }
603  if (ti.get_type() == kCOLUMN_LIST) {
604  original_input_idx += ti.get_dimension();
605  } else {
606  original_input_idx++;
607  }
608  pos += offset;
609  }
610 
611  if ((size_t)pos == ext_func_args.size()) {
612  CHECK_EQ(args_are_constants.size(), original_input_idx);
613  // prefer smaller return types
614  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
615  if (penalty_score < minimal_score) {
616  optimal = index;
617  minimal_score = penalty_score;
618  optimal_variant = index_variant;
619  }
620  }
621  }
622  }
623 
624  if (optimal == -1) {
625  /* no extension function found that argument types would match
626  with types in `arg_types` */
627  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
628  std::string message;
629  if (!ext_funcs.size()) {
630  message = "Function " + name + "(" + sarg_types + ") not supported.";
631  throw ExtensionFunctionBindingError(message);
632  } else {
633  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
634  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
635  " UDTF implementation.";
636  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
637  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
638  " UDF implementation.";
639  } else {
640  LOG(FATAL) << "bind_function: unknown extension function type "
641  << typeid(T).name();
642  }
643  message += "\n Existing extension function implementations:";
644  for (const auto& ext_func : ext_funcs) {
645  // Do not show functions missing the sizer argument
646  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
647  if (ext_func.useDefaultSizer())
648  continue;
649  message += "\n " + ext_func.toStringSQL();
650  }
651  }
652  throw ExtensionFunctionBindingError(message);
653  }
654 
655  // Functions with "_default_" suffix only exist for calcite
656  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
657  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
658  ext_funcs[optimal].useDefaultSizer()) {
659  std::string name = ext_funcs[optimal].getName();
660  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
662  for (size_t i = 0; i < ext_funcs.size(); i++) {
663  if (ext_funcs[i].getName() == name) {
664  optimal = i;
665  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
666  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
667  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
668  return {ext_funcs[optimal], type_info};
669  }
670  }
671  UNREACHABLE();
672  }
673  }
674 
675  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
676 }
677 
678 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
681  const std::vector<table_functions::TableFunction>& table_funcs,
682  const bool is_gpu) {
683  std::string processor = (is_gpu ? "GPU" : "CPU");
684  return bind_function<table_functions::TableFunction>(
685  name, input_args, table_funcs, processor);
686 }
687 
689  Analyzer::ExpressionPtrVector func_args) {
690  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
691  // to CPU UDFs.
692  bool is_gpu = true;
693  std::string processor = "GPU";
694  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
695  if (!ext_funcs.size()) {
696  is_gpu = false;
697  processor = "CPU";
698  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
699  }
700  try {
701  return std::get<0>(
702  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
703  } catch (ExtensionFunctionBindingError& e) {
704  if (is_gpu) {
705  is_gpu = false;
706  processor = "GPU|CPU";
707  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
708  return std::get<0>(
709  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
710  } else {
711  throw;
712  }
713  }
714 }
715 
718  const bool is_gpu) {
719  // used below
720  std::vector<ExtensionFunction> ext_funcs =
722  std::string processor = (is_gpu ? "GPU" : "CPU");
723  return std::get<0>(
724  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
725 }
726 
728  const bool is_gpu) {
729  // used in ExtensionsIR.cpp
730  auto name = function_oper->getName();
731  Analyzer::ExpressionPtrVector func_args = {};
732  for (size_t i = 0; i < function_oper->getArity(); ++i) {
733  func_args.push_back(function_oper->getOwnArg(i));
734  }
735  return bind_function(name, func_args, is_gpu);
736 }
737 
738 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
741  const bool is_gpu) {
742  // used in RelAlgExecutor.cpp
743  std::vector<table_functions::TableFunction> table_funcs =
745  return bind_table_function(name, input_args, table_funcs, is_gpu);
746 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name, const bool is_gpu)
#define CHECK_EQ(x, y)
Definition: Logger.h:217
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
size_t getArity() const
Definition: Analyzer.h:1515
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
#define LOG(tag)
Definition: Logger.h:203
auto generate_column_type(const SQLTypes subtype)
Definition: sqltypes.h:1119
string name
Definition: setup.in.py:72
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:253
#define CHECK_GE(x, y)
Definition: Logger.h:222
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:329
#define CHECK_GT(x, y)
Definition: Logger.h:221
std::string to_string(char const *&&v)
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:1522
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
Definition: sqltypes.h:637
auto generate_column_list_type(const SQLTypes subtype)
Definition: sqltypes.h:1133
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
Definition: sqltypes.h:52
#define CHECK_LE(x, y)
Definition: Logger.h:220
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:337
std::string get_type_name() const
Definition: sqltypes.h:432
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
Definition: sqltypes.h:698
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
HOST DEVICE int get_comp_param() const
Definition: sqltypes.h:338
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:209
static std::vector< TableFunction > get_table_funcs(const std::string &name, const bool is_gpu)
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:181
Definition: sqltypes.h:45
std::string getName() const
Definition: Analyzer.h:1513
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:850
bool is_array() const
Definition: sqltypes.h:517
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)