OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsBinding.cpp File Reference
#include "ExtensionFunctionsBinding.h"
#include <algorithm>
#include "ExternalExecutor.h"
+ Include dependency graph for ExtensionFunctionsBinding.cpp:

Go to the source code of this file.

Namespaces

 anonymous_namespace{ExtensionFunctionsBinding.cpp}
 

Functions

ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_arg_elem_type (const ExtArgumentType ext_arg_column_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_list_arg_elem_type (const ExtArgumentType ext_arg_column_list_type)
 
ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_array_arg_elem_type (const ExtArgumentType ext_arg_array_type)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_numeric_argument (const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
 
static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments (const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
 
bool anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier (std::string str)
 
template<typename T >
std::tuple< T, std::vector
< SQLTypeInfo > > 
bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
 
const std::tuple
< table_functions::TableFunction,
std::vector< SQLTypeInfo > > 
bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args)
 
ExtensionFunction bind_function (std::string name, Analyzer::ExpressionPtrVector func_args, const bool is_gpu)
 
ExtensionFunction bind_function (const Analyzer::FunctionOper *function_oper, const bool is_gpu)
 
const std::tuple
< table_functions::TableFunction,
std::vector< SQLTypeInfo > > 
bind_table_function (std::string name, Analyzer::ExpressionPtrVector input_args, const bool is_gpu)
 

Function Documentation

template<typename T >
std::tuple<T, std::vector<SQLTypeInfo> > bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const std::vector< T > &  ext_funcs,
const std::string  processor 
)

Definition at line 515 of file ExtensionFunctionsBinding.cpp.

References CHECK, CHECK_EQ, CHECK_LE, DEFAULT_ROW_MULTIPLIER_SUFFIX, ext_arg_type_to_type_info(), logger::FATAL, generate_column_list_type(), generate_column_type(), SQLTypeInfo::get_type(), is_ext_arg_type_column_list(), anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier(), kCOLUMN_LIST, kINT, kNULLT, kTEXT, LOG, anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments(), setup::name, heavydb.dtypes::T, to_string(), SQLTypeInfo::to_string(), ExtensionFunctionsWhitelist::toString(), and UNREACHABLE.

Referenced by bind_function(), CodeGenerator::codegenFunctionOper(), and RelAlgTranslator::translateFunction().

519  {
520  /* worker function
521 
522  Template type T must implement the following methods:
523 
524  std::vector<ExtArgumentType> getInputArgs()
525  */
526  /*
527  Return extension function/table function that has the following
528  properties
529 
530  1. each argument type in `arg_types` matches with extension
531  function argument types.
532 
533  For scalar types, the matching means that the types are either
534  equal or the argument type is smaller than the corresponding
535  the extension function argument type. This ensures that no
536  information is lost when casting of argument values is
537  required.
538 
539  For array and geo types, the matching means that the argument
540  type matches exactly with a group of extension function
541  argument types. See `match_arguments`.
542 
543  2. has minimal penalty score among all implementations of the
544  extension function with given `name`, see `get_penalty_score`
545  for the definition of penalty score.
546 
547  It is assumed that function_oper and extension functions in
548  ext_funcs have the same name.
549  */
550  if (!is_valid_identifier(name)) {
551  throw NativeExecutionError(
552  "Cannot bind function with invalid UDF/UDTF function name: " + name);
553  }
554 
555  int minimal_score = std::numeric_limits<int>::max();
556  int index = -1;
557  int optimal = -1;
558  int optimal_variant = -1;
559 
560  std::vector<SQLTypeInfo> type_infos_input;
561  std::vector<bool> args_are_constants;
562  for (auto atype : func_args) {
563  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
564  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
565  SQLTypeInfo type_info = atype->get_type_info();
566  auto ti = generate_column_type(type_info);
567  if (ti.get_subtype() == kNULLT) {
568  throw std::runtime_error(std::string(__FILE__) + "#" +
569  std::to_string(__LINE__) +
570  ": column support for type info " +
571  type_info.to_string() + " is not implemented");
572  }
573  type_infos_input.push_back(ti);
574  args_are_constants.push_back(type_info.get_type() != kTEXT);
575  continue;
576  }
577  }
578  type_infos_input.push_back(atype->get_type_info());
579  if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
580  args_are_constants.push_back(true);
581  } else {
582  args_are_constants.push_back(false);
583  }
584  }
585  CHECK_EQ(type_infos_input.size(), args_are_constants.size());
586 
587  if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
588  CHECK_EQ(ext_funcs.size(), static_cast<size_t>(1));
589  CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
590  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
591  CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
592  }
593  std::vector<SQLTypeInfo> empty_type_info_variant(0);
594  return {ext_funcs[0], empty_type_info_variant};
595  }
596 
597  // clang-format off
598  /*
599  Table functions may have arguments such as ColumnList that collect
600  neighboring columns with the same data type into a single object.
601  Here we compute all possible combinations of mapping a subset of
602  columns into columns sets. For example, if the types of function
603  arguments are (as given in func_args argument)
604 
605  (Column<int>, Column<int>, Column<int>, int)
606 
607  then the computed variants will be
608 
609  (Column<int>, Column<int>, Column<int>, int)
610  (Column<int>, Column<int>, ColumnList[1]<int>, int)
611  (Column<int>, ColumnList[1]<int>, Column<int>, int)
612  (Column<int>, ColumnList[2]<int>, int)
613  (ColumnList[1]<int>, Column<int>, Column<int>, int)
614  (ColumnList[1]<int>, Column<int>, ColumnList[1]<int>, int)
615  (ColumnList[2]<int>, Column<int>, int)
616  (ColumnList[3]<int>, int)
617 
618  where the integers in [..] indicate the number of collected
619  columns. In the SQLTypeInfo instance, this number is stored in the
620  SQLTypeInfo dimension attribute.
621 
622  As an example, let us consider a SQL query containing the
623  following expression calling a UDTF foo:
624 
625  table(foo(cursor(select a, b, c from tableofints), 1))
626 
627  Here follows a list of table functions and the corresponding
628  optimal argument type variants that are computed for the given
629  query expression:
630 
631  UDTF: foo(ColumnList<int>, RowMultiplier) -> Column<int>
632  (ColumnList[3]<int>, int) # a, b, c are all collected to column_list
633 
634  UDTF: foo(Column<int>, ColumnList<int>, RowMultiplier) -> Column<int>
635  (Column<int>, ColumnList[2]<int>, int) # b and c are collected to column_list
636 
637  UDTF: foo(Column<int>, Column<int>, Column<int>, RowMultiplier) -> Column<int>
638  (Column<int>, Column<int>, Column<int>, int)
639  */
640  // clang-format on
641 
642  // We first check if any of the matched extension functions
643  // in the ext_funcs list allow for column lists, as if they do not,
644  // we do not need to account for possible input permutations with
645  // ColumnList arguments, which currently can be slow time
646  // to match with the extension functions if the number of arguments
647  // is high
648 
649  // Todo: Develop faster matching algorithm to avoid such performance
650  // hits when ColumnLists are allowed
651 
652  bool ext_funcs_allow_column_lists{false};
653  for (const auto& ext_func : ext_funcs) {
654  auto ext_func_args = ext_func.getInputArgs();
655  for (const auto& arg : ext_func_args) {
656  if (is_ext_arg_type_column_list(arg)) {
657  ext_funcs_allow_column_lists = true;
658  break;
659  }
660  }
661  if (ext_funcs_allow_column_lists) {
662  break;
663  }
664  }
665 
666  std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
667  if (ext_funcs_allow_column_lists) {
668  for (const auto& ti : type_infos_input) {
669  if (type_infos_variants.begin() == type_infos_variants.end()) {
670  type_infos_variants.push_back({ti});
671  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
672  if (ti.is_column()) {
673  auto mti = generate_column_list_type(ti);
674  if (mti.get_subtype() == kNULLT) {
675  continue; // skip unsupported element type.
676  }
677  mti.set_dimension(1);
678  type_infos_variants.push_back({mti});
679  }
680  }
681  continue;
682  }
683  std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
684  for (auto& type_infos : type_infos_variants) {
685  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
686  if (ti.is_column()) {
687  auto new_type_infos = type_infos; // makes a copy
688  const auto& last = type_infos.back();
689  if (last.is_column_list() && last.has_same_itemtype(ti)) {
690  // last column_list consumes column argument if item types match
691  new_type_infos.back().set_dimension(last.get_dimension() + 1);
692  } else {
693  // add column as column_list argument
694  auto mti = generate_column_list_type(ti);
695  if (mti.get_subtype() == kNULLT) {
696  // skip unsupported element type
697  type_infos.push_back(ti);
698  continue;
699  }
700  mti.set_dimension(1);
701  new_type_infos.push_back(mti);
702  }
703  new_type_infos_variants.push_back(new_type_infos);
704  }
705  }
706  type_infos.push_back(ti);
707  }
708  type_infos_variants.insert(type_infos_variants.end(),
709  new_type_infos_variants.begin(),
710  new_type_infos_variants.end());
711  }
712  } else {
713  type_infos_variants.emplace_back(type_infos_input);
714  }
715 
716  // Find extension function that gives the best match on the set of
717  // argument type variants:
718  for (const auto& ext_func : ext_funcs) {
719  index++;
720 
721  const auto& ext_func_args = ext_func.getInputArgs();
722  int index_variant = -1;
723  for (const auto& type_infos : type_infos_variants) {
724  index_variant++;
725  int penalty_score = 0;
726  int pos = 0;
727  int original_input_idx = 0;
728  CHECK_LE(type_infos.size(), args_are_constants.size());
729  for (const auto& ti : type_infos) {
730  int offset = match_arguments(ti,
731  args_are_constants[original_input_idx],
732  pos,
733  ext_func_args,
734  penalty_score);
735  if (offset < 0) {
736  // atype does not match with ext_func argument
737  pos = -1;
738  break;
739  }
740  if (ti.get_type() == kCOLUMN_LIST) {
741  original_input_idx += ti.get_dimension();
742  } else {
743  original_input_idx++;
744  }
745  pos += offset;
746  }
747 
748  if ((size_t)pos == ext_func_args.size()) {
749  CHECK_EQ(args_are_constants.size(), original_input_idx);
750  // prefer smaller return types
751  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
752  if (penalty_score < minimal_score) {
753  optimal = index;
754  minimal_score = penalty_score;
755  optimal_variant = index_variant;
756  }
757  }
758  }
759  }
760 
761  if (optimal == -1) {
762  /* no extension function found that argument types would match
763  with types in `arg_types` */
764  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos_input);
765  std::string message;
766  if (!ext_funcs.size()) {
767  message = "Function " + name + "(" + sarg_types + ") not supported.";
768  throw ExtensionFunctionBindingError(message);
769  } else {
770  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
771  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
772  " UDTF implementation.";
773  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
774  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
775  " UDF implementation.";
776  } else {
777  LOG(FATAL) << "bind_function: unknown extension function type "
778  << typeid(T).name();
779  }
780  message += "\n Existing extension function implementations:";
781  for (const auto& ext_func : ext_funcs) {
782  // Do not show functions missing the sizer argument
783  if constexpr (std::is_same_v<T, table_functions::TableFunction>)
784  if (ext_func.useDefaultSizer())
785  continue;
786  message += "\n " + ext_func.toStringSQL();
787  }
788  }
789  throw ExtensionFunctionBindingError(message);
790  }
791 
792  // Functions with "_default_" suffix only exist for calcite
793  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
794  if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
795  ext_funcs[optimal].useDefaultSizer()) {
796  std::string name = ext_funcs[optimal].getName();
797  name.erase(name.find(DEFAULT_ROW_MULTIPLIER_SUFFIX),
798  sizeof(DEFAULT_ROW_MULTIPLIER_SUFFIX));
799  for (size_t i = 0; i < ext_funcs.size(); i++) {
800  if (ext_funcs[i].getName() == name) {
801  optimal = i;
802  std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
803  size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
804  type_info.insert(type_info.begin() + sizer - 1, SQLTypeInfo(kINT, true));
805  return {ext_funcs[optimal], type_info};
806  }
807  }
808  UNREACHABLE();
809  }
810  }
811 
812  return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
813 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
#define LOG(tag)
Definition: Logger.h:285
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
#define UNREACHABLE()
Definition: Logger.h:337
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:381
std::string to_string(char const *&&v)
std::string to_string() const
Definition: sqltypes.h:547
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
auto generate_column_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1387
Definition: sqltypes.h:69
#define CHECK_LE(x, y)
Definition: Logger.h:304
auto generate_column_list_type(const SQLTypeInfo &elem_ti)
Definition: sqltypes.h:1445
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:291
Definition: sqltypes.h:62
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args 
)

Definition at line 825 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

826  {
827  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
828  // to CPU UDFs.
829  bool is_gpu = true;
830  std::string processor = "GPU";
831  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
832  if (!ext_funcs.size()) {
833  is_gpu = false;
834  processor = "CPU";
835  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
836  }
837  try {
838  return std::get<0>(
839  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
840  } catch (ExtensionFunctionBindingError& e) {
841  if (is_gpu) {
842  is_gpu = false;
843  processor = "GPU|CPU";
844  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
845  return std::get<0>(
846  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
847  } else {
848  throw;
849  }
850  }
851 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

ExtensionFunction bind_function ( std::string  name,
Analyzer::ExpressionPtrVector  func_args,
const bool  is_gpu 
)

Definition at line 853 of file ExtensionFunctionsBinding.cpp.

References ExtensionFunctionsWhitelist::get_ext_funcs(), and setup::name.

855  {
856  // used below
857  std::vector<ExtensionFunction> ext_funcs =
858  ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
859  std::string processor = (is_gpu ? "GPU" : "CPU");
860  return std::get<0>(
861  bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor));
862 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

ExtensionFunction bind_function ( const Analyzer::FunctionOper function_oper,
const bool  is_gpu 
)

Definition at line 864 of file ExtensionFunctionsBinding.cpp.

References bind_function(), Analyzer::FunctionOper::getArity(), Analyzer::FunctionOper::getName(), Analyzer::FunctionOper::getOwnArg(), and setup::name.

865  {
866  // used in ExtensionsIR.cpp
867  auto name = function_oper->getName();
868  Analyzer::ExpressionPtrVector func_args = {};
869  for (size_t i = 0; i < function_oper->getArity(); ++i) {
870  func_args.push_back(function_oper->getOwnArg(i));
871  }
872  return bind_function(name, func_args, is_gpu);
873 }
size_t getArity() const
Definition: Analyzer.h:2408
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:2415
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:186
std::string getName() const
Definition: Analyzer.h:2406
string name
Definition: setup.in.py:72

+ Here is the call graph for this function:

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const std::vector< table_functions::TableFunction > &  table_funcs,
const bool  is_gpu 
)

Definition at line 816 of file ExtensionFunctionsBinding.cpp.

References setup::name.

Referenced by bind_table_function(), and RelAlgExecutor::createTableFunctionWorkUnit().

819  {
820  std::string processor = (is_gpu ? "GPU" : "CPU");
821  return bind_function<table_functions::TableFunction>(
822  name, input_args, table_funcs, processor);
823 }
string name
Definition: setup.in.py:72

+ Here is the caller graph for this function:

const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo> > bind_table_function ( std::string  name,
Analyzer::ExpressionPtrVector  input_args,
const bool  is_gpu 
)

Definition at line 876 of file ExtensionFunctionsBinding.cpp.

References bind_table_function(), and table_functions::TableFunctionsFactory::get_table_funcs().

878  {
879  // used in RelAlgExecutor.cpp
880  std::vector<table_functions::TableFunction> table_funcs =
881  table_functions::TableFunctionsFactory::get_table_funcs(name, is_gpu);
882  return bind_table_function(name, input_args, table_funcs, is_gpu);
883 }
static std::vector< TableFunction > get_table_funcs()
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
string name
Definition: setup.in.py:72

+ Here is the call graph for this function: