37 switch (ext_arg_column_type) {
80 switch (ext_arg_column_list_type) {
120 switch (ext_arg_array_type) {
144 const bool is_arg_literal,
146 int32_t& penalty_score) {
147 const auto arg_type = arg_type_info.
get_type();
153 const auto sig_type = sig_type_info.get_type();
176 const bool is_integer_to_fp_cast = (arg_type ==
kTINYINT || arg_type ==
kSMALLINT ||
181 CHECK_GE(arg_type_relative_scale, 1);
182 CHECK_LE(arg_type_relative_scale, 8);
183 auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
184 CHECK_GE(sig_type_relative_scale, 1);
185 CHECK_LE(sig_type_relative_scale, 8);
187 if (is_integer_to_fp_cast) {
189 sig_type_relative_scale = (3 - (sig_type_relative_scale >> 2)) << 3;
194 CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
197 const auto sig_type_scale_gain_ratio =
198 sig_type_relative_scale / arg_type_relative_scale;
199 CHECK_GE(sig_type_scale_gain_ratio, 1);
205 const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
207 int32_t scale_cast_penalty_score;
221 if (is_arg_literal) {
222 scale_cast_penalty_score =
223 (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
225 scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
228 const auto cast_penalty_score =
229 type_family_cast_penalty_score + scale_cast_penalty_score;
231 penalty_score += cast_penalty_score;
236 const bool is_arg_literal,
238 const std::vector<ExtArgumentType>& sig_types,
239 int& penalty_score) {
262 int max_pos = sig_types.size() - 1;
263 if (sig_pos > max_pos) {
266 auto sig_type = sig_types[sig_pos];
285 penalty_score += 1000;
290 penalty_score += 1000;
299 penalty_score += 1000;
302 penalty_score += 1000;
312 penalty_score += 1000;
317 const auto sig_type_ti =
320 sig_type_ti.get_type() ==
kTINYINT) {
322 penalty_score += 1000;
325 penalty_score += 1000;
337 penalty_score += 1000;
340 penalty_score += 1000;
351 penalty_score += 1000;
354 penalty_score += 1000;
364 penalty_score += 1000;
371 const auto sig_type_ti =
374 sig_type_ti.get_type() ==
kARRAY) {
376 sig_type_ti.get_elem_type().get_type()) {
377 penalty_score += 1000;
383 sig_type_ti.get_type() ==
kTINYINT) {
385 penalty_score += 1000;
388 penalty_score += 1000;
398 const auto sig_type_ti =
401 sig_type_ti.get_type() ==
kARRAY) {
403 sig_type_ti.get_elem_type().get_type()) {
404 penalty_score += 1000;
410 sig_type_ti.get_type() ==
kTINYINT) {
412 penalty_score += 10000;
415 penalty_score += 10000;
428 penalty_score += 1000;
440 penalty_score += 1000;
446 penalty_score += 1000;
455 penalty_score += 1000;
461 penalty_score += 1000;
468 penalty_score += 1000;
484 throw std::runtime_error(std::string(__FILE__) +
"#" +
std::to_string(__LINE__) +
499 if (!(std::isalpha(str[0]) || str[0] ==
'_')) {
503 for (
size_t i = 1; i < str.size(); i++) {
504 if (!(std::isalnum(str[i]) || str[i] ==
'_')) {
514 template <
typename T>
518 const std::vector<T>& ext_funcs,
519 const std::string processor) {
552 "Cannot bind function with invalid UDF/UDTF function name: " + name);
555 int minimal_score = std::numeric_limits<int>::max();
558 int optimal_variant = -1;
560 std::vector<SQLTypeInfo> type_infos_input;
561 std::vector<bool> args_are_constants;
562 for (
auto atype : func_args) {
563 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
564 if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
567 if (ti.get_subtype() ==
kNULLT) {
568 throw std::runtime_error(std::string(__FILE__) +
"#" +
570 ": column support for type info " +
571 type_info.
to_string() +
" is not implemented");
573 type_infos_input.push_back(ti);
578 type_infos_input.push_back(atype->get_type_info());
579 if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
580 args_are_constants.push_back(
true);
582 args_are_constants.push_back(
false);
585 CHECK_EQ(type_infos_input.size(), args_are_constants.size());
587 if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
588 CHECK_EQ(ext_funcs.size(),
static_cast<size_t>(1));
589 CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
590 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
591 CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
593 std::vector<SQLTypeInfo> empty_type_info_variant(0);
594 return {ext_funcs[0], empty_type_info_variant};
652 bool ext_funcs_allow_column_lists{
false};
653 for (
const auto& ext_func : ext_funcs) {
654 auto ext_func_args = ext_func.getInputArgs();
655 for (
const auto& arg : ext_func_args) {
657 ext_funcs_allow_column_lists =
true;
661 if (ext_funcs_allow_column_lists) {
666 std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
667 if (ext_funcs_allow_column_lists) {
668 for (
const auto& ti : type_infos_input) {
669 if (type_infos_variants.begin() == type_infos_variants.end()) {
670 type_infos_variants.push_back({ti});
671 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
672 if (ti.is_column()) {
674 if (mti.get_subtype() ==
kNULLT) {
677 mti.set_dimension(1);
678 type_infos_variants.push_back({mti});
683 std::vector<std::vector<SQLTypeInfo>> new_type_infos_variants;
684 for (
auto& type_infos : type_infos_variants) {
685 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
686 if (ti.is_column()) {
687 auto new_type_infos = type_infos;
688 const auto& last = type_infos.back();
689 if (last.is_column_list() && last.has_same_itemtype(ti)) {
691 new_type_infos.back().set_dimension(last.get_dimension() + 1);
695 if (mti.get_subtype() ==
kNULLT) {
697 type_infos.push_back(ti);
700 mti.set_dimension(1);
701 new_type_infos.push_back(mti);
703 new_type_infos_variants.push_back(new_type_infos);
706 type_infos.push_back(ti);
708 type_infos_variants.insert(type_infos_variants.end(),
709 new_type_infos_variants.begin(),
710 new_type_infos_variants.end());
713 type_infos_variants.emplace_back(type_infos_input);
718 for (
const auto& ext_func : ext_funcs) {
721 const auto& ext_func_args = ext_func.getInputArgs();
722 int index_variant = -1;
723 for (
const auto& type_infos : type_infos_variants) {
725 int penalty_score = 0;
727 int original_input_idx = 0;
728 CHECK_LE(type_infos.size(), args_are_constants.size());
729 for (
const auto& ti : type_infos) {
731 args_are_constants[original_input_idx],
741 original_input_idx += ti.get_dimension();
743 original_input_idx++;
748 if ((
size_t)pos == ext_func_args.size()) {
749 CHECK_EQ(args_are_constants.size(), original_input_idx);
752 if (penalty_score < minimal_score) {
754 minimal_score = penalty_score;
755 optimal_variant = index_variant;
766 if (!ext_funcs.size()) {
767 message =
"Function " + name +
"(" + sarg_types +
") not supported.";
770 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
771 message =
"Could not bind " + name +
"(" + sarg_types +
") to any " + processor +
772 " UDTF implementation.";
773 }
else if constexpr (std::is_same_v<T, ExtensionFunction>) {
774 message =
"Could not bind " + name +
"(" + sarg_types +
") to any " + processor +
775 " UDF implementation.";
777 LOG(
FATAL) <<
"bind_function: unknown extension function type "
780 message +=
"\n Existing extension function implementations:";
781 for (
const auto& ext_func : ext_funcs) {
783 if constexpr (std::is_same_v<T, table_functions::TableFunction>)
784 if (ext_func.useDefaultSizer())
786 message +=
"\n " + ext_func.toStringSQL();
793 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
794 if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
795 ext_funcs[optimal].useDefaultSizer()) {
796 std::string name = ext_funcs[optimal].getName();
799 for (
size_t i = 0; i < ext_funcs.size(); i++) {
800 if (ext_funcs[i].getName() ==
name) {
802 std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
803 size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
804 type_info.insert(type_info.begin() + sizer - 1,
SQLTypeInfo(
kINT,
true));
805 return {ext_funcs[optimal], type_info};
812 return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
815 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
818 const std::vector<table_functions::TableFunction>& table_funcs,
820 std::string processor = (is_gpu ?
"GPU" :
"CPU");
821 return bind_function<table_functions::TableFunction>(
822 name, input_args, table_funcs, processor);
830 std::string processor =
"GPU";
832 if (!ext_funcs.size()) {
839 bind_function<ExtensionFunction>(
name, func_args, ext_funcs, processor));
843 processor =
"GPU|CPU";
846 bind_function<ExtensionFunction>(
name, func_args, ext_funcs, processor));
857 std::vector<ExtensionFunction> ext_funcs =
859 std::string processor = (is_gpu ?
"GPU" :
"CPU");
861 bind_function<ExtensionFunction>(
name, func_args, ext_funcs, processor));
869 for (
size_t i = 0; i < function_oper->
getArity(); ++i) {
870 func_args.push_back(function_oper->
getOwnArg(i));
875 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
880 std::vector<table_functions::TableFunction> table_funcs =
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
static std::vector< TableFunction > get_table_funcs()
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
std::string to_string() const
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
auto generate_column_type(const SQLTypeInfo &elem_ti)
HOST DEVICE EncodingType get_compression() const
auto generate_column_list_type(const SQLTypeInfo &elem_ti)
std::string get_type_name() const
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
std::vector< ExpressionPtr > ExpressionPtrVector
bool is_valid_identifier(std::string str)
std::string getName() const
SQLTypeInfo get_elem_type() const
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)