OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string/join.hpp>
20 #include <iostream>
21 
25 #include "Shared/StringTransform.h"
26 
27 // Get the list of all type specializations for the given function name.
28 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
29  const std::string& name) {
30  const auto it = functions_.find(to_upper(name));
31  if (it == functions_.end()) {
32  return nullptr;
33  }
34  return &it->second;
35 }
36 
37 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
38  const std::string& name) {
39  const auto it = udf_functions_.find(to_upper(name));
40  if (it == udf_functions_.end()) {
41  return nullptr;
42  }
43  return &it->second;
44 }
45 
46 // Get the list of all udfs
47 std::unordered_set<std::string> ExtensionFunctionsWhitelist::get_udfs_name(
48  const bool is_runtime) {
49  std::unordered_set<std::string> names;
50  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
51  for (auto funcs : collections) {
52  for (auto& pair : *funcs) {
53  ExtensionFunction udf = pair.second.at(0);
54  if (udf.isRuntime() == is_runtime) {
55  names.insert(udf.getName(/* keep_suffix */ false));
56  }
57  }
58  }
59  return names;
60 }
61 
62 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
63  const std::string& name) {
64  std::vector<ExtensionFunction> ext_funcs = {};
65  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
66  const auto uname = to_upper(name);
67  for (auto funcs : collections) {
68  const auto it = funcs->find(uname);
69  if (it == funcs->end()) {
70  continue;
71  }
72  auto ext_func_sigs = it->second;
73  std::copy(ext_func_sigs.begin(), ext_func_sigs.end(), std::back_inserter(ext_funcs));
74  }
75  return ext_funcs;
76 }
77 
78 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
79  const std::string& name,
80  const bool is_gpu) {
81  std::vector<ExtensionFunction> ext_funcs = {};
82  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
83  const auto uname = to_upper(name);
84  for (auto funcs : collections) {
85  const auto it = funcs->find(uname);
86  if (it == funcs->end()) {
87  continue;
88  }
89  auto ext_func_sigs = it->second;
90  std::copy_if(ext_func_sigs.begin(),
91  ext_func_sigs.end(),
92  std::back_inserter(ext_funcs),
93  [is_gpu](auto sig) { return (is_gpu ? sig.isGPU() : sig.isCPU()); });
94  }
95  return ext_funcs;
96 }
97 
98 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
99  const std::string& name,
100  size_t arity) {
101  std::vector<ExtensionFunction> ext_funcs = {};
102  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
103  const auto uname = to_upper(name);
104  for (auto funcs : collections) {
105  const auto it = funcs->find(uname);
106  if (it == funcs->end()) {
107  continue;
108  }
109  auto ext_func_sigs = it->second;
110  std::copy_if(ext_func_sigs.begin(),
111  ext_func_sigs.end(),
112  std::back_inserter(ext_funcs),
113  [arity](auto sig) { return arity == sig.getInputArgs().size(); });
114  }
115  return ext_funcs;
116 }
117 
118 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
119  const std::string& name,
120  size_t arity,
121  const SQLTypeInfo& rtype) {
122  std::vector<ExtensionFunction> ext_funcs = {};
123  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
124  const auto uname = to_upper(name);
125  for (auto funcs : collections) {
126  const auto it = funcs->find(uname);
127  if (it == funcs->end()) {
128  continue;
129  }
130  auto ext_func_sigs = it->second;
131  std::copy_if(ext_func_sigs.begin(),
132  ext_func_sigs.end(),
133  std::back_inserter(ext_funcs),
134  [arity, rtype](auto sig) {
135  // Ideally, arity should be equal to the number of
136  // sig arguments but there seems to be many cases
137  // where some sig arguments will be represented
138  // with multiple arguments, for instance, array
139  // argument is translated to data pointer and array
140  // size arguments.
141  if (arity > sig.getInputArgs().size()) {
142  return false;
143  }
144  auto rt = rtype.get_type();
145  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
146  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
147  });
148  }
149  return ext_funcs;
150 }
151 
152 namespace {
153 
154 // Returns the LLVM name for `type`.
156  bool byval = true,
157  bool declare = false) {
158  switch (type) {
160  return "i8"; // clang converts bool to i8
162  return "i8";
164  return "i16";
166  return "i32";
168  return "i64";
170  return "float";
172  return "double";
174  return "void";
176  return "i8*";
178  return "i16*";
180  return "i32*";
182  return "i64*";
184  return "float*";
186  return "double*";
188  return "i1*";
190  return (declare ? "{i8*, i64, i8}*" : "Array<i8>");
192  return (declare ? "{i16*, i64, i8}*" : "Array<i16>");
194  return (declare ? "{i32*, i64, i8}*" : "Array<i32>");
196  return (declare ? "{i64*, i64, i8}*" : "Array<i64>");
198  return (declare ? "{float*, i64, i8}*" : "Array<float>");
200  return (declare ? "{double*, i64, i8}*" : "Array<double>");
202  return (declare ? "{i1*, i64, i8}*" : "Array<i1>");
204  return (declare ? "{i32*, i64, i8}*" : "Array<TextEncodingDict>");
206  return "geo_point";
208  return "geo_multi_point";
210  return "geo_linestring";
212  return "geo_multi_linestring";
214  return "geo_polygon";
216  return "geo_multi_polygon";
218  return "cursor";
220  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<i8>");
222  return (declare ? (byval ? "{i16*, i64}" : "i8*") : "Column<i16>");
224  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<i32>");
226  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<i64>");
228  return (declare ? (byval ? "{float*, i64}" : "i8*") : "Column<float>");
230  return (declare ? (byval ? "{double*, i64}" : "i8*") : "Column<double>");
232  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<bool>");
234  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<TextEncodingDict>");
236  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<Timestamp>");
238  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "TextEncodingNone");
240  return (declare ? "{ i32 }" : "TextEncodingDict");
242  return (declare ? "{ i64 }" : "Timestamp");
244  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i8>");
246  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i16>");
248  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i32>");
250  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i64>");
252  return (declare ? "{i8**, i64, i64}*" : "ColumnList<float>");
254  return (declare ? "{i8**, i64, i64}*" : "ColumnList<double>");
256  return (declare ? "{i8**, i64, i64}*" : "ColumnList<bool>");
258  return (declare ? "{i8**, i64, i64}*" : "ColumnList<TextEncodingDict>");
260  return (declare ? "{i8*, i64}*" : "Column<Array<i8>>");
262  return (declare ? "{i8*, i64}*" : "Column<Array<i16>>");
264  return (declare ? "{i8*, i64}*" : "Column<Array<i32>>");
266  return (declare ? "{i8*, i64}*" : "Column<Array<i64>>");
268  return (declare ? "{i8*, i64}*" : "Column<Array<float>>");
270  return (declare ? "{i8*, i64}*" : "Column<Array<double>>");
272  return (declare ? "{i8*, i64}*" : "Column<Array<bool>>");
274  return (declare ? "{i8*, i64}" : "Column<Array<TextEncodingDict>>");
276  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i8>");
278  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i16>");
280  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i32>");
282  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i64>");
284  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<float>");
286  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<double>");
288  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<bool>");
290  return (declare ? "{i8**, i64, i64}" : "ColumnList<Array<TextEncodingDict>>");
292  return (declare ? "{ i64 }" : "DayTimeInterval");
294  return (declare ? "{ i64 }" : "YearMonthTimeInterval");
295  default:
296  CHECK(false);
297  }
298  CHECK(false);
299  return "";
300 }
301 
302 std::string drop_suffix(const std::string& str) {
303  const auto idx = str.find("__");
304  if (idx == std::string::npos) {
305  return str;
306  }
307  CHECK_GT(idx, std::string::size_type(0));
308  return str.substr(0, idx);
309 }
310 
311 } // namespace
312 
314  SQLTypes type = kNULLT;
315  int d = 0;
316  int s = 0;
317  bool n = false;
319  int p = 0;
320  SQLTypes subtype = kNULLT;
321 
322 #define EXTARGTYPECASE(EXTARGTYPE, ELEMTYPE, ENCODING, ARRAYENCODING) \
323  case ExtArgumentType::EXTARGTYPE: \
324  type = ELEMTYPE; \
325  c = kENCODING_##ENCODING; \
326  break; \
327  case ExtArgumentType::Array##EXTARGTYPE: \
328  type = kARRAY; \
329  c = kENCODING_##ENCODING; \
330  subtype = ELEMTYPE; \
331  break; \
332  case ExtArgumentType::Column##EXTARGTYPE: \
333  type = kCOLUMN; \
334  c = kENCODING_##ENCODING; \
335  subtype = ELEMTYPE; \
336  break; \
337  case ExtArgumentType::ColumnList##EXTARGTYPE: \
338  type = kCOLUMN_LIST; \
339  c = kENCODING_##ENCODING; \
340  subtype = ELEMTYPE; \
341  break; \
342  case ExtArgumentType::ColumnArray##EXTARGTYPE: \
343  type = kCOLUMN; \
344  subtype = ELEMTYPE; \
345  c = kENCODING_##ARRAYENCODING; \
346  break; \
347  case ExtArgumentType::ColumnListArray##EXTARGTYPE: \
348  type = kCOLUMN_LIST; \
349  subtype = ELEMTYPE; \
350  c = kENCODING_##ARRAYENCODING; \
351  break;
352 
353  switch (ext_arg_type) {
354  EXTARGTYPECASE(Bool, kBOOLEAN, NONE, ARRAY);
355  EXTARGTYPECASE(Int8, kTINYINT, NONE, ARRAY);
357  EXTARGTYPECASE(Int32, kINT, NONE, ARRAY);
358  EXTARGTYPECASE(Int64, kBIGINT, NONE, ARRAY);
359  EXTARGTYPECASE(Float, kFLOAT, NONE, ARRAY);
360  EXTARGTYPECASE(Double, kDOUBLE, NONE, ARRAY);
362  EXTARGTYPECASE(TextEncodingDict, kTEXT, DICT, ARRAY_DICT);
363  // TODO: EXTARGTYPECASE(Timestamp, kTIMESTAMP, NONE, ARRAY);
365  type = kTIMESTAMP;
366  c = kENCODING_NONE;
367  d = 9;
368  break;
370  type = kCOLUMN;
371  subtype = kTIMESTAMP;
372  c = kENCODING_NONE;
373  d = 9;
374  break;
376  type = kINTERVAL_DAY_TIME;
377  break;
379  type = kINTERVAL_YEAR_MONTH;
380  break;
381  default:
382  LOG(FATAL) << "ExtArgumentType `" << serialize_type(ext_arg_type)
383  << "` cannot be converted to SQLTypes.";
384  UNREACHABLE();
385  }
386  return SQLTypeInfo(type, d, s, n, c, p, subtype);
387 }
388 
390  const std::vector<ExtensionFunction>& ext_funcs,
391  std::string tab) {
392  std::string r = "";
393  for (auto sig : ext_funcs) {
394  r += tab + sig.toString() + "\n";
395  }
396  return r;
397 }
398 
400  const std::vector<SQLTypeInfo>& arg_types) {
401  std::string r = "";
402  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
403  r += sig->get_type_name();
404  sig++;
405  if (sig != arg_types.end()) {
406  r += ", ";
407  }
408  }
409  return r;
410 }
411 
413  const std::vector<ExtArgumentType>& sig_types) {
414  std::string r = "";
415  for (auto t = sig_types.begin(); t != sig_types.end();) {
416  r += serialize_type(*t, /* byval */ false, /* declare */ false);
417  t++;
418  if (t != sig_types.end()) {
419  r += ", ";
420  }
421  }
422  return r;
423 }
424 
426  const std::vector<ExtArgumentType>& sig_types) {
427  std::string r = "";
428  for (auto t = sig_types.begin(); t != sig_types.end();) {
430  t++;
431  if (t != sig_types.end()) {
432  r += ", ";
433  }
434  }
435  return r;
436 }
437 
439  return serialize_type(sig_type, /* byval */ false, /* declare */ false);
440 }
441 
443  switch (sig_type) {
445  return "TINYINT";
447  return "SMALLINT";
449  return "INTEGER";
451  return "BIGINT";
453  return "FLOAT";
455  return "DOUBLE";
457  return "BOOLEAN";
459  return "TINYINT[]";
461  return "SMALLINT[]";
463  return "INT[]";
465  return "BIGINT[]";
467  return "FLOAT[]";
469  return "DOUBLE[]";
471  return "BOOLEAN[]";
473  return "ARRAY<TINYINT>";
475  return "ARRAY<SMALLINT>";
477  return "ARRAY<INT>";
479  return "ARRAY<BIGINT>";
481  return "ARRAY<FLOAT>";
483  return "ARRAY<DOUBLE>";
485  return "ARRAY<BOOLEAN>";
487  return "ARRAY<TEXT ENCODING DICT>";
489  return "COLUMN<TINYINT>";
491  return "COLUMN<SMALLINT>";
493  return "COLUMN<INT>";
495  return "COLUMN<BIGINT>";
497  return "COLUMN<FLOAT>";
499  return "COLUMN<DOUBLE>";
501  return "COLUMN<BOOLEAN>";
503  return "COLUMN<TEXT ENCODING DICT>";
505  return "COLUMN<TIMESTAMP(9)>";
507  return "CURSOR";
509  return "POINT";
511  return "MULTIPOINT";
513  return "LINESTRING";
515  return "MULTILINESTRING";
517  return "POLYGON";
519  return "MULTIPOLYGON";
521  return "VOID";
523  return "TEXT ENCODING NONE";
525  return "TEXT ENCODING DICT";
527  return "TIMESTAMP(9)";
529  return "COLUMNLIST<TINYINT>";
531  return "COLUMNLIST<SMALLINT>";
533  return "COLUMNLIST<INT>";
535  return "COLUMNLIST<BIGINT>";
537  return "COLUMNLIST<FLOAT>";
539  return "COLUMNLIST<DOUBLE>";
541  return "COLUMNLIST<BOOLEAN>";
543  return "COLUMNLIST<TEXT ENCODING DICT>";
545  return "COLUMN<ARRAY<TINYINT>>";
547  return "COLUMN<ARRAY<SMALLINT>>";
549  return "COLUMN<ARRAY<INT>>";
551  return "COLUMN<ARRAY<BIGINT>>";
553  return "COLUMN<ARRAY<FLOAT>>";
555  return "COLUMN<ARRAY<DOUBLE>>";
557  return "COLUMN<ARRAY<BOOLEAN>>";
559  return "COLUMN<ARRAY<TEXT ENCODING DICT>>";
561  return "COLUMNLIST<ARRAY<TINYINT>>";
563  return "COLUMNLIST<ARRAY<SMALLINT>>";
565  return "COLUMNLIST<ARRAY<INT>>";
567  return "COLUMNLIST<ARRAY<BIGINT>>";
569  return "COLUMNLIST<ARRAY<FLOAT>>";
571  return "COLUMNLIST<ARRAY<DOUBLE>>";
573  return "COLUMNLIST<ARRAY<BOOLEAN>>";
575  return "COLUMNLIST<ARRAY<TEXT ENCODING DICT>>";
577  return "DAY TIME INTERVAL";
579  return "YEAR MONTH INTERVAL";
580  default:
581  UNREACHABLE();
582  }
583  return "";
584 }
585 
587  auto func_annotations = annotations_.back();
588  auto mgr_annotation = func_annotations.find("uses_manager");
589  bool uses_manager = mgr_annotation != func_annotations.end() &&
590  boost::algorithm::to_lower_copy(mgr_annotation->second) == "true";
591  return uses_manager;
592 }
593 
594 const std::string ExtensionFunction::getName(bool keep_suffix) const {
595  return (keep_suffix ? name_ : drop_suffix(name_));
596 }
597 
598 std::string ExtensionFunction::toString() const {
599  return getName() + "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
601 }
602 
603 std::string ExtensionFunction::toStringSQL() const {
604  return getName(/* keep_suffix = */ false) + "(" +
607 }
608 
609 std::string ExtensionFunction::toSignature() const {
610  return "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
612 }
613 
614 // Converts the extension function signatures to their LLVM representation.
616  const std::unordered_set<std::string>& udf_decls,
617  const bool is_gpu) {
618  std::vector<std::string> declarations;
619  for (const auto& kv : functions_) {
620  const std::vector<ExtensionFunction>& ext_funcs = kv.second;
621  CHECK(!ext_funcs.empty());
622  for (const auto& ext_func : ext_funcs) {
623  // If there is a udf function declaration matching an extension function signature
624  // do not emit a duplicate declaration.
625  if (!udf_decls.empty() && udf_decls.find(ext_func.getName()) != udf_decls.end()) {
626  continue;
627  }
628 
629  std::string decl_prefix;
630  std::vector<std::string> arg_strs;
631 
632  if (is_ext_arg_type_array(ext_func.getRet())) {
633  decl_prefix = "declare void @" + ext_func.getName();
634  arg_strs.emplace_back(
635  serialize_type(ext_func.getRet(), /* byval */ true, /* declare */ true));
636  } else {
637  decl_prefix =
638  "declare " +
639  serialize_type(ext_func.getRet(), /* byval */ true, /* declare */ true) +
640  " @" + ext_func.getName();
641  }
642 
643  // if the extension function uses a Row Function Manager, append "i8*" as the first
644  // arg
645  if (ext_func.usesManager()) {
646  arg_strs.emplace_back("i8*");
647  }
648 
649  for (const auto arg : ext_func.getInputArgs()) {
650  arg_strs.emplace_back(serialize_type(arg, /* byval */ false, /* declare */ true));
651  }
652  declarations.emplace_back(decl_prefix + "(" +
653  boost::algorithm::join(arg_strs, ", ") + ");");
654  }
655  }
656 
658  if (kv.second.isRuntime() || kv.second.useDefaultSizer()) {
659  // Runtime UDTFs are defined in LLVM/NVVM IR module
660  // UDTFs using default sizer share LLVM IR
661  continue;
662  }
663  if (!((is_gpu && kv.second.isGPU()) || (!is_gpu && kv.second.isCPU()))) {
664  continue;
665  }
666  std::string decl_prefix{
667  "declare " +
668  serialize_type(ExtArgumentType::Int32, /* byval */ true, /* declare */ true) +
669  " @" + kv.first};
670  std::vector<std::string> arg_strs;
671  for (const auto arg : kv.second.getArgs(/* ensure_column = */ true)) {
672  arg_strs.push_back(
673  serialize_type(arg, /* byval= */ kv.second.isRuntime(), /* declare= */ true));
674  }
675  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
676  ");");
677  }
678  return declarations;
679 }
680 
681 namespace {
682 
684  if (type_name == "bool" || type_name == "i1") {
685  return ExtArgumentType::Bool;
686  }
687  if (type_name == "i8") {
688  return ExtArgumentType::Int8;
689  }
690  if (type_name == "i16") {
691  return ExtArgumentType::Int16;
692  }
693  if (type_name == "i32") {
694  return ExtArgumentType::Int32;
695  }
696  if (type_name == "i64") {
697  return ExtArgumentType::Int64;
698  }
699  if (type_name == "float") {
700  return ExtArgumentType::Float;
701  }
702  if (type_name == "double") {
704  }
705  if (type_name == "void") {
706  return ExtArgumentType::Void;
707  }
708  if (type_name == "i8*") {
709  return ExtArgumentType::PInt8;
710  }
711  if (type_name == "i16*") {
713  }
714  if (type_name == "i32*") {
716  }
717  if (type_name == "i64*") {
719  }
720  if (type_name == "float*") {
722  }
723  if (type_name == "double*") {
725  }
726  if (type_name == "i1*" || type_name == "bool*") {
727  return ExtArgumentType::PBool;
728  }
729  if (type_name == "Array<i8>") {
731  }
732  if (type_name == "Array<i16>") {
734  }
735  if (type_name == "Array<i32>") {
737  }
738  if (type_name == "Array<i64>") {
740  }
741  if (type_name == "Array<float>") {
743  }
744  if (type_name == "Array<double>") {
746  }
747  if (type_name == "Array<bool>" || type_name == "Array<i1>") {
749  }
750  if (type_name == "Array<TextEncodingDict>") {
752  }
753  if (type_name == "geo_point") {
755  }
756  if (type_name == "geo_multi_point") {
758  }
759  if (type_name == "geo_linestring") {
761  }
762  if (type_name == "geo_multi_linestring") {
764  }
765  if (type_name == "geo_polygon") {
767  }
768  if (type_name == "geo_multi_polygon") {
770  }
771  if (type_name == "cursor") {
773  }
774  if (type_name == "Column<i8>") {
776  }
777  if (type_name == "Column<i16>") {
779  }
780  if (type_name == "Column<i32>") {
782  }
783  if (type_name == "Column<i64>") {
785  }
786  if (type_name == "Column<float>") {
788  }
789  if (type_name == "Column<double>") {
791  }
792  if (type_name == "Column<bool>") {
794  }
795  if (type_name == "Column<TextEncodingDict>") {
797  }
798  if (type_name == "Column<Timestamp>") {
800  }
801  if (type_name == "TextEncodingNone") {
803  }
804  if (type_name == "TextEncodingDict") {
806  }
807  if (type_name == "timestamp") {
809  }
810  if (type_name == "ColumnList<i8>") {
812  }
813  if (type_name == "ColumnList<i16>") {
815  }
816  if (type_name == "ColumnList<i32>") {
818  }
819  if (type_name == "ColumnList<i64>") {
821  }
822  if (type_name == "ColumnList<float>") {
824  }
825  if (type_name == "ColumnList<double>") {
827  }
828  if (type_name == "ColumnList<bool>") {
830  }
831  if (type_name == "ColumnList<TextEncodingDict>") {
833  }
834  if (type_name == "Column<Array<i8>>") {
836  }
837  if (type_name == "Column<Array<i16>>") {
839  }
840  if (type_name == "Column<Array<i32>>") {
842  }
843  if (type_name == "Column<Array<i64>>") {
845  }
846  if (type_name == "Column<Array<float>>") {
848  }
849  if (type_name == "Column<Array<double>>") {
851  }
852  if (type_name == "Column<Array<bool>>") {
854  }
855  if (type_name == "Column<Array<TextEncodingDict>>") {
857  }
858  if (type_name == "ColumnList<Array<i8>>") {
860  }
861  if (type_name == "ColumnList<Array<i16>>") {
863  }
864  if (type_name == "ColumnList<Array<i32>>") {
866  }
867  if (type_name == "ColumnList<Array<i64>>") {
869  }
870  if (type_name == "ColumnList<Array<float>>") {
872  }
873  if (type_name == "ColumnList<Array<double>>") {
875  }
876  if (type_name == "ColumnList<Array<bool>>") {
878  }
879  if (type_name == "ColumnList<Array<TextEncodingDict>>") {
881  }
882  if (type_name == "DayTimeInterval") {
884  }
885  if (type_name == "YearMonthTimeInterval") {
887  }
888  CHECK(false);
889  return ExtArgumentType::Int16;
890 }
891 
892 } // namespace
893 
894 using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
895 
897  const std::string& json_func_sigs,
898  const bool is_runtime) {
899  rapidjson::Document func_sigs;
900  func_sigs.Parse(json_func_sigs.c_str());
901  CHECK(func_sigs.IsArray());
902  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
903  ++func_sigs_it) {
904  CHECK(func_sigs_it->IsObject());
905  const auto name = json_str(field(*func_sigs_it, "name"));
906  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
907  std::vector<ExtArgumentType> args;
908  const auto& args_serialized = field(*func_sigs_it, "args");
909  CHECK(args_serialized.IsArray());
910  for (auto args_serialized_it = args_serialized.Begin();
911  args_serialized_it != args_serialized.End();
912  ++args_serialized_it) {
913  args.push_back(deserialize_type(json_str(*args_serialized_it)));
914  }
915 
916  std::vector<std::map<std::string, std::string>> annotations;
917  const auto& anns = field(*func_sigs_it, "annotations");
918  CHECK(anns.IsArray());
919  static const std::map<std::string, std::string> map_empty = {};
920  for (auto obj = anns.Begin(); obj != anns.End(); ++obj) {
921  CHECK(obj->IsObject());
922  if (obj->ObjectEmpty()) {
923  annotations.push_back(map_empty);
924  } else {
925  std::map<std::string, std::string> m;
926  for (auto kv = obj->MemberBegin(); kv != obj->MemberEnd(); ++kv) {
927  m[kv->name.GetString()] = kv->value.GetString();
928  }
929  annotations.push_back(m);
930  }
931  }
932  signatures[to_upper(drop_suffix(name))].emplace_back(
933  name, args, ret, annotations, is_runtime);
934  }
935 }
936 
937 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
938 // them to its operator table and shares the list with the execution layer in
939 // JSON format. Build an in-memory representation of that list here so that it
940 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
941 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
942  // Valid json_func_sigs example:
943  // [
944  // {
945  // "name":"sum",
946  // "ret":"i32",
947  // "args":[
948  // "i32",
949  // "i32"
950  // ]
951  // }
952  // ]
953 
954  addCommon(functions_, json_func_sigs, /* is_runtime */ false);
955 }
956 
957 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
958  if (!json_func_sigs.empty()) {
959  addCommon(udf_functions_, json_func_sigs, /* is_runtime */ false);
960  }
961 }
962 
964  rt_udf_functions_.clear();
965 }
966 
967 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
968  if (!json_func_sigs.empty()) {
969  addCommon(rt_udf_functions_, json_func_sigs, /* is_runtime */ true);
970  }
971 }
972 
973 std::unordered_map<std::string, std::vector<ExtensionFunction>>
975 
976 std::unordered_map<std::string, std::vector<ExtensionFunction>>
978 
979 std::unordered_map<std::string, std::vector<ExtensionFunction>>
981 
982 std::string toString(const ExtArgumentType& sig_type) {
983  return ExtensionFunctionsWhitelist::toString(sig_type);
984 }
static void addUdfs(const std::string &json_func_sigs)
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs, const bool is_runtime)
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
SQLTypes
Definition: sqltypes.h:53
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:283
std::string toSignature() const
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
std::string join(T const &container, std::string const &delim)
const std::vector< ExtArgumentType > args_
static void add(const std::string &json_func_sigs)
#define UNREACHABLE()
Definition: Logger.h:333
std::unordered_map< std::string, std::vector< ExtensionFunction >> SignatureMap
const std::string getName(bool keep_suffix=true) const
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:380
std::string toStringSQL() const
#define CHECK_GT(x, y)
Definition: Logger.h:301
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
EncodingType
Definition: sqltypes.h:228
Supported runtime functions management and retrieval.
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
#define EXTARGTYPECASE(EXTARGTYPE, ELEMTYPE, ENCODING, ARRAYENCODING)
DEVICE auto copy(ARGS &&...args)
Definition: gpu_enabled.h:51
ExtArgumentType deserialize_type(const std::string &type_name)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
static std::unordered_set< std::string > get_udfs_name(const bool is_runtime)
Checked json field retrieval.
std::string toString(const ExecutorDeviceType &device_type)
Argument type based extension function binding.
std::string to_upper(const std::string &str)
Definition: sqltypes.h:67
std::string serialize_type(const ExtArgumentType type, bool byval=true, bool declare=false)
const std::vector< std::map< std::string, std::string > > annotations_
const ExtArgumentType ret_
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:289
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:60
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
constexpr auto type_name() noexcept
static std::vector< std::string > getLLVMDeclarations(const std::unordered_set< std::string > &udf_decls, const bool is_gpu=false)
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static std::string toStringSQL(const std::vector< ExtArgumentType > &sig_types)
static void addRTUdfs(const std::string &json_func_sigs)