OmniSciDB  cde582ebc3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp File Reference
#include "ExtensionFunctionsBinding.h"
#include <algorithm>
#include "ExternalExecutor.h"
+ Include dependency graph for ExtensionFunctionsBinding.cpp:

Go to the source code of this file.

Namespaces

 anonymous_namespace{ExtensionFunctionsBinding.cpp}
 

Functions

ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_arg_elem_type (const ExtArgumentType ext_arg_column_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_list_arg_elem_type (const ExtArgumentType ext_arg_column_list_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_array_arg_elem_type (const ExtArgumentType ext_arg_array_type)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_numeric_argument (const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments (const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
 
bool anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier (std::string str)
 
template<typename T >
std::tuple< T, std::vector
< SQLTypeInfo > > 
bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
 
const std::tuple
< table_functions::TableFunction,
std::vector< SQLTypeInfo > > 
bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const bool is_gpu)
 
ExtensionFunction bind_function (const Analyzer::FunctionOper *function_oper, const bool is_gpu)
 
const std::tuple
< table_functions::TableFunction,
std::vector< SQLTypeInfo > > 
bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const bool is_gpu)
 

Function Documentation

template<typename T >
std::tuple<T, std::vector<SQLTypeInfo> > bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const std::vector< T > &  ext_funcs,
const std::string  processor 
)

Definition at line 433 of file ExtensionFunctionsBinding.cpp.

References CHECK, CHECK_EQ, CHECK_LE, DEFAULT_ROW_MULTIPLIER_SUFFIX, ext_arg_type_to_type_info(), logger::FATAL, generate_column_list_type(), generate_column_type(), SQLTypeInfo::get_comp_param(), SQLTypeInfo::get_compression(), SQLTypeInfo::get_type(), anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier(), kCOLUMN_LIST, kINT, kTEXT, LOG, anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments(), setup::name, heavydb.dtypes::T, ExtensionFunctionsWhitelist::toString(), and UNREACHABLE.

Referenced by bind_function(), CodeGenerator::codegenFunctionOper(), and RelAlgTranslator::translateFunction().

437  {
438  /* worker function
 Binds the call `name(func_args...)` to the best-matching candidate in
 `ext_funcs` and returns that candidate together with the argument
 type variant it was matched against. `processor` is only used when
 composing error messages.

 Throws NativeExecutionError when `name` is not a valid identifier and
 ExtensionFunctionBindingError when no candidate matches.
439 
440  Template type T must implement the following methods
 (getRet() and toStringSQL() are also called on T below):
441 
442  std::vector<ExtArgumentType> getInputArgs()
443  */
444  /*
445  Return extension function/table function that has the following
446  properties
447 
448  1. each argument type in `func_args` matches with extension
449  function argument types.
450 
451  For scalar types, the matching means that the types are either
452  equal or the argument type is smaller than the corresponding
453  extension function argument type. This ensures that no
454  information is lost when casting of argument values is
455  required.
456 
457  For array and geo types, the matching means that the argument
458  type matches exactly with a group of extension function
459  argument types. See `match_arguments`.
460 
461  2. has minimal penalty score among all implementations of the
462  extension function with given `name`. In this function the score is
463  accumulated by `match_arguments` plus the logical size of the return type.
464 
465  It is assumed that the requested function (`name`) and the extension
466  functions in ext_funcs have the same name.
467  */
468  if (!is_valid_identifier(name)) {
469  throw NativeExecutionError(
470  "Cannot bind function with invalid UDF/UDTF function name: " + name);
471  }
472 
473  int minimal_score = std::numeric_limits<int>::max();  // lowest penalty seen so far
474  int index = -1;  // position of the candidate currently examined in ext_funcs
475  int optimal = -1;  // position of the best candidate, -1 while none matched
476  int optimal_variant = -1;  // position of the best argument type variant
477 
 // Collect the type of every argument and whether it is treated as a
 // constant; the two vectors are kept the same length.
478  std::vector<SQLTypeInfo> type_infos_input;
479  std::vector<bool> args_are_constants;
480  for (auto atype : func_args) {
481  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
 // For table functions, a column variable argument is wrapped into a
 // column type before matching.
482  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
483  SQLTypeInfo type_info = atype->get_type_info();
484  if (atype->get_type_info().get_type() == kTEXT) {
485  auto ti = generate_column_type(type_info.get_type(), // subtype
486  type_info.get_compression(), // compression
487  type_info.get_comp_param()); // comp_param
488  type_infos_input.push_back(ti);
489  args_are_constants.push_back(false);
490  } else {
491  auto ti = generate_column_type(type_info.get_type());
492  type_infos_input.push_back(ti);
493  args_are_constants.push_back(true);  // NOTE(review): flags a non-TEXT ColumnVar as constant, unlike the TEXT branch above -- confirm this is intended
494  }
495  continue;
496  }
497  }
498  type_infos_input.push_back(atype->get_type_info());
499  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
500  args_are_constants.push_back(true);
501  } else {
502  args_are_constants.push_back(false);
503  }
504  }
505  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
506 
 // Call without arguments: exactly one candidate taking no input
 // arguments is expected and is returned with an empty type list.
507  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
508  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
509  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
510  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
511  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
512  }
513  std::vector<SQLTypeInfo> empty_type_info_variant(0);
514  return {ext_funcs[0], empty_type_info_variant};
515  }
516 
517  // clang-format off
518  /*
519  Table functions may have arguments such as ColumnList that collect
520  neighboring columns with the same data type into a single object.
521  Here we compute all possible combinations of mapping a subset of
522  columns into column lists. For example, if the types of function
523  arguments are (as given in func_args argument)
524 
525  (Column<int>, Column<int>, Column<int>, int)
526 
527  then the computed variants will be
528 
529  (Column<int>, Column<int>, Column<int>, int)
530  (Column<int>, Column<int>, ColumnList[1]<int>, int)
531  (Column<int>, ColumnList[1]<int>, Column<int>, int)
532  (Column<int>, ColumnList[2]<int>, int)
533  (ColumnList[1]<int>, Column<int>, Column<int>, int)
534  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
535  (ColumnList[2]<int>, Column<int>, int)
536  (ColumnList[3]<int>, int)
537 
538  where the integers in [..] indicate the number of collected
539  columns. In the SQLTypeInfo instance, this number is stored in the
540  SQLTypeInfo dimension attribute.
541 
542  As an example, let us consider a SQL query containing the
543  following expression calling a UDTF foo:
544 
545  table(foo(cursor(select a, b, c from tableofints), 1))
546 
547  Here follows a list of table functions and the corresponding
548  optimal argument type variants that are computed for the given
549  query expression:
550 
551  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
552  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
553 
554  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
555  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
556 
557  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
558  (Column<int>, Column<int>, Column<int>, int)
559  */
560  // clang-format on
561  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
562  for (auto ti : type_infos_input) {
 // First argument: seed the variant list (plus a one-element column
 // list variant when the argument is a column of a table function).
563  if (type_infos_variants.begin() == type_infos_variants.end()) {
564  type_infos_variants.push_back({ti});
565  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
566  if (ti.is_column()) {
567  auto mti = generate_column_list_type(ti.get_subtype());
568  mti.set_dimension(1);
569  type_infos_variants.push_back({mti});
570  }
571  }
572  continue;
573  }
 // Later arguments: every existing variant is extended with `ti` as is;
 // for table-function columns an additional variant is created in which
 // `ti` joins (or starts) a column list.
574  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
575  for (auto& type_infos : type_infos_variants) {
576  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
577  if (ti.is_column()) {
578  auto new_type_infos = type_infos; // makes a copy
579  const auto& last = type_infos.back();
580  if (last.is_column_list() && last.get_subtype() == ti.get_subtype()) {
581  // last column_list consumes column argument if types match
582  new_type_infos.back().set_dimension(last.get_dimension() + 1);
583  } else {
584  // add column as column_list argument
585  auto mti = generate_column_list_type(ti.get_subtype());
586  mti.set_dimension(1);
587  new_type_infos.push_back(mti);
588  }
589  new_type_infos_variants.push_back(new_type_infos);
590  }
591  }
592  type_infos.push_back(ti);
593  }
594  type_infos_variants.insert(type_infos_variants.end(),
595  new_type_infos_variants.begin(),
596  new_type_infos_variants.end());
597  }
598 
599  // Find extension function that gives the best match on the set of
600  // argument type variants:
601  for (auto ext_func : ext_funcs) {
602  index++;
603 
604  auto ext_func_args = ext_func.getInputArgs();
605  int index_variant = -1;
606  for (const auto& type_infos : type_infos_variants) {
607  index_variant++;
608  int penalty_score = 0;
609  int pos = 0;  // position in ext_func_args consumed so far, -1 on mismatch
610  int original_input_idx = 0;  // index into args_are_constants (original arguments)
611  CHECK_LE(type_infos.size(), args_are_constants.size());
612  // for (size_t ti_idx = 0; ti_idx != type_infos.size(); ++ti_idx) {
613  for (const auto& ti : type_infos) {
614  int offset = match_arguments(ti,
615  args_are_constants[original_input_idx],
616  pos,
617  ext_func_args,
618  penalty_score);
619  if (offset < 0) {
620  // ti does not match with ext_func argument
621  pos = -1;
622  break;
623  }
 // A column list stands for `dimension` original arguments.
624  if (ti.get_type() == kCOLUMN_LIST) {
625  original_input_idx += ti.get_dimension();
626  } else {
627  original_input_idx++;
628  }
629  pos += offset;
630  }
631 
 // The variant matches only when all signature arguments were consumed.
632  if ((size_t)pos == ext_func_args.size()) {
633  CHECK_EQ(args_are_constants.size(), original_input_idx);
634  // prefer smaller return types
635  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
 // Strict comparison: on equal scores the earlier candidate/variant is kept.
636  if (penalty_score < minimal_score) {
637  optimal = index;
638  minimal_score = penalty_score;
639  optimal_variant = index_variant;
640  }
641  }
642  }
643  }
644 
645  if (optimal == -1) {
646  /* no extension function was found whose argument types match
647  the types in `type_infos_input` */
648  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
649  std::string message;
650  if (!ext_funcs.size()) {
651  message = "Function " + name + "(" + sarg_types + ") not supported.";
652  throw ExtensionFunctionBindingError(message);
653  } else {
654  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
655  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
656  " UDTF implementation.";
657  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
658  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
659  " UDF implementation.";
660  } else {
661  LOG(FATAL) << "bind_function: unknown extension function type "
662  << typeid(T).name();
663  }
664  message += "\n Existing extension function implementations:";
665  for (const auto& ext_func : ext_funcs) {
666  // Do not show functions missing the sizer argument
667  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
668  if (ext_func.useDefaultSizer())
669  continue;
670  message += "\n " + ext_func.toStringSQL();
671  }
672  }
673  throw ExtensionFunctionBindingError(message);
674  }
675 
676  // Functions with "_default_" suffix only exist for calcite
 // When such a function was selected, look up the candidate whose name is
 // the selected name with the suffix erased and return that one instead,
 // with an extra kINT entry inserted into the type list at the position
 // given by getOutputRowSizeParameter() - 1.
677  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
678  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
679  ext_funcs[optimal].useDefaultSizer()) {
680  std::string name = ext_funcs[optimal].getName();
681  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
 // NOTE(review): source line 682 (the remaining argument of name.erase and
 // the end of that statement) is absent from this listing.
683  for (size_t i = 0; i < ext_funcs.size(); i++) {
684  if (ext_funcs[i].getName() == name) {
685  optimal = i;
686  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
687  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
688  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
689  return {ext_funcs[optimal], type_info};
690  }
691  }
692  UNREACHABLE();
693  }
694  }
695 
696  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
697 }
#define CHECK_EQ(x, y)
Definition: Logger.h:230
#define LOG(tag)
Definition: Logger.h:216
auto generate_column_type(const SQLTypes subtype)
Definition: sqltypes.h:1124
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:266
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:329
auto generate_column_list_type(const SQLTypes subtype)
Definition: sqltypes.h:1138
Definition: sqltypes.h:52
#define CHECK_LE(x, y)
Definition: Logger.h:233
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:337
HOST DEVICE int get_comp_param() const
Definition: sqltypes.h:338
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:222
Definition: sqltypes.h:45
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args 
)

Definition at line 709 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

710  {
711  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
712  // to CPU UDFs.
 // Returns only the bound ExtensionFunction (element 0 of the tuple
 // produced by the templated bind_function).
713  bool is_gpu = true;
714  std::string processor = "GPU";
715  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
 // No GPU implementations registered under this name: switch to CPU.
716  if (!ext_funcs.size()) {
717  is_gpu = false;
718  processor = "CPU";
 // NOTE(review): source line 719 is absent from this listing; presumably it
 // re-fetches ext_funcs for the CPU -- confirm against the real file.
720  }
721  try {
722  return std::get<0>(
723  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
724  } catch (ExtensionFunctionBindingError& e) {
 // A failed GPU binding is retried once with processor set to "GPU|CPU";
 // a failed CPU binding is rethrown unchanged.
725  if (is_gpu) {
726  is_gpu = false;
727  processor = "GPU|CPU";
 // NOTE(review): source line 728 is absent from this listing; presumably it
 // re-fetches ext_funcs for the CPU before the retry -- confirm.
729  return std::get<0>(
730  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
731  } else {
732  throw;
733  }
734  }
735 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const bool  is_gpu 
)

Definition at line 737 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

739  {
740  // used below
 // Binds `name(func_args...)` for the processor selected by `is_gpu` and
 // returns only the bound ExtensionFunction (element 0 of the tuple).
741  std::vector<ExtensionFunction> ext_funcs =
 // NOTE(review): source line 742 (the initializer of ext_funcs) is absent
 // from this listing.
743  std::string processor = (is_gpu ? "GPU" : "CPU");  // label passed on to the templated bind_function
744  return std::get<0>(
745  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
746 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

ExtensionFunction bind_function ( const Analyzer::FunctionOper function_oper,
const bool  is_gpu 
)

Definition at line 748 of file ExtensionFunctionsBinding.cpp.

References bind_function(), Analyzer::FunctionOper::getArity(), Analyzer::FunctionOper::getName(), Analyzer::FunctionOper::getOwnArg(), and setup::name.

749  {
750  // used in ExtensionsIR.cpp
 // Convenience overload: takes the name and the owned arguments from
 // `function_oper` and delegates to bind_function(name, func_args, is_gpu).
751  auto name = function_oper->getName();
752  Analyzer::ExpressionPtrVector func_args = {};
 // Copy each argument (getOwnArg returns a shared_ptr) in positional order.
753  for (size_t i = 0; i < function_oper->getArity(); ++i) {
754  func_args.push_back(function_oper->getOwnArg(i));
755  }
756  return bind_function(name, func_args, is_gpu);
757 }
size_t getArity() const
Definition: Analyzer.h:2169
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:2176
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:188
std::string getName() const
Definition: Analyzer.h:2167
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const std::vector< table_functions::TableFunction > &  table_funcs,
const bool  is_gpu 
)

Definition at line 700 of file ExtensionFunctionsBinding.cpp.

References setup::name.

Referenced by bind_table_function(), and RelAlgExecutor::createTableFunctionWorkUnit().

703  {
 // Binds a table function call against the given candidates `table_funcs`
 // by forwarding to the templated bind_function; the result is the bound
 // TableFunction together with the matched argument type list.
704  std::string processor = (is_gpu ? "GPU" : "CPU");  // label passed on to bind_function
705  return bind_function<table_functions::TableFunction>(
706  name, input_args, table_funcs, processor);
707 }
string name
Definition: setup.in.py:72

+ Here is the caller graph for this function:

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const bool  is_gpu 
)

Definition at line 760 of file ExtensionFunctionsBinding.cpp.

References bind_table_function(), and table_functions::TableFunctionsFactory::get_table_funcs().

762  {
763  // used in RelAlgExecutor.cpp
 // Obtains the candidate table functions and delegates to the
 // four-argument bind_table_function overload.
764  std::vector<table_functions::TableFunction> table_funcs =
 // NOTE(review): source line 765 (the initializer of table_funcs) is absent
 // from this listing.
766  return bind_table_function(name, input_args, table_funcs, is_gpu);
767 }
static std::vector< TableFunction > get_table_funcs()
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function: