OmniSciDB  c0231cc57d
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string/join.hpp>
20 #include <iostream>
21 
25 #include "Shared/StringTransform.h"
26 
27 // Get the list of all type specializations for the given function name.
28 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
29  const std::string& name) {
30  const auto it = functions_.find(to_upper(name));
31  if (it == functions_.end()) {
32  return nullptr;
33  }
34  return &it->second;
35 }
36 
37 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
38  const std::string& name) {
39  const auto it = udf_functions_.find(to_upper(name));
40  if (it == udf_functions_.end()) {
41  return nullptr;
42  }
43  return &it->second;
44 }
45 
46 // Get the list of all udfs
47 std::unordered_set<std::string> ExtensionFunctionsWhitelist::get_udfs_name(
48  const bool is_runtime) {
49  std::unordered_set<std::string> names;
50  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
51  for (auto funcs : collections) {
52  for (auto& pair : *funcs) {
53  ExtensionFunction udf = pair.second.at(0);
54  if (udf.isRuntime() == is_runtime) {
55  names.insert(udf.getName(/* keep_suffix */ false));
56  }
57  }
58  }
59  return names;
60 }
61 
62 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
63  const std::string& name) {
64  std::vector<ExtensionFunction> ext_funcs = {};
65  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
66  const auto uname = to_upper(name);
67  for (auto funcs : collections) {
68  const auto it = funcs->find(uname);
69  if (it == funcs->end()) {
70  continue;
71  }
72  auto ext_func_sigs = it->second;
73  std::copy(ext_func_sigs.begin(), ext_func_sigs.end(), std::back_inserter(ext_funcs));
74  }
75  return ext_funcs;
76 }
77 
78 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
79  const std::string& name,
80  const bool is_gpu) {
81  std::vector<ExtensionFunction> ext_funcs = {};
82  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
83  const auto uname = to_upper(name);
84  for (auto funcs : collections) {
85  const auto it = funcs->find(uname);
86  if (it == funcs->end()) {
87  continue;
88  }
89  auto ext_func_sigs = it->second;
90  std::copy_if(ext_func_sigs.begin(),
91  ext_func_sigs.end(),
92  std::back_inserter(ext_funcs),
93  [is_gpu](auto sig) { return (is_gpu ? sig.isGPU() : sig.isCPU()); });
94  }
95  return ext_funcs;
96 }
97 
98 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
99  const std::string& name,
100  size_t arity) {
101  std::vector<ExtensionFunction> ext_funcs = {};
102  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
103  const auto uname = to_upper(name);
104  for (auto funcs : collections) {
105  const auto it = funcs->find(uname);
106  if (it == funcs->end()) {
107  continue;
108  }
109  auto ext_func_sigs = it->second;
110  std::copy_if(ext_func_sigs.begin(),
111  ext_func_sigs.end(),
112  std::back_inserter(ext_funcs),
113  [arity](auto sig) { return arity == sig.getInputArgs().size(); });
114  }
115  return ext_funcs;
116 }
117 
118 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
119  const std::string& name,
120  size_t arity,
121  const SQLTypeInfo& rtype) {
122  std::vector<ExtensionFunction> ext_funcs = {};
123  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
124  const auto uname = to_upper(name);
125  for (auto funcs : collections) {
126  const auto it = funcs->find(uname);
127  if (it == funcs->end()) {
128  continue;
129  }
130  auto ext_func_sigs = it->second;
131  std::copy_if(ext_func_sigs.begin(),
132  ext_func_sigs.end(),
133  std::back_inserter(ext_funcs),
134  [arity, rtype](auto sig) {
135  // Ideally, arity should be equal to the number of
136  // sig arguments but there seems to be many cases
137  // where some sig arguments will be represented
138  // with multiple arguments, for instance, array
139  // argument is translated to data pointer and array
140  // size arguments.
141  if (arity > sig.getInputArgs().size()) {
142  return false;
143  }
144  auto rt = rtype.get_type();
145  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
146  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
147  });
148  }
149  return ext_funcs;
150 }
151 
152 namespace {
153 
154 // Returns the LLVM name for `type`.
156  bool byval = true,
157  bool declare = false) {
158  switch (type) {
160  return "i8"; // clang converts bool to i8
162  return "i8";
164  return "i16";
166  return "i32";
168  return "i64";
170  return "float";
172  return "double";
174  return "void";
176  return "i8*";
178  return "i16*";
180  return "i32*";
182  return "i64*";
184  return "float*";
186  return "double*";
188  return "i1*";
190  return (declare ? "{i8*, i64, i8}*" : "Array<i8>");
192  return (declare ? "{i16*, i64, i8}*" : "Array<i16>");
194  return (declare ? "{i32*, i64, i8}*" : "Array<i32>");
196  return (declare ? "{i64*, i64, i8}*" : "Array<i64>");
198  return (declare ? "{float*, i64, i8}*" : "Array<float>");
200  return (declare ? "{double*, i64, i8}*" : "Array<double>");
202  return (declare ? "{i1*, i64, i8}*" : "Array<i1>");
204  return (declare ? "{i32*, i64, i8}*" : "Array<TextEncodingDict>");
206  return "geo_point";
208  return "geo_multi_point";
210  return "geo_linestring";
212  return "geo_multi_linestring";
214  return "geo_polygon";
216  return "geo_multi_polygon";
218  return "cursor";
220  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<i8>");
222  return (declare ? (byval ? "{i16*, i64}" : "i8*") : "Column<i16>");
224  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<i32>");
226  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<i64>");
228  return (declare ? (byval ? "{float*, i64}" : "i8*") : "Column<float>");
230  return (declare ? (byval ? "{double*, i64}" : "i8*") : "Column<double>");
232  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<bool>");
234  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<TextEncodingDict>");
236  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<timestamp>");
238  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "TextEncodingNone");
240  return (declare ? "{i8*, i32}*" : "TextEncodingDict");
242  return (declare ? "{ i64 }" : "timestamp");
244  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i8>");
246  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i16>");
248  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i32>");
250  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i64>");
252  return (declare ? "{i8**, i64, i64}*" : "ColumnList<float>");
254  return (declare ? "{i8**, i64, i64}*" : "ColumnList<double>");
256  return (declare ? "{i8**, i64, i64}*" : "ColumnList<bool>");
258  return (declare ? "{i8**, i64, i64}*" : "ColumnList<TextEncodingDict>");
260  return (declare ? "{i8*, i64}*" : "Column<Array<i8>>");
262  return (declare ? "{i8*, i64}*" : "Column<Array<i16>>");
264  return (declare ? "{i8*, i64}*" : "Column<Array<i32>>");
266  return (declare ? "{i8*, i64}*" : "Column<Array<i64>>");
268  return (declare ? "{i8*, i64}*" : "Column<Array<float>>");
270  return (declare ? "{i8*, i64}*" : "Column<Array<double>>");
272  return (declare ? "{i8*, i64}*" : "Column<Array<bool>>");
274  return (declare ? "{i8*, i64}" : "Column<Array<TextEncodingDict>>");
276  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i8>");
278  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i16>");
280  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i32>");
282  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i64>");
284  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<float>");
286  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<double>");
288  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<bool>");
290  return (declare ? "{i8**, i64, i64}" : "ColumnList<Array<TextEncodingDict>>");
291  default:
292  CHECK(false);
293  }
294  CHECK(false);
295  return "";
296 }
297 
298 std::string drop_suffix(const std::string& str) {
299  const auto idx = str.find("__");
300  if (idx == std::string::npos) {
301  return str;
302  }
303  CHECK_GT(idx, std::string::size_type(0));
304  return str.substr(0, idx);
305 }
306 
307 } // namespace
308 
310  SQLTypes type = kNULLT;
311  int d = 0;
312  int s = 0;
313  bool n = false;
315  int p = 0;
316  SQLTypes subtype = kNULLT;
317 
318 #define EXTARGTYPECASE(EXTARGTYPE, ELEMTYPE, ENCODING, ARRAYENCODING) \
319  case ExtArgumentType::EXTARGTYPE: \
320  type = ELEMTYPE; \
321  c = kENCODING_##ENCODING; \
322  break; \
323  case ExtArgumentType::Array##EXTARGTYPE: \
324  type = kARRAY; \
325  c = kENCODING_##ENCODING; \
326  subtype = ELEMTYPE; \
327  break; \
328  case ExtArgumentType::Column##EXTARGTYPE: \
329  type = kCOLUMN; \
330  c = kENCODING_##ENCODING; \
331  subtype = ELEMTYPE; \
332  break; \
333  case ExtArgumentType::ColumnList##EXTARGTYPE: \
334  type = kCOLUMN_LIST; \
335  c = kENCODING_##ENCODING; \
336  subtype = ELEMTYPE; \
337  break; \
338  case ExtArgumentType::ColumnArray##EXTARGTYPE: \
339  type = kCOLUMN; \
340  subtype = ELEMTYPE; \
341  c = kENCODING_##ARRAYENCODING; \
342  break; \
343  case ExtArgumentType::ColumnListArray##EXTARGTYPE: \
344  type = kCOLUMN_LIST; \
345  subtype = ELEMTYPE; \
346  c = kENCODING_##ARRAYENCODING; \
347  break;
348 
349  switch (ext_arg_type) {
350  EXTARGTYPECASE(Bool, kBOOLEAN, NONE, ARRAY);
351  EXTARGTYPECASE(Int8, kTINYINT, NONE, ARRAY);
353  EXTARGTYPECASE(Int32, kINT, NONE, ARRAY);
354  EXTARGTYPECASE(Int64, kBIGINT, NONE, ARRAY);
355  EXTARGTYPECASE(Float, kFLOAT, NONE, ARRAY);
356  EXTARGTYPECASE(Double, kDOUBLE, NONE, ARRAY);
358  EXTARGTYPECASE(TextEncodingDict, kTEXT, DICT, ARRAY_DICT);
359  // TODO: EXTARGTYPECASE(Timestamp, kTIMESTAMP, NONE, ARRAY);
361  type = kTIMESTAMP;
362  c = kENCODING_NONE;
363  d = 9;
364  break;
366  type = kCOLUMN;
367  subtype = kTIMESTAMP;
368  c = kENCODING_NONE;
369  d = 9;
370  break;
371  default:
372  LOG(FATAL) << "ExtArgumentType `" << serialize_type(ext_arg_type)
373  << "` cannot be converted to SQLTypes.";
374  UNREACHABLE();
375  }
376  return SQLTypeInfo(type, d, s, n, c, p, subtype);
377 }
378 
380  const std::vector<ExtensionFunction>& ext_funcs,
381  std::string tab) {
382  std::string r = "";
383  for (auto sig : ext_funcs) {
384  r += tab + sig.toString() + "\n";
385  }
386  return r;
387 }
388 
390  const std::vector<SQLTypeInfo>& arg_types) {
391  std::string r = "";
392  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
393  r += sig->get_type_name();
394  sig++;
395  if (sig != arg_types.end()) {
396  r += ", ";
397  }
398  }
399  return r;
400 }
401 
403  const std::vector<ExtArgumentType>& sig_types) {
404  std::string r = "";
405  for (auto t = sig_types.begin(); t != sig_types.end();) {
406  r += serialize_type(*t, /* byval */ false, /* declare */ false);
407  t++;
408  if (t != sig_types.end()) {
409  r += ", ";
410  }
411  }
412  return r;
413 }
414 
416  const std::vector<ExtArgumentType>& sig_types) {
417  std::string r = "";
418  for (auto t = sig_types.begin(); t != sig_types.end();) {
420  t++;
421  if (t != sig_types.end()) {
422  r += ", ";
423  }
424  }
425  return r;
426 }
427 
429  return serialize_type(sig_type, /* byval */ false, /* declare */ false);
430 }
431 
433  switch (sig_type) {
435  return "TINYINT";
437  return "SMALLINT";
439  return "INTEGER";
441  return "BIGINT";
443  return "FLOAT";
445  return "DOUBLE";
447  return "BOOLEAN";
449  return "TINYINT[]";
451  return "SMALLINT[]";
453  return "INT[]";
455  return "BIGINT[]";
457  return "FLOAT[]";
459  return "DOUBLE[]";
461  return "BOOLEAN[]";
463  return "ARRAY<TINYINT>";
465  return "ARRAY<SMALLINT>";
467  return "ARRAY<INT>";
469  return "ARRAY<BIGINT>";
471  return "ARRAY<FLOAT>";
473  return "ARRAY<DOUBLE>";
475  return "ARRAY<BOOLEAN>";
477  return "ARRAY<TEXT ENCODING DICT>";
479  return "COLUMN<TINYINT>";
481  return "COLUMN<SMALLINT>";
483  return "COLUMN<INT>";
485  return "COLUMN<BIGINT>";
487  return "COLUMN<FLOAT>";
489  return "COLUMN<DOUBLE>";
491  return "COLUMN<BOOLEAN>";
493  return "COLUMN<TEXT ENCODING DICT>";
495  return "COLUMN<TIMESTAMP(9)>";
497  return "CURSOR";
499  return "POINT";
501  return "MULTIPOINT";
503  return "LINESTRING";
505  return "MULTILINESTRING";
507  return "POLYGON";
509  return "MULTIPOLYGON";
511  return "VOID";
513  return "TEXT ENCODING NONE";
515  return "TEXT ENCODING DICT";
517  return "TIMESTAMP(9)";
519  return "COLUMNLIST<TINYINT>";
521  return "COLUMNLIST<SMALLINT>";
523  return "COLUMNLIST<INT>";
525  return "COLUMNLIST<BIGINT>";
527  return "COLUMNLIST<FLOAT>";
529  return "COLUMNLIST<DOUBLE>";
531  return "COLUMNLIST<BOOLEAN>";
533  return "COLUMNLIST<TEXT ENCODING DICT>";
535  return "COLUMN<ARRAY<TINYINT>>";
537  return "COLUMN<ARRAY<SMALLINT>>";
539  return "COLUMN<ARRAY<INT>>";
541  return "COLUMN<ARRAY<BIGINT>>";
543  return "COLUMN<ARRAY<FLOAT>>";
545  return "COLUMN<ARRAY<DOUBLE>>";
547  return "COLUMN<ARRAY<BOOLEAN>>";
549  return "COLUMN<ARRAY<TEXT ENCODING DICT>>";
551  return "COLUMNLIST<ARRAY<TINYINT>>";
553  return "COLUMNLIST<ARRAY<SMALLINT>>";
555  return "COLUMNLIST<ARRAY<INT>>";
557  return "COLUMNLIST<ARRAY<BIGINT>>";
559  return "COLUMNLIST<ARRAY<FLOAT>>";
561  return "COLUMNLIST<ARRAY<DOUBLE>>";
563  return "COLUMNLIST<ARRAY<BOOLEAN>>";
565  return "COLUMNLIST<ARRAY<TEXT ENCODING DICT>>";
566  default:
567  UNREACHABLE();
568  }
569  return "";
570 }
571 
572 const std::string ExtensionFunction::getName(bool keep_suffix) const {
573  return (keep_suffix ? name_ : drop_suffix(name_));
574 }
575 
576 std::string ExtensionFunction::toString() const {
577  return getName() + "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
579 }
580 
581 std::string ExtensionFunction::toStringSQL() const {
582  return getName(/* keep_suffix = */ false) + "(" +
585 }
586 
587 std::string ExtensionFunction::toSignature() const {
588  return "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
590 }
591 
592 // Converts the extension function signatures to their LLVM representation.
594  const std::unordered_set<std::string>& udf_decls,
595  const bool is_gpu) {
596  std::vector<std::string> declarations;
597  for (const auto& kv : functions_) {
598  const auto& signatures = kv.second;
599  CHECK(!signatures.empty());
600  for (const auto& signature : kv.second) {
601  // If there is a udf function declaration matching an extension function signature
602  // do not emit a duplicate declaration.
603  if (!udf_decls.empty() && udf_decls.find(signature.getName()) != udf_decls.end()) {
604  continue;
605  }
606 
607  std::string decl_prefix;
608  std::vector<std::string> arg_strs;
609 
610  if (is_ext_arg_type_array(signature.getRet())) {
611  decl_prefix = "declare void @" + signature.getName();
612  arg_strs.emplace_back(
613  serialize_type(signature.getRet(), /* byval */ true, /* declare */ true));
614  } else {
615  decl_prefix =
616  "declare " +
617  serialize_type(signature.getRet(), /* byval */ true, /* declare */ true) +
618  " @" + signature.getName();
619  }
620  for (const auto arg : signature.getInputArgs()) {
621  arg_strs.push_back(serialize_type(arg, /* byval */ false, /* declare */ true));
622  }
623  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
624  ");");
625  }
626  }
627 
629  if (kv.second.isRuntime() || kv.second.useDefaultSizer()) {
630  // Runtime UDTFs are defined in LLVM/NVVM IR module
631  // UDTFs using default sizer share LLVM IR
632  continue;
633  }
634  if (!((is_gpu && kv.second.isGPU()) || (!is_gpu && kv.second.isCPU()))) {
635  continue;
636  }
637  std::string decl_prefix{
638  "declare " +
639  serialize_type(ExtArgumentType::Int32, /* byval */ true, /* declare */ true) +
640  " @" + kv.first};
641  std::vector<std::string> arg_strs;
642  for (const auto arg : kv.second.getArgs(/* ensure_column = */ true)) {
643  arg_strs.push_back(
644  serialize_type(arg, /* byval= */ kv.second.isRuntime(), /* declare= */ true));
645  }
646  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
647  ");");
648  }
649  return declarations;
650 }
651 
652 namespace {
653 
655  if (type_name == "bool" || type_name == "i1") {
656  return ExtArgumentType::Bool;
657  }
658  if (type_name == "i8") {
659  return ExtArgumentType::Int8;
660  }
661  if (type_name == "i16") {
662  return ExtArgumentType::Int16;
663  }
664  if (type_name == "i32") {
665  return ExtArgumentType::Int32;
666  }
667  if (type_name == "i64") {
668  return ExtArgumentType::Int64;
669  }
670  if (type_name == "float") {
671  return ExtArgumentType::Float;
672  }
673  if (type_name == "double") {
675  }
676  if (type_name == "void") {
677  return ExtArgumentType::Void;
678  }
679  if (type_name == "i8*") {
680  return ExtArgumentType::PInt8;
681  }
682  if (type_name == "i16*") {
684  }
685  if (type_name == "i32*") {
687  }
688  if (type_name == "i64*") {
690  }
691  if (type_name == "float*") {
693  }
694  if (type_name == "double*") {
696  }
697  if (type_name == "i1*" || type_name == "bool*") {
698  return ExtArgumentType::PBool;
699  }
700  if (type_name == "Array<i8>") {
702  }
703  if (type_name == "Array<i16>") {
705  }
706  if (type_name == "Array<i32>") {
708  }
709  if (type_name == "Array<i64>") {
711  }
712  if (type_name == "Array<float>") {
714  }
715  if (type_name == "Array<double>") {
717  }
718  if (type_name == "Array<bool>" || type_name == "Array<i1>") {
720  }
721  if (type_name == "Array<TextEncodingDict>") {
723  }
724  if (type_name == "geo_point") {
726  }
727  if (type_name == "geo_multi_point") {
729  }
730  if (type_name == "geo_linestring") {
732  }
733  if (type_name == "geo_multi_linestring") {
735  }
736  if (type_name == "geo_polygon") {
738  }
739  if (type_name == "geo_multi_polygon") {
741  }
742  if (type_name == "cursor") {
744  }
745  if (type_name == "Column<i8>") {
747  }
748  if (type_name == "Column<i16>") {
750  }
751  if (type_name == "Column<i32>") {
753  }
754  if (type_name == "Column<i64>") {
756  }
757  if (type_name == "Column<float>") {
759  }
760  if (type_name == "Column<double>") {
762  }
763  if (type_name == "Column<bool>") {
765  }
766  if (type_name == "Column<TextEncodingDict>") {
768  }
769  if (type_name == "Column<timestamp>") {
771  }
772  if (type_name == "TextEncodingNone") {
774  }
775  if (type_name == "TextEncodingDict") {
777  }
778  if (type_name == "timestamp") {
780  }
781  if (type_name == "ColumnList<i8>") {
783  }
784  if (type_name == "ColumnList<i16>") {
786  }
787  if (type_name == "ColumnList<i32>") {
789  }
790  if (type_name == "ColumnList<i64>") {
792  }
793  if (type_name == "ColumnList<float>") {
795  }
796  if (type_name == "ColumnList<double>") {
798  }
799  if (type_name == "ColumnList<bool>") {
801  }
802  if (type_name == "ColumnList<TextEncodingDict>") {
804  }
805  if (type_name == "Column<Array<i8>>") {
807  }
808  if (type_name == "Column<Array<i16>>") {
810  }
811  if (type_name == "Column<Array<i32>>") {
813  }
814  if (type_name == "Column<Array<i64>>") {
816  }
817  if (type_name == "Column<Array<float>>") {
819  }
820  if (type_name == "Column<Array<double>>") {
822  }
823  if (type_name == "Column<Array<bool>>") {
825  }
826  if (type_name == "Column<Array<TextEncodingDict>>") {
828  }
829  if (type_name == "ColumnList<Array<i8>>") {
831  }
832  if (type_name == "ColumnList<Array<i16>>") {
834  }
835  if (type_name == "ColumnList<Array<i32>>") {
837  }
838  if (type_name == "ColumnList<Array<i64>>") {
840  }
841  if (type_name == "ColumnList<Array<float>>") {
843  }
844  if (type_name == "ColumnList<Array<double>>") {
846  }
847  if (type_name == "ColumnList<Array<bool>>") {
849  }
850  if (type_name == "ColumnList<Array<TextEncodingDict>>") {
852  }
853  CHECK(false);
854  return ExtArgumentType::Int16;
855 }
856 
857 } // namespace
858 
859 using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
860 
862  const std::string& json_func_sigs,
863  const bool is_runtime) {
864  rapidjson::Document func_sigs;
865  func_sigs.Parse(json_func_sigs.c_str());
866  CHECK(func_sigs.IsArray());
867  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
868  ++func_sigs_it) {
869  CHECK(func_sigs_it->IsObject());
870  const auto name = json_str(field(*func_sigs_it, "name"));
871  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
872  std::vector<ExtArgumentType> args;
873  const auto& args_serialized = field(*func_sigs_it, "args");
874  CHECK(args_serialized.IsArray());
875  for (auto args_serialized_it = args_serialized.Begin();
876  args_serialized_it != args_serialized.End();
877  ++args_serialized_it) {
878  args.push_back(deserialize_type(json_str(*args_serialized_it)));
879  }
880  signatures[to_upper(drop_suffix(name))].emplace_back(name, args, ret, is_runtime);
881  }
882 }
883 
884 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
885 // them to its operator table and shares the list with the execution layer in
886 // JSON format. Build an in-memory representation of that list here so that it
887 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
888 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
889  // Valid json_func_sigs example:
890  // [
891  // {
892  // "name":"sum",
893  // "ret":"i32",
894  // "args":[
895  // "i32",
896  // "i32"
897  // ]
898  // }
899  // ]
900 
901  addCommon(functions_, json_func_sigs, /* is_runtime */ false);
902 }
903 
904 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
905  if (!json_func_sigs.empty()) {
906  addCommon(udf_functions_, json_func_sigs, /* is_runtime */ false);
907  }
908 }
909 
911  rt_udf_functions_.clear();
912 }
913 
914 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
915  if (!json_func_sigs.empty()) {
916  addCommon(rt_udf_functions_, json_func_sigs, /* is_runtime */ true);
917  }
918 }
919 
920 std::unordered_map<std::string, std::vector<ExtensionFunction>>
922 
923 std::unordered_map<std::string, std::vector<ExtensionFunction>>
925 
926 std::unordered_map<std::string, std::vector<ExtensionFunction>>
928 
929 std::string toString(const ExtArgumentType& sig_type) {
930  return ExtensionFunctionsWhitelist::toString(sig_type);
931 }
static void addUdfs(const std::string &json_func_sigs)
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs, const bool is_runtime)
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
SQLTypes
Definition: sqltypes.h:52
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:216
std::string toSignature() const
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
std::string join(T const &container, std::string const &delim)
const std::vector< ExtArgumentType > args_
static void add(const std::string &json_func_sigs)
#define UNREACHABLE()
Definition: Logger.h:266
std::unordered_map< std::string, std::vector< ExtensionFunction >> SignatureMap
std::string toString(const QueryDescriptionType &type)
Definition: Types.h:64
const std::string getName(bool keep_suffix=true) const
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:404
std::string toStringSQL() const
#define CHECK_GT(x, y)
Definition: Logger.h:234
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
EncodingType
Definition: sqltypes.h:253
Supported runtime functions management and retrieval.
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
#define EXTARGTYPECASE(EXTARGTYPE, ELEMTYPE, ENCODING, ARRAYENCODING)
DEVICE auto copy(ARGS &&...args)
Definition: gpu_enabled.h:51
ExtArgumentType deserialize_type(const std::string &type_name)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
static std::unordered_set< std::string > get_udfs_name(const bool is_runtime)
Checked json field retrieval.
Argument type based extension function binding.
std::string to_upper(const std::string &str)
Definition: sqltypes.h:66
std::string serialize_type(const ExtArgumentType type, bool byval=true, bool declare=false)
const ExtArgumentType ret_
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:222
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:59
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
constexpr auto type_name() noexcept
static std::vector< std::string > getLLVMDeclarations(const std::unordered_set< std::string > &udf_decls, const bool is_gpu=false)
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static std::string toStringSQL(const std::vector< ExtArgumentType > &sig_types)
static void addRTUdfs(const std::string &json_func_sigs)