OmniSciDB  c0231cc57d
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 #include "ExternalExecutor.h"
20 
21 // A rather crude function binding logic based on the types of the arguments.
22 // We want it to be possible to write specialized versions of functions to be
23 // exposed as SQL extensions. This is important especially for performance
24 // reasons, since double operations can be significantly slower than float. We
25 // compute a score for each candidate signature based on the conversions required to go
26 // from the function arguments as specified in the SQL query to the versions in
27 // ExtensionFunctions.hpp.
28 
29 /*
30  New implementation for binding a SQL function operator to the
31  optimal candidate among all available extension functions.
32  */
33 
34 namespace {
35 
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
72  default:
73  UNREACHABLE();
74  }
75  return ExtArgumentType{};
76 }
77 
79  const ExtArgumentType ext_arg_column_list_type) {
80  switch (ext_arg_column_list_type) {
82  return ExtArgumentType::Int8;
94  return ExtArgumentType::Bool;
113  default:
114  UNREACHABLE();
115  }
116  return ExtArgumentType{};
117 }
118 
120  switch (ext_arg_array_type) {
122  return ExtArgumentType::Int8;
124  return ExtArgumentType::Int16;
126  return ExtArgumentType::Int32;
128  return ExtArgumentType::Int64;
130  return ExtArgumentType::Float;
134  return ExtArgumentType::Bool;
137  default:
138  UNREACHABLE();
139  }
140  return ExtArgumentType{};
141 }
142 
143 static int match_numeric_argument(const SQLTypeInfo& arg_type_info,
144  const bool is_arg_literal,
145  const ExtArgumentType& sig_ext_arg_type,
146  int32_t& penalty_score) {
147  const auto arg_type = arg_type_info.get_type();
148  CHECK(arg_type == kBOOLEAN || arg_type == kTINYINT || arg_type == kSMALLINT ||
149  arg_type == kINT || arg_type == kBIGINT || arg_type == kFLOAT ||
150  arg_type == kDOUBLE || arg_type == kDECIMAL || arg_type == kNUMERIC);
151  // Todo (todd): Add support for timestamp, date, and time types
152  const auto sig_type_info = ext_arg_type_to_type_info(sig_ext_arg_type);
153  const auto sig_type = sig_type_info.get_type();
154 
155  // If we can't legally auto-cast to sig_type, abort
156  if (!arg_type_info.is_numeric_scalar_auto_castable(sig_type_info)) {
157  return -1;
158  }
159 
160  // We now compare a measure of the scale of the sig_type with the
161  // arg_type, which provides a basis for scoring the match between
162  // the two. Note that get_numeric_scalar_scale for the most part
163  // returns the logical byte width of the type, with a few caveats
164  // for decimals and timestamps described in more depth in comments
165  // in the function itself. Also even though for example float and
166  // int types return 4 (as in 4 bytes), and double and bigint types
167  // return 8, a fp32 type cannot express every 32-bit integer (even
168  // if it can cover a larger absolute range), and an fp64 type
169  // likewise cannot express every 64-bit integer. With the aim to
170  // minimize the precision loss from casting (always precise) integer
171  // value to (imprecise) floating point value, in the case of integer
172  // inputs, we'll penalize wider floating point argument types least
173  // by a specific scale transformation (see the implementation
174  // below). For instance, casting tinyint to fp64 is prefered over
175  // casting it to fp32 to minimize precision loss.
176  const bool is_integer_to_fp_cast = (arg_type == kTINYINT || arg_type == kSMALLINT ||
177  arg_type == kINT || arg_type == kBIGINT) &&
178  (sig_type == kFLOAT || sig_type == kDOUBLE);
179 
180  const auto arg_type_relative_scale = arg_type_info.get_numeric_scalar_scale();
181  CHECK_GE(arg_type_relative_scale, 1);
182  CHECK_LE(arg_type_relative_scale, 8);
183  auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
184  CHECK_GE(sig_type_relative_scale, 1);
185  CHECK_LE(sig_type_relative_scale, 8);
186 
187  if (is_integer_to_fp_cast) {
188  // transform fp scale: 4 becomes 16, 8 remains 8
189  sig_type_relative_scale = (3 - (sig_type_relative_scale >> 2)) << 3;
190  }
191 
192  // We do not allow auto-casting to types with less scale/precision
193  // within the same type family.
194  CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
195 
196  // Calculate the ratio of the sig_type by the arg_type, per the above check will be >= 1
197  const auto sig_type_scale_gain_ratio =
198  sig_type_relative_scale / arg_type_relative_scale;
199  CHECK_GE(sig_type_scale_gain_ratio, 1);
200 
201  // Following the old bespoke scoring logic this function replaces, we heavily penalize
202  // any casts that move ints to floats/doubles for the precision-loss reasons above
203  // Arguably all integers in the tinyint and smallint can be fully specified with both
204  // float and double types, but we treat them the same as int and bigint types here.
205  const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
206 
207  int32_t scale_cast_penalty_score;
208 
209  // The following logic is new. Basically there are strong reasons to
210  // prefer the promotion of constant literals to the most precise type possible, as
211  // rather than the type being inherent in the data - that is a column or columns where
212  // a user specified a type (and with any expressions on those columns following our
213  // standard sql casting logic), literal types are given to us by Calcite and do not
214  // necessarily convey any semantic intent (i.e. 10 will be an int, but 10.0 a decimal)
215  // Hence it is better to promote these types to the most precise sig_type available,
216  // while at the same time keeping column expressions as close as possible to the input
217  // types (mainly for performance, we have many float versions of various functions
218  // to allow for greater performance when the underlying data is not of double precision,
219  // and hence there is little benefit of the extra cost of computing double precision
220  // operators on this data)
221  if (is_arg_literal) {
222  scale_cast_penalty_score =
223  (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
224  } else {
225  scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
226  }
227 
228  const auto cast_penalty_score =
229  type_family_cast_penalty_score + scale_cast_penalty_score;
230  CHECK_GT(cast_penalty_score, 0);
231  penalty_score += cast_penalty_score;
232  return 1;
233 }
234 
// Matches one SQL argument type against the signature slot(s) of a candidate
// extension function, starting at sig_types[sig_pos].
//
// On a match, returns the number of signature slots the argument consumes and
// adds a penalty to penalty_score. Returns -1 on a mismatch. Throws
// std::runtime_error for argument types that have no case below.
// is_arg_literal is only forwarded to match_numeric_argument.
235 static int match_arguments(const SQLTypeInfo& arg_type,
236  const bool is_arg_literal,
237  int sig_pos,
238  const std::vector<ExtArgumentType>& sig_types,
239  int& penalty_score) {
240  /*
241  Returns non-negative integer `offset` if `arg_type` and
242  `sig_types[sig_pos:sig_pos + offset]` match.
243 
244  The `offset` value can be interpreted as the number of extension
245  function arguments that is consumed by the given `arg_type`. For
246  instance, for scalar types the offset is always 1, for array
247  types the offset is 2: one argument for array pointer value and
248  one argument for the array size value, etc.
249 
250  Returns -1 when the types of an argument and the corresponding
251  extension function argument(s) mismatch, or when downcasting would
252  be effective.
253 
254  In case of non-negative `offset` result, the function updates
255  penalty_score argument as follows:
256 
257  add 1000 if arg_type is non-scalar, otherwise:
258  add 1000 * sizeof(sig_type) / sizeof(arg_type)
259  add 1000000 if type kinds differ (integer vs double, for instance)
260 
261  */
// max_pos is the index of the last signature slot; a sig_pos beyond it means
// the signature has no slot left for this argument.
262  int max_pos = sig_types.size() - 1;
263  if (sig_pos > max_pos) {
264  return -1;
265  }
266  auto sig_type = sig_types[sig_pos];
267  switch (arg_type.get_type()) {
// Numeric scalars: scoring is delegated to match_numeric_argument.
268  case kBOOLEAN:
269  case kTINYINT:
270  case kSMALLINT:
271  case kINT:
272  case kBIGINT:
273  case kFLOAT:
274  case kDOUBLE:
275  case kDECIMAL:
276  case kNUMERIC:
277  return match_numeric_argument(arg_type, is_arg_literal, sig_type, penalty_score);
// These geo types match either a (pointer, Int64) pair of slots or a single
// GeoPoint / GeoMultiPoint / GeoLineString slot.
// NOTE(review): the single-slot branch accepts any of the three Geo* slot
// types for any of the three argument types -- confirm that is intended.
278  case kPOINT:
279  case kMULTIPOINT:
280  case kLINESTRING:
281  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
282  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
283  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble) &&
284  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
285  penalty_score += 1000;
286  return 2;
287  } else if (sig_type == ExtArgumentType::GeoPoint ||
288  sig_type == ExtArgumentType::GeoMultiPoint ||
289  sig_type == ExtArgumentType::GeoLineString) {
290  penalty_score += 1000;
291  return 1;
292  }
293  return -1;
// Matches (PInt8, Int64, PInt8, Int64) or a single GeoMultiLineString slot.
// NOTE(review): the two-slot cases guard with `sig_pos < max_pos`, i.e. the
// last slot read may be the final one. Here `sig_pos + 3 < max_pos` requires
// at least one more slot AFTER the four consumed, so a signature ending with
// these four slots is not matched. Possibly `<=` was intended -- confirm. The
// same pattern appears in the kPOLYGON and kMULTIPOLYGON cases below.
294  case kMULTILINESTRING:
295  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
296  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
297  sig_types[sig_pos + 2] == ExtArgumentType::PInt8 &&
298  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
299  penalty_score += 1000;
300  return 4;
301  } else if (sig_type == ExtArgumentType::GeoMultiLineString) {
302  penalty_score += 1000;
303  return 1;
304  }
305  break;
// Arrays match either a (pointer, Int64) pair of slots or a single array-typed
// slot whose element type agrees with the argument's element type.
306  case kARRAY:
307  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
308  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
309  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
310  sig_type == ExtArgumentType::PBool) &&
311  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
312  penalty_score += 1000;
313  return 2;
314  } else if (is_ext_arg_type_array(sig_type)) {
315  // array arguments must match exactly
316  CHECK(arg_type.is_array());
// NOTE(review): the initializer of sig_type_ti (source line 318) is absent
// from this listing.
317  const auto sig_type_ti =
319  if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
320  sig_type_ti.get_type() == kTINYINT) {
321  /* Boolean array has the same low-level structure as Int8 array. */
322  penalty_score += 1000;
323  return 1;
324  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
325  penalty_score += 1000;
326  return 1;
327  } else {
328  return -1;
329  }
330  }
331  break;
// Matches (PInt8, Int64, PInt32, Int64) or a single GeoPolygon slot.
// NOTE(review): same `sig_pos + 3 < max_pos` bound as kMULTILINESTRING above.
332  case kPOLYGON:
333  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
334  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
335  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
336  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
337  penalty_score += 1000;
338  return 4;
339  } else if (sig_type == ExtArgumentType::GeoPolygon) {
340  penalty_score += 1000;
341  return 1;
342  }
343  break;
// Matches (PInt8, Int64, PInt32, Int64, PInt32, Int64) or a single
// GeoMultiPolygon slot.
// NOTE(review): `sig_pos + 5 < max_pos` likewise requires a slot after the six
// consumed -- confirm whether `<=` was intended.
344  case kMULTIPOLYGON:
345  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
346  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
347  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
348  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
349  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
350  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
351  penalty_score += 1000;
352  return 6;
353  } else if (sig_type == ExtArgumentType::GeoMultiPolygon) {
354  penalty_score += 1000;
355  return 1;
356  }
357  break;
358  case kNULLT: // NULL maps to a pointer and size argument
359  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
360  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
361  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
362  sig_type == ExtArgumentType::PBool) &&
363  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
364  penalty_score += 1000;
365  return 2;
366  }
367  break;
// Columns consume one column-typed slot. For array columns the array element
// types are compared; otherwise the element types must be equal, except that a
// boolean column is also accepted by a kTINYINT slot.
368  case kCOLUMN:
369  if (is_ext_arg_type_column(sig_type)) {
370  // column arguments must match exactly
// NOTE(review): the initializer of sig_type_ti (source line 372) is absent
// from this listing.
371  const auto sig_type_ti =
373  if (arg_type.get_elem_type().get_type() == kARRAY &&
374  sig_type_ti.get_type() == kARRAY) {
375  if (arg_type.get_elem_type().get_elem_type().get_type() ==
376  sig_type_ti.get_elem_type().get_type()) {
377  penalty_score += 1000;
378  return 1;
379  } else {
380  return -1;
381  }
382  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
383  sig_type_ti.get_type() == kTINYINT) {
384  /* Boolean column has the same low-level structure as Int8 column. */
385  penalty_score += 1000;
386  return 1;
387  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
388  penalty_score += 1000;
389  return 1;
390  } else {
391  return -1;
392  }
393  }
394  break;
// Column lists follow the same rules as kCOLUMN, against a column_list slot.
395  case kCOLUMN_LIST:
396  if (is_ext_arg_type_column_list(sig_type)) {
397  // column_list arguments must match exactly
// NOTE(review): the initializer of sig_type_ti (source line 399) is absent
// from this listing.
398  const auto sig_type_ti =
400  if (arg_type.get_elem_type().get_type() == kARRAY &&
401  sig_type_ti.get_type() == kARRAY) {
402  if (arg_type.get_elem_type().get_elem_type().get_type() ==
403  sig_type_ti.get_elem_type().get_type()) {
404  penalty_score += 1000;
405  return 1;
406  } else {
407  return -1;
408  }
409  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
410  sig_type_ti.get_type() == kTINYINT) {
411  /* Boolean column_list has the same low-level structure as Int8 column_list. */
// NOTE(review): this branch and the next add 10000, whereas the array-element
// branch just above and every kCOLUMN branch add 1000 -- confirm whether the
// difference is intentional.
412  penalty_score += 10000;
413  return 1;
414  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
415  penalty_score += 10000;
416  return 1;
417  } else {
418  return -1;
419  }
420  }
421  break;
// Strings only match a TextEncodingNone slot, and only when the argument is
// itself not dictionary encoded.
// NOTE(review): nothing follows the inner switch, so this case relies on
// UNREACHABLE() not returning; if it did, control would continue into the next
// case label. The same holds for kTEXT below.
422  case kVARCHAR:
423  if (sig_type != ExtArgumentType::TextEncodingNone) {
424  return -1;
425  }
426  switch (arg_type.get_compression()) {
427  case kENCODING_NONE:
428  penalty_score += 1000;
429  return 1;
430  case kENCODING_DICT:
431  return -1;
432  // Todo (todd): Evaluate when and where we can translate to dictionary-encoded
433  default:
434  UNREACHABLE();
435  }
436  case kTEXT:
437  if (sig_type != ExtArgumentType::TextEncodingNone) {
438  return -1;
439  }
440  switch (arg_type.get_compression()) {
441  case kENCODING_NONE:
442  penalty_score += 1000;
443  return 1;
444  case kENCODING_DICT:
445  return -1;
446  default:
447  UNREACHABLE();
448  }
// Timestamps match a Timestamp slot only when the argument precision is 9.
449  case kTIMESTAMP:
450  if (sig_type == ExtArgumentType::Timestamp) {
451  if (arg_type.get_precision() != 9) {
452  return -1;
453  }
454  penalty_score += 1000;
455  return 1;
456  }
457  break;
458  /* Not implemented types:
459  kCHAR
460  kTIME
461  kDATE
462  kINTERVAL_DAY_TIME
463  kINTERVAL_YEAR_MONTH
464  kGEOMETRY
465  kGEOGRAPHY
466  kEVAL_CONTEXT_TYPE
467  kVOID
468  kCURSOR
469  */
// Any type without a case above is reported as not implemented, with the
// position and the full signature included in the message.
470  default:
471  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
472  ": support for " + arg_type.get_type_name() +
473  "(type=" + std::to_string(arg_type.get_type()) + ")" +
474  +" not implemented: \n pos=" + std::to_string(sig_pos) +
475  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
476  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
477  }
// Reached by every case that ended in `break`: no branch matched.
478  return -1;
479 }
480 
// Returns true when `str` is a valid function identifier: non-empty, first
// character a letter or '_', and every following character a letter, a digit
// or '_'. Returns false otherwise.
bool is_valid_identifier(std::string str) {
  // The <cctype> classifiers have undefined behaviour when handed a negative
  // value other than EOF, which is what a plain `char` holding a non-ASCII
  // byte becomes on platforms where char is signed. Taking the character as
  // unsigned char makes every byte value safe to classify.
  const auto is_leading_char = [](const unsigned char c) {
    return std::isalpha(c) || c == '_';
  };
  const auto is_trailing_char = [](const unsigned char c) {
    return std::isalnum(c) || c == '_';
  };

  if (str.empty() || !is_leading_char(str.front())) {
    return false;
  }
  return std::all_of(str.begin() + 1, str.end(), is_trailing_char);
}
498 
499 } // namespace
500 
// Selects, among the candidates in `ext_funcs`, the one whose input signature
// best matches the types of `func_args`, and returns it together with the
// argument type-info variant that produced the match.
//
// name       function name; validated with is_valid_identifier and used in
//            error messages
// func_args  argument expressions taken from the SQL query
// ext_funcs  candidate implementations to choose from
// processor  label used only in error messages
//
// Throws NativeExecutionError when `name` is not a valid identifier and
// ExtensionFunctionBindingError when no candidate matches.
501 template <typename T>
502 std::tuple<T, std::vector<SQLTypeInfo>> bind_function(
503  std::string name,
504  Analyzer::ExpressionPtrVector func_args, // function args from sql query
505  const std::vector<T>& ext_funcs, // list of functions registered
506  const std::string processor) {
507  /* worker function
508 
509  Template type T must implement the following methods:
510 
511  std::vector<ExtArgumentType> getInputArgs()
512  */
513  /*
514  Return extension function/table function that has the following
515  properties
516 
517  1. each argument type in `arg_types` matches with extension
518  function argument types.
519 
520  For scalar types, the matching means that the types are either
521  equal or the argument type is smaller than the corresponding
522  the extension function argument type. This ensures that no
523  information is lost when casting of argument values is
524  required.
525 
526  For array and geo types, the matching means that the argument
527  type matches exactly with a group of extension function
528  argument types. See `match_arguments`.
529 
530  2. has minimal penalty score among all implementations of the
531  extension function with given `name`, see `get_penalty_score`
532  for the definition of penalty score.
533 
534  It is assumed that function_oper and extension functions in
535  ext_funcs have the same name.
536  */
537  if (!is_valid_identifier(name)) {
538  throw NativeExecutionError(
539  "Cannot bind function with invalid UDF/UDTF function name: " + name);
540  }
541 
// Best score seen so far, plus the indices of the candidate (`optimal`) and of
// the argument-type variant (`optimal_variant`) that achieved it. `index`
// starts at -1 because it is incremented at the top of each candidate loop
// iteration.
542  int minimal_score = std::numeric_limits<int>::max();
543  int index = -1;
544  int optimal = -1;
545  int optimal_variant = -1;
546 
// Collect, per input argument, its type info and a flag that is later passed
// to match_arguments as `is_arg_literal`.
547  std::vector<SQLTypeInfo> type_infos_input;
548  std::vector<bool> args_are_constants;
549  for (auto atype : func_args) {
550  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
// For table functions, a ColumnVar argument is recorded with the column type
// produced by generate_column_type instead of its own type info.
551  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
552  SQLTypeInfo type_info = atype->get_type_info();
553  auto ti = generate_column_type(type_info);
554  if (ti.get_subtype() == kNULLT) {
555  throw std::runtime_error(std::string(__FILE__) + "#" +
556  std::to_string(__LINE__) +
557  ": column support for type info " +
558  type_info.to_string() + " is not implemented");
559  }
560  type_infos_input.push_back(ti);
// NOTE(review): a ColumnVar is flagged as "constant" unless its type is kTEXT;
// the reason is not evident from this function -- confirm.
561  args_are_constants.push_back(type_info.get_type() != kTEXT);
562  continue;
563  }
564  }
565  type_infos_input.push_back(atype->get_type_info());
566  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
567  args_are_constants.push_back(true);
568  } else {
569  args_are_constants.push_back(false);
570  }
571  }
572  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
573 
// Zero-argument call: exactly one candidate, itself taking no inputs, is
// expected, and it is returned directly with an empty type-info list.
574  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
575  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
576  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
577  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
578  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
579  }
580  std::vector<SQLTypeInfo> empty_type_info_variant(0);
581  return {ext_funcs[0], empty_type_info_variant};
582  }
583 
584  // clang-format off
585  /*
586  Table functions may have arguments such as ColumnList that collect
587  neighboring columns with the same data type into a single object.
588  Here we compute all possible combinations of mapping a subset of
589  columns into columns sets. For example, if the types of function
590  arguments are (as given in func_args argument)
591 
592  (Column<int>, Column<int>, Column<int>, int)
593 
594  then the computed variants will be
595 
596  (Column<int>, Column<int>, Column<int>, int)
597  (Column<int>, Column<int>, ColumnList[1]<int>, int)
598  (Column<int>, ColumnList[1]<int>, Column<int>, int)
599  (Column<int>, ColumnList[2]<int>, int)
600  (ColumnList[1]<int>, Column<int>, Column<int>, int)
601  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
602  (ColumnList[2]<int>, Column<int>, int)
603  (ColumnList[3]<int>, int)
604 
605  where the integers in [..] indicate the number of collected
606  columns. In the SQLTypeInfo instance, this number is stored in the
607  SQLTypeInfo dimension attribute.
608 
609  As an example, let us consider a SQL query containing the
610  following expression calling a UDTF foo:
611 
612  table(foo(cursor(select a, b, c from tableofints), 1))
613 
614  Here follows a list of table functions and the corresponding
615  optimal argument type variants that are computed for the given
616  query expression:
617 
618  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
619  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
620 
621  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
622  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
623 
624  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
625  (Column<int>, Column<int>, Column<int>, int)
626  */
627  // clang-format on
628  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
629  for (auto ti : type_infos_input) {
// First argument: seed the variant list with the plain type and, for a table
// function column, additionally with a one-element column list.
630  if (type_infos_variants.begin() == type_infos_variants.end()) {
631  type_infos_variants.push_back({ti});
632  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
633  if (ti.is_column()) {
634  auto mti = generate_column_list_type(ti);
635  if (mti.get_subtype() == kNULLT) {
636  continue; // skip unsupported element type.
637  }
638  mti.set_dimension(1);
639  type_infos_variants.push_back({mti});
640  }
641  }
642  continue;
643  }
// Later arguments: every existing variant is extended in place with the plain
// type; for table function columns a second copy is also produced in which the
// column either joins the preceding column list or starts a new one. The
// copies are appended only after the loop, so the vector being iterated is not
// resized during iteration.
644  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
645  for (auto& type_infos : type_infos_variants) {
646  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
647  if (ti.is_column()) {
648  auto new_type_infos = type_infos; // makes a copy
649  const auto& last = type_infos.back();
650  if (last.is_column_list() && last.has_same_itemtype(ti)) {
651  // last column_list consumes column argument if item types match
652  new_type_infos.back().set_dimension(last.get_dimension() + 1);
653  } else {
654  // add column as column_list argument
655  auto mti = generate_column_list_type(ti);
656  if (mti.get_subtype() == kNULLT) {
657  // skip unsupported element type
658  type_infos.push_back(ti);
659  continue;
660  }
661  mti.set_dimension(1);
662  new_type_infos.push_back(mti);
663  }
664  new_type_infos_variants.push_back(new_type_infos);
665  }
666  }
667  type_infos.push_back(ti);
668  }
669  type_infos_variants.insert(type_infos_variants.end(),
670  new_type_infos_variants.begin(),
671  new_type_infos_variants.end());
672  }
673 
674  // Find extension function that gives the best match on the set of
675  // argument type variants:
676  for (auto ext_func : ext_funcs) {
677  index++;
678 
679  auto ext_func_args = ext_func.getInputArgs();
680  int index_variant = -1;
681  for (const auto& type_infos : type_infos_variants) {
682  index_variant++;
683  int penalty_score = 0;
// `pos` walks the candidate's signature slots; `original_input_idx` walks the
// original inputs, so that the matching is_arg_literal flag is looked up even
// when a column list stands for several inputs.
684  int pos = 0;
685  int original_input_idx = 0;
686  CHECK_LE(type_infos.size(), args_are_constants.size());
687  for (const auto& ti : type_infos) {
688  int offset = match_arguments(ti,
689  args_are_constants[original_input_idx],
690  pos,
691  ext_func_args,
692  penalty_score);
693  if (offset < 0) {
694  // atype does not match with ext_func argument
695  pos = -1;
696  break;
697  }
698  if (ti.get_type() == kCOLUMN_LIST) {
699  original_input_idx += ti.get_dimension();
700  } else {
701  original_input_idx++;
702  }
703  pos += offset;
704  }
705 
// A variant is accepted only when it consumed every signature slot. After a
// mismatch pos is -1, which once cast to size_t is presumably never equal to
// an argument count, so the variant is skipped.
706  if ((size_t)pos == ext_func_args.size()) {
707  CHECK_EQ(args_are_constants.size(), original_input_idx);
708  // prefer smaller return types
709  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
// Strict `<`: on equal scores the earliest candidate/variant pair is kept.
710  if (penalty_score < minimal_score) {
711  optimal = index;
712  minimal_score = penalty_score;
713  optimal_variant = index_variant;
714  }
715  }
716  }
717  }
718 
719  if (optimal == -1) {
720  /* no extension function found that argument types would match
721  with types in `arg_types` */
722  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
723  std::string message;
// With no candidates at all the function is reported as unsupported;
// otherwise the message lists the existing implementations.
724  if (!ext_funcs.size()) {
725  message = "Function " + name + "(" + sarg_types + ") not supported.";
726  throw ExtensionFunctionBindingError(message);
727  } else {
728  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
729  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
730  " UDTF implementation.";
731  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
732  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
733  " UDF implementation.";
734  } else {
735  LOG(FATAL) << "bind_function: unknown extension function type "
736  << typeid(T).name();
737  }
738  message += "\n Existing extension function implementations:";
739  for (const auto& ext_func : ext_funcs) {
740  // Do not show functions missing the sizer argument
741  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
742  if (ext_func.useDefaultSizer())
743  continue;
744  message += "\n " + ext_func.toStringSQL();
745  }
746  }
747  throw ExtensionFunctionBindingError(message);
748  }
749 
750  // Functions with "_default_" suffix only exist for calcite
751  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
752  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
753  ext_funcs[optimal].useDefaultSizer()) {
// NOTE(review): this local shadows the `name` parameter. The second argument
// of erase() (source line 756) is absent from this listing. If find() returned
// npos, std::string::erase would throw std::out_of_range -- presumably the
// suffix is always present when useDefaultSizer() is true; confirm.
754  std::string name = ext_funcs[optimal].getName();
755  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
// Switch to the candidate carrying the stripped name and add a kINT entry to
// the chosen variant at the position derived from getOutputRowSizeParameter()
// (presumably 1-based, hence the `- 1`).
757  for (size_t i = 0; i < ext_funcs.size(); i++) {
758  if (ext_funcs[i].getName() == name) {
759  optimal = i;
760  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
761  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
762  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
763  return {ext_funcs[optimal], type_info};
764  }
765  }
766  UNREACHABLE();
767  }
768  }
769 
770  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
771 }
772 
773 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
776  const std::vector<table_functions::TableFunction>& table_funcs,
777  const bool is_gpu) {
778  std::string processor = (is_gpu ? "GPU" : "CPU");
779  return bind_function<table_functions::TableFunction>(
780  name, input_args, table_funcs, processor);
781 }
782 
784  Analyzer::ExpressionPtrVector func_args) {
785  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
786  // to CPU UDFs.
787  bool is_gpu = true;
788  std::string processor = "GPU";
789  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
790  if (!ext_funcs.size()) {
791  is_gpu = false;
792  processor = "CPU";
793  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
794  }
795  try {
796  return std::get<0>(
797  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
798  } catch (ExtensionFunctionBindingError& e) {
799  if (is_gpu) {
800  is_gpu = false;
801  processor = "GPU|CPU";
802  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
803  return std::get<0>(
804  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
805  } else {
806  throw;
807  }
808  }
809 }
810 
813  const bool is_gpu) {
814  // used below
815  std::vector<ExtensionFunction> ext_funcs =
817  std::string processor = (is_gpu ? "GPU" : "CPU");
818  return std::get<0>(
819  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
820 }
821 
823  const bool is_gpu) {
824  // used in ExtensionsIR.cpp
825  auto name = function_oper->getName();
826  Analyzer::ExpressionPtrVector func_args = {};
827  for (size_t i = 0; i < function_oper->getArity(); ++i) {
828  func_args.push_back(function_oper->getOwnArg(i));
829  }
830  return bind_function(name, func_args, is_gpu);
831 }
832 
833 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
836  const bool is_gpu) {
837  // used in RelAlgExecutor.cpp
838  std::vector<table_functions::TableFunction> table_funcs =
840  return bind_table_function(name, input_args, table_funcs, is_gpu);
841 }
#define CHECK_EQ(x, y)
Definition: Logger.h:230
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
static std::vector< TableFunction > get_table_funcs()
size_t getArity() const
Definition: Analyzer.h:2260
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
#define LOG(tag)
Definition: Logger.h:216
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:266
#define CHECK_GE(x, y)
Definition: Logger.h:235
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:404
#define CHECK_GT(x, y)
Definition: Logger.h:234
std::string to_string(char const *&&v)
std::string to_string() const
Definition: sqltypes.h:568
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:2267
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
Definition: sqltypes.h:768
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
int get_precision() const
Definition: sqltypes.h:407
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
auto generate_column_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1376
Definition: sqltypes.h:66
#define CHECK_LE(x, y)
Definition: Logger.h:233
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:412
auto generate_column_list_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1424
std::string get_type_name() const
Definition: sqltypes.h:528
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
Definition: sqltypes.h:829
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:222
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:189
Definition: sqltypes.h:59
std::string getName() const
Definition: Analyzer.h:2258
string name
Definition: setup.in.py:72
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:981
bool is_array() const
Definition: sqltypes.h:608
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)