OmniSciDB  8fa3bf436f
ExtensionFunctionsBinding.cpp File Reference
#include "ExtensionFunctionsBinding.h"
#include <algorithm>
#include "ExternalExecutor.h"

Namespaces

 anonymous_namespace{ExtensionFunctionsBinding.cpp}
 

Functions

ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_arg_elem_type (const ExtArgumentType ext_arg_column_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_list_arg_elem_type (const ExtArgumentType ext_arg_column_list_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_array_arg_elem_type (const ExtArgumentType ext_arg_array_type)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments (const SQLTypeInfo &arg_type, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
 
bool anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier (std::string str)
 
template<typename T >
std::tuple< T, std::vector< SQLTypeInfo > > bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
 
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const bool is_gpu)
 
ExtensionFunction bind_function (const Analyzer::FunctionOper *function_oper, const bool is_gpu)
 
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const bool is_gpu)
 

Function Documentation

template<typename T >
std::tuple<T, std::vector<SQLTypeInfo> > bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const std::vector< T > &  ext_funcs,
const std::string  processor 
)

Definition at line 408 of file ExtensionFunctionsBinding.cpp.

References DEFAULT_ROW_MULTIPLIER_SUFFIX, ext_arg_type_to_type_info(), logger::FATAL, generate_column_list_type(), generate_column_type(), i, anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier(), kINT, LOG, anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments(), setup::name, generate_TableFunctionsFactory_init::sizer, omnisci.dtypes::T, ExtensionFunctionsWhitelist::toString(), and UNREACHABLE.

Referenced by bind_function(), CodeGenerator::codegenFunctionOper(), and RelAlgTranslator::translateFunction().
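The worker listed below returns the implementation whose signature accepts every argument and whose total penalty score is minimal; the logical size of the return type is added to the score, so that, all else being equal, implementations with smaller return types are preferred. The following is a minimal sketch of that selection loop, not the OmniSciDB code. `ScalarType`, `Overload`, `match_scalar` and `pick_overload` are hypothetical names; only scalar arguments are modelled, each consuming exactly one signature slot (the real `match_arguments` can consume a group of slots for array and geo arguments); and the per-argument penalty used here (parameter size minus argument size) is an assumption.

```cpp
#include <cstddef>
#include <limits>
#include <string>
#include <vector>

// Hypothetical stand-in for a scalar type: logical size in bytes plus a
// family flag (integer vs. floating point).
struct ScalarType {
  int size;
  bool is_fp;
};

// Hypothetical stand-in for one registered implementation of a UDF.
struct Overload {
  std::string name;
  std::vector<ScalarType> params;
  ScalarType ret;
};

// Scalar-only analogue of match_arguments: returns the number of signature
// slots consumed (always 1 here) or -1 if `arg` cannot be passed at `pos`.
// Only widening within the same family is accepted; the penalty added is the
// size difference (an assumption made for this sketch).
int match_scalar(const ScalarType& arg, int pos,
                 const std::vector<ScalarType>& params, int& penalty) {
  if (pos < 0 || static_cast<std::size_t>(pos) >= params.size()) {
    return -1;
  }
  const ScalarType& p = params[static_cast<std::size_t>(pos)];
  if (p.is_fp != arg.is_fp || arg.size > p.size) {
    return -1;
  }
  penalty += p.size - arg.size;
  return 1;
}

// Returns the index of the overload with the minimal penalty score, or -1
// when no overload accepts all arguments.
int pick_overload(const std::vector<ScalarType>& args,
                  const std::vector<Overload>& overloads) {
  int minimal_score = std::numeric_limits<int>::max();
  int optimal = -1;
  for (std::size_t i = 0; i < overloads.size(); ++i) {
    const auto& params = overloads[i].params;
    int penalty = 0;
    int pos = 0;
    for (const auto& arg : args) {
      int offset = match_scalar(arg, pos, params, penalty);
      if (offset < 0) {
        pos = -1;  // this overload cannot take the argument
        break;
      }
      pos += offset;
    }
    // Every signature slot must be consumed, no more and no fewer.
    if (pos < 0 || static_cast<std::size_t>(pos) != params.size()) {
      continue;
    }
    penalty += overloads[i].ret.size;  // prefer smaller return types
    if (penalty < minimal_score) {
      minimal_score = penalty;
      optimal = static_cast<int>(i);
    }
  }
  return optimal;
}
```

For example, with a single 4-byte integer argument and overloads taking an 8-byte integer, a 4-byte integer and an 8-byte float (each returning its parameter type), the sketch picks the 4-byte overload (score 0 + 4) over the 8-byte one (score 4 + 8), and rejects the float overload because it never casts across families.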

412  {
413  /* worker function
414 
415  Template type T must implement the following methods:
416 
417  std::vector<ExtArgumentType> getInputArgs()
418  */
419  /*
420  Return extension function/table function that has the following
421  properties
422 
423  1. each argument type in `arg_types` matches with extension
424  function argument types.
425 
426  For scalar types, the matching means that the types are either
427  equal or the argument type is smaller than the corresponding
428  the extension function argument type. This ensures that no
429  information is lost when casting of argument values is
430  required.
431 
432  For array and geo types, the matching means that the argument
433  type matches exactly with a group of extension function
434  argument types. See `match_arguments`.
435 
436  2. has minimal penalty score among all implementations of the
437  extension function with given `name`, see `get_penalty_score`
438  for the definition of penalty score.
439 
440  It is assumed that function_oper and extension functions in
441  ext_funcs have the same name.
442  */
443  if (!is_valid_identifier(name)) {
444  throw NativeExecutionError(
445  "Cannot bind function with invalid UDF/UDTF function name: " + name);
446  }
447 
448  int minimal_score = std::numeric_limits<int>::max();
449  int index = -1;
450  int optimal = -1;
451  int optimal_variant = -1;
452 
453  std::vector<SQLTypeInfo> type_infos_input;
454  for (auto atype : func_args) {
455  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
456  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
457  auto ti = generate_column_type(atype->get_type_info().get_type());
458  type_infos_input.push_back(ti);
459  continue;
460  }
461  }
462  type_infos_input.push_back(atype->get_type_info());
463  }
464 
465  // clang-format off
466  /*
467  Table functions may have arguments such as ColumnList that collect
468  neighboring columns with the same data type into a single object.
469  Here we compute all possible combinations of mapping a subset of
470  columns into columns sets. For example, if the types of function
471  arguments are (as given in func_args argument)
472 
473  (Column<int>, Column<int>, Column<int>, int)
474 
475  then the computed variants will be
476 
477  (Column<int>, Column<int>, Column<int>, int)
478  (Column<int>, Column<int>, ColumnList[1]<int>, int)
479  (Column<int>, ColumnList[1]<int>, Column<int>, int)
480  (Column<int>, ColumnList[2]<int>, int)
481  (ColumnList[1]<int>, Column<int>, Column<int>, int)
482  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
483  (ColumnList[2]<int>, Column<int>, int)
484  (ColumnList[3]<int>, int)
485 
486  where the integers in [..] indicate the number of collected
487  columns. In the SQLTypeInfo instance, this number is stored in the
488  SQLTypeInfo dimension attribute.
489 
490  As an example, let us consider a SQL query containing the
491  following expression calling a UDTF foo:
492 
493  table(foo(cursor(select a, b, c from tableofints), 1))
494 
495  Here follows a list of table functions and the corresponding
496  optimal argument type variants that are computed for the given
497  query expression:
498 
499  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
500  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
501 
502  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
503  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
504 
505  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
506  (Column<int>, Column<int>, Column<int>, int)
507  */
508  // clang-format on
509  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
510  for (auto ti : type_infos_input) {
511  if (type_infos_variants.begin() == type_infos_variants.end()) {
512  type_infos_variants.push_back({ti});
513  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
514  if (ti.is_column()) {
515  auto mti = generate_column_list_type(ti.get_subtype());
516  mti.set_dimension(1);
517  type_infos_variants.push_back({mti});
518  }
519  }
520  continue;
521  }
522  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
523  for (auto& type_infos : type_infos_variants) {
524  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
525  if (ti.is_column()) {
526  auto new_type_infos = type_infos; // makes a copy
527  const auto& last = type_infos.back();
528  if (last.is_column_list() && last.get_subtype() == ti.get_subtype()) {
529  // last column_list consumes column argument if types match
530  new_type_infos.back().set_dimension(last.get_dimension() + 1);
531  } else {
532  // add column as column_list argument
533  auto mti = generate_column_list_type(ti.get_subtype());
534  mti.set_dimension(1);
535  new_type_infos.push_back(mti);
536  }
537  new_type_infos_variants.push_back(new_type_infos);
538  }
539  }
540  type_infos.push_back(ti);
541  }
542  type_infos_variants.insert(type_infos_variants.end(),
543  new_type_infos_variants.begin(),
544  new_type_infos_variants.end());
545  }
546 
547  // Find extension function that gives the best match on the set of
548  // argument type variants:
549  for (auto ext_func : ext_funcs) {
550  index++;
551 
552  auto ext_func_args = ext_func.getInputArgs();
553  int index_variant = -1;
554  for (const auto& type_infos : type_infos_variants) {
555  index_variant++;
556  int penalty_score = 0;
557  int pos = 0;
558  for (const auto& ti : type_infos) {
559  int offset = match_arguments(ti, pos, ext_func_args, penalty_score);
560  if (offset < 0) {
561  // atype does not match with ext_func argument
562  pos = -1;
563  break;
564  }
565  pos += offset;
566  }
567 
568  if ((size_t)pos == ext_func_args.size()) {
569  // prefer smaller return types
570  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
571  if (penalty_score < minimal_score) {
572  optimal = index;
573  minimal_score = penalty_score;
574  optimal_variant = index_variant;
575  }
576  }
577  }
578  }
579 
580  if (optimal == -1) {
581  /* no extension function found that argument types would match
582  with types in `arg_types` */
583  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
584  std::string message;
585  if (!ext_funcs.size()) {
586  message = "Function " + name + "(" + sarg_types + ") not supported.";
587  throw ExtensionFunctionBindingError(message);
588  } else {
589  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
590  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
591  " UDTF implementation.";
592  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
593  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
594  " UDF implementation.";
595  } else {
596  LOG(FATAL) << "bind_function: unknown extension function type "
597  << typeid(T).name();
598  }
599  message += "\n Existing extension function implementations:";
600  for (const auto& ext_func : ext_funcs) {
601  // Do not show functions missing the sizer argument
602  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
603  if (ext_func.useDefaultSizer())
604  continue;
605  message += "\n " + ext_func.toStringSQL();
606  }
607  }
608  throw ExtensionFunctionBindingError(message);
609  }
610 
611  // Functions with "_default_" suffix only exist for calcite
612  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
613  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
614  ext_funcs[optimal].useDefaultSizer()) {
615  std::string name = ext_funcs[optimal].getName();
616  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
618  for (size_t i = 0; i < ext_funcs.size(); i++) {
619  if (ext_funcs[i].getName() == name) {
620  optimal = i;
621  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
622  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
623  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
624  return {ext_funcs[optimal], type_info};
625  }
626  }
627  UNREACHABLE();
628  }
629  }
630 
631  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
632 }
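The ColumnList variant enumeration described in the comment inside the listing above (source lines 466-507) can be reproduced in isolation. The sketch below is a simplified, hypothetical model: `ArgType` stands in for `SQLTypeInfo` and keeps only the kind, the element subtype and the number of collected columns (which the real code stores in the dimension attribute), and only the table-function path is modelled. It follows the loop at source lines 509-545: the first argument seeds one variant (two if it is a column), and every later column either extends a trailing ColumnList of the same subtype or opens a new one, while the original variant keeps it as a plain column.

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for SQLTypeInfo.
enum class Kind { Scalar, Column, ColumnList };

struct ArgType {
  Kind kind;
  std::string subtype;  // element type, e.g. "int"
  int dimension = 0;    // number of collected columns for a ColumnList

  bool operator==(const ArgType& o) const {
    return kind == o.kind && subtype == o.subtype && dimension == o.dimension;
  }
};

using Variant = std::vector<ArgType>;

ArgType column_list_of(const std::string& subtype, int dim) {
  return ArgType{Kind::ColumnList, subtype, dim};
}

// Enumerates every way of folding runs of adjacent, same-typed Column
// arguments into ColumnList arguments.
std::vector<Variant> enumerate_variants(const std::vector<ArgType>& inputs) {
  std::vector<Variant> variants;
  for (const ArgType& ti : inputs) {
    if (variants.empty()) {
      variants.push_back({ti});
      if (ti.kind == Kind::Column) {
        variants.push_back({column_list_of(ti.subtype, 1)});
      }
      continue;
    }
    std::vector<Variant> extra;  // folded copies, appended after the loop
    for (Variant& v : variants) {
      if (ti.kind == Kind::Column) {
        Variant nv = v;  // copy taken before v is extended below
        const ArgType& last = v.back();
        if (last.kind == Kind::ColumnList && last.subtype == ti.subtype) {
          nv.back().dimension = last.dimension + 1;  // grow trailing list
        } else {
          nv.push_back(column_list_of(ti.subtype, 1));  // start a new list
        }
        extra.push_back(nv);
      }
      v.push_back(ti);  // variant that keeps ti as a plain argument
    }
    variants.insert(variants.end(), extra.begin(), extra.end());
  }
  return variants;
}
```

For the documented input (Column<int>, Column<int>, Column<int>, int), `enumerate_variants` yields the same eight variants as the comment, in a different order; the first is the unchanged input and the last is (ColumnList[3]<int>, int). The count doubles with each Column argument (2^k variants for k columns), and the worker scores every variant against every implementation.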

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args 
)

Definition at line 644 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

645  {
646  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
647  // to CPU UDFs.
648  bool is_gpu = true;
649  std::string processor = "GPU";
650  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
651  if (!ext_funcs.size()) {
652  is_gpu = false;
653  processor = "CPU";
655  }
656  try {
657  return std::get<0>(
658  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
659  } catch (ExtensionFunctionBindingError& e) {
660  if (is_gpu) {
661  is_gpu = false;
662  processor = "GPU|CPU";
664  return std::get<0>(
665  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
666  } else {
667  throw;
668  }
669  }
670 }
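This overload tries the GPU implementations first and falls back to the CPU ones, either because no GPU implementation is registered or because none of the registered GPU implementations accepts the arguments; in the latter case the processor label becomes "GPU|CPU". The listing above does not show the statements that refresh `ext_funcs` (source lines 654 and 663), so the sketch below assumes they re-query the registry with `is_gpu == false`. `BindingError`, `Lookup`, `Binder` and `bind_gpu_then_cpu` are hypothetical names used only to illustrate the control flow.

```cpp
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for ExtensionFunctionBindingError.
struct BindingError : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Hypothetical registry lookup (name, is_gpu) -> implementations, standing in
// for ExtensionFunctionsWhitelist::get_ext_funcs.
using Lookup =
    std::function<std::vector<std::string>(const std::string&, bool)>;

// Hypothetical binder (implementations, processor label) -> chosen
// implementation; throws BindingError when nothing matches.
using Binder = std::function<std::string(const std::vector<std::string>&,
                                         const std::string&)>;

std::string bind_gpu_then_cpu(const std::string& name, const Lookup& lookup,
                              const Binder& bind) {
  bool is_gpu = true;
  std::string processor = "GPU";
  std::vector<std::string> impls = lookup(name, is_gpu);
  if (impls.empty()) {
    // No GPU implementation registered: go straight to CPU.
    is_gpu = false;
    processor = "CPU";
    impls = lookup(name, is_gpu);  // assumed content of the omitted line
  }
  try {
    return bind(impls, processor);
  } catch (const BindingError&) {
    if (!is_gpu) {
      throw;  // CPU was already the last resort
    }
    // GPU implementations exist but none accepted the arguments.
    is_gpu = false;
    processor = "GPU|CPU";  // label reported if the CPU attempt fails too
    impls = lookup(name, is_gpu);  // assumed content of the omitted line
    return bind(impls, processor);
  }
}
```

The rethrow in the CPU-only path mirrors the `throw;` in the listing, so a caller sees the original binding error rather than a second, less specific one.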

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const bool  is_gpu 
)

Definition at line 672 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

674  {
675  // used below
676  std::vector<ExtensionFunction> ext_funcs =
678  std::string processor = (is_gpu ? "GPU" : "CPU");
679  return std::get<0>(
680  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
681 }

ExtensionFunction bind_function ( const Analyzer::FunctionOper function_oper,
const bool  is_gpu 
)

Definition at line 683 of file ExtensionFunctionsBinding.cpp.

References bind_function(), Analyzer::FunctionOper::getArity(), Analyzer::FunctionOper::getName(), Analyzer::FunctionOper::getOwnArg(), i, and setup::name.

684  {
685  // used in ExtensionsIR.cpp
686  auto name = function_oper->getName();
687  Analyzer::ExpressionPtrVector func_args = {};
688  for (size_t i = 0; i < function_oper->getArity(); ++i) {
689  func_args.push_back(function_oper->getOwnArg(i));
690  }
691  return bind_function(name, func_args, is_gpu);
692 }

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const std::vector< table_functions::TableFunction > &  table_funcs,
const bool  is_gpu 
)

Definition at line 635 of file ExtensionFunctionsBinding.cpp.

References setup::name.

Referenced by bind_table_function(), and RelAlgExecutor::createTableFunctionWorkUnit().

638  {
639  std::string processor = (is_gpu ? "GPU" : "CPU");
640  return bind_function<table_functions::TableFunction>(
641  name, input_args, table_funcs, processor);
642 }

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const bool  is_gpu 
)

Definition at line 695 of file ExtensionFunctionsBinding.cpp.

References bind_table_function(), and table_functions::TableFunctionsFactory::get_table_funcs().

697  {
698  // used in RelAlgExecutor.cpp
699  std::vector<table_functions::TableFunction> table_funcs =
701  return bind_table_function(name, input_args, table_funcs, is_gpu);
702 }