OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 #include "ExternalExecutor.h"
20 
21 // A rather crude function binding logic based on the types of the arguments.
22 // We want it to be possible to write specialized versions of functions to be
23 // exposed as SQL extensions. This is important especially for performance
24 // reasons, since double operations can be significantly slower than float. We
25 // compute a score for each candidate signature based on the conversions required
26 // to go from the function arguments as specified in the SQL query to the versions in
27 // ExtensionFunctions.hpp.
28 
29 /*
30  New implementation for binding a SQL function operator to the
31  optimal candidate within all available extension functions.
32  */
33 
34 namespace {
35 
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
72  default:
73  UNREACHABLE();
74  }
75  return ExtArgumentType{};
76 }
77 
79  const ExtArgumentType ext_arg_column_list_type) {
80  switch (ext_arg_column_list_type) {
82  return ExtArgumentType::Int8;
94  return ExtArgumentType::Bool;
113  default:
114  UNREACHABLE();
115  }
116  return ExtArgumentType{};
117 }
118 
120  switch (ext_arg_array_type) {
122  return ExtArgumentType::Int8;
124  return ExtArgumentType::Int16;
126  return ExtArgumentType::Int32;
128  return ExtArgumentType::Int64;
130  return ExtArgumentType::Float;
134  return ExtArgumentType::Bool;
137  default:
138  UNREACHABLE();
139  }
140  return ExtArgumentType{};
141 }
142 
143 static int match_numeric_argument(const SQLTypeInfo& arg_type_info,
144  const bool is_arg_literal,
145  const ExtArgumentType& sig_ext_arg_type,
146  int32_t& penalty_score) {
147  const auto arg_type = arg_type_info.get_type();
148  CHECK(arg_type == kBOOLEAN || arg_type == kTINYINT || arg_type == kSMALLINT ||
149  arg_type == kINT || arg_type == kBIGINT || arg_type == kFLOAT ||
150  arg_type == kDOUBLE || arg_type == kDECIMAL || arg_type == kNUMERIC);
151  // Todo (todd): Add support for timestamp, date, and time types
152  const auto sig_type_info = ext_arg_type_to_type_info(sig_ext_arg_type);
153  const auto sig_type = sig_type_info.get_type();
154 
155  // If we can't legally auto-cast to sig_type, abort
156  if (!arg_type_info.is_numeric_scalar_auto_castable(sig_type_info)) {
157  return -1;
158  }
159 
160  // We now compare a measure of the scale of the sig_type with the
161  // arg_type, which provides a basis for scoring the match between
162  // the two. Note that get_numeric_scalar_scale for the most part
163  // returns the logical byte width of the type, with a few caveats
164  // for decimals and timestamps described in more depth in comments
165  // in the function itself. Also even though for example float and
166  // int types return 4 (as in 4 bytes), and double and bigint types
167  // return 8, a fp32 type cannot express every 32-bit integer (even
168  // if it can cover a larger absolute range), and an fp64 type
169  // likewise cannot express every 64-bit integer. With the aim to
170  // minimize the precision loss from casting (always precise) integer
171  // value to (imprecise) floating point value, in the case of integer
172  // inputs, we'll penalize wider floating point argument types least
173  // by a specific scale transformation (see the implementation
174  // below). For instance, casting tinyint to fp64 is prefered over
175  // casting it to fp32 to minimize precision loss.
176  const bool is_integer_to_fp_cast = (arg_type == kTINYINT || arg_type == kSMALLINT ||
177  arg_type == kINT || arg_type == kBIGINT) &&
178  (sig_type == kFLOAT || sig_type == kDOUBLE);
179 
180  const auto arg_type_relative_scale = arg_type_info.get_numeric_scalar_scale();
181  CHECK_GE(arg_type_relative_scale, 1);
182  CHECK_LE(arg_type_relative_scale, 8);
183  auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
184  CHECK_GE(sig_type_relative_scale, 1);
185  CHECK_LE(sig_type_relative_scale, 8);
186 
187  if (is_integer_to_fp_cast) {
188  // transform fp scale: 4 becomes 16, 8 remains 8
189  sig_type_relative_scale = (3 - (sig_type_relative_scale >> 2)) << 3;
190  }
191 
192  // We do not allow auto-casting to types with less scale/precision
193  // within the same type family.
194  CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
195 
196  // Calculate the ratio of the sig_type by the arg_type, per the above check will be >= 1
197  const auto sig_type_scale_gain_ratio =
198  sig_type_relative_scale / arg_type_relative_scale;
199  CHECK_GE(sig_type_scale_gain_ratio, 1);
200 
201  // Following the old bespoke scoring logic this function replaces, we heavily penalize
202  // any casts that move ints to floats/doubles for the precision-loss reasons above
203  // Arguably all integers in the tinyint and smallint can be fully specified with both
204  // float and double types, but we treat them the same as int and bigint types here.
205  const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
206 
207  int32_t scale_cast_penalty_score;
208 
209  // The following logic is new. Basically there are strong reasons to
210  // prefer the promotion of constant literals to the most precise type possible, as
211  // rather than the type being inherent in the data - that is a column or columns where
212  // a user specified a type (and with any expressions on those columns following our
213  // standard sql casting logic), literal types are given to us by Calcite and do not
214  // necessarily convey any semantic intent (i.e. 10 will be an int, but 10.0 a decimal)
215  // Hence it is better to promote these types to the most precise sig_type available,
216  // while at the same time keeping column expressions as close as possible to the input
217  // types (mainly for performance, we have many float versions of various functions
218  // to allow for greater performance when the underlying data is not of double precision,
219  // and hence there is little benefit of the extra cost of computing double precision
220  // operators on this data)
221  if (is_arg_literal) {
222  scale_cast_penalty_score =
223  (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
224  } else {
225  scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
226  }
227 
228  const auto cast_penalty_score =
229  type_family_cast_penalty_score + scale_cast_penalty_score;
230  CHECK_GT(cast_penalty_score, 0);
231  penalty_score += cast_penalty_score;
232  return 1;
233 }
234 
235 static int match_arguments(const SQLTypeInfo& arg_type,
236  const bool is_arg_literal,
237  int sig_pos,
238  const std::vector<ExtArgumentType>& sig_types,
239  int& penalty_score) {
240  /*
241  Returns non-negative integer `offset` if `arg_type` and
242  `sig_types[sig_pos:sig_pos + offset]` match.
243 
244  The `offset` value can be interpreted as the number of extension
245  function arguments that is consumed by the given `arg_type`. For
246  instance, for scalar types the offset is always 1, for array
247  types the offset is 2: one argument for array pointer value and
248  one argument for the array size value, etc.
249 
250  Returns -1 when the types of an argument and the corresponding
251  extension function argument(s) mismatch, or when downcasting would
252  be effective.
253 
254  In case of non-negative `offset` result, the function updates
255  penalty_score argument as follows:
256 
257  add 1000 if arg_type is non-scalar, otherwise:
258  add 1000 * sizeof(sig_type) / sizeof(arg_type)
259  add 1000000 if type kinds differ (integer vs double, for instance)
260 
261  */
262  int max_pos = sig_types.size() - 1;
263  if (sig_pos > max_pos) {
264  return -1;
265  }
266  auto sig_type = sig_types[sig_pos];
267  switch (arg_type.get_type()) {
268  case kBOOLEAN:
269  case kTINYINT:
270  case kSMALLINT:
271  case kINT:
272  case kBIGINT:
273  case kFLOAT:
274  case kDOUBLE:
275  case kDECIMAL:
276  case kNUMERIC:
277  return match_numeric_argument(arg_type, is_arg_literal, sig_type, penalty_score);
278  case kPOINT:
279  case kMULTIPOINT:
280  case kLINESTRING:
281  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
282  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
283  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble) &&
284  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
285  penalty_score += 1000;
286  return 2;
287  } else if (sig_type == ExtArgumentType::GeoPoint ||
288  sig_type == ExtArgumentType::GeoMultiPoint ||
289  sig_type == ExtArgumentType::GeoLineString) {
290  penalty_score += 1000;
291  return 1;
292  }
293  return -1;
294  case kMULTILINESTRING:
295  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
296  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
297  sig_types[sig_pos + 2] == ExtArgumentType::PInt8 &&
298  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
299  penalty_score += 1000;
300  return 4;
301  } else if (sig_type == ExtArgumentType::GeoMultiLineString) {
302  penalty_score += 1000;
303  return 1;
304  }
305  break;
306  case kARRAY:
307  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
308  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
309  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
310  sig_type == ExtArgumentType::PBool) &&
311  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
312  penalty_score += 1000;
313  return 2;
314  } else if (is_ext_arg_type_array(sig_type)) {
315  // array arguments must match exactly
316  CHECK(arg_type.is_array());
317  const auto sig_type_ti =
319  if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
320  sig_type_ti.get_type() == kTINYINT) {
321  /* Boolean array has the same low-level structure as Int8 array. */
322  penalty_score += 1000;
323  return 1;
324  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
325  penalty_score += 1000;
326  return 1;
327  } else {
328  return -1;
329  }
330  }
331  break;
332  case kPOLYGON:
333  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
334  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
335  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
336  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
337  penalty_score += 1000;
338  return 4;
339  } else if (sig_type == ExtArgumentType::GeoPolygon) {
340  penalty_score += 1000;
341  return 1;
342  }
343  break;
344  case kMULTIPOLYGON:
345  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
346  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
347  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
348  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
349  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
350  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
351  penalty_score += 1000;
352  return 6;
353  } else if (sig_type == ExtArgumentType::GeoMultiPolygon) {
354  penalty_score += 1000;
355  return 1;
356  }
357  break;
358  case kNULLT: // NULL maps to a pointer and size argument
359  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
360  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
361  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
362  sig_type == ExtArgumentType::PBool) &&
363  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
364  penalty_score += 1000;
365  return 2;
366  }
367  break;
368  case kCOLUMN:
369  if (is_ext_arg_type_column(sig_type)) {
370  // column arguments must match exactly
371  const auto sig_type_ti =
373  if (arg_type.get_elem_type().get_type() == kARRAY &&
374  sig_type_ti.get_type() == kARRAY) {
375  if (arg_type.get_elem_type().get_elem_type().get_type() ==
376  sig_type_ti.get_elem_type().get_type()) {
377  penalty_score += 1000;
378  return 1;
379  } else {
380  return -1;
381  }
382  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
383  sig_type_ti.get_type() == kTINYINT) {
384  /* Boolean column has the same low-level structure as Int8 column. */
385  penalty_score += 1000;
386  return 1;
387  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
388  penalty_score += 1000;
389  return 1;
390  } else {
391  return -1;
392  }
393  }
394  break;
395  case kCOLUMN_LIST:
396  if (is_ext_arg_type_column_list(sig_type)) {
397  // column_list arguments must match exactly
398  const auto sig_type_ti =
400  if (arg_type.get_elem_type().get_type() == kARRAY &&
401  sig_type_ti.get_type() == kARRAY) {
402  if (arg_type.get_elem_type().get_elem_type().get_type() ==
403  sig_type_ti.get_elem_type().get_type()) {
404  penalty_score += 1000;
405  return 1;
406  } else {
407  return -1;
408  }
409  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
410  sig_type_ti.get_type() == kTINYINT) {
411  /* Boolean column_list has the same low-level structure as Int8 column_list. */
412  penalty_score += 10000;
413  return 1;
414  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
415  penalty_score += 10000;
416  return 1;
417  } else {
418  return -1;
419  }
420  }
421  break;
422  case kVARCHAR:
423  if (sig_type != ExtArgumentType::TextEncodingNone) {
424  return -1;
425  }
426  switch (arg_type.get_compression()) {
427  case kENCODING_NONE:
428  penalty_score += 1000;
429  return 1;
430  case kENCODING_DICT:
431  return -1;
432  // Todo (todd): Evaluate when and where we can translate to dictionary-encoded
433  default:
434  UNREACHABLE();
435  }
436  case kTEXT:
437  switch (arg_type.get_compression()) {
438  case kENCODING_NONE:
439  if (sig_type == ExtArgumentType::TextEncodingNone) {
440  penalty_score += 1000;
441  return 1;
442  }
443  return -1;
444  case kENCODING_DICT:
445  if (sig_type == ExtArgumentType::TextEncodingDict) {
446  penalty_score += 1000;
447  return 1;
448  }
449  return -1;
450  default:
451  UNREACHABLE();
452  }
453  case kTIMESTAMP:
454  if (sig_type == ExtArgumentType::Timestamp) {
455  penalty_score += 1000;
456  return 1;
457  }
458  break;
459  case kINTERVAL_DAY_TIME:
460  if (sig_type == ExtArgumentType::DayTimeInterval) {
461  penalty_score += 1000;
462  return 1;
463  }
464  break;
465 
467  if (sig_type == ExtArgumentType::YearMonthTimeInterval) {
468  penalty_score += 1000;
469  return 1;
470  }
471  break;
472 
473  /* Not implemented types:
474  kCHAR
475  kTIME
476  kDATE
477  kGEOMETRY
478  kGEOGRAPHY
479  kEVAL_CONTEXT_TYPE
480  kVOID
481  kCURSOR
482  */
483  default:
484  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
485  ": support for " + arg_type.get_type_name() +
486  "(type=" + std::to_string(arg_type.get_type()) + ")" +
487  +" not implemented: \n pos=" + std::to_string(sig_pos) +
488  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
489  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
490  }
491  return -1;
492 }
493 
// Returns true when `str` is usable as a UDF/UDTF function name: it must be
// non-empty, start with a letter or underscore, and contain only letters,
// digits and underscores after that.
//
// Characters are converted to unsigned char before being handed to the
// <cctype> classifiers: passing a negative plain `char` (e.g. a UTF-8 byte on
// platforms where char is signed) to std::isalpha/std::isalnum is undefined
// behaviour.
bool is_valid_identifier(std::string str) {
  if (str.empty()) {
    return false;
  }

  const auto is_head_char = [](const char c) {
    return c == '_' || std::isalpha(static_cast<unsigned char>(c)) != 0;
  };
  const auto is_tail_char = [](const char c) {
    return c == '_' || std::isalnum(static_cast<unsigned char>(c)) != 0;
  };

  if (!is_head_char(str.front())) {
    return false;
  }
  return std::all_of(str.begin() + 1, str.end(), is_tail_char);
}
511 
512 } // namespace
513 
514 template <typename T>
515 std::tuple<T, std::vector<SQLTypeInfo>> bind_function(
516  std::string name,
517  Analyzer::ExpressionPtrVector func_args, // function args from sql query
518  const std::vector<T>& ext_funcs, // list of functions registered
519  const std::string processor) {
520  /* worker function
521 
522  Template type T must implement the following methods:
523 
524  std::vector<ExtArgumentType> getInputArgs()
525  */
526  /*
527  Return extension function/table function that has the following
528  properties
529 
530  1. each argument type in `arg_types` matches with extension
531  function argument types.
532 
533  For scalar types, the matching means that the types are either
534  equal or the argument type is smaller than the corresponding
535  extension function argument type. This ensures that no
536  information is lost when casting of argument values is
537  required.
538 
539  For array and geo types, the matching means that the argument
540  type matches exactly with a group of extension function
541  argument types. See `match_arguments`.
542 
543  2. has minimal penalty score among all implementations of the
544  extension function with given `name`, see `get_penalty_score`
545  for the definition of penalty score.
546 
547  It is assumed that function_oper and extension functions in
548  ext_funcs have the same name.
549  */
550  if (!is_valid_identifier(name)) {
551  throw NativeExecutionError(
552  "Cannot bind function with invalid UDF/UDTF function name: " + name);
553  }
554 
555  int minimal_score = std::numeric_limits<int>::max();
556  int index = -1;
557  int optimal = -1;
558  int optimal_variant = -1;
559 
560  std::vector<SQLTypeInfo> type_infos_input;
561  std::vector<bool> args_are_constants;
562  for (auto atype : func_args) {
563  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
564  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
565  SQLTypeInfo type_info = atype->get_type_info();
566  auto ti = generate_column_type(type_info);
567  if (ti.get_subtype() == kNULLT) {
568  throw std::runtime_error(std::string(__FILE__) + "#" +
569  std::to_string(__LINE__) +
570  ": column support for type info " +
571  type_info.to_string() + " is not implemented");
572  }
573  type_infos_input.push_back(ti);
574  args_are_constants.push_back(type_info.get_type() != kTEXT);
575  continue;
576  }
577  }
578  type_infos_input.push_back(atype->get_type_info());
579  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
580  args_are_constants.push_back(true);
581  } else {
582  args_are_constants.push_back(false);
583  }
584  }
585  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
586 
587  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
588  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
589  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
590  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
591  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
592  }
593  std::vector<SQLTypeInfo> empty_type_info_variant(0);
594  return {ext_funcs[0], empty_type_info_variant};
595  }
596 
597  // clang-format off
598  /*
599  Table functions may have arguments such as ColumnList that collect
600  neighboring columns with the same data type into a single object.
601  Here we compute all possible combinations of mapping a subset of
602  columns into columns sets. For example, if the types of function
603  arguments are (as given in func_args argument)
604 
605  (Column<int>, Column<int>, Column<int>, int)
606 
607  then the computed variants will be
608 
609  (Column<int>, Column<int>, Column<int>, int)
610  (Column<int>, Column<int>, ColumnList[1]<int>, int)
611  (Column<int>, ColumnList[1]<int>, Column<int>, int)
612  (Column<int>, ColumnList[2]<int>, int)
613  (ColumnList[1]<int>, Column<int>, Column<int>, int)
614  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
615  (ColumnList[2]<int>, Column<int>, int)
616  (ColumnList[3]<int>, int)
617 
618  where the integers in [..] indicate the number of collected
619  columns. In the SQLTypeInfo instance, this number is stored in the
620  SQLTypeInfo dimension attribute.
621 
622  As an example, let us consider a SQL query containing the
623  following expression calling a UDTF foo:
624 
625  table(foo(cursor(select a, b, c from tableofints), 1))
626 
627  Here follows a list of table functions and the corresponding
628  optimal argument type variants that are computed for the given
629  query expression:
630 
631  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
632  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
633 
634  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
635  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
636 
637  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
638  (Column<int>, Column<int>, Column<int>, int)
639  */
640  // clang-format on
641  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
642  for (auto ti : type_infos_input) {
643  if (type_infos_variants.begin() == type_infos_variants.end()) {
644  type_infos_variants.push_back({ti});
645  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
646  if (ti.is_column()) {
647  auto mti = generate_column_list_type(ti);
648  if (mti.get_subtype() == kNULLT) {
649  continue; // skip unsupported element type.
650  }
651  mti.set_dimension(1);
652  type_infos_variants.push_back({mti});
653  }
654  }
655  continue;
656  }
657  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
658  for (auto& type_infos : type_infos_variants) {
659  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
660  if (ti.is_column()) {
661  auto new_type_infos = type_infos; // makes a copy
662  const auto& last = type_infos.back();
663  if (last.is_column_list() && last.has_same_itemtype(ti)) {
664  // last column_list consumes column argument if item types match
665  new_type_infos.back().set_dimension(last.get_dimension() + 1);
666  } else {
667  // add column as column_list argument
668  auto mti = generate_column_list_type(ti);
669  if (mti.get_subtype() == kNULLT) {
670  // skip unsupported element type
671  type_infos.push_back(ti);
672  continue;
673  }
674  mti.set_dimension(1);
675  new_type_infos.push_back(mti);
676  }
677  new_type_infos_variants.push_back(new_type_infos);
678  }
679  }
680  type_infos.push_back(ti);
681  }
682  type_infos_variants.insert(type_infos_variants.end(),
683  new_type_infos_variants.begin(),
684  new_type_infos_variants.end());
685  }
686 
687  // Find extension function that gives the best match on the set of
688  // argument type variants:
689  for (auto ext_func : ext_funcs) {
690  index++;
691 
692  auto ext_func_args = ext_func.getInputArgs();
693  int index_variant = -1;
694  for (const auto& type_infos : type_infos_variants) {
695  index_variant++;
696  int penalty_score = 0;
697  int pos = 0;
698  int original_input_idx = 0;
699  CHECK_LE(type_infos.size(), args_are_constants.size());
700  for (const auto& ti : type_infos) {
701  int offset = match_arguments(ti,
702  args_are_constants[original_input_idx],
703  pos,
704  ext_func_args,
705  penalty_score);
706  if (offset < 0) {
707  // atype does not match with ext_func argument
708  pos = -1;
709  break;
710  }
711  if (ti.get_type() == kCOLUMN_LIST) {
712  original_input_idx += ti.get_dimension();
713  } else {
714  original_input_idx++;
715  }
716  pos += offset;
717  }
718 
719  if ((size_t)pos == ext_func_args.size()) {
720  CHECK_EQ(args_are_constants.size(), original_input_idx);
721  // prefer smaller return types
722  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
723  if (penalty_score < minimal_score) {
724  optimal = index;
725  minimal_score = penalty_score;
726  optimal_variant = index_variant;
727  }
728  }
729  }
730  }
731 
732  if (optimal == -1) {
733  /* no extension function found that argument types would match
734  with types in `arg_types` */
735  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
736  std::string message;
737  if (!ext_funcs.size()) {
738  message = "Function " + name + "(" + sarg_types + ") not supported.";
739  throw ExtensionFunctionBindingError(message);
740  } else {
741  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
742  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
743  " UDTF implementation.";
744  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
745  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
746  " UDF implementation.";
747  } else {
748  LOG(FATAL) << "bind_function: unknown extension function type "
749  << typeid(T).name();
750  }
751  message += "\n Existing extension function implementations:";
752  for (const auto& ext_func : ext_funcs) {
753  // Do not show functions missing the sizer argument
754  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
755  if (ext_func.useDefaultSizer())
756  continue;
757  message += "\n " + ext_func.toStringSQL();
758  }
759  }
760  throw ExtensionFunctionBindingError(message);
761  }
762 
763  // Functions with "_default_" suffix only exist for calcite
764  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
765  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
766  ext_funcs[optimal].useDefaultSizer()) {
767  std::string name = ext_funcs[optimal].getName();
768  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
770  for (size_t i = 0; i < ext_funcs.size(); i++) {
771  if (ext_funcs[i].getName() == name) {
772  optimal = i;
773  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
774  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
775  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
776  return {ext_funcs[optimal], type_info};
777  }
778  }
779  UNREACHABLE();
780  }
781  }
782 
783  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
784 }
785 
786 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
789  const std::vector<table_functions::TableFunction>& table_funcs,
790  const bool is_gpu) {
791  std::string processor = (is_gpu ? "GPU" : "CPU");
792  return bind_function<table_functions::TableFunction>(
793  name, input_args, table_funcs, processor);
794 }
795 
797  Analyzer::ExpressionPtrVector func_args) {
798  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
799  // to CPU UDFs.
800  bool is_gpu = true;
801  std::string processor = "GPU";
802  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
803  if (!ext_funcs.size()) {
804  is_gpu = false;
805  processor = "CPU";
806  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
807  }
808  try {
809  return std::get<0>(
810  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
811  } catch (ExtensionFunctionBindingError& e) {
812  if (is_gpu) {
813  is_gpu = false;
814  processor = "GPU|CPU";
815  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
816  return std::get<0>(
817  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
818  } else {
819  throw;
820  }
821  }
822 }
823 
826  const bool is_gpu) {
827  // used below
828  std::vector<ExtensionFunction> ext_funcs =
830  std::string processor = (is_gpu ? "GPU" : "CPU");
831  return std::get<0>(
832  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
833 }
834 
836  const bool is_gpu) {
837  // used in ExtensionsIR.cpp
838  auto name = function_oper->getName();
839  Analyzer::ExpressionPtrVector func_args = {};
840  for (size_t i = 0; i < function_oper->getArity(); ++i) {
841  func_args.push_back(function_oper->getOwnArg(i));
842  }
843  return bind_function(name, func_args, is_gpu);
844 }
845 
846 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
849  const bool is_gpu) {
850  // used in RelAlgExecutor.cpp
851  std::vector<table_functions::TableFunction> table_funcs =
853  return bind_table_function(name, input_args, table_funcs, is_gpu);
854 }
#define CHECK_EQ(x, y)
Definition: Logger.h:297
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
static std::vector< TableFunction > get_table_funcs()
size_t getArity() const
Definition: Analyzer.h:2404
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
#define LOG(tag)
Definition: Logger.h:283
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:333
#define CHECK_GE(x, y)
Definition: Logger.h:302
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:380
#define CHECK_GT(x, y)
Definition: Logger.h:301
std::string to_string(char const *&&v)
std::string to_string() const
Definition: sqltypes.h:544
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:2411
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
Definition: sqltypes.h:744
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
auto generate_column_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1361
Definition: sqltypes.h:67
#define CHECK_LE(x, y)
Definition: Logger.h:300
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:388
auto generate_column_list_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1409
std::string get_type_name() const
Definition: sqltypes.h:504
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
Definition: sqltypes.h:805
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:289
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:190
Definition: sqltypes.h:60
std::string getName() const
Definition: Analyzer.h:2402
string name
Definition: setup.in.py:72
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:957
bool is_array() const
Definition: sqltypes.h:584
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)