OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 #include "ExternalExecutor.h"
20 
21 // A rather crude function binding logic based on the types of the arguments.
22 // We want it to be possible to write specialized versions of functions to be
23 // exposed as SQL extensions. This is important especially for performance
24 // reasons, since double operations can be significantly slower than float. We
25 // compute a score for each candidate signature based on the conversions required
26 // to go from the function arguments as specified in the SQL query to the versions
27 // in ExtensionFunctions.hpp.
28 
29 /*
30  New implementation for binding a SQL function operator to the
31  optimal candidate among all available extension functions.
32  */
33 
34 namespace {
35 
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
72  default:
73  UNREACHABLE();
74  }
75  return ExtArgumentType{};
76 }
77 
79  const ExtArgumentType ext_arg_column_list_type) {
80  switch (ext_arg_column_list_type) {
82  return ExtArgumentType::Int8;
94  return ExtArgumentType::Bool;
113  default:
114  UNREACHABLE();
115  }
116  return ExtArgumentType{};
117 }
118 
120  switch (ext_arg_array_type) {
122  return ExtArgumentType::Int8;
124  return ExtArgumentType::Int16;
126  return ExtArgumentType::Int32;
128  return ExtArgumentType::Int64;
130  return ExtArgumentType::Float;
134  return ExtArgumentType::Bool;
137  default:
138  UNREACHABLE();
139  }
140  return ExtArgumentType{};
141 }
142 
143 static int match_numeric_argument(const SQLTypeInfo& arg_type_info,
144  const bool is_arg_literal,
145  const ExtArgumentType& sig_ext_arg_type,
146  int32_t& penalty_score) {
147  const auto arg_type = arg_type_info.get_type();
148  CHECK(arg_type == kBOOLEAN || arg_type == kTINYINT || arg_type == kSMALLINT ||
149  arg_type == kINT || arg_type == kBIGINT || arg_type == kFLOAT ||
150  arg_type == kDOUBLE || arg_type == kDECIMAL || arg_type == kNUMERIC);
151  // Todo (todd): Add support for timestamp, date, and time types
152  const auto sig_type_info = ext_arg_type_to_type_info(sig_ext_arg_type);
153  const auto sig_type = sig_type_info.get_type();
154 
155  // If we can't legally auto-cast to sig_type, abort
156  if (!arg_type_info.is_numeric_scalar_auto_castable(sig_type_info)) {
157  return -1;
158  }
159 
160  // We now compare a measure of the scale of the sig_type with the
161  // arg_type, which provides a basis for scoring the match between
162  // the two. Note that get_numeric_scalar_scale for the most part
163  // returns the logical byte width of the type, with a few caveats
164  // for decimals and timestamps described in more depth in comments
165  // in the function itself. Also even though for example float and
166  // int types return 4 (as in 4 bytes), and double and bigint types
167  // return 8, a fp32 type cannot express every 32-bit integer (even
168  // if it can cover a larger absolute range), and an fp64 type
169  // likewise cannot express every 64-bit integer. With the aim to
170  // minimize the precision loss from casting (always precise) integer
171  // value to (imprecise) floating point value, in the case of integer
172  // inputs, we'll penalize wider floating point argument types least
173  // by a specific scale transformation (see the implementation
174  // below). For instance, casting tinyint to fp64 is prefered over
175  // casting it to fp32 to minimize precision loss.
176  const bool is_integer_to_fp_cast = (arg_type == kTINYINT || arg_type == kSMALLINT ||
177  arg_type == kINT || arg_type == kBIGINT) &&
178  (sig_type == kFLOAT || sig_type == kDOUBLE);
179 
180  const auto arg_type_relative_scale = arg_type_info.get_numeric_scalar_scale();
181  CHECK_GE(arg_type_relative_scale, 1);
182  CHECK_LE(arg_type_relative_scale, 8);
183  auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
184  CHECK_GE(sig_type_relative_scale, 1);
185  CHECK_LE(sig_type_relative_scale, 8);
186 
187  if (is_integer_to_fp_cast) {
188  // transform fp scale: 4 becomes 16, 8 remains 8
189  sig_type_relative_scale = (3 - (sig_type_relative_scale >> 2)) << 3;
190  }
191 
192  // We do not allow auto-casting to types with less scale/precision
193  // within the same type family.
194  CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
195 
196  // Calculate the ratio of the sig_type by the arg_type, per the above check will be >= 1
197  const auto sig_type_scale_gain_ratio =
198  sig_type_relative_scale / arg_type_relative_scale;
199  CHECK_GE(sig_type_scale_gain_ratio, 1);
200 
201  // Following the old bespoke scoring logic this function replaces, we heavily penalize
202  // any casts that move ints to floats/doubles for the precision-loss reasons above
203  // Arguably all integers in the tinyint and smallint can be fully specified with both
204  // float and double types, but we treat them the same as int and bigint types here.
205  const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
206 
207  int32_t scale_cast_penalty_score;
208 
209  // The following logic is new. Basically there are strong reasons to
210  // prefer the promotion of constant literals to the most precise type possible, as
211  // rather than the type being inherent in the data - that is a column or columns where
212  // a user specified a type (and with any expressions on those columns following our
213  // standard sql casting logic), literal types are given to us by Calcite and do not
214  // necessarily convey any semantic intent (i.e. 10 will be an int, but 10.0 a decimal)
215  // Hence it is better to promote these types to the most precise sig_type available,
216  // while at the same time keeping column expressions as close as possible to the input
217  // types (mainly for performance, we have many float versions of various functions
218  // to allow for greater performance when the underlying data is not of double precision,
219  // and hence there is little benefit of the extra cost of computing double precision
220  // operators on this data)
221  if (is_arg_literal) {
222  scale_cast_penalty_score =
223  (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
224  } else {
225  scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
226  }
227 
228  const auto cast_penalty_score =
229  type_family_cast_penalty_score + scale_cast_penalty_score;
230  CHECK_GT(cast_penalty_score, 0);
231  penalty_score += cast_penalty_score;
232  return 1;
233 }
234 
// Matches one SQL argument type against the extension-function signature
// slot(s) starting at `sig_pos` and accumulates a penalty for the match.
// The return value and scoring contract are described in the comment at the
// top of the body.
235 static int match_arguments(const SQLTypeInfo& arg_type,
236  const bool is_arg_literal,
237  int sig_pos,
238  const std::vector<ExtArgumentType>& sig_types,
239  int& penalty_score) {
240  /*
241  Returns a positive integer `offset` if `arg_type` and
242  `sig_types[sig_pos:sig_pos + offset]` match.
243 
244  The `offset` value can be interpreted as the number of extension
245  function arguments that are consumed by the given `arg_type`. For
246  instance, for scalar types the offset is always 1; for an array
247  passed as pointer + size the offset is 2 (one argument for the
248  array pointer value and one for the array size value); the geo
 types below consume 1, 2, 4 or 6 arguments depending on the type
 and on whether the signature takes a Geo* struct or raw buffers.
249 
250  Returns -1 when the types of an argument and the corresponding
251  extension function argument(s) mismatch, when `sig_pos` is past
252  the end of `sig_types`, or when a numeric cast is not allowed.
 Throws std::runtime_error for SQL types that are not handled here.
253 
254  In case of a positive `offset` result, the function updates the
255  penalty_score argument as follows:
256 
257  numeric scalars: the score computed by match_numeric_argument
258  column_list element-type matches: add 10000 (see kCOLUMN_LIST)
259  every other successful match: add 1000
260 
 `is_arg_literal` is only forwarded to match_numeric_argument.
261  */
262  int max_pos = sig_types.size() - 1;
263  if (sig_pos > max_pos) {
264  return -1;
265  }
266  auto sig_type = sig_types[sig_pos];
267  switch (arg_type.get_type()) {
268  case kBOOLEAN:
269  case kTINYINT:
270  case kSMALLINT:
271  case kINT:
272  case kBIGINT:
273  case kFLOAT:
274  case kDOUBLE:
275  case kDECIMAL:
276  case kNUMERIC:
277  return match_numeric_argument(arg_type, is_arg_literal, sig_type, penalty_score);
278  case kPOINT:
279  case kMULTIPOINT:
280  case kLINESTRING:
 // Either a (pointer, Int64 size) pair or a single Geo* struct argument.
281  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
282  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
283  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble) &&
284  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
285  penalty_score += 1000;
286  return 2;
287  } else if (sig_type == ExtArgumentType::GeoPoint ||
288  sig_type == ExtArgumentType::GeoMultiPoint ||
289  sig_type == ExtArgumentType::GeoLineString) {
290  penalty_score += 1000;
291  return 1;
292  }
293  return -1;
294  case kMULTILINESTRING:
 // NOTE(review): `sig_pos + 3 < max_pos` (rather than `<=`) only holds when
 // at least one more signature argument follows the 4-argument group, unlike
 // the `sig_pos < max_pos` test of the 2-argument groups. The same pattern is
 // used for kPOLYGON and kMULTIPOLYGON below - confirm this is intended.
295  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
296  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
297  sig_types[sig_pos + 2] == ExtArgumentType::PInt8 &&
298  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
299  penalty_score += 1000;
300  return 4;
301  } else if (sig_type == ExtArgumentType::GeoMultiLineString) {
302  penalty_score += 1000;
303  return 1;
304  }
305  break;
306  case kARRAY:
307  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
308  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
309  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
310  sig_type == ExtArgumentType::PBool) &&
311  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
312  penalty_score += 1000;
313  return 2;
314  } else if (is_ext_arg_type_array(sig_type)) {
315  // array arguments must match exactly
316  CHECK(arg_type.is_array());
317  const auto sig_type_ti =
319  if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
320  sig_type_ti.get_type() == kTINYINT) {
321  /* Boolean array has the same low-level structure as Int8 array. */
322  penalty_score += 1000;
323  return 1;
324  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
325  penalty_score += 1000;
326  return 1;
327  } else {
328  return -1;
329  }
330  }
331  break;
332  case kPOLYGON:
333  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
334  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
335  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
336  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
337  penalty_score += 1000;
338  return 4;
339  } else if (sig_type == ExtArgumentType::GeoPolygon) {
340  penalty_score += 1000;
341  return 1;
342  }
343  break;
344  case kMULTIPOLYGON:
345  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
346  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
347  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
348  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
349  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
350  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
351  penalty_score += 1000;
352  return 6;
353  } else if (sig_type == ExtArgumentType::GeoMultiPolygon) {
354  penalty_score += 1000;
355  return 1;
356  }
357  break;
358  case kNULLT: // NULL maps to a pointer and size argument
359  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
360  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
361  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
362  sig_type == ExtArgumentType::PBool) &&
363  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
364  penalty_score += 1000;
365  return 2;
366  }
367  break;
368  case kCOLUMN:
369  if (is_ext_arg_type_column(sig_type)) {
370  // column arguments must match exactly
371  const auto sig_type_ti =
373  if (arg_type.get_elem_type().get_type() == kARRAY &&
374  sig_type_ti.get_type() == kARRAY) {
 // Column of arrays: compare the element types of the inner arrays.
375  if (arg_type.get_elem_type().get_elem_type().get_type() ==
376  sig_type_ti.get_elem_type().get_type()) {
377  penalty_score += 1000;
378  return 1;
379  } else {
380  return -1;
381  }
382  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
383  sig_type_ti.get_type() == kTINYINT) {
384  /* Boolean column has the same low-level structure as Int8 column. */
385  penalty_score += 1000;
386  return 1;
387  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
388  penalty_score += 1000;
389  return 1;
390  } else {
391  return -1;
392  }
393  }
394  break;
395  case kCOLUMN_LIST:
396  if (is_ext_arg_type_column_list(sig_type)) {
397  // column_list arguments must match exactly
398  const auto sig_type_ti =
400  if (arg_type.get_elem_type().get_type() == kARRAY &&
401  sig_type_ti.get_type() == kARRAY) {
402  if (arg_type.get_elem_type().get_elem_type().get_type() ==
403  sig_type_ti.get_elem_type().get_type()) {
404  penalty_score += 1000;
405  return 1;
406  } else {
407  return -1;
408  }
409  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
410  sig_type_ti.get_type() == kTINYINT) {
411  /* Boolean column_list has the same low-level structure as Int8 column_list. */
 // NOTE(review): this branch and the next add 10000, whereas the array-element
 // branch above and every kCOLUMN branch add 1000 - presumably to make a
 // column_list binding lose against a plain column binding; confirm.
412  penalty_score += 10000;
413  return 1;
414  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
415  penalty_score += 10000;
416  return 1;
417  } else {
418  return -1;
419  }
420  }
421  break;
422  case kVARCHAR:
 // VARCHAR only binds to TextEncodingNone, and only when not dictionary-encoded.
423  if (sig_type != ExtArgumentType::TextEncodingNone) {
424  return -1;
425  }
426  switch (arg_type.get_compression()) {
427  case kENCODING_NONE:
428  penalty_score += 1000;
429  return 1;
430  case kENCODING_DICT:
431  return -1;
432  // Todo (todd): Evaluate when and where we can translate to dictionary-encoded
433  default:
434  UNREACHABLE();
435  }
 // NOTE(review): there is no break/return after the inner switch, so control
 // would fall through into the kTEXT case if UNREACHABLE() ever returned.
436  case kTEXT:
437  switch (arg_type.get_compression()) {
438  case kENCODING_NONE:
439  if (sig_type == ExtArgumentType::TextEncodingNone) {
440  penalty_score += 1000;
441  return 1;
442  }
443  return -1;
444  case kENCODING_DICT:
445  if (sig_type == ExtArgumentType::TextEncodingDict) {
446  penalty_score += 1000;
447  return 1;
448  }
449  return -1;
450  default:
451  UNREACHABLE();
452  }
453  case kTIMESTAMP:
454  if (sig_type == ExtArgumentType::Timestamp) {
455  penalty_score += 1000;
456  return 1;
457  }
458  break;
459  case kINTERVAL_DAY_TIME:
460  if (sig_type == ExtArgumentType::DayTimeInterval) {
461  penalty_score += 1000;
462  return 1;
463  }
464  break;
465 
467  if (sig_type == ExtArgumentType::YearMonthTimeInterval) {
468  penalty_score += 1000;
469  return 1;
470  }
471  break;
472 
473  /* Not implemented types:
474  kCHAR
475  kTIME
476  kDATE
477  kGEOMETRY
478  kGEOGRAPHY
479  kEVAL_CONTEXT_TYPE
480  kVOID
481  kCURSOR
482  */
483  default:
484  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
485  ": support for " + arg_type.get_type_name() +
486  "(type=" + std::to_string(arg_type.get_type()) + ")" +
487  +" not implemented: \n pos=" + std::to_string(sig_pos) +
488  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
489  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
490  }
 // Reached via `break` from a handled SQL type whose signature slot did not match.
491  return -1;
492 }
493 
// Returns true when `str` has the shape of a C-style identifier: it is
// non-empty, its first character is a letter or an underscore, and every
// following character is a letter, a digit or an underscore. Letters and
// digits are classified with std::isalpha / std::isalnum, i.e. according to
// the current C locale.
bool is_valid_identifier(std::string str) {
  // The <cctype> classifiers have undefined behavior for arguments that are
  // neither EOF nor representable as unsigned char, so a plain (possibly
  // negative) char must be converted through unsigned char first.
  const auto is_ident_start = [](const char c) {
    return c == '_' || std::isalpha(static_cast<unsigned char>(c)) != 0;
  };
  const auto is_ident_part = [](const char c) {
    return c == '_' || std::isalnum(static_cast<unsigned char>(c)) != 0;
  };

  if (str.empty() || !is_ident_start(str.front())) {
    return false;
  }
  return std::all_of(str.begin() + 1, str.end(), is_ident_part);
}
511 
512 } // namespace
513 
// Generic binder shared by scalar extension functions (T = ExtensionFunction)
// and table functions (T = table_functions::TableFunction): picks, among
// `ext_funcs`, the candidate whose input signature matches `func_args` with
// the lowest penalty score. `processor` is only used in error messages.
514 template <typename T>
515 std::tuple<T, std::vector<SQLTypeInfo>> bind_function(
516  std::string name,
517  Analyzer::ExpressionPtrVector func_args, // function args from sql query
518  const std::vector<T>& ext_funcs, // list of functions registered
519  const std::string processor) {
520  /* worker function
521 
522  Template type T must implement the following methods:
523 
524  std::vector<ExtArgumentType> getInputArgs()
 
 getRet() and toStringSQL() are also called on T below; for
 T = table_functions::TableFunction further members are used inside
 the `if constexpr` branches.
525  */
526  /*
527  Return extension function/table function that has the following
528  properties
529 
530  1. each argument type in `func_args` matches with extension
531  function argument types.
532 
533  For scalar types, the matching means that the types are either
534  equal or the argument type is smaller than the corresponding
535  extension function argument type. This ensures that no
536  information is lost when casting of argument values is
537  required.
538 
539  For array and geo types, the matching means that the argument
540  type matches exactly with a group of extension function
541  argument types. See `match_arguments`.
542 
543  2. has minimal penalty score among all implementations of the
544  extension function with given `name`, see `match_arguments`
545  for the definition of penalty score.
546 
547  It is assumed that function_oper and extension functions in
548  ext_funcs have the same name.
 
 The result pairs the selected function with the argument type
 variant (see the ColumnList discussion below) that gave the best
 score.
 
 Throws NativeExecutionError when `name` is not a valid identifier
 and ExtensionFunctionBindingError when no candidate matches.
549  */
550  if (!is_valid_identifier(name)) {
551  throw NativeExecutionError(
552  "Cannot bind function with invalid UDF/UDTF function name: " + name);
553  }
554 
555  int minimal_score = std::numeric_limits<int>::max();
556  int index = -1;
557  int optimal = -1;
558  int optimal_variant = -1;
559 
 // Collect, per query argument, its type and whether it is to be scored as a
 // constant literal.
560  std::vector<SQLTypeInfo> type_infos_input;
561  std::vector<bool> args_are_constants;
562  for (auto atype : func_args) {
563  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
 // Column references passed to a table function are wrapped into a column type.
564  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
565  SQLTypeInfo type_info = atype->get_type_info();
566  auto ti = generate_column_type(type_info);
567  if (ti.get_subtype() == kNULLT) {
568  throw std::runtime_error(std::string(__FILE__) + "#" +
569  std::to_string(__LINE__) +
570  ": column support for type info " +
571  type_info.to_string() + " is not implemented");
572  }
573  type_infos_input.push_back(ti);
 // NOTE(review): column inputs are flagged as "constant" for every element
 // type except kTEXT; the rationale is not visible here - confirm.
574  args_are_constants.push_back(type_info.get_type() != kTEXT);
575  continue;
576  }
577  }
578  type_infos_input.push_back(atype->get_type_info());
579  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
580  args_are_constants.push_back(true);
581  } else {
582  args_are_constants.push_back(false);
583  }
584  }
585  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
586 
 // Zero-argument call: exactly one zero-argument candidate is expected and is
 // returned without any scoring.
587  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
588  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
589  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
590  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
591  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
592  }
593  std::vector<SQLTypeInfo> empty_type_info_variant(0);
594  return {ext_funcs[0], empty_type_info_variant};
595  }
596 
597  // clang-format off
598  /*
599  Table functions may have arguments such as ColumnList that collect
600  neighboring columns with the same data type into a single object.
601  Here we compute all possible combinations of mapping a subset of
602  columns into columns sets. For example, if the types of function
603  arguments are (as given in func_args argument)
604 
605  (Column<int>, Column<int>, Column<int>, int)
606 
607  then the computed variants will be
608 
609  (Column<int>, Column<int>, Column<int>, int)
610  (Column<int>, Column<int>, ColumnList[1]<int>, int)
611  (Column<int>, ColumnList[1]<int>, Column<int>, int)
612  (Column<int>, ColumnList[2]<int>, int)
613  (ColumnList[1]<int>, Column<int>, Column<int>, int)
614  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
615  (ColumnList[2]<int>, Column<int>, int)
616  (ColumnList[3]<int>, int)
617 
618  where the integers in [..] indicate the number of collected
619  columns. In the SQLTypeInfo instance, this number is stored in the
620  SQLTypeInfo dimension attribute.
621 
622  As an example, let us consider a SQL query containing the
623  following expression calling a UDTF foo:
624 
625  table(foo(cursor(select a, b, c from tableofints), 1))
626 
627  Here follows a list of table functions and the corresponding
628  optimal argument type variants that are computed for the given
629  query expression:
630 
631  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
632  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
633 
634  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
635  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
636 
637  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
638  (Column<int>, Column<int>, Column<int>, int)
639  */
640  // clang-format on
641 
642  // We first check whether any of the candidate extension functions
643  // in the ext_funcs list accept column lists. If none do,
644  // we do not need to account for possible input permutations with
645  // ColumnList arguments, which currently can be slow
646  // to match against the extension functions when the number of arguments
647  // is high.
648 
649  // Todo: Develop faster matching algorithm to avoid such performance
650  // hits when ColumnLists are allowed
651 
652  bool ext_funcs_allow_column_lists{false};
653  for (const auto& ext_func : ext_funcs) {
654  auto ext_func_args = ext_func.getInputArgs();
655  for (const auto& arg : ext_func_args) {
656  if (is_ext_arg_type_column_list(arg)) {
657  ext_funcs_allow_column_lists = true;
658  break;
659  }
660  }
661  if (ext_funcs_allow_column_lists) {
662  break;
663  }
664  }
665 
666  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
667  if (ext_funcs_allow_column_lists) {
 // Build the variants incrementally, one input argument at a time.
668  for (const auto& ti : type_infos_input) {
669  if (type_infos_variants.begin() == type_infos_variants.end()) {
 // First argument: seed with the argument as-is and, for a column, also with
 // a one-element column list.
670  type_infos_variants.push_back({ti});
671  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
672  if (ti.is_column()) {
673  auto mti = generate_column_list_type(ti);
674  if (mti.get_subtype() == kNULLT) {
675  continue; // skip unsupported element type.
676  }
677  mti.set_dimension(1);
678  type_infos_variants.push_back({mti});
679  }
680  }
681  continue;
682  }
 // Later arguments: every existing variant is extended with `ti` as-is; for a
 // column, a second copy is created in which `ti` is absorbed into (or starts)
 // a column list.
683  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
684  for (auto& type_infos : type_infos_variants) {
685  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
686  if (ti.is_column()) {
687  auto new_type_infos = type_infos; // makes a copy
688  const auto& last = type_infos.back();
689  if (last.is_column_list() && last.has_same_itemtype(ti)) {
690  // last column_list consumes column argument if item types match
691  new_type_infos.back().set_dimension(last.get_dimension() + 1);
692  } else {
693  // add column as column_list argument
694  auto mti = generate_column_list_type(ti);
695  if (mti.get_subtype() == kNULLT) {
696  // skip unsupported element type
697  type_infos.push_back(ti);
698  continue;
699  }
700  mti.set_dimension(1);
701  new_type_infos.push_back(mti);
702  }
703  new_type_infos_variants.push_back(new_type_infos);
704  }
705  }
706  type_infos.push_back(ti);
707  }
708  type_infos_variants.insert(type_infos_variants.end(),
709  new_type_infos_variants.begin(),
710  new_type_infos_variants.end());
711  }
712  } else {
713  type_infos_variants.emplace_back(type_infos_input);
714  }
715 
716  // Find extension function that gives the best match on the set of
717  // argument type variants:
718  for (const auto& ext_func : ext_funcs) {
719  index++;
720 
721  const auto& ext_func_args = ext_func.getInputArgs();
722  int index_variant = -1;
723  for (const auto& type_infos : type_infos_variants) {
724  index_variant++;
725  int penalty_score = 0;
 // `pos` walks the candidate's signature; `original_input_idx` walks the
 // original query arguments (a column list covers several of them).
726  int pos = 0;
727  int original_input_idx = 0;
728  CHECK_LE(type_infos.size(), args_are_constants.size());
729  for (const auto& ti : type_infos) {
730  int offset = match_arguments(ti,
731  args_are_constants[original_input_idx],
732  pos,
733  ext_func_args,
734  penalty_score);
735  if (offset < 0) {
736  // atype does not match with ext_func argument
737  pos = -1;
738  break;
739  }
740  if (ti.get_type() == kCOLUMN_LIST) {
741  original_input_idx += ti.get_dimension();
742  } else {
743  original_input_idx++;
744  }
745  pos += offset;
746  }
747 
 // A full match consumed exactly all signature arguments. After a mismatch
 // `pos` is -1, which converts to a huge size_t and never equals the size.
748  if ((size_t)pos == ext_func_args.size()) {
749  CHECK_EQ(args_are_constants.size(), original_input_idx);
750  // prefer smaller return types
751  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
 // Strict `<`: on equal scores the earliest candidate/variant is kept.
752  if (penalty_score < minimal_score) {
753  optimal = index;
754  minimal_score = penalty_score;
755  optimal_variant = index_variant;
756  }
757  }
758  }
759  }
760 
761  if (optimal == -1) {
762  /* no extension function found whose argument types would match
763  the types in `func_args` */
764  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
765  std::string message;
766  if (!ext_funcs.size()) {
767  message = "Function " + name + "(" + sarg_types + ") not supported.";
768  throw ExtensionFunctionBindingError(message);
769  } else {
770  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
771  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
772  " UDTF implementation.";
773  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
774  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
775  " UDF implementation.";
776  } else {
777  LOG(FATAL) << "bind_function: unknown extension function type "
778  << typeid(T).name();
779  }
780  message += "\n Existing extension function implementations:";
781  for (const auto& ext_func : ext_funcs) {
782  // Do not show functions missing the sizer argument
783  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
784  if (ext_func.useDefaultSizer())
785  continue;
786  message += "\n " + ext_func.toStringSQL();
787  }
788  }
789  throw ExtensionFunctionBindingError(message);
790  }
791 
792  // Functions with "_default_" suffix only exist for calcite
 // When such a variant won, switch to the function of the same name without
 // the suffix and add an INT argument to the returned type list for its sizer.
793  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
794  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
795  ext_funcs[optimal].useDefaultSizer()) {
 // NOTE: this local `name` shadows the function parameter of the same name.
796  std::string name = ext_funcs[optimal].getName();
797  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
799  for (size_t i = 0; i < ext_funcs.size(); i++) {
800  if (ext_funcs[i].getName() == name) {
801  optimal = i;
802  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
 // presumably `sizer` is the 1-based position of the row-size argument - TODO confirm
803  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
804  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
805  return {ext_funcs[optimal], type_info};
806  }
807  }
808  UNREACHABLE();
809  }
810  }
811 
812  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
813 }
814 
// Binds a table function call against an explicitly supplied candidate list:
// forwards `name`, `input_args` and `table_funcs` to
// bind_function<table_functions::TableFunction>, passing "GPU" or "CPU"
// (chosen by `is_gpu`) as the processor label used in binding error messages.
815 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
818  const std::vector<table_functions::TableFunction>& table_funcs,
819  const bool is_gpu) {
820  std::string processor = (is_gpu ? "GPU" : "CPU");
821  return bind_function<table_functions::TableFunction>(
822  name, input_args, table_funcs, processor);
823 }
824 
826  Analyzer::ExpressionPtrVector func_args) {
827  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
828  // to CPU UDFs.
 // Start from the list fetched with is_gpu == true; if that list is empty,
 // switch to the list fetched with is_gpu == false right away.
829  bool is_gpu = true;
830  std::string processor = "GPU";
831  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
832  if (!ext_funcs.size()) {
833  is_gpu = false;
834  processor = "CPU";
835  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
836  }
837  try {
838  return std::get<0>(
839  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
840  } catch (ExtensionFunctionBindingError& e) {
 // GPU candidates existed but none matched: retry once with the CPU list.
 // The processor label becomes "GPU|CPU" so that the message of a second
 // failure reflects both attempts; that second failure is not caught here.
841  if (is_gpu) {
842  is_gpu = false;
843  processor = "GPU|CPU";
844  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
845  return std::get<0>(
846  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
847  } else {
 // The CPU list was already in use: propagate the original error.
848  throw;
849  }
850  }
851 }
852 
855  const bool is_gpu) {
856  // used below by the overload that takes a function operator. Binds against a
 // single candidate list with the processor label chosen by `is_gpu`; there is
 // no GPU -> CPU retry in this overload, a binding error propagates directly.
857  std::vector<ExtensionFunction> ext_funcs =
859  std::string processor = (is_gpu ? "GPU" : "CPU");
 // Only the selected function is returned; the matched type variant is dropped.
860  return std::get<0>(
861  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
862 }
863 
865  const bool is_gpu) {
866  // used in ExtensionsIR.cpp
 // Unpacks the operator into its name and owned argument expressions and
 // delegates to the (name, func_args, is_gpu) overload.
867  auto name = function_oper->getName();
868  Analyzer::ExpressionPtrVector func_args = {};
869  for (size_t i = 0; i < function_oper->getArity(); ++i) {
870  func_args.push_back(function_oper->getOwnArg(i));
871  }
872  return bind_function(name, func_args, is_gpu);
873 }
874 
// Convenience overload of bind_table_function: builds the candidate list
// itself and delegates to the overload that takes the list explicitly.
875 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
878  const bool is_gpu) {
879  // used in RelAlgExecutor.cpp
 // NOTE(review): the initializer of `table_funcs` is not visible in this listing.
880  std::vector<table_functions::TableFunction> table_funcs =
882  return bind_table_function(name, input_args, table_funcs, is_gpu);
883 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
static std::vector< TableFunction > get_table_funcs()
size_t getArity() const
Definition: Analyzer.h:2408
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
#define LOG(tag)
Definition: Logger.h:285
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:337
#define CHECK_GE(x, y)
Definition: Logger.h:306
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:381
#define CHECK_GT(x, y)
Definition: Logger.h:305
std::string to_string(char const *&&v)
std::string to_string() const
Definition: sqltypes.h:547
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:2415
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
Definition: sqltypes.h:749
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
auto generate_column_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1387
Definition: sqltypes.h:69
#define CHECK_LE(x, y)
Definition: Logger.h:304
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:389
auto generate_column_list_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1445
std::string get_type_name() const
Definition: sqltypes.h:507
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
Definition: sqltypes.h:810
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:291
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:186
Definition: sqltypes.h:62
std::string getName() const
Definition: Analyzer.h:2406
string name
Definition: setup.in.py:72
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:963
bool is_array() const
Definition: sqltypes.h:588
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)