OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string/join.hpp>
20 #include <iostream>
21 
25 #include "Shared/StringTransform.h"
26 
27 // Get the list of all type specializations for the given function name.
28 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
29  const std::string& name) {
30  const auto it = functions_.find(to_upper(name));
31  if (it == functions_.end()) {
32  return nullptr;
33  }
34  return &it->second;
35 }
36 
37 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
38  const std::string& name) {
39  const auto it = udf_functions_.find(to_upper(name));
40  if (it == udf_functions_.end()) {
41  return nullptr;
42  }
43  return &it->second;
44 }
45 
46 // Get the list of all udfs
47 std::unordered_set<std::string> ExtensionFunctionsWhitelist::get_udfs_name(
48  const bool is_runtime) {
49  std::unordered_set<std::string> names;
50  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
51  for (auto funcs : collections) {
52  for (auto& pair : *funcs) {
53  ExtensionFunction udf = pair.second.at(0);
54  if (udf.isRuntime() == is_runtime) {
55  names.insert(udf.getName(/* keep_suffix */ false));
56  }
57  }
58  }
59  return names;
60 }
61 
62 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
63  const std::string& name) {
64  std::vector<ExtensionFunction> ext_funcs = {};
65  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
66  const auto uname = to_upper(name);
67  for (auto funcs : collections) {
68  const auto it = funcs->find(uname);
69  if (it == funcs->end()) {
70  continue;
71  }
72  auto ext_func_sigs = it->second;
73  std::copy(ext_func_sigs.begin(), ext_func_sigs.end(), std::back_inserter(ext_funcs));
74  }
75  return ext_funcs;
76 }
77 
78 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
79  const std::string& name,
80  const bool is_gpu) {
81  std::vector<ExtensionFunction> ext_funcs = {};
82  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
83  const auto uname = to_upper(name);
84  for (auto funcs : collections) {
85  const auto it = funcs->find(uname);
86  if (it == funcs->end()) {
87  continue;
88  }
89  auto ext_func_sigs = it->second;
90  std::copy_if(ext_func_sigs.begin(),
91  ext_func_sigs.end(),
92  std::back_inserter(ext_funcs),
93  [is_gpu](auto sig) { return (is_gpu ? sig.isGPU() : sig.isCPU()); });
94  }
95  return ext_funcs;
96 }
97 
98 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
99  const std::string& name,
100  size_t arity) {
101  std::vector<ExtensionFunction> ext_funcs = {};
102  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
103  const auto uname = to_upper(name);
104  for (auto funcs : collections) {
105  const auto it = funcs->find(uname);
106  if (it == funcs->end()) {
107  continue;
108  }
109  auto ext_func_sigs = it->second;
110  std::copy_if(ext_func_sigs.begin(),
111  ext_func_sigs.end(),
112  std::back_inserter(ext_funcs),
113  [arity](auto sig) { return arity == sig.getInputArgs().size(); });
114  }
115  return ext_funcs;
116 }
117 
118 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
119  const std::string& name,
120  size_t arity,
121  const SQLTypeInfo& rtype) {
122  std::vector<ExtensionFunction> ext_funcs = {};
123  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
124  const auto uname = to_upper(name);
125  for (auto funcs : collections) {
126  const auto it = funcs->find(uname);
127  if (it == funcs->end()) {
128  continue;
129  }
130  auto ext_func_sigs = it->second;
131  std::copy_if(ext_func_sigs.begin(),
132  ext_func_sigs.end(),
133  std::back_inserter(ext_funcs),
134  [arity, rtype](auto sig) {
135  // Ideally, arity should be equal to the number of
136  // sig arguments but there seems to be many cases
137  // where some sig arguments will be represented
138  // with multiple arguments, for instance, array
139  // argument is translated to data pointer and array
140  // size arguments.
141  if (arity > sig.getInputArgs().size()) {
142  return false;
143  }
144  auto rt = rtype.get_type();
145  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
146  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
147  });
148  }
149  return ext_funcs;
150 }
151 
152 namespace {
153 
154 // Returns the LLVM name for `type`.
156  bool byval = true,
157  bool declare = false) {
158  switch (type) {
160  return "i8"; // clang converts bool to i8
162  return "i8";
164  return "i16";
166  return "i32";
168  return "i64";
170  return "float";
172  return "double";
174  return "void";
176  return "i8*";
178  return "i16*";
180  return "i32*";
182  return "i64*";
184  return "float*";
186  return "double*";
188  return "i1*";
190  return (declare ? "{i8*, i64, i8}*" : "Array<i8>");
192  return (declare ? "{i16*, i64, i8}*" : "Array<i16>");
194  return (declare ? "{i32*, i64, i8}*" : "Array<i32>");
196  return (declare ? "{i64*, i64, i8}*" : "Array<i64>");
198  return (declare ? "{float*, i64, i8}*" : "Array<float>");
200  return (declare ? "{double*, i64, i8}*" : "Array<double>");
202  return (declare ? "{i1*, i64, i8}*" : "Array<i1>");
204  return (declare ? "{i32*, i64, i8}*" : "Array<TextEncodingDict>");
206  return "geo_point";
208  return "geo_multi_point";
210  return "geo_linestring";
212  return "geo_multi_linestring";
214  return "geo_polygon";
216  return "geo_multi_polygon";
218  return "cursor";
220  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<i8>");
222  return (declare ? (byval ? "{i16*, i64}" : "i8*") : "Column<i16>");
224  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<i32>");
226  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<i64>");
228  return (declare ? (byval ? "{float*, i64}" : "i8*") : "Column<float>");
230  return (declare ? (byval ? "{double*, i64}" : "i8*") : "Column<double>");
232  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<bool>");
234  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<TextEncodingDict>");
236  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<Timestamp>");
238  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "TextEncodingNone");
240  return (declare ? "{ i32 }" : "TextEncodingDict");
242  return (declare ? "{ i64 }" : "Timestamp");
244  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i8>");
246  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i16>");
248  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i32>");
250  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i64>");
252  return (declare ? "{i8**, i64, i64}*" : "ColumnList<float>");
254  return (declare ? "{i8**, i64, i64}*" : "ColumnList<double>");
256  return (declare ? "{i8**, i64, i64}*" : "ColumnList<bool>");
258  return (declare ? "{i8**, i64, i64}*" : "ColumnList<TextEncodingDict>");
260  return (declare ? "{i8*, i64}*" : "Column<Array<i8>>");
262  return (declare ? "{i8*, i64}*" : "Column<Array<i16>>");
264  return (declare ? "{i8*, i64}*" : "Column<Array<i32>>");
266  return (declare ? "{i8*, i64}*" : "Column<Array<i64>>");
268  return (declare ? "{i8*, i64}*" : "Column<Array<float>>");
270  return (declare ? "{i8*, i64}*" : "Column<Array<double>>");
272  return (declare ? "{i8*, i64}*" : "Column<Array<bool>>");
274  return (declare ? "{i8*, i64}" : "Column<Array<TextEncodingDict>>");
276  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i8>");
278  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i16>");
280  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i32>");
282  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i64>");
284  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<float>");
286  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<double>");
288  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<bool>");
290  return (declare ? "{i8**, i64, i64}" : "ColumnList<Array<TextEncodingDict>>");
292  return (declare ? "{ i64 }" : "DayTimeInterval");
294  return (declare ? "{ i64 }" : "YearMonthTimeInterval");
295  default:
296  CHECK(false);
297  }
298  CHECK(false);
299  return "";
300 }
301 
302 std::string drop_suffix(const std::string& str) {
303  const auto idx = str.find("__");
304  if (idx == std::string::npos) {
305  return str;
306  }
307  CHECK_GT(idx, std::string::size_type(0));
308  return str.substr(0, idx);
309 }
310 
311 } // namespace
312 
314  SQLTypes type = kNULLT;
315  int d = 0;
316  int s = 0;
317  bool n = false;
319  int p = 0;
320  SQLTypes subtype = kNULLT;
321 
322 #define EXTARGTYPECASE(EXTARGTYPE, ELEMTYPE, ENCODING, ARRAYENCODING) \
323  case ExtArgumentType::EXTARGTYPE: \
324  type = ELEMTYPE; \
325  c = kENCODING_##ENCODING; \
326  break; \
327  case ExtArgumentType::Array##EXTARGTYPE: \
328  type = kARRAY; \
329  c = kENCODING_##ENCODING; \
330  subtype = ELEMTYPE; \
331  break; \
332  case ExtArgumentType::Column##EXTARGTYPE: \
333  type = kCOLUMN; \
334  c = kENCODING_##ENCODING; \
335  subtype = ELEMTYPE; \
336  break; \
337  case ExtArgumentType::ColumnList##EXTARGTYPE: \
338  type = kCOLUMN_LIST; \
339  c = kENCODING_##ENCODING; \
340  subtype = ELEMTYPE; \
341  break; \
342  case ExtArgumentType::ColumnArray##EXTARGTYPE: \
343  type = kCOLUMN; \
344  subtype = ELEMTYPE; \
345  c = kENCODING_##ARRAYENCODING; \
346  break; \
347  case ExtArgumentType::ColumnListArray##EXTARGTYPE: \
348  type = kCOLUMN_LIST; \
349  subtype = ELEMTYPE; \
350  c = kENCODING_##ARRAYENCODING; \
351  break;
352 
353  switch (ext_arg_type) {
354  EXTARGTYPECASE(Bool, kBOOLEAN, NONE, ARRAY);
355  EXTARGTYPECASE(Int8, kTINYINT, NONE, ARRAY);
357  EXTARGTYPECASE(Int32, kINT, NONE, ARRAY);
358  EXTARGTYPECASE(Int64, kBIGINT, NONE, ARRAY);
359  EXTARGTYPECASE(Float, kFLOAT, NONE, ARRAY);
360  EXTARGTYPECASE(Double, kDOUBLE, NONE, ARRAY);
362  EXTARGTYPECASE(TextEncodingDict, kTEXT, DICT, ARRAY_DICT);
363  // TODO: EXTARGTYPECASE(Timestamp, kTIMESTAMP, NONE, ARRAY);
365  type = kTIMESTAMP;
366  c = kENCODING_NONE;
367  d = 9;
368  break;
370  type = kCOLUMN;
371  subtype = kTIMESTAMP;
372  c = kENCODING_NONE;
373  d = 9;
374  break;
376  type = kINTERVAL_DAY_TIME;
377  break;
379  type = kINTERVAL_YEAR_MONTH;
380  break;
381  default:
382  LOG(FATAL) << "ExtArgumentType `" << serialize_type(ext_arg_type)
383  << "` cannot be converted to SQLTypes.";
384  UNREACHABLE();
385  }
386  return SQLTypeInfo(type, d, s, n, c, p, subtype);
387 }
388 
390  const std::vector<ExtensionFunction>& ext_funcs,
391  std::string tab) {
392  std::string r = "";
393  for (auto sig : ext_funcs) {
394  r += tab + sig.toString() + "\n";
395  }
396  return r;
397 }
398 
400  const std::vector<SQLTypeInfo>& arg_types) {
401  std::string r = "";
402  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
403  r += sig->get_type_name();
404  sig++;
405  if (sig != arg_types.end()) {
406  r += ", ";
407  }
408  }
409  return r;
410 }
411 
413  const std::vector<ExtArgumentType>& sig_types) {
414  std::string r = "";
415  for (auto t = sig_types.begin(); t != sig_types.end();) {
416  r += serialize_type(*t, /* byval */ false, /* declare */ false);
417  t++;
418  if (t != sig_types.end()) {
419  r += ", ";
420  }
421  }
422  return r;
423 }
424 
426  const std::vector<ExtArgumentType>& sig_types) {
427  std::string r = "";
428  for (auto t = sig_types.begin(); t != sig_types.end();) {
430  t++;
431  if (t != sig_types.end()) {
432  r += ", ";
433  }
434  }
435  return r;
436 }
437 
439  return serialize_type(sig_type, /* byval */ false, /* declare */ false);
440 }
441 
443  switch (sig_type) {
445  return "TINYINT";
447  return "SMALLINT";
449  return "INTEGER";
451  return "BIGINT";
453  return "FLOAT";
455  return "DOUBLE";
457  return "BOOLEAN";
459  return "TINYINT[]";
461  return "SMALLINT[]";
463  return "INT[]";
465  return "BIGINT[]";
467  return "FLOAT[]";
469  return "DOUBLE[]";
471  return "BOOLEAN[]";
473  return "ARRAY<TINYINT>";
475  return "ARRAY<SMALLINT>";
477  return "ARRAY<INT>";
479  return "ARRAY<BIGINT>";
481  return "ARRAY<FLOAT>";
483  return "ARRAY<DOUBLE>";
485  return "ARRAY<BOOLEAN>";
487  return "ARRAY<TEXT ENCODING DICT>";
489  return "COLUMN<TINYINT>";
491  return "COLUMN<SMALLINT>";
493  return "COLUMN<INT>";
495  return "COLUMN<BIGINT>";
497  return "COLUMN<FLOAT>";
499  return "COLUMN<DOUBLE>";
501  return "COLUMN<BOOLEAN>";
503  return "COLUMN<TEXT ENCODING DICT>";
505  return "COLUMN<TIMESTAMP(9)>";
507  return "CURSOR";
509  return "POINT";
511  return "MULTIPOINT";
513  return "LINESTRING";
515  return "MULTILINESTRING";
517  return "POLYGON";
519  return "MULTIPOLYGON";
521  return "VOID";
523  return "TEXT ENCODING NONE";
525  return "TEXT ENCODING DICT";
527  return "TIMESTAMP(9)";
529  return "COLUMNLIST<TINYINT>";
531  return "COLUMNLIST<SMALLINT>";
533  return "COLUMNLIST<INT>";
535  return "COLUMNLIST<BIGINT>";
537  return "COLUMNLIST<FLOAT>";
539  return "COLUMNLIST<DOUBLE>";
541  return "COLUMNLIST<BOOLEAN>";
543  return "COLUMNLIST<TEXT ENCODING DICT>";
545  return "COLUMN<ARRAY<TINYINT>>";
547  return "COLUMN<ARRAY<SMALLINT>>";
549  return "COLUMN<ARRAY<INT>>";
551  return "COLUMN<ARRAY<BIGINT>>";
553  return "COLUMN<ARRAY<FLOAT>>";
555  return "COLUMN<ARRAY<DOUBLE>>";
557  return "COLUMN<ARRAY<BOOLEAN>>";
559  return "COLUMN<ARRAY<TEXT ENCODING DICT>>";
561  return "COLUMNLIST<ARRAY<TINYINT>>";
563  return "COLUMNLIST<ARRAY<SMALLINT>>";
565  return "COLUMNLIST<ARRAY<INT>>";
567  return "COLUMNLIST<ARRAY<BIGINT>>";
569  return "COLUMNLIST<ARRAY<FLOAT>>";
571  return "COLUMNLIST<ARRAY<DOUBLE>>";
573  return "COLUMNLIST<ARRAY<BOOLEAN>>";
575  return "COLUMNLIST<ARRAY<TEXT ENCODING DICT>>";
577  return "DAY TIME INTERVAL";
579  return "YEAR MONTH INTERVAL";
580  default:
581  UNREACHABLE();
582  }
583  return "";
584 }
585 
587  // if-else exists to keep compatibility with older versions of RBC
588  if (annotations_.empty()) {
589  return false;
590  } else {
591  auto func_annotations = annotations_.back();
592  auto mgr_annotation = func_annotations.find("uses_manager");
593  if (mgr_annotation != func_annotations.end()) {
594  return boost::algorithm::to_lower_copy(mgr_annotation->second) == "true";
595  }
596  return false;
597  }
598 }
599 
600 const std::string ExtensionFunction::getName(bool keep_suffix) const {
601  return (keep_suffix ? name_ : drop_suffix(name_));
602 }
603 
604 std::string ExtensionFunction::toString() const {
605  return getName() + "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
607 }
608 
609 std::string ExtensionFunction::toStringSQL() const {
610  return getName(/* keep_suffix = */ false) + "(" +
613 }
614 
615 std::string ExtensionFunction::toSignature() const {
616  return "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
618 }
619 
620 // Converts the extension function signatures to their LLVM representation.
622  const std::unordered_set<std::string>& udf_decls,
623  const bool is_gpu) {
624  std::vector<std::string> declarations;
625  for (const auto& kv : functions_) {
626  const std::vector<ExtensionFunction>& ext_funcs = kv.second;
627  CHECK(!ext_funcs.empty());
628  for (const auto& ext_func : ext_funcs) {
629  // If there is a udf function declaration matching an extension function signature
630  // do not emit a duplicate declaration.
631  if (!udf_decls.empty() && udf_decls.find(ext_func.getName()) != udf_decls.end()) {
632  continue;
633  }
634 
635  std::string decl_prefix;
636  std::vector<std::string> arg_strs;
637 
638  if (is_ext_arg_type_array(ext_func.getRet())) {
639  decl_prefix = "declare void @" + ext_func.getName();
640  arg_strs.emplace_back(
641  serialize_type(ext_func.getRet(), /* byval */ true, /* declare */ true));
642  } else {
643  decl_prefix =
644  "declare " +
645  serialize_type(ext_func.getRet(), /* byval */ true, /* declare */ true) +
646  " @" + ext_func.getName();
647  }
648 
649  // if the extension function uses a Row Function Manager, append "i8*" as the first
650  // arg
651  if (ext_func.usesManager()) {
652  arg_strs.emplace_back("i8*");
653  }
654 
655  for (const auto arg : ext_func.getInputArgs()) {
656  arg_strs.emplace_back(serialize_type(arg, /* byval */ false, /* declare */ true));
657  }
658  declarations.emplace_back(decl_prefix + "(" +
659  boost::algorithm::join(arg_strs, ", ") + ");");
660  }
661  }
662 
664  if (kv.second.isRuntime() || kv.second.useDefaultSizer()) {
665  // Runtime UDTFs are defined in LLVM/NVVM IR module
666  // UDTFs using default sizer share LLVM IR
667  continue;
668  }
669  if (!((is_gpu && kv.second.isGPU()) || (!is_gpu && kv.second.isCPU()))) {
670  continue;
671  }
672  std::string decl_prefix{
673  "declare " +
674  serialize_type(ExtArgumentType::Int32, /* byval */ true, /* declare */ true) +
675  " @" + kv.first};
676  std::vector<std::string> arg_strs;
677  for (const auto arg : kv.second.getArgs(/* ensure_column = */ true)) {
678  arg_strs.push_back(
679  serialize_type(arg, /* byval= */ kv.second.isRuntime(), /* declare= */ true));
680  }
681  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
682  ");");
683  }
684  return declarations;
685 }
686 
687 namespace {
688 
690  if (type_name == "bool" || type_name == "i1") {
691  return ExtArgumentType::Bool;
692  }
693  if (type_name == "i8") {
694  return ExtArgumentType::Int8;
695  }
696  if (type_name == "i16") {
697  return ExtArgumentType::Int16;
698  }
699  if (type_name == "i32") {
700  return ExtArgumentType::Int32;
701  }
702  if (type_name == "i64") {
703  return ExtArgumentType::Int64;
704  }
705  if (type_name == "float") {
706  return ExtArgumentType::Float;
707  }
708  if (type_name == "double") {
710  }
711  if (type_name == "void") {
712  return ExtArgumentType::Void;
713  }
714  if (type_name == "i8*") {
715  return ExtArgumentType::PInt8;
716  }
717  if (type_name == "i16*") {
719  }
720  if (type_name == "i32*") {
722  }
723  if (type_name == "i64*") {
725  }
726  if (type_name == "float*") {
728  }
729  if (type_name == "double*") {
731  }
732  if (type_name == "i1*" || type_name == "bool*") {
733  return ExtArgumentType::PBool;
734  }
735  if (type_name == "Array<i8>") {
737  }
738  if (type_name == "Array<i16>") {
740  }
741  if (type_name == "Array<i32>") {
743  }
744  if (type_name == "Array<i64>") {
746  }
747  if (type_name == "Array<float>") {
749  }
750  if (type_name == "Array<double>") {
752  }
753  if (type_name == "Array<bool>" || type_name == "Array<i1>") {
755  }
756  if (type_name == "Array<TextEncodingDict>") {
758  }
759  if (type_name == "geo_point") {
761  }
762  if (type_name == "geo_multi_point") {
764  }
765  if (type_name == "geo_linestring") {
767  }
768  if (type_name == "geo_multi_linestring") {
770  }
771  if (type_name == "geo_polygon") {
773  }
774  if (type_name == "geo_multi_polygon") {
776  }
777  if (type_name == "cursor") {
779  }
780  if (type_name == "Column<i8>") {
782  }
783  if (type_name == "Column<i16>") {
785  }
786  if (type_name == "Column<i32>") {
788  }
789  if (type_name == "Column<i64>") {
791  }
792  if (type_name == "Column<float>") {
794  }
795  if (type_name == "Column<double>") {
797  }
798  if (type_name == "Column<bool>") {
800  }
801  if (type_name == "Column<TextEncodingDict>") {
803  }
804  if (type_name == "Column<Timestamp>") {
806  }
807  if (type_name == "TextEncodingNone") {
809  }
810  if (type_name == "TextEncodingDict") {
812  }
813  if (type_name == "timestamp") {
815  }
816  if (type_name == "ColumnList<i8>") {
818  }
819  if (type_name == "ColumnList<i16>") {
821  }
822  if (type_name == "ColumnList<i32>") {
824  }
825  if (type_name == "ColumnList<i64>") {
827  }
828  if (type_name == "ColumnList<float>") {
830  }
831  if (type_name == "ColumnList<double>") {
833  }
834  if (type_name == "ColumnList<bool>") {
836  }
837  if (type_name == "ColumnList<TextEncodingDict>") {
839  }
840  if (type_name == "Column<Array<i8>>") {
842  }
843  if (type_name == "Column<Array<i16>>") {
845  }
846  if (type_name == "Column<Array<i32>>") {
848  }
849  if (type_name == "Column<Array<i64>>") {
851  }
852  if (type_name == "Column<Array<float>>") {
854  }
855  if (type_name == "Column<Array<double>>") {
857  }
858  if (type_name == "Column<Array<bool>>") {
860  }
861  if (type_name == "Column<Array<TextEncodingDict>>") {
863  }
864  if (type_name == "ColumnList<Array<i8>>") {
866  }
867  if (type_name == "ColumnList<Array<i16>>") {
869  }
870  if (type_name == "ColumnList<Array<i32>>") {
872  }
873  if (type_name == "ColumnList<Array<i64>>") {
875  }
876  if (type_name == "ColumnList<Array<float>>") {
878  }
879  if (type_name == "ColumnList<Array<double>>") {
881  }
882  if (type_name == "ColumnList<Array<bool>>") {
884  }
885  if (type_name == "ColumnList<Array<TextEncodingDict>>") {
887  }
888  if (type_name == "DayTimeInterval") {
890  }
891  if (type_name == "YearMonthTimeInterval") {
893  }
894  CHECK(false);
895  return ExtArgumentType::Int16;
896 }
897 
898 } // namespace
899 
900 using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
901 
903  const std::string& json_func_sigs,
904  const bool is_runtime) {
905  rapidjson::Document func_sigs;
906  func_sigs.Parse(json_func_sigs.c_str());
907  CHECK(func_sigs.IsArray());
908  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
909  ++func_sigs_it) {
910  CHECK(func_sigs_it->IsObject());
911  const auto name = json_str(field(*func_sigs_it, "name"));
912  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
913  std::vector<ExtArgumentType> args;
914  const auto& args_serialized = field(*func_sigs_it, "args");
915  CHECK(args_serialized.IsArray());
916  for (auto args_serialized_it = args_serialized.Begin();
917  args_serialized_it != args_serialized.End();
918  ++args_serialized_it) {
919  args.push_back(deserialize_type(json_str(*args_serialized_it)));
920  }
921 
922  std::vector<std::map<std::string, std::string>> annotations;
923  const auto& anns = field(*func_sigs_it, "annotations");
924  CHECK(anns.IsArray());
925  static const std::map<std::string, std::string> map_empty = {};
926  for (auto obj = anns.Begin(); obj != anns.End(); ++obj) {
927  CHECK(obj->IsObject());
928  if (obj->ObjectEmpty()) {
929  annotations.push_back(map_empty);
930  } else {
931  std::map<std::string, std::string> m;
932  for (auto kv = obj->MemberBegin(); kv != obj->MemberEnd(); ++kv) {
933  m[kv->name.GetString()] = kv->value.GetString();
934  }
935  annotations.push_back(m);
936  }
937  }
938  signatures[to_upper(drop_suffix(name))].emplace_back(
939  name, args, ret, annotations, is_runtime);
940  }
941 }
942 
943 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
944 // them to its operator table and shares the list with the execution layer in
945 // JSON format. Build an in-memory representation of that list here so that it
946 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
947 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
948  // Valid json_func_sigs example:
949  // [
950  // {
951  // "name":"sum",
952  // "ret":"i32",
953  // "args":[
954  // "i32",
955  // "i32"
956  // ]
957  // }
958  // ]
959 
960  addCommon(functions_, json_func_sigs, /* is_runtime */ false);
961 }
962 
963 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
964  if (!json_func_sigs.empty()) {
965  addCommon(udf_functions_, json_func_sigs, /* is_runtime */ false);
966  }
967 }
968 
970  rt_udf_functions_.clear();
971 }
972 
973 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
974  if (!json_func_sigs.empty()) {
975  addCommon(rt_udf_functions_, json_func_sigs, /* is_runtime */ true);
976  }
977 }
978 
979 std::unordered_map<std::string, std::vector<ExtensionFunction>>
981 
982 std::unordered_map<std::string, std::vector<ExtensionFunction>>
984 
985 std::unordered_map<std::string, std::vector<ExtensionFunction>>
987 
988 std::string toString(const ExtArgumentType& sig_type) {
989  return ExtensionFunctionsWhitelist::toString(sig_type);
990 }
static void addUdfs(const std::string &json_func_sigs)
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs, const bool is_runtime)
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
SQLTypes
Definition: sqltypes.h:55
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:285
std::string toSignature() const
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
std::string join(T const &container, std::string const &delim)
const std::vector< ExtArgumentType > args_
static void add(const std::string &json_func_sigs)
#define UNREACHABLE()
Definition: Logger.h:337
std::unordered_map< std::string, std::vector< ExtensionFunction >> SignatureMap
const std::string getName(bool keep_suffix=true) const
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:381
std::string toStringSQL() const
#define CHECK_GT(x, y)
Definition: Logger.h:305
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
EncodingType
Definition: sqltypes.h:230
Supported runtime functions management and retrieval.
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
#define EXTARGTYPECASE(EXTARGTYPE, ELEMTYPE, ENCODING, ARRAYENCODING)
DEVICE auto copy(ARGS &&...args)
Definition: gpu_enabled.h:51
ExtArgumentType deserialize_type(const std::string &type_name)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
static std::unordered_set< std::string > get_udfs_name(const bool is_runtime)
Checked json field retrieval.
std::string toString(const ExecutorDeviceType &device_type)
Argument type based extension function binding.
std::string to_upper(const std::string &str)
Definition: sqltypes.h:69
std::string serialize_type(const ExtArgumentType type, bool byval=true, bool declare=false)
const std::vector< std::map< std::string, std::string > > annotations_
const ExtArgumentType ret_
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:291
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:62
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
constexpr auto type_name() noexcept
static std::vector< std::string > getLLVMDeclarations(const std::unordered_set< std::string > &udf_decls, const bool is_gpu=false)
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static std::string toStringSQL(const std::vector< ExtArgumentType > &sig_types)
static void addRTUdfs(const std::string &json_func_sigs)