OmniSciDB  cde582ebc3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 #include "ExternalExecutor.h"
20 
21 // A rather crude function binding logic based on the types of the arguments.
22 // We want it to be possible to write specialized versions of functions to be
23 // exposed as SQL extensions. This is important especially for performance
24 // reasons, since double operations can be significantly slower than float. We
25 // compute a score for each candidate signature based on the conversions required
26 // to go from the function arguments as specified in the SQL query to the versions
27 // in ExtensionFunctions.hpp.
28 
29 /*
30  New implementation for binding a SQL function operator to the
31  optimal candidate among all available extension functions.
32  */
33 
34 namespace {
35 
// Maps a column ExtArgumentType to the ExtArgumentType of its elements
// (the visible cases return Int8 and Bool); any other input hits
// UNREACHABLE().
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
56  default:
57  UNREACHABLE();
58  }
// Only reached if UNREACHABLE() returns; presumably here to satisfy
// compilers that warn about a missing return value.
59  return ExtArgumentType{};
60 }
61 
63  const ExtArgumentType ext_arg_column_list_type) {
// Maps a column-list ExtArgumentType to the ExtArgumentType of its
// elements (the visible cases return Int8 and Bool); any other input
// hits UNREACHABLE().
64  switch (ext_arg_column_list_type) {
66  return ExtArgumentType::Int8;
78  return ExtArgumentType::Bool;
81  default:
82  UNREACHABLE();
83  }
// Only reached if UNREACHABLE() returns.
84  return ExtArgumentType{};
85 }
86 
// Maps an array ExtArgumentType to the ExtArgumentType of its elements
// (the visible cases return Int8 and Bool); any other input hits
// UNREACHABLE().
88  switch (ext_arg_array_type) {
90  return ExtArgumentType::Int8;
102  return ExtArgumentType::Bool;
103  default:
104  UNREACHABLE();
105  }
// Only reached if UNREACHABLE() returns.
106  return ExtArgumentType{};
107 }
108 
109 static int match_numeric_argument(const SQLTypeInfo& arg_type_info,
110  const bool is_arg_literal,
111  const ExtArgumentType& sig_ext_arg_type,
112  int32_t& penalty_score) {
113  const auto arg_type = arg_type_info.get_type();
114  CHECK(arg_type == kBOOLEAN || arg_type == kTINYINT || arg_type == kSMALLINT ||
115  arg_type == kINT || arg_type == kBIGINT || arg_type == kFLOAT ||
116  arg_type == kDOUBLE || arg_type == kDECIMAL || arg_type == kNUMERIC);
117  // Todo (todd): Add support for timestamp, date, and time types
118  const auto sig_type_info = ext_arg_type_to_type_info(sig_ext_arg_type);
119  const auto sig_type = sig_type_info.get_type();
120 
121  // If we can't legally auto-cast to sig_type, abort
122  if (!arg_type_info.is_numeric_scalar_auto_castable(sig_type_info)) {
123  return -1;
124  }
125 
126  // We now compare a measure of the scale of the sig_type with the
127  // arg_type, which provides a basis for scoring the match between
128  // the two. Note that get_numeric_scalar_scale for the most part
129  // returns the logical byte width of the type, with a few caveats
130  // for decimals and timestamps described in more depth in comments
131  // in the function itself. Also even though for example float and
132  // int types return 4 (as in 4 bytes), and double and bigint types
133  // return 8, a fp32 type cannot express every 32-bit integer (even
134  // if it can cover a larger absolute range), and an fp64 type
135  // likewise cannot express every 64-bit integer. With the aim to
136  // minimize the precision loss from casting (always precise) integer
137  // value to (imprecise) floating point value, in the case of integer
138  // inputs, we'll penalize wider floating point argument types least
139  // by a specific scale transformation (see the implementation
140  // below). For instance, casting tinyint to fp64 is prefered over
141  // casting it to fp32 to minimize precision loss.
142  const bool is_integer_to_fp_cast = (arg_type == kTINYINT || arg_type == kSMALLINT ||
143  arg_type == kINT || arg_type == kBIGINT) &&
144  (sig_type == kFLOAT || sig_type == kDOUBLE);
145 
146  const auto arg_type_relative_scale = arg_type_info.get_numeric_scalar_scale();
147  CHECK_GE(arg_type_relative_scale, 1);
148  CHECK_LE(arg_type_relative_scale, 8);
149  auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
150  CHECK_GE(sig_type_relative_scale, 1);
151  CHECK_LE(sig_type_relative_scale, 8);
152 
153  if (is_integer_to_fp_cast) {
154  // transform fp scale: 4 becomes 16, 8 remains 8
155  sig_type_relative_scale = (3 - (sig_type_relative_scale >> 2)) << 3;
156  }
157 
158  // We do not allow auto-casting to types with less scale/precision
159  // within the same type family.
160  CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
161 
162  // Calculate the ratio of the sig_type by the arg_type, per the above check will be >= 1
163  const auto sig_type_scale_gain_ratio =
164  sig_type_relative_scale / arg_type_relative_scale;
165  CHECK_GE(sig_type_scale_gain_ratio, 1);
166 
167  // Following the old bespoke scoring logic this function replaces, we heavily penalize
168  // any casts that move ints to floats/doubles for the precision-loss reasons above
169  // Arguably all integers in the tinyint and smallint can be fully specified with both
170  // float and double types, but we treat them the same as int and bigint types here.
171  const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
172 
173  int32_t scale_cast_penalty_score;
174 
175  // The following logic is new. Basically there are strong reasons to
176  // prefer the promotion of constant literals to the most precise type possible, as
177  // rather than the type being inherent in the data - that is a column or columns where
178  // a user specified a type (and with any expressions on those columns following our
179  // standard sql casting logic), literal types are given to us by Calcite and do not
180  // necessarily convey any semantic intent (i.e. 10 will be an int, but 10.0 a decimal)
181  // Hence it is better to promote these types to the most precise sig_type available,
182  // while at the same time keeping column expressions as close as possible to the input
183  // types (mainly for performance, we have many float versions of various functions
184  // to allow for greater performance when the underlying data is not of double precision,
185  // and hence there is little benefit of the extra cost of computing double precision
186  // operators on this data)
187  if (is_arg_literal) {
188  scale_cast_penalty_score =
189  (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
190  } else {
191  scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
192  }
193 
194  const auto cast_penalty_score =
195  type_family_cast_penalty_score + scale_cast_penalty_score;
196  CHECK_GT(cast_penalty_score, 0);
197  penalty_score += cast_penalty_score;
198  return 1;
199 }
200 
201 static int match_arguments(const SQLTypeInfo& arg_type,
202  const bool is_arg_literal,
203  int sig_pos,
204  const std::vector<ExtArgumentType>& sig_types,
205  int& penalty_score) {
206  /*
207  Returns non-negative integer `offset` if `arg_type` and
208  `sig_types[sig_pos:sig_pos + offset]` match.
209 
210  The `offset` value can be interpreted as the number of extension
211  function arguments that is consumed by the given `arg_type`. For
212  instance, for scalar types the offset is always 1; an array passed
213  as a pointer plus a size consumes 2 (one argument for the array
214  pointer value and one for the array size value), while an array
     matched to an array-typed parameter consumes 1. Geo arguments
     consume 1 (Geo* parameters) or 2/4/6 flattened pointer/size
     arguments.
215 
216  Returns -1 when the types of an argument and the corresponding
217  extension function argument(s) mismatch, when downcasting would
218  be required, or when sig_pos is past the end of sig_types.

     Throws std::runtime_error for argument types that are not handled.
219 
220  In case of non-negative `offset` result, the function updates
221  penalty_score argument as follows:
222 
223  add 1000 for a matching non-numeric argument (10000 for column
224  lists); numeric scalars are scored by match_numeric_argument:
225  1000, or 1001000 for integer -> floating point casts, plus a
     scale-dependent term that differs for literals and non-literals.
226 
227  */
228  int max_pos = sig_types.size() - 1;
229  if (sig_pos > max_pos) {
230  return -1;
231  }
232  auto sig_type = sig_types[sig_pos];
233  switch (arg_type.get_type()) {
234  case kBOOLEAN:
235  case kTINYINT:
236  case kSMALLINT:
237  case kINT:
238  case kBIGINT:
239  case kFLOAT:
240  case kDOUBLE:
241  case kDECIMAL:
242  case kNUMERIC:
243  return match_numeric_argument(arg_type, is_arg_literal, sig_type, penalty_score);
244  case kPOINT:
245  case kLINESTRING:
// Either a flattened (coords pointer, Int64 size) pair or a single Geo*
// parameter.
246  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
247  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
248  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble) &&
249  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
250  penalty_score += 1000;
251  return 2;
252  } else if (sig_type == ExtArgumentType::GeoPoint ||
253  sig_type == ExtArgumentType::GeoLineString) {
254  penalty_score += 1000;
255  return 1;
256  }
257  return -1;
258  case kARRAY:
259  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
260  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
261  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
262  sig_type == ExtArgumentType::PBool) &&
263  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
264  penalty_score += 1000;
265  return 2;
266  } else if (is_ext_arg_type_array(sig_type)) {
267  // array arguments must match exactly
268  CHECK(arg_type.is_array());
269  const auto sig_type_ti =
271  if (arg_type.get_elem_type() == kBOOLEAN && sig_type_ti.get_type() == kTINYINT) {
272  /* Boolean array has the same low-level structure as Int8 array. */
273  penalty_score += 1000;
274  return 1;
275  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
276  penalty_score += 1000;
277  return 1;
278  } else {
279  return -1;
280  }
281  }
282  break;
283  case kPOLYGON:
// Flattened form: (coords ptr, size, ring sizes ptr, size).
// NOTE(review): indexing sig_types[sig_pos + 3] only needs
// `sig_pos + 3 <= max_pos`; the strict `<` also requires another parameter
// to follow, so this form can never be the last parameters of a
// signature. Confirm intent.
284  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
285  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
286  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
287  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
288  penalty_score += 1000;
289  return 4;
290  } else if (sig_type == ExtArgumentType::GeoPolygon) {
291  penalty_score += 1000;
292  return 1;
293  }
294  break;
295  case kMULTIPOLYGON:
// Flattened form: three (pointer, Int64 size) pairs.
// NOTE(review): same strict `<` bound as kPOLYGON; `sig_pos + 5 <= max_pos`
// would suffice for the indexing below.
296  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
297  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
298  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
299  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
300  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
301  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
302  penalty_score += 1000;
303  return 6;
304  } else if (sig_type == ExtArgumentType::GeoMultiPolygon) {
305  penalty_score += 1000;
306  return 1;
307  }
308  break;
309  case kNULLT: // NULL maps to a pointer and size argument
310  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
311  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
312  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
313  sig_type == ExtArgumentType::PBool) &&
314  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
315  penalty_score += 1000;
316  return 2;
317  }
318  break;
319  case kCOLUMN:
320  if (is_ext_arg_type_column(sig_type)) {
321  // column arguments must match exactly
322  const auto sig_type_ti =
324  if (arg_type.get_elem_type() == kBOOLEAN && sig_type_ti.get_type() == kTINYINT) {
325  /* Boolean column has the same low-level structure as Int8 column. */
326  penalty_score += 1000;
327  return 1;
328  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
329  penalty_score += 1000;
330  return 1;
331  } else {
332  return -1;
333  }
334  }
335  break;
336  case kCOLUMN_LIST:
337  if (is_ext_arg_type_column_list(sig_type)) {
338  // column_list arguments must match exactly
339  const auto sig_type_ti =
// Column lists are charged 10000 (vs 1000 for a column), so a variant that
// binds individual columns scores lower than one that collects them.
341  if (arg_type.get_elem_type() == kBOOLEAN && sig_type_ti.get_type() == kTINYINT) {
342  /* Boolean column_list has the same low-level structure as Int8 column_list. */
343  penalty_score += 10000;
344  return 1;
345  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
346  penalty_score += 10000;
347  return 1;
348  } else {
349  return -1;
350  }
351  }
352  break;
353  case kVARCHAR:
// Only none-encoded strings can bind to TextEncodingNone. There is no
// `break` after the inner switch: every branch returns or hits
// UNREACHABLE() (assumed not to return).
354  if (sig_type != ExtArgumentType::TextEncodingNone) {
355  return -1;
356  }
357  switch (arg_type.get_compression()) {
358  case kENCODING_NONE:
359  penalty_score += 1000;
360  return 1;
361  case kENCODING_DICT:
362  return -1;
363  // Todo (todd): Evaluate when and where we can translate to dictionary-encoded
364  default:
365  UNREACHABLE();
366  }
367  case kTEXT:
368  if (sig_type != ExtArgumentType::TextEncodingNone) {
369  return -1;
370  }
371  switch (arg_type.get_compression()) {
372  case kENCODING_NONE:
373  penalty_score += 1000;
374  return 1;
375  case kENCODING_DICT:
376  return -1;
377  default:
378  UNREACHABLE();
379  }
380  case kTIMESTAMP:
// Only nanosecond-precision (precision 9) timestamps are accepted.
// NOTE(review): `arg_type.is_timestamp()` is always true under this case
// label and sig_type is never inspected, so such a timestamp matches
// whatever parameter type sits at sig_pos. Confirm whether a check on
// sig_type is intended.
381  if (arg_type.is_timestamp()) {
382  if (arg_type.get_precision() != 9) {
383  return -1;
384  }
385  penalty_score += 1000;
386  return 1;
387  }
388  break;
389  /* Not implemented types:
390  kCHAR
391  kTIME
392  kDATE
393  kINTERVAL_DAY_TIME
394  kINTERVAL_YEAR_MONTH
395  kGEOMETRY
396  kGEOGRAPHY
397  kEVAL_CONTEXT_TYPE
398  kVOID
399  kCURSOR
400  */
401  default:
402  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
403  ": support for " + arg_type.get_type_name() +
404  "(type=" + std::to_string(arg_type.get_type()) + ")" +
405  +" not implemented: \n pos=" + std::to_string(sig_pos) +
406  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
407  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
408  }
// Reached via `break`: the argument kind is supported but did not match.
409  return -1;
410 }
411 
// Returns true when `str` is a C-style identifier: non-empty, first
// character a letter (per std::isalpha) or '_', remaining characters
// letters, digits (per std::isalnum) or '_'.
//
// Each character is converted to unsigned char before classification.
// Passing a negative `char` to std::isalpha/std::isalnum is undefined
// behavior, and that happens for non-ASCII bytes (e.g. UTF-8) wherever
// `char` is signed.
bool is_valid_identifier(const std::string& str) {
  if (str.empty()) {
    return false;
  }
  const auto is_identifier_start = [](const unsigned char c) {
    return std::isalpha(c) || c == '_';
  };
  const auto is_identifier_char = [](const unsigned char c) {
    return std::isalnum(c) || c == '_';
  };
  return is_identifier_start(static_cast<unsigned char>(str.front())) &&
         std::all_of(str.begin() + 1, str.end(), is_identifier_char);
}
429 
430 } // namespace
431 
// Binds a call of `name` with SQL arguments `func_args` to the registered
// implementation in `ext_funcs` with the lowest penalty score. Returns the
// chosen implementation together with the argument-type variant it was
// matched against. Throws NativeExecutionError when `name` is not a valid
// identifier, and ExtensionFunctionBindingError when no implementation
// matches. `processor` is used only to label the binding error message.
432 template <typename T>
433 std::tuple<T, std::vector<SQLTypeInfo>> bind_function(
434  std::string name,
435  Analyzer::ExpressionPtrVector func_args, // function args from sql query
436  const std::vector<T>& ext_funcs, // list of functions registered
437  const std::string processor) {
438  /* worker function
439 
440  Template type T must implement the following methods:
441 
442  std::vector<ExtArgumentType> getInputArgs()
443  */
444  /*
445  Return extension function/table function that has the following
446  properties
447 
448  1. each argument type in `arg_types` matches with extension
449  function argument types.
450 
451  For scalar types, the matching means that the types are either
452  equal or the argument type is smaller than the corresponding
453  extension function argument type. This ensures that no
454  information is lost when casting of argument values is
455  required.
456 
457  For array and geo types, the matching means that the argument
458  type matches exactly with a group of extension function
459  argument types. See `match_arguments`.
460 
461  2. has minimal penalty score among all implementations of the
462  extension function with given `name`, see `match_arguments` and
463  `match_numeric_argument` for the definition of penalty score.
464 
465  It is assumed that `name` and the extension functions in
466  ext_funcs have the same name.
467  */
// Reject names that are not plain identifiers before any matching.
468  if (!is_valid_identifier(name)) {
469  throw NativeExecutionError(
470  "Cannot bind function with invalid UDF/UDTF function name: " + name);
471  }
472 
473  int minimal_score = std::numeric_limits<int>::max();
474  int index = -1;
475  int optimal = -1;
476  int optimal_variant = -1;
477 
// Collect each argument's SQL type and whether it is a literal
// (Analyzer::Constant); the literal flag feeds the numeric scoring in
// match_arguments. For table functions, ColumnVar arguments are turned
// into column types.
478  std::vector<SQLTypeInfo> type_infos_input;
479  std::vector<bool> args_are_constants;
480  for (auto atype : func_args) {
481  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
482  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
483  SQLTypeInfo type_info = atype->get_type_info();
484  if (atype->get_type_info().get_type() == kTEXT) {
485  auto ti = generate_column_type(type_info.get_type(), // subtype
486  type_info.get_compression(), // compression
487  type_info.get_comp_param()); // comp_param
488  type_infos_input.push_back(ti);
489  args_are_constants.push_back(false);
490  } else {
491  auto ti = generate_column_type(type_info.get_type());
492  type_infos_input.push_back(ti);
// NOTE(review): unlike the kTEXT branch above, this marks a column
// argument as a constant. match_arguments only consults the flag for
// numeric scalar types, so it appears to have no effect on kCOLUMN
// matching. Confirm whether `false` was intended.
493  args_are_constants.push_back(true);
494  }
495  continue;
496  }
497  }
498  type_infos_input.push_back(atype->get_type_info());
499  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
500  args_are_constants.push_back(true);
501  } else {
502  args_are_constants.push_back(false);
503  }
504  }
505  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
506 
// No arguments: exactly one zero-argument implementation must be
// registered (for table functions, one whose output size is not
// user-specified).
507  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
508  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
509  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
510  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
511  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
512  }
513  std::vector<SQLTypeInfo> empty_type_info_variant(0);
514  return {ext_funcs[0], empty_type_info_variant};
515  }
516 
517  // clang-format off
518  /*
519  Table functions may have arguments such as ColumnList that collect
520  neighboring columns with the same data type into a single object.
521  Here we compute all possible combinations of mapping a subset of
522  columns into columns sets. For example, if the types of function
523  arguments are (as given in func_args argument)
524 
525  (Column<int>, Column<int>, Column<int>, int)
526 
527  then the computed variants will be
528 
529  (Column<int>, Column<int>, Column<int>, int)
530  (Column<int>, Column<int>, ColumnList[1]<int>, int)
531  (Column<int>, ColumnList[1]<int>, Column<int>, int)
532  (Column<int>, ColumnList[2]<int>, int)
533  (ColumnList[1]<int>, Column<int>, Column<int>, int)
534  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
535  (ColumnList[2]<int>, Column<int>, int)
536  (ColumnList[3]<int>, int)
537 
538  where the integers in [..] indicate the number of collected
539  columns. In the SQLTypeInfo instance, this number is stored in the
540  SQLTypeInfo dimension attribute.
541 
542  As an example, let us consider a SQL query containing the
543  following expression calling a UDTF foo:
544 
545  table(foo(cursor(select a, b, c from tableofints), 1))
546 
547  Here follows a list of table functions and the corresponding
548  optimal argument type variants that are computed for the given
549  query expression:
550 
551  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
552  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
553 
554  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
555  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
556 
557  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
558  (Column<int>, Column<int>, Column<int>, int)
559  */
560  // clang-format on
// For non-table functions the `if constexpr` branches below are dropped,
// so exactly one variant (the input types as given) is produced.
561  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
562  for (auto ti : type_infos_input) {
563  if (type_infos_variants.begin() == type_infos_variants.end()) {
564  type_infos_variants.push_back({ti});
565  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
566  if (ti.is_column()) {
567  auto mti = generate_column_list_type(ti.get_subtype());
568  mti.set_dimension(1);
569  type_infos_variants.push_back({mti});
570  }
571  }
572  continue;
573  }
574  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
575  for (auto& type_infos : type_infos_variants) {
576  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
577  if (ti.is_column()) {
578  auto new_type_infos = type_infos; // makes a copy
579  const auto& last = type_infos.back();
580  if (last.is_column_list() && last.get_subtype() == ti.get_subtype()) {
581  // last column_list consumes column argument if types match
582  new_type_infos.back().set_dimension(last.get_dimension() + 1);
583  } else {
584  // add column as column_list argument
585  auto mti = generate_column_list_type(ti.get_subtype());
586  mti.set_dimension(1);
587  new_type_infos.push_back(mti);
588  }
589  new_type_infos_variants.push_back(new_type_infos);
590  }
591  }
592  type_infos.push_back(ti);
593  }
594  type_infos_variants.insert(type_infos_variants.end(),
595  new_type_infos_variants.begin(),
596  new_type_infos_variants.end());
597  }
598 
599  // Find extension function that gives the best match on the set of
600  // argument type variants:
// A (candidate, variant) pair is viable only if the variant consumes every
// signature slot of the candidate. Ties keep the earliest pair because the
// comparison with minimal_score is strict.
// NOTE(review): `auto ext_func` copies every candidate; `const auto&`
// would avoid the copies.
601  for (auto ext_func : ext_funcs) {
602  index++;
603 
604  auto ext_func_args = ext_func.getInputArgs();
605  int index_variant = -1;
606  for (const auto& type_infos : type_infos_variants) {
607  index_variant++;
608  int penalty_score = 0;
609  int pos = 0;
// Index into args_are_constants; a column list advances it by the number
// of columns it collected.
610  int original_input_idx = 0;
611  CHECK_LE(type_infos.size(), args_are_constants.size());
612  // for (size_t ti_idx = 0; ti_idx != type_infos.size(); ++ti_idx) {
613  for (const auto& ti : type_infos) {
614  int offset = match_arguments(ti,
615  args_are_constants[original_input_idx],
616  pos,
617  ext_func_args,
618  penalty_score);
619  if (offset < 0) {
620  // atype does not match with ext_func argument
621  pos = -1;
622  break;
623  }
624  if (ti.get_type() == kCOLUMN_LIST) {
625  original_input_idx += ti.get_dimension();
626  } else {
627  original_input_idx++;
628  }
629  pos += offset;
630  }
631 
632  if ((size_t)pos == ext_func_args.size()) {
633  CHECK_EQ(args_are_constants.size(), original_input_idx);
634  // prefer smaller return types
635  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
636  if (penalty_score < minimal_score) {
637  optimal = index;
638  minimal_score = penalty_score;
639  optimal_variant = index_variant;
640  }
641  }
642  }
643  }
644 
645  if (optimal == -1) {
646  /* no extension function found that argument types would match
647  with types in `arg_types` */
648  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
649  std::string message;
650  if (!ext_funcs.size()) {
651  message = "Function " + name + "(" + sarg_types + ") not supported.";
652  throw ExtensionFunctionBindingError(message);
653  } else {
654  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
655  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
656  " UDTF implementation.";
657  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
658  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
659  " UDF implementation.";
660  } else {
661  LOG(FATAL) << "bind_function: unknown extension function type "
662  << typeid(T).name();
663  }
664  message += "\n Existing extension function implementations:";
665  for (const auto& ext_func : ext_funcs) {
666  // Do not show functions missing the sizer argument
667  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
668  if (ext_func.useDefaultSizer())
669  continue;
670  message += "\n " + ext_func.toStringSQL();
671  }
672  }
673  throw ExtensionFunctionBindingError(message);
674  }
675 
676  // Functions with "_default_" suffix only exist for calcite
// When the winner is such a default-sizer variant, strip
// DEFAULT_ROW_MULTIPLIER_SUFFIX from its name, bind to the implementation
// with the stripped name, and splice an INT argument type into the matched
// variant at the sizer position (getOutputRowSizeParameter() is treated
// as 1-based).
// NOTE(review): the local `name` shadows the function parameter, and
// `sizer - 1` assumes sizer >= 1.
677  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
678  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
679  ext_funcs[optimal].useDefaultSizer()) {
680  std::string name = ext_funcs[optimal].getName();
681  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
683  for (size_t i = 0; i < ext_funcs.size(); i++) {
684  if (ext_funcs[i].getName() == name) {
685  optimal = i;
686  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
687  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
688  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
689  return {ext_funcs[optimal], type_info};
690  }
691  }
692  UNREACHABLE();
693  }
694  }
695 
696  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
697 }
698 
// Binds a table function call to one of the candidates in `table_funcs`.
// `is_gpu` only selects the "GPU"/"CPU" label that bind_function uses in
// its binding error message.
699 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
702  const std::vector<table_functions::TableFunction>& table_funcs,
703  const bool is_gpu) {
704  std::string processor = (is_gpu ? "GPU" : "CPU");
705  return bind_function<table_functions::TableFunction>(
706  name, input_args, table_funcs, processor);
707 }
708 
710  Analyzer::ExpressionPtrVector func_args) {
711  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
712  // to CPU UDFs.
// If no GPU implementations are registered at all, binding starts with the
// CPU list. If GPU implementations exist but none of them binds, the call
// is retried once against the CPU list, and any error from that retry is
// labelled "GPU|CPU". A failure against the CPU list is rethrown
// unchanged.
713  bool is_gpu = true;
714  std::string processor = "GPU";
715  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
716  if (!ext_funcs.size()) {
717  is_gpu = false;
718  processor = "CPU";
719  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
720  }
721  try {
722  return std::get<0>(
723  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
724  } catch (ExtensionFunctionBindingError& e) {
725  if (is_gpu) {
726  is_gpu = false;
727  processor = "GPU|CPU";
728  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
729  return std::get<0>(
730  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
731  } else {
732  throw;
733  }
734  }
735 }
736 
739  const bool is_gpu) {
740  // used below
// Binds to a single candidate list with no GPU->CPU retry; `is_gpu`
// selects the "GPU"/"CPU" label used in the binding error message.
741  std::vector<ExtensionFunction> ext_funcs =
743  std::string processor = (is_gpu ? "GPU" : "CPU");
744  return std::get<0>(
745  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
746 }
747 
749  const bool is_gpu) {
750  // used in ExtensionsIR.cpp
// Collects the operator's name and its arguments (via getOwnArg) and
// delegates to the (name, args, is_gpu) overload.
751  auto name = function_oper->getName();
752  Analyzer::ExpressionPtrVector func_args = {};
753  for (size_t i = 0; i < function_oper->getArity(); ++i) {
754  func_args.push_back(function_oper->getOwnArg(i));
755  }
756  return bind_function(name, func_args, is_gpu);
757 }
758 
// Convenience overload: builds the candidate list `table_funcs` and
// delegates to the bind_table_function overload that takes an explicit
// candidate list.
759 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
762  const bool is_gpu) {
763  // used in RelAlgExecutor.cpp
764  std::vector<table_functions::TableFunction> table_funcs =
766  return bind_table_function(name, input_args, table_funcs, is_gpu);
767 }
#define CHECK_EQ(x, y)
Definition: Logger.h:230
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
static std::vector< TableFunction > get_table_funcs()
size_t getArity() const
Definition: Analyzer.h:2169
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_timestamp() const
Definition: sqltypes.h:895
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
#define LOG(tag)
Definition: Logger.h:216
auto generate_column_type(const SQLTypes subtype)
Definition: sqltypes.h:1124
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:266
#define CHECK_GE(x, y)
Definition: Logger.h:235
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:329
#define CHECK_GT(x, y)
Definition: Logger.h:234
std::string to_string(char const *&&v)
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:2176
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
Definition: sqltypes.h:652
auto generate_column_list_type(const SQLTypes subtype)
Definition: sqltypes.h:1138
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
int get_precision() const
Definition: sqltypes.h:332
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
Definition: sqltypes.h:52
#define CHECK_LE(x, y)
Definition: Logger.h:233
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:337
std::string get_type_name() const
Definition: sqltypes.h:443
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
Definition: sqltypes.h:713
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
HOST DEVICE int get_comp_param() const
Definition: sqltypes.h:338
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:222
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:188
Definition: sqltypes.h:45
std::string getName() const
Definition: Analyzer.h:2167
string name
Definition: setup.in.py:72
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:865
bool is_array() const
Definition: sqltypes.h:518
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)