OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 #include "ExternalExecutor.h"
20 
21 // A rather crude function binding logic based on the types of the arguments.
22 // We want it to be possible to write specialized versions of functions to be
23 // exposed as SQL extensions. This is important especially for performance
24 // reasons, since double operations can be significantly slower than float. We
25 // compute a score for each candidate signature based on the conversions required to
26 // convert the function arguments as specified in the SQL query to the versions in
27 // ExtensionFunctions.hpp.
28 
29 /*
30  New implementation for binding a SQL function operator to the
31  optimal candidate among all available extension functions.
32  */
33 
34 namespace {
35 
// get_column_arg_elem_type(ext_arg_column_type) (signature per this page's
// symbol index): presumably maps a Column ExtArgumentType to the ExtArgumentType
// of its elements; the returns that survive in this listing are Int8 and Bool.
// Inputs without a case label reach UNREACHABLE(); the trailing
// `return ExtArgumentType{}` is presumably only there to satisfy the compiler,
// assuming UNREACHABLE() does not return.
// NOTE(review): this extracted listing omits the function header (source line
// 36) and most case labels; consult the full source before relying on it.
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
86  default:
87  UNREACHABLE();
88  }
89  return ExtArgumentType{};
90 }
91 
// get_column_list_arg_elem_type(ext_arg_column_list_type) (signature per this
// page's symbol index): presumably maps a ColumnList ExtArgumentType to the
// ExtArgumentType of its elements; the returns that survive in this listing are
// Int8, Int32, Int64, Float and Bool. Inputs without a case label reach
// UNREACHABLE().
// NOTE(review): this extracted listing omits the function header (source line
// 92) and most case labels; consult the full source before relying on it.
93  const ExtArgumentType ext_arg_column_list_type) {
94  switch (ext_arg_column_list_type) {
96  return ExtArgumentType::Int8;
100  return ExtArgumentType::Int32;
102  return ExtArgumentType::Int64;
104  return ExtArgumentType::Float;
108  return ExtArgumentType::Bool;
141  default:
142  UNREACHABLE();
143  }
144  return ExtArgumentType{};
145 }
146 
// get_array_arg_elem_type(ext_arg_array_type) (signature per this page's symbol
// index): presumably maps an Array ExtArgumentType to the ExtArgumentType of its
// elements; the returns that survive in this listing are Int8, Int16, Int32,
// Int64, Float and Bool. Inputs without a case label reach UNREACHABLE().
// NOTE(review): this extracted listing omits the function header (source line
// 147) and most case labels; consult the full source before relying on it.
148  switch (ext_arg_array_type) {
150  return ExtArgumentType::Int8;
152  return ExtArgumentType::Int16;
154  return ExtArgumentType::Int32;
156  return ExtArgumentType::Int64;
158  return ExtArgumentType::Float;
162  return ExtArgumentType::Bool;
167  default:
168  UNREACHABLE();
169  }
170  return ExtArgumentType{};
171 }
172 
173 static int match_numeric_argument(const SQLTypeInfo& arg_type_info,
174  const bool is_arg_literal,
175  const ExtArgumentType& sig_ext_arg_type,
176  int32_t& penalty_score) {
177  const auto arg_type = arg_type_info.get_type();
178  CHECK(arg_type == kBOOLEAN || arg_type == kTINYINT || arg_type == kSMALLINT ||
179  arg_type == kINT || arg_type == kBIGINT || arg_type == kFLOAT ||
180  arg_type == kDOUBLE || arg_type == kDECIMAL || arg_type == kNUMERIC);
181  // Todo (todd): Add support for timestamp, date, and time types
182  const auto sig_type_info = ext_arg_type_to_type_info(sig_ext_arg_type);
183  const auto sig_type = sig_type_info.get_type();
184 
185  // If we can't legally auto-cast to sig_type, abort
186  if (!arg_type_info.is_numeric_scalar_auto_castable(sig_type_info)) {
187  return -1;
188  }
189 
190  // We now compare a measure of the scale of the sig_type with the
191  // arg_type, which provides a basis for scoring the match between
192  // the two. Note that get_numeric_scalar_scale for the most part
193  // returns the logical byte width of the type, with a few caveats
194  // for decimals and timestamps described in more depth in comments
195  // in the function itself. Also even though for example float and
196  // int types return 4 (as in 4 bytes), and double and bigint types
197  // return 8, a fp32 type cannot express every 32-bit integer (even
198  // if it can cover a larger absolute range), and an fp64 type
199  // likewise cannot express every 64-bit integer. With the aim to
200  // minimize the precision loss from casting (always precise) integer
201  // value to (imprecise) floating point value, in the case of integer
202  // inputs, we'll penalize wider floating point argument types least
203  // by a specific scale transformation (see the implementation
204  // below). For instance, casting tinyint to fp64 is prefered over
205  // casting it to fp32 to minimize precision loss.
206  const bool is_integer_to_fp_cast = (arg_type == kTINYINT || arg_type == kSMALLINT ||
207  arg_type == kINT || arg_type == kBIGINT) &&
208  (sig_type == kFLOAT || sig_type == kDOUBLE);
209 
210  const auto arg_type_relative_scale = arg_type_info.get_numeric_scalar_scale();
211  CHECK_GE(arg_type_relative_scale, 1);
212  CHECK_LE(arg_type_relative_scale, 8);
213  auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
214  CHECK_GE(sig_type_relative_scale, 1);
215  CHECK_LE(sig_type_relative_scale, 8);
216 
217  if (is_integer_to_fp_cast) {
218  // transform fp scale: 4 becomes 16, 8 remains 8
219  sig_type_relative_scale = (3 - (sig_type_relative_scale >> 2)) << 3;
220  }
221 
222  // We do not allow auto-casting to types with less scale/precision
223  // within the same type family.
224  CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
225 
226  // Calculate the ratio of the sig_type by the arg_type, per the above check will be >= 1
227  const auto sig_type_scale_gain_ratio =
228  sig_type_relative_scale / arg_type_relative_scale;
229  CHECK_GE(sig_type_scale_gain_ratio, 1);
230 
231  // Following the old bespoke scoring logic this function replaces, we heavily penalize
232  // any casts that move ints to floats/doubles for the precision-loss reasons above
233  // Arguably all integers in the tinyint and smallint can be fully specified with both
234  // float and double types, but we treat them the same as int and bigint types here.
235  const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
236 
237  int32_t scale_cast_penalty_score;
238 
239  // The following logic is new. Basically there are strong reasons to
240  // prefer the promotion of constant literals to the most precise type possible, as
241  // rather than the type being inherent in the data - that is a column or columns where
242  // a user specified a type (and with any expressions on those columns following our
243  // standard sql casting logic), literal types are given to us by Calcite and do not
244  // necessarily convey any semantic intent (i.e. 10 will be an int, but 10.0 a decimal)
245  // Hence it is better to promote these types to the most precise sig_type available,
246  // while at the same time keeping column expressions as close as possible to the input
247  // types (mainly for performance, we have many float versions of various functions
248  // to allow for greater performance when the underlying data is not of double precision,
249  // and hence there is little benefit of the extra cost of computing double precision
250  // operators on this data)
251  if (is_arg_literal) {
252  scale_cast_penalty_score =
253  (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
254  } else {
255  scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
256  }
257 
258  const auto cast_penalty_score =
259  type_family_cast_penalty_score + scale_cast_penalty_score;
260  CHECK_GT(cast_penalty_score, 0);
261  penalty_score += cast_penalty_score;
262  return 1;
263 }
264 
265 static int match_arguments(const SQLTypeInfo& arg_type,
266  const bool is_arg_literal,
267  int sig_pos,
268  const std::vector<ExtArgumentType>& sig_types,
269  int& penalty_score) {
270  /*
271  Returns non-negative integer `offset` if `arg_type` and
272  `sig_types[sig_pos:sig_pos + offset]` match.
273 
274  The `offset` value can be interpreted as the number of extension
275  function arguments that is consumed by the given `arg_type`. For
276  instance, for scalar types the offset is always 1, for array
277  types the offset is 2: one argument for array pointer value and
278  one argument for the array size value, etc.
279 
280  Returns -1 when the types of an argument and the corresponding
281  extension function argument(s) mismatch, or when downcasting would
282  be effective.
283 
284  In case of non-negative `offset` result, the function updates
285  penalty_score argument as follows:
286 
287  add 1000 if arg_type is non-scalar, otherwise:
288  add 1000 * sizeof(sig_type) / sizeof(arg_type)
289  add 1000000 if type kinds differ (integer vs double, for instance)
290 
291  */
292  int max_pos = sig_types.size() - 1;
293  if (sig_pos > max_pos) {
294  return -1;
295  }
296  auto sig_type = sig_types[sig_pos];
297  switch (arg_type.get_type()) {
298  case kBOOLEAN:
299  case kTINYINT:
300  case kSMALLINT:
301  case kINT:
302  case kBIGINT:
303  case kFLOAT:
304  case kDOUBLE:
305  case kDECIMAL:
306  case kNUMERIC:
307  return match_numeric_argument(arg_type, is_arg_literal, sig_type, penalty_score);
308  case kPOINT:
309  case kMULTIPOINT:
310  case kLINESTRING:
311  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
312  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
313  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble) &&
314  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
315  penalty_score += 1000;
316  return 2;
317  } else if ((sig_type == ExtArgumentType::GeoPoint &&
318  arg_type.get_type() == kPOINT) ||
319  (sig_type == ExtArgumentType::GeoMultiPoint &&
320  arg_type.get_type() == kMULTIPOINT) ||
321  (sig_type == ExtArgumentType::GeoLineString &&
322  arg_type.get_type() == kLINESTRING)) {
323  penalty_score += 1000;
324  return 1;
325  }
326  return -1;
327  case kMULTILINESTRING:
328  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
329  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
330  sig_types[sig_pos + 2] == ExtArgumentType::PInt8 &&
331  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
332  penalty_score += 1000;
333  return 4;
334  } else if (sig_type == ExtArgumentType::GeoMultiLineString) {
335  penalty_score += 1000;
336  return 1;
337  }
338  break;
339  case kARRAY:
340  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
341  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
342  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
343  sig_type == ExtArgumentType::PBool) &&
344  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
345  penalty_score += 1000;
346  return 2;
347  } else if (is_ext_arg_type_array(sig_type)) {
348  // array arguments must match exactly
349  CHECK(arg_type.is_array());
350  const auto sig_type_ti =
352  if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
353  sig_type_ti.get_type() == kTINYINT) {
354  /* Boolean array has the same low-level structure as Int8 array. */
355  penalty_score += 1000;
356  return 1;
357  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
358  penalty_score += 1000;
359  return 1;
360  } else {
361  return -1;
362  }
363  }
364  break;
365  case kPOLYGON:
366  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
367  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
368  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
369  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
370  penalty_score += 1000;
371  return 4;
372  } else if (sig_type == ExtArgumentType::GeoPolygon) {
373  penalty_score += 1000;
374  return 1;
375  }
376  break;
377  case kMULTIPOLYGON:
378  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
379  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
380  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
381  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
382  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
383  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
384  penalty_score += 1000;
385  return 6;
386  } else if (sig_type == ExtArgumentType::GeoMultiPolygon) {
387  penalty_score += 1000;
388  return 1;
389  }
390  break;
391  case kNULLT: // NULL maps to a pointer and size argument
392  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
393  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
394  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
395  sig_type == ExtArgumentType::PBool) &&
396  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
397  penalty_score += 1000;
398  return 2;
399  }
400  break;
401  case kCOLUMN:
402  if (is_ext_arg_type_column(sig_type)) {
403  // column arguments must match exactly
404  const auto sig_type_ti =
406  if (arg_type.get_elem_type().get_type() == kARRAY &&
407  sig_type_ti.get_type() == kARRAY) {
408  if (arg_type.get_elem_type().get_elem_type().get_type() ==
409  sig_type_ti.get_elem_type().get_type()) {
410  penalty_score += 1000;
411  return 1;
412  } else {
413  return -1;
414  }
415  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
416  sig_type_ti.get_type() == kTINYINT) {
417  /* Boolean column has the same low-level structure as Int8 column. */
418  penalty_score += 1000;
419  return 1;
420  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
421  penalty_score += 1000;
422  return 1;
423  } else {
424  return -1;
425  }
426  }
427  break;
428  case kCOLUMN_LIST:
429  if (is_ext_arg_type_column_list(sig_type)) {
430  // column_list arguments must match exactly
431  const auto sig_type_ti =
433  if (arg_type.get_elem_type().get_type() == kARRAY &&
434  sig_type_ti.get_type() == kARRAY) {
435  if (arg_type.get_elem_type().get_elem_type().get_type() ==
436  sig_type_ti.get_elem_type().get_type()) {
437  penalty_score += 1000;
438  return 1;
439  } else {
440  return -1;
441  }
442  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
443  sig_type_ti.get_type() == kTINYINT) {
444  /* Boolean column_list has the same low-level structure as Int8 column_list. */
445  penalty_score += 10000;
446  return 1;
447  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
448  penalty_score += 10000;
449  return 1;
450  } else {
451  return -1;
452  }
453  }
454  break;
455  case kVARCHAR:
456  if (sig_type != ExtArgumentType::TextEncodingNone) {
457  return -1;
458  }
459  switch (arg_type.get_compression()) {
460  case kENCODING_NONE:
461  penalty_score += 1000;
462  return 1;
463  case kENCODING_DICT:
464  return -1;
465  // Todo (todd): Evaluate when and where we can tranlate to dictionary-encoded
466  default:
467  UNREACHABLE();
468  }
469  case kTEXT:
470  switch (arg_type.get_compression()) {
471  case kENCODING_NONE:
472  if (sig_type == ExtArgumentType::TextEncodingNone) {
473  penalty_score += 1000;
474  return 1;
475  }
476  return -1;
477  case kENCODING_DICT:
478  if (sig_type == ExtArgumentType::TextEncodingDict) {
479  penalty_score += 1000;
480  return 1;
481  }
482  return -1;
483  default:
484  UNREACHABLE();
485  }
486  case kTIMESTAMP:
487  if (sig_type == ExtArgumentType::Timestamp) {
488  penalty_score += 1000;
489  return 1;
490  }
491  break;
492  case kINTERVAL_DAY_TIME:
493  if (sig_type == ExtArgumentType::DayTimeInterval) {
494  penalty_score += 1000;
495  return 1;
496  }
497  break;
498 
500  if (sig_type == ExtArgumentType::YearMonthTimeInterval) {
501  penalty_score += 1000;
502  return 1;
503  }
504  break;
505 
506  /* Not implemented types:
507  kCHAR
508  kTIME
509  kDATE
510  kGEOMETRY
511  kGEOGRAPHY
512  kEVAL_CONTEXT_TYPE
513  kVOID
514  kCURSOR
515  */
516  default:
517  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
518  ": support for " + arg_type.get_type_name() +
519  "(type=" + std::to_string(arg_type.get_type()) + ")" +
520  +" not implemented: \n pos=" + std::to_string(sig_pos) +
521  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
522  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
523  }
524  return -1;
525 }
526 
// Returns true when `str` is a C-style identifier. That means a non-empty
// string whose first character is alphabetic (std::isalpha) or '_', and whose
// remaining characters are alphanumeric (std::isalnum) or '_'. bind_function
// uses it to reject malformed UDF/UDTF names before binding.
//
// Each character is converted to unsigned char before it reaches the <cctype>
// classifiers. Passing a negative char value, such as a byte of a UTF-8
// multi-byte sequence, to std::isalpha/std::isalnum is undefined behavior.
bool is_valid_identifier(const std::string& str) {
  if (str.empty()) {
    return false;
  }

  const auto first = static_cast<unsigned char>(str.front());
  if (!(std::isalpha(first) || first == '_')) {
    return false;
  }

  return std::all_of(str.begin() + 1, str.end(), [](const char ch) {
    const auto uch = static_cast<unsigned char>(ch);
    return std::isalnum(uch) || uch == '_';
  });
}
544 
545 } // namespace
546 
// bind_function<T>: selects, among the candidate implementations `ext_funcs`
// (ExtensionFunction UDFs or table_functions::TableFunction UDTFs), the one
// whose input signature best matches the types of the SQL arguments
// `func_args`.
//
// Parameters:
//   name      - SQL-level function name; must pass is_valid_identifier().
//   func_args - argument expressions from the SQL query.
//   ext_funcs - registered candidate implementations.
//   processor - label ("GPU", "CPU" or "GPU|CPU" at the call sites in this
//               file), used only to build binding error messages.
// Returns the chosen implementation together with the argument-type variant
// it was matched against. ColumnList entries carry the number of collected
// columns in their dimension.
// Throws:
//   NativeExecutionError         - invalid name;
//   ExtensionFunctionBindingError - no candidate matches;
//   std::runtime_error           - unsupported column or argument types.
547 template <typename T>
548 std::tuple<T, std::vector<SQLTypeInfo>> bind_function(
549  std::string name,
550  Analyzer::ExpressionPtrVector func_args, // function args from sql query
551  const std::vector<T>& ext_funcs, // list of functions registered
552  const std::string processor) {
553  /* worker function
554 
555  Template type T must implement the following methods:
556 
557  std::vector<ExtArgumentType> getInputArgs()
558  */
559  /*
560  Return extension function/table function that has the following
561  properties
562 
563  1. each argument type in `arg_types` matches with extension
564  function argument types.
565 
566  For scalar types, the matching means that the types are either
567  equal or the argument type is smaller than the corresponding
568  the extension function argument type. This ensures that no
569  information is lost when casting of argument values is
570  required.
571 
572  For array and geo types, the matching means that the argument
573  type matches exactly with a group of extension function
574  argument types. See `match_arguments`.
575 
576  2. has minimal penalty score among all implementations of the
577  extension function with given `name`; the penalty score is the sum
578  accumulated by `match_arguments` plus the return type's logical size.
579 
580  It is assumed that `name` and the extension functions in
581  ext_funcs have the same name.
582  */
583  if (!is_valid_identifier(name)) {
584  throw NativeExecutionError(
585  "Cannot bind function with invalid UDF/UDTF function name: " + name);
586  }
587 
  // Collect the SQL argument types.
  // - For table functions, column references are rewritten via
  //   generate_column_type() (presumably to a Column<T> type).
  // - Such column arguments are recorded in args_are_constants as `true`
  //   unless they are TEXT. The kCOLUMN/kCOLUMN_LIST cases of match_arguments
  //   do not consult that flag.
  // NOTE(review): the rationale for the TEXT distinction is not visible here;
  // confirm.
588  std::vector<SQLTypeInfo> type_infos_input;
589  std::vector<bool> args_are_constants;
590  for (auto atype : func_args) {
591  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
592  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
593  SQLTypeInfo type_info = atype->get_type_info();
594  auto ti = generate_column_type(type_info);
595  if (ti.get_subtype() == kNULLT) {
596  throw std::runtime_error(std::string(__FILE__) + "#" +
597  std::to_string(__LINE__) +
598  ": column support for type info " +
599  type_info.to_string() + " is not implemented");
600  }
601  ti.setUsesFlatBuffer(type_info.supportsFlatBuffer());
602  type_infos_input.push_back(ti);
603  args_are_constants.push_back(type_info.get_type() != kTEXT);
604  continue;
605  }
606  }
607  type_infos_input.push_back(atype->get_type_info());
608  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
609  args_are_constants.push_back(true);
610  } else {
611  args_are_constants.push_back(false);
612  }
613  }
614  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
615 
  // A call without SQL arguments binds only to a single zero-argument
  // candidate, as enforced by the CHECKs below.
616  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
617  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
618  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
619  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
620  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
621  }
622  std::vector<SQLTypeInfo> empty_type_info_variant(0);
623  return {ext_funcs[0], empty_type_info_variant};
624  }
625 
  // Best candidate so far. `optimal` indexes ext_funcs; `optimal_variant`
  // indexes type_infos_variants, which gets one entry appended per candidate.
626  int minimal_score = std::numeric_limits<int>::max();
627  int index = -1;
628  int optimal = -1;
629  int optimal_variant = -1;
630  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
631 
632  // clang-format off
633  /*
634  Table functions may have arguments such as ColumnList that collect
635  neighboring columns with the same data type into a single object. In
636  general, the binding of UDTFs with ColumnLists might be ambiguous depending
637  on the order of the arguments:
638 
639  foo(ColumnList<T>, ColumnList<T>) -> Column<T>, T=[int]
640  bar(ColumnList<T>, Column<T>, ColumnList<T>) -> Column<T>, T=[int]
641 
642  Here both declarations above are ambiguous as the first ColumnList can
643  consume as many columns as possible, leaving a single column for each
644  one of the remaining types. Or it can consume one argument, leaving the bulk
645  to the last ColumnList. Nevertheless, not all ColumnList declarations result
646  in an ambiguous signature. The example below shows an example of a function
647  that takes two ColumnLists of different types which has an exact match.
648 
649  baz(ColumnList<P>, ColumnList<T>) -> Column<T>, T=[int], P=[float]
650 
651  To match a list of SQL arguments with an extension function, HeavyDB uses a
652  greedy algorithm that resolves the binding ambiguity as explained
653  below. As an example, let us consider a SQL query containing the following
654  expression calling a UDTF `bar` defined above:
655 
656  table(bar(select a, b, c, d, e from tableofints), 1)
657 
658  The algorithm will generate the following type variant, where the integer
659  value in [..] indicates the number of collected columns. This number is later
660  stored in the SQLTypeInfo dimension attribute.
661 
662  bar(ColumnList<T>[3], Column<T>, ColumnList<T>[1])
663 
664  */
665 
666  // clang-format on
667 
668  // Find extension function that gives the best match on the set of
669  // argument type variants
670  for (const auto& ext_func : ext_funcs) {
671  index++;
672 
673  const auto& ext_func_args = ext_func.getInputArgs();
674 
675  int penalty_score = 0;
676  int pos = 0;
677  int original_input_idx = 0;
678  type_infos_variants.emplace_back();
679 
    // Walk the SQL arguments. match_arguments lets each argument consume one
    // or more signature slots. `pos` is the next unconsumed slot and is set to
    // -1 on mismatch.
680  for (size_t i = 0; i < type_infos_input.size(); i++) {
681  const SQLTypeInfo& ti = type_infos_input[i];
682 
683  if ((size_t)pos >= ext_func_args.size()) {
684  pos = -1;
685  break;
686  } else if (is_ext_arg_type_column_list(ext_func_args[pos])) {
687  SQLTypeInfo ti_col_list = generate_column_list_type(ti);
688  int offset = match_arguments(ti_col_list,
689  args_are_constants[original_input_idx],
690  pos,
691  ext_func_args,
692  penalty_score);
693  if (offset < 0) {
694  pos = -1;
695  break;
696  }
697 
698  // if offset > 0, greedily iterate over the rest of input args
699  // to consume columns with the same type as "ti"
        // while leaving at least one input per remaining signature slot.
        // NOTE(review): if no column is consumed (j == i), the ColumnList
        // gets dimension 0. `i = j - 1` then relies on size_t wraparound when
        // i == 0. Confirm that zero-column lists are meant to bind.
700  int j = i;
701  size_t args_left = ext_func_args.size() - pos - 1;
702  while ((type_infos_input.size() - j > args_left) and
703  (ti_col_list.has_same_itemtype(type_infos_input[j]))) {
704  j++;
705  }
706  // push_back a ColumnList with dimension equals to the number of columns
707  // consumed above
708  ti_col_list.set_dimension(j - i);
709  type_infos_variants.back().push_back(ti_col_list);
710  // Move the "i" pointer to the last argument consumed
711  i = j - 1;
712  original_input_idx = j;
713  pos += offset;
714  } else {
715  int offset = match_arguments(ti,
716  args_are_constants[original_input_idx],
717  pos,
718  ext_func_args,
719  penalty_score);
720 
721  if (offset > 0) {
722  type_infos_variants.back().push_back(ti);
723  original_input_idx += 1;
724  pos += offset;
725  } else {
726  pos = -1;
727  break;
728  }
729  }
730  }
731 
    // A candidate matches only if every signature slot was consumed. Ties
    // keep the earliest candidate because the comparison below is strict.
732  if ((size_t)pos == ext_func_args.size()) {
733  CHECK_EQ(args_are_constants.size(), original_input_idx);
734  // prefer smaller return types
735  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
736  if (penalty_score < minimal_score) {
737  optimal = index;
738  minimal_score = penalty_score;
739  optimal_variant = type_infos_variants.size() - 1;
740  }
741  }
742  }
743 
744  if (optimal == -1) {
745  /* no extension function found that argument types would match
746  with types in `arg_types` */
747  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
748  std::string message;
749  if (!ext_funcs.size()) {
750  message = "Function " + name + "(" + sarg_types + ") not supported.";
751  throw ExtensionFunctionBindingError(message);
752  } else {
753  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
754  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
755  " UDTF implementation.";
756  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
757  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
758  " UDF implementation.";
759  } else {
760  LOG(FATAL) << "bind_function: unknown extension function type "
761  << typeid(T).name();
762  }
763  message += "\n Existing extension function implementations:";
764  for (const auto& ext_func : ext_funcs) {
765  // Do not show functions missing the sizer argument
766  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
767  if (ext_func.useDefaultSizer()) {
768  continue;
769  }
770  }
771  message += "\n " + ext_func.toStringSQL();
772  }
773  }
774  throw ExtensionFunctionBindingError(message);
775  }
776 
777  // Functions with "_default_" suffix only exist for calcite
778  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
779  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
780  ext_funcs[optimal].useDefaultSizer()) {
      // Map the calcite-only "_default_" variant back to the implementation
      // whose name lacks the suffix. Then insert an INT argument at position
      // sizer - 1 (so getOutputRowSizeParameter() appears to be 1-based).
      // NOTE(review):
      // - This listing omits source line 783 (the erase length).
      // - The local `name` below shadows the function parameter `name`.
      // - std::string::erase throws std::out_of_range if find() returns npos,
      //   so this assumes useDefaultSizer() implies the suffix is present.
781  std::string name = ext_funcs[optimal].getName();
782  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
784  for (size_t i = 0; i < ext_funcs.size(); i++) {
785  if (ext_funcs[i].getName() == name) {
786  optimal = i;
787  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
788  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
789  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
790  return {ext_funcs[optimal], type_info};
791  }
792  }
793  UNREACHABLE();
794  }
795  }
796 
797  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
798 }
799 
// bind_table_function(name, input_args, table_funcs, is_gpu): binds a table
// function call against the explicitly provided candidate list `table_funcs`.
// `is_gpu` only selects the "GPU"/"CPU" label that bind_function uses in its
// error messages.
// NOTE(review): this listing omits source lines 801-802 (function name and the
// first parameters). The full signature appears in this page's symbol index.
800 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
803  const std::vector<table_functions::TableFunction>& table_funcs,
804  const bool is_gpu) {
805  std::string processor = (is_gpu ? "GPU" : "CPU");
806  return bind_function<table_functions::TableFunction>(
807  name, input_args, table_funcs, processor);
808 }
809 
// bind_function(name, func_args): binds a UDF call, preferring GPU
// implementations.
// - If no GPU candidates are registered, it binds against the CPU list; a
//   failure there is rethrown.
// - If GPU candidates exist but none matches, binding is retried once against
//   the CPU list, with "GPU|CPU" as the error-message label.
// Returns the selected ExtensionFunction.
// NOTE(review): this listing omits source line 810 (return type, function name
// and first parameter).
811  Analyzer::ExpressionPtrVector func_args) {
812  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
813  // to CPU UDFs.
814  bool is_gpu = true;
815  std::string processor = "GPU";
816  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
817  if (!ext_funcs.size()) {
818  is_gpu = false;
819  processor = "CPU";
820  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
821  }
822  try {
823  return std::get<0>(
824  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
825  } catch (ExtensionFunctionBindingError& e) {
826  if (is_gpu) {
827  is_gpu = false;
828  processor = "GPU|CPU";
829  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
830  return std::get<0>(
831  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
832  } else {
833  throw;
834  }
835  }
836 }
837 
// bind_function(name, func_args, is_gpu): binds a UDF call against the
// candidates for a single target (GPU or CPU, per `is_gpu`), with no fallback.
// Returns the selected ExtensionFunction.
// NOTE(review): this listing omits source lines 838-839 (function header) and
// 843 (the initializer of `ext_funcs`, presumably a get_ext_funcs(name, is_gpu)
// lookup — confirm against the full source).
840  const bool is_gpu) {
841  // used below
842  std::vector<ExtensionFunction> ext_funcs =
844  std::string processor = (is_gpu ? "GPU" : "CPU");
845  return std::get<0>(
846  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
847 }
848 
// bind_function(function_oper, is_gpu): gathers the operator's own arguments
// (getOwnArg(i) for i < getArity()) and its name. It then delegates to
// bind_function(name, func_args, is_gpu).
// NOTE(review): this listing omits source line 849 (function header;
// `function_oper` is presumably a const Analyzer::FunctionOper* — confirm).
850  const bool is_gpu) {
851  // used in ExtensionsIR.cpp
852  auto name = function_oper->getName();
853  Analyzer::ExpressionPtrVector func_args = {};
854  for (size_t i = 0; i < function_oper->getArity(); ++i) {
855  func_args.push_back(function_oper->getOwnArg(i));
856  }
857  return bind_function(name, func_args, is_gpu);
858 }
859 
// bind_table_function(name, input_args, is_gpu): looks up the registered table
// function candidates and delegates to the overload that takes an explicit
// candidate list.
// NOTE(review): this listing omits source lines 861-862 (function name and
// first parameters) and 866 (the initializer of `table_funcs`, presumably a
// table-function registry lookup by name/is_gpu — confirm against the full
// source).
860 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
863  const bool is_gpu) {
864  // used in RelAlgExecutor.cpp
865  std::vector<table_functions::TableFunction> table_funcs =
867  return bind_table_function(name, input_args, table_funcs, is_gpu);
868 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
static std::vector< TableFunction > get_table_funcs()
size_t getArity() const
Definition: Analyzer.h:2615
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
#define LOG(tag)
Definition: Logger.h:285
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:338
#define CHECK_GE(x, y)
Definition: Logger.h:306
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:391
#define CHECK_GT(x, y)
Definition: Logger.h:305
std::string to_string(char const *&&v)
std::string to_string() const
Definition: sqltypes.h:526
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:2622
bool has_same_itemtype(const SQLTypeInfo &other) const
Definition: sqltypes.h:662
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
bool supportsFlatBuffer() const
Definition: sqltypes.h:1086
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
Definition: sqltypes.h:761
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
auto generate_column_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1627
Definition: sqltypes.h:79
#define CHECK_LE(x, y)
Definition: Logger.h:304
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:399
auto generate_column_list_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1699
void set_dimension(int d)
Definition: sqltypes.h:470
std::string get_type_name() const
Definition: sqltypes.h:482
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
Definition: sqltypes.h:822
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:291
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:186
Definition: sqltypes.h:72
std::string getName() const
Definition: Analyzer.h:2613
string name
Definition: setup.in.py:72
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:975
bool is_array() const
Definition: sqltypes.h:583
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)