OmniSciDB  6686921089
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp File Reference
#include "ExtensionFunctionsBinding.h"
#include <algorithm>
#include "ExternalExecutor.h"
+ Include dependency graph for ExtensionFunctionsBinding.cpp:

Go to the source code of this file.

Namespaces

 anonymous_namespace{ExtensionFunctionsBinding.cpp}
 

Functions

ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_arg_elem_type (const ExtArgumentType ext_arg_column_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_list_arg_elem_type (const ExtArgumentType ext_arg_column_list_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_array_arg_elem_type (const ExtArgumentType ext_arg_array_type)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_numeric_argument (const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments (const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
 
bool anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier (std::string str)
 
template<typename T >
std::tuple< T, std::vector
< SQLTypeInfo > > 
bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
 
const std::tuple
< table_functions::TableFunction,
std::vector< SQLTypeInfo > > 
bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const bool is_gpu)
 
ExtensionFunction bind_function (const Analyzer::FunctionOper *function_oper, const bool is_gpu)
 
const std::tuple
< table_functions::TableFunction,
std::vector< SQLTypeInfo > > 
bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const bool is_gpu)
 

Function Documentation

template<typename T >
std::tuple<T, std::vector<SQLTypeInfo> > bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const std::vector< T > &  ext_funcs,
const std::string  processor 
)

Definition at line 412 of file ExtensionFunctionsBinding.cpp.

References CHECK, CHECK_EQ, CHECK_LE, DEFAULT_ROW_MULTIPLIER_SUFFIX, ext_arg_type_to_type_info(), logger::FATAL, generate_column_list_type(), generate_column_type(), SQLTypeInfo::get_comp_param(), SQLTypeInfo::get_compression(), SQLTypeInfo::get_type(), i, anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier(), kCOLUMN_LIST, kINT, kTEXT, LOG, anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments(), setup::name, omnisci.dtypes::T, ExtensionFunctionsWhitelist::toString(), and UNREACHABLE.

Referenced by bind_function(), CodeGenerator::codegenFunctionOper(), and RelAlgTranslator::translateFunction().

416  {
417  /* worker function
418 
419  Template type T must implement the following methods:
420 
421  std::vector<ExtArgumentType> getInputArgs()
422  */
423  /*
424  Return extension function/table function that has the following
425  properties
426 
427  1. each argument type in `func_args` matches the corresponding
428  extension function argument type.
429 
430  For scalar types, the matching means that the types are either
431  equal or the argument type is smaller than the corresponding
432  extension function argument type. This ensures that no
433  information is lost when casting of argument values is
434  required.
435 
436  For array and geo types, the matching means that the argument
437  type matches exactly with a group of extension function
438  argument types. See `match_arguments`.
439 
440  2. has minimal penalty score among all implementations of the
441  extension function with given `name`, see `get_penalty_score`
442  for the definition of penalty score.
443 
444  It is assumed that all extension functions in ext_funcs have
445  the given `name`.
446  */
447  if (!is_valid_identifier(name)) {
448  throw NativeExecutionError(
449  "Cannot bind function with invalid UDF/UDTF function name: " + name);
450  }
451 
452  int minimal_score = std::numeric_limits<int>::max();
453  int index = -1;
454  int optimal = -1;
455  int optimal_variant = -1;
456 
457  std::vector<SQLTypeInfo> type_infos_input;
458  std::vector<bool> args_are_constants;
459  for (auto atype : func_args) {
460  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
461  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
462  SQLTypeInfo type_info = atype->get_type_info();
463  if (atype->get_type_info().get_type() == kTEXT) {
464  auto ti = generate_column_type(type_info.get_type(), // subtype
465  type_info.get_compression(), // compression
466  type_info.get_comp_param()); // comp_param
467  type_infos_input.push_back(ti);
468  args_are_constants.push_back(false);
469  } else {
470  auto ti = generate_column_type(type_info.get_type());
471  type_infos_input.push_back(ti);
472  args_are_constants.push_back(true);
473  }
474  continue;
475  }
476  }
477  type_infos_input.push_back(atype->get_type_info());
478  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
479  args_are_constants.push_back(true);
480  } else {
481  args_are_constants.push_back(false);
482  }
483  }
484  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
485 
486  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
487  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
488  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
489  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
490  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
491  }
492  std::vector<SQLTypeInfo> empty_type_info_variant(0);
493  return {ext_funcs[0], empty_type_info_variant};
494  }
495 
496  // clang-format off
497  /*
498  Table functions may have arguments such as ColumnList that collect
499  neighboring columns with the same data type into a single object.
500  Here we compute all possible combinations of mapping a subset of
501  columns into column sets. For example, if the types of function
502  arguments are (as given in func_args argument)
503 
504  (Column<int>, Column<int>, Column<int>, int)
505 
506  then the computed variants will be
507 
508  (Column<int>, Column<int>, Column<int>, int)
509  (Column<int>, Column<int>, ColumnList[1]<int>, int)
510  (Column<int>, ColumnList[1]<int>, Column<int>, int)
511  (Column<int>, ColumnList[2]<int>, int)
512  (ColumnList[1]<int>, Column<int>, Column<int>, int)
513  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
514  (ColumnList[2]<int>, Column<int>, int)
515  (ColumnList[3]<int>, int)
516 
517  where the integers in [..] indicate the number of collected
518  columns. In the SQLTypeInfo instance, this number is stored in the
519  SQLTypeInfo dimension attribute.
520 
521  As an example, let us consider a SQL query containing the
522  following expression calling a UDTF foo:
523 
524  table(foo(cursor(select a, b, c from tableofints), 1))
525 
526  Here follows a list of table functions and the corresponding
527  optimal argument type variants that are computed for the given
528  query expression:
529 
530  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
531  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
532 
533  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
534  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
535 
536  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
537  (Column<int>, Column<int>, Column<int>, int)
538  */
539  // clang-format on
540  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
541  for (auto ti : type_infos_input) {
542  if (type_infos_variants.begin() == type_infos_variants.end()) {
543  type_infos_variants.push_back({ti});
544  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
545  if (ti.is_column()) {
546  auto mti = generate_column_list_type(ti.get_subtype());
547  mti.set_dimension(1);
548  type_infos_variants.push_back({mti});
549  }
550  }
551  continue;
552  }
553  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
554  for (auto& type_infos : type_infos_variants) {
555  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
556  if (ti.is_column()) {
557  auto new_type_infos = type_infos; // makes a copy
558  const auto& last = type_infos.back();
559  if (last.is_column_list() && last.get_subtype() == ti.get_subtype()) {
560  // last column_list consumes column argument if types match
561  new_type_infos.back().set_dimension(last.get_dimension() + 1);
562  } else {
563  // add column as column_list argument
564  auto mti = generate_column_list_type(ti.get_subtype());
565  mti.set_dimension(1);
566  new_type_infos.push_back(mti);
567  }
568  new_type_infos_variants.push_back(new_type_infos);
569  }
570  }
571  type_infos.push_back(ti);
572  }
573  type_infos_variants.insert(type_infos_variants.end(),
574  new_type_infos_variants.begin(),
575  new_type_infos_variants.end());
576  }
577 
578  // Find extension function that gives the best match on the set of
579  // argument type variants:
580  for (auto ext_func : ext_funcs) {
581  index++;
582 
583  auto ext_func_args = ext_func.getInputArgs();
584  int index_variant = -1;
585  for (const auto& type_infos : type_infos_variants) {
586  index_variant++;
587  int penalty_score = 0;
588  int pos = 0;
589  int original_input_idx = 0;
590  CHECK_LE(type_infos.size(), args_are_constants.size());
591  // for (size_t ti_idx = 0; ti_idx != type_infos.size(); ++ti_idx) {
592  for (const auto& ti : type_infos) {
593  int offset = match_arguments(ti,
594  args_are_constants[original_input_idx],
595  pos,
596  ext_func_args,
597  penalty_score);
598  if (offset < 0) {
599  // ti does not match the ext_func argument at position pos
600  pos = -1;
601  break;
602  }
603  if (ti.get_type() == kCOLUMN_LIST) {
604  original_input_idx += ti.get_dimension();
605  } else {
606  original_input_idx++;
607  }
608  pos += offset;
609  }
610 
611  if ((size_t)pos == ext_func_args.size()) {
612  CHECK_EQ(args_are_constants.size(), original_input_idx);
613  // prefer smaller return types
614  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
615  if (penalty_score < minimal_score) {
616  optimal = index;
617  minimal_score = penalty_score;
618  optimal_variant = index_variant;
619  }
620  }
621  }
622  }
623 
624  if (optimal == -1) {
625  /* no extension function found whose argument types match the
626  types of the arguments in `func_args` */
627  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
628  std::string message;
629  if (!ext_funcs.size()) {
630  message = "Function " + name + "(" + sarg_types + ") not supported.";
631  throw ExtensionFunctionBindingError(message);
632  } else {
633  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
634  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
635  " UDTF implementation.";
636  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
637  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
638  " UDF implementation.";
639  } else {
640  LOG(FATAL) << "bind_function: unknown extension function type "
641  << typeid(T).name();
642  }
643  message += "\n Existing extension function implementations:";
644  for (const auto& ext_func : ext_funcs) {
645  // Do not show functions missing the sizer argument
646  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
647  if (ext_func.useDefaultSizer())
648  continue;
649  message += "\n " + ext_func.toStringSQL();
650  }
651  }
652  throw ExtensionFunctionBindingError(message);
653  }
654 
655  // Functions with "_default_" suffix only exist for calcite
656  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
657  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
658  ext_funcs[optimal].useDefaultSizer()) {
659  std::string name = ext_funcs[optimal].getName();
660  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
662  for (size_t i = 0; i < ext_funcs.size(); i++) {
663  if (ext_funcs[i].getName() == name) {
664  optimal = i;
665  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
666  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
667  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
668  return {ext_funcs[optimal], type_info};
669  }
670  }
671  UNREACHABLE();
672  }
673  }
674 
675  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
676 }
#define CHECK_EQ(x, y)
Definition: Logger.h:217
#define LOG(tag)
Definition: Logger.h:203
auto generate_column_type(const SQLTypes subtype)
Definition: sqltypes.h:1119
string name
Definition: setup.in.py:72
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:253
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:329
auto generate_column_list_type(const SQLTypes subtype)
Definition: sqltypes.h:1133
Definition: sqltypes.h:52
#define CHECK_LE(x, y)
Definition: Logger.h:220
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:337
HOST DEVICE int get_comp_param() const
Definition: sqltypes.h:338
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:209
Definition: sqltypes.h:45
constexpr double n
Definition: Utm.h:46
if(yyssp >=yyss+yystacksize-1)
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args 
)

Definition at line 688 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

689  {
690  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
691  // to CPU UDFs.
692  bool is_gpu = true;
693  std::string processor = "GPU";
694  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
695  if (!ext_funcs.size()) {
696  is_gpu = false;
697  processor = "CPU";
698  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
699  }
700  try {
701  return std::get<0>(
702  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
703  } catch (ExtensionFunctionBindingError& e) {
704  if (is_gpu) {
705  is_gpu = false;
706  processor = "GPU|CPU";
707  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
708  return std::get<0>(
709  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
710  } else {
711  throw;
712  }
713  }
714 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name, const bool is_gpu)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const bool  is_gpu 
)

Definition at line 716 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

718  {
719  // used below
720  std::vector<ExtensionFunction> ext_funcs =
721  ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
722  std::string processor = (is_gpu ? "GPU" : "CPU");
723  return std::get<0>(
724  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
725 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name, const bool is_gpu)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

ExtensionFunction bind_function ( const Analyzer::FunctionOper *  function_oper,
const bool  is_gpu 
)

Definition at line 727 of file ExtensionFunctionsBinding.cpp.

References bind_function(), Analyzer::FunctionOper::getArity(), Analyzer::FunctionOper::getName(), Analyzer::FunctionOper::getOwnArg(), i, and setup::name.

728  {
729  // used in ExtensionsIR.cpp
730  auto name = function_oper->getName();
731  Analyzer::ExpressionPtrVector func_args = {};
732  for (size_t i = 0; i < function_oper->getArity(); ++i) {
733  func_args.push_back(function_oper->getOwnArg(i));
734  }
735  return bind_function(name, func_args, is_gpu);
736 }
size_t getArity() const
Definition: Analyzer.h:1515
string name
Definition: setup.in.py:72
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:1522
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:181
std::string getName() const
Definition: Analyzer.h:1513

+ Here is the call graph for this function:

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const std::vector< table_functions::TableFunction > &  table_funcs,
const bool  is_gpu 
)

Definition at line 679 of file ExtensionFunctionsBinding.cpp.

References setup::name.

Referenced by bind_table_function(), and RelAlgExecutor::createTableFunctionWorkUnit().

682  {
683  std::string processor = (is_gpu ? "GPU" : "CPU");
684  return bind_function<table_functions::TableFunction>(
685  name, input_args, table_funcs, processor);
686 }
string name
Definition: setup.in.py:72

+ Here is the caller graph for this function:

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const bool  is_gpu 
)

Definition at line 739 of file ExtensionFunctionsBinding.cpp.

References bind_table_function(), and table_functions::TableFunctionsFactory::get_table_funcs().

741  {
742  // used in RelAlgExecutor.cpp
743  std::vector<table_functions::TableFunction> table_funcs =
744  table_functions::TableFunctionsFactory::get_table_funcs(name, is_gpu);
745  return bind_table_function(name, input_args, table_funcs, is_gpu);
746 }
string name
Definition: setup.in.py:72
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
static std::vector< TableFunction > get_table_funcs(const std::string &name, const bool is_gpu)

+ Here is the call graph for this function: