OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp File Reference
#include "ExtensionFunctionsBinding.h"
#include <algorithm>
#include "ExternalExecutor.h"
+ Include dependency graph for ExtensionFunctionsBinding.cpp:

Go to the source code of this file.

Namespaces

 anonymous_namespace{ExtensionFunctionsBinding.cpp}
 

Functions

ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_arg_elem_type (const ExtArgumentType ext_arg_column_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_list_arg_elem_type (const ExtArgumentType ext_arg_column_list_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_array_arg_elem_type (const ExtArgumentType ext_arg_array_type)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_numeric_argument (const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments (const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
 
bool anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier (std::string str)
 
template<typename T >
std::tuple< T, std::vector
< SQLTypeInfo > > 
bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
 
const std::tuple
< table_functions::TableFunction,
std::vector< SQLTypeInfo > > 
bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const bool is_gpu)
 
ExtensionFunction bind_function (const Analyzer::FunctionOper *function_oper, const bool is_gpu)
 
const std::tuple
< table_functions::TableFunction,
std::vector< SQLTypeInfo > > 
bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const bool is_gpu)
 

Function Documentation

template<typename T >
std::tuple<T, std::vector<SQLTypeInfo> > bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const std::vector< T > &  ext_funcs,
const std::string  processor 
)

Definition at line 548 of file ExtensionFunctionsBinding.cpp.

References CHECK, CHECK_EQ, DEFAULT_ROW_MULTIPLIER_SUFFIX, ext_arg_type_to_type_info(), logger::FATAL, generate_column_list_type(), generate_column_type(), SQLTypeInfo::get_type(), SQLTypeInfo::has_same_itemtype(), is_ext_arg_type_column_list(), anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier(), kINT, kNULLT, kTEXT, LOG, anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments(), setup::name, SQLTypeInfo::set_dimension(), SQLTypeInfo::supportsFlatBuffer(), heavydb.dtypes::T, to_string(), SQLTypeInfo::to_string(), ExtensionFunctionsWhitelist::toString(), and UNREACHABLE.

Referenced by bind_function(), CodeGenerator::codegenFunctionOper(), and RelAlgTranslator::translateFunction().

552  {
553  /* worker function shared by the UDF and UDTF binding entry points
554 
555  Template type T must implement the following methods:
556 
557  std::vector<ExtArgumentType> getInputArgs()
558  */
559  /*
560  Return extension function/table function that has the following
561  properties
562 
563  1. each argument type in `arg_types` matches with extension
564  function argument types.
565 
566  For scalar types, the matching means that the types are either
567  equal or the argument type is smaller than the corresponding
568  extension function argument type. This ensures that no
569  information is lost when casting of argument values is
570  required.
571 
572  For array and geo types, the matching means that the argument
573  type matches exactly with a group of extension function
574  argument types. See `match_arguments`.
575 
576  2. has minimal penalty score among all implementations of the
577  extension function with given `name`, see `get_penalty_score`
578  for the definition of penalty score.
579 
580  It is assumed that function_oper and extension functions in
581  ext_funcs have the same name.
582  */
583  if (!is_valid_identifier(name)) {
584  throw NativeExecutionError(
585  "Cannot bind function with invalid UDF/UDTF function name: " + name);
586  }
587 
588  std::vector<SQLTypeInfo> type_infos_input;
589  std::vector<bool> args_are_constants;
590  for (auto atype : func_args) {
591  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
592  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
593  SQLTypeInfo type_info = atype->get_type_info();
594  auto ti = generate_column_type(type_info);
595  if (ti.get_subtype() == kNULLT) {
596  throw std::runtime_error(std::string(__FILE__) + "#" +
597  std::to_string(__LINE__) +
598  ": column support for type info " +
599  type_info.to_string() + " is not implemented");
600  }
601  ti.setUsesFlatBuffer(type_info.supportsFlatBuffer());
602  type_infos_input.push_back(ti);
603  args_are_constants.push_back(type_info.get_type() != kTEXT);  // flag is true for every column type except kTEXT
604  continue;
605  }
606  }
607  type_infos_input.push_back(atype->get_type_info());
608  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
609  args_are_constants.push_back(true);
610  } else {
611  args_are_constants.push_back(false);
612  }
613  }
614  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
615 
616  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {  // call without arguments: nothing to score
617  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));  // exactly one candidate is expected ...
618  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));  // ... and it must take no inputs
619  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
620  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
621  }
622  std::vector<SQLTypeInfo> empty_type_info_variant(0);
623  return {ext_funcs[0], empty_type_info_variant};
624  }
625 
626  int minimal_score = std::numeric_limits<int>::max();  // lowest penalty score seen so far
627  int index = -1;  // position of the current candidate in ext_funcs (incremented at loop start)
628  int optimal = -1;  // index of the best candidate; stays -1 when nothing matches
629  int optimal_variant = -1;  // index into type_infos_variants belonging to `optimal`
630  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;  // one entry is appended per candidate
631 
632  // clang-format off
633  /*
634  Table functions may have arguments such as ColumnList that collect
635  neighboring columns with the same data type into a single object. In
636  general, the binding of UDTFs with ColumnLists might be ambiguous depending
637  on the order of the arguments:
638 
639  foo(ColumnList<T>, ColumnList<T>) -> Column<T>, T=[int]
640  bar(ColumnList<T>, Column<T>, ColumnList<T>) -> Column<T>, T=[int]
641 
642  Here both declarations above are ambiguous as the first ColumnList can
643  consume as many columns as possible, leaving a single column for each
644  one of the remaining types. Or it can consume one argument, leaving the bulk
645  to the last ColumnList. Nevertheless, not all ColumnList declarations result
646  in an ambiguous signature. The example below shows an example of a function
647  that takes two ColumnLists of different types which has an exact match.
648 
649  baz(ColumnList<P>, ColumnList<T>) -> Column<T>, T=[int], P=[float]
650 
651  To match a list of SQL arguments with an extension function, HeavyDB uses a
652  greedy algorithm that resolves the binding ambiguity as explained
653  below. As an example, let us consider a SQL query containing the following
654  expression calling a UDTF `bar` defined above:
655 
656  table(bar(select a, b, c, d, e from tableofints), 1)
657 
658  The algorithm will generate the following type variant, where the integer
659  value in [..] indicates the number of collected columns. This number is later
660  stored in the SQLTypeInfo dimension attribute.
661 
662  bar(ColumnList<T>[3], Column<T>, ColumnList<T>[1])
663 
664  */
665 
666  // clang-format on
667 
668  // Find extension function that gives the best match on the set of
669  // argument type variants
670  for (const auto& ext_func : ext_funcs) {
671  index++;
672 
673  const auto& ext_func_args = ext_func.getInputArgs();
674 
675  int penalty_score = 0;
676  int pos = 0;  // current position in the candidate's signature; -1 marks "no match"
677  int original_input_idx = 0;  // number of input arguments consumed so far
678  type_infos_variants.emplace_back();
679 
680  for (size_t i = 0; i < type_infos_input.size(); i++) {
681  const SQLTypeInfo& ti = type_infos_input[i];
682 
683  if ((size_t)pos >= ext_func_args.size()) {  // inputs remain but the signature is exhausted
684  pos = -1;
685  break;
686  } else if (is_ext_arg_type_column_list(ext_func_args[pos])) {
687  SQLTypeInfo ti_col_list = generate_column_list_type(ti);
688  int offset = match_arguments(ti_col_list,
689  args_are_constants[original_input_idx],
690  pos,
691  ext_func_args,
692  penalty_score);
693  if (offset < 0) {
694  pos = -1;
695  break;
696  }
697 
698  // if offset > 0, greedily iterate over the rest of input args
699  // to consume columns with the same type as "ti"
700  int j = i;
701  size_t args_left = ext_func_args.size() - pos - 1;  // signature arguments after this ColumnList
702  while ((type_infos_input.size() - j > args_left) and
703  (ti_col_list.has_same_itemtype(type_infos_input[j]))) {
704  j++;
705  }
706  // push_back a ColumnList with dimension equal to the number of columns
707  // consumed above
708  ti_col_list.set_dimension(j - i);
709  type_infos_variants.back().push_back(ti_col_list);
710  // Move the "i" pointer to the last argument consumed
711  i = j - 1;
712  original_input_idx = j;
713  pos += offset;
714  } else {
715  int offset = match_arguments(ti,
716  args_are_constants[original_input_idx],
717  pos,
718  ext_func_args,
719  penalty_score);
720 
721  if (offset > 0) {
722  type_infos_variants.back().push_back(ti);
723  original_input_idx += 1;
724  pos += offset;
725  } else {
726  pos = -1;
727  break;
728  }
729  }
730  }
731 
732  if ((size_t)pos == ext_func_args.size()) {  // candidate matches only if its whole signature was consumed
733  CHECK_EQ(args_are_constants.size(), original_input_idx);
734  // prefer smaller return types
735  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
736  if (penalty_score < minimal_score) {  // strict '<': on equal scores the earlier candidate is kept
737  optimal = index;
738  minimal_score = penalty_score;
739  optimal_variant = type_infos_variants.size() - 1;
740  }
741  }
742  }
743 
744  if (optimal == -1) {
745  /* no extension function found whose argument types would match
746  the types in `arg_types` */
747  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
748  std::string message;
749  if (!ext_funcs.size()) {
750  message = "Function " + name + "(" + sarg_types + ") not supported.";
751  throw ExtensionFunctionBindingError(message);
752  } else {
753  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
754  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
755  " UDTF implementation.";
756  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
757  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
758  " UDF implementation.";
759  } else {
760  LOG(FATAL) << "bind_function: unknown extension function type "
761  << typeid(T).name();
762  }
763  message += "\n Existing extension function implementations:";
764  for (const auto& ext_func : ext_funcs) {
765  // Do not show functions missing the sizer argument
766  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
767  if (ext_func.useDefaultSizer()) {
768  continue;
769  }
770  }
771  message += "\n " + ext_func.toStringSQL();
772  }
773  }
774  throw ExtensionFunctionBindingError(message);
775  }
776 
777  // Functions with "_default_" suffix only exist for calcite
778  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
779  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
780  ext_funcs[optimal].useDefaultSizer()) {
781  std::string name = ext_funcs[optimal].getName();  // local copy; shadows the `name` parameter
782  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),  // NOTE(review): listing line 783 (the rest of this call) is missing from this rendering
784  for (size_t i = 0; i < ext_funcs.size(); i++) {
785  if (ext_funcs[i].getName() == name) {  // look up the candidate carrying the stripped name
786  optimal = i;
787  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
788  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
789  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));  // add a kINT entry at position sizer - 1
790  return {ext_funcs[optimal], type_info};
791  }
792  }
793  UNREACHABLE();
794  }
795  }
796 
797  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
798 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
#define LOG(tag)
Definition: Logger.h:285
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:338
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:391
std::string to_string(char const *&&v)
std::string to_string() const
Definition: sqltypes.h:526
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
bool has_same_itemtype(const SQLTypeInfo &other) const
Definition: sqltypes.h:662
bool supportsFlatBuffer() const
Definition: sqltypes.h:1086
auto generate_column_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1627
Definition: sqltypes.h:79
auto generate_column_list_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1699
void set_dimension(int d)
Definition: sqltypes.h:470
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:291
Definition: sqltypes.h:72
string name
Definition: setup.in.py:72
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args 
)

Definition at line 810 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

811  {
812  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
813  // to CPU UDFs.
814  bool is_gpu = true;
815  std::string processor = "GPU";
816  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
817  if (!ext_funcs.size()) {  // no GPU candidates registered under this name
818  is_gpu = false;
819  processor = "CPU";  // NOTE(review): listing line 820 is missing from this rendering; presumably it re-queries ext_funcs via get_ext_funcs() for CPU - confirm against the source file
821  }
822  try {
823  return std::get<0>(
824  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
825  } catch (ExtensionFunctionBindingError& e) {
826  if (is_gpu) {  // the GPU attempt failed to bind: retry once with is_gpu cleared
827  is_gpu = false;
828  processor = "GPU|CPU";  // NOTE(review): listing line 829 is missing from this rendering; presumably ext_funcs is refreshed before the retry - confirm against the source file
830  return std::get<0>(
831  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
832  } else {
833  throw;  // the failing attempt was already the non-GPU one: propagate the binding error
834  }
835  }
836 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const bool  is_gpu 
)

Definition at line 838 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

840  {
841  // used below (by the overload that takes an Analyzer::FunctionOper)
842  std::vector<ExtensionFunction> ext_funcs =  // NOTE(review): listing line 843 (the initializer) is missing from this rendering; the References list for this function names ExtensionFunctionsWhitelist::get_ext_funcs() - confirm against the source file
844  std::string processor = (is_gpu ? "GPU" : "CPU");  // label passed on to the template worker
845  return std::get<0>(  // keep only the bound function; the SQLTypeInfo vector of the tuple is dropped
846  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
847 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

ExtensionFunction bind_function ( const Analyzer::FunctionOper function_oper,
const bool  is_gpu 
)

Definition at line 849 of file ExtensionFunctionsBinding.cpp.

References bind_function(), Analyzer::FunctionOper::getArity(), Analyzer::FunctionOper::getName(), Analyzer::FunctionOper::getOwnArg(), and setup::name.

850  {
851  // used in ExtensionsIR.cpp
852  auto name = function_oper->getName();
853  Analyzer::ExpressionPtrVector func_args = {};
854  for (size_t i = 0; i < function_oper->getArity(); ++i) {  // collect the operator's arguments in order
855  func_args.push_back(function_oper->getOwnArg(i));
856  }
857  return bind_function(name, func_args, is_gpu);  // delegate to the (name, func_args, is_gpu) overload
858 }
size_t getArity() const
Definition: Analyzer.h:2615
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:2622
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:186
std::string getName() const
Definition: Analyzer.h:2613
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const std::vector< table_functions::TableFunction > &  table_funcs,
const bool  is_gpu 
)

Definition at line 801 of file ExtensionFunctionsBinding.cpp.

References setup::name.

Referenced by bind_table_function(), and RelAlgExecutor::createTableFunctionWorkUnit().

804  {
805  std::string processor = (is_gpu ? "GPU" : "CPU");  // label handed to bind_function, which uses it in its "Could not bind ..." message
806  return bind_function<table_functions::TableFunction>(  // returns the (TableFunction, SQLTypeInfo vector) tuple unchanged
807  name, input_args, table_funcs, processor);
808 }
string name
Definition: setup.in.py:72

+ Here is the caller graph for this function:

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const bool  is_gpu 
)

Definition at line 861 of file ExtensionFunctionsBinding.cpp.

References bind_table_function(), and table_functions::TableFunctionsFactory::get_table_funcs().

863  {
864  // used in RelAlgExecutor.cpp
865  std::vector<table_functions::TableFunction> table_funcs =  // NOTE(review): listing line 866 (the initializer) is missing from this rendering; the References list for this function names table_functions::TableFunctionsFactory::get_table_funcs() - confirm against the source file
867  return bind_table_function(name, input_args, table_funcs, is_gpu);  // delegate to the overload that takes an explicit candidate list
868 }
static std::vector< TableFunction > get_table_funcs()
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function: