OmniSciDB  fe05a0c208
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string/join.hpp>
20 #include <iostream>
21 
25 #include "Shared/StringTransform.h"
26 
27 // Get the list of all type specializations for the given function name.
28 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
29  const std::string& name) {
30  const auto it = functions_.find(to_upper(name));
31  if (it == functions_.end()) {
32  return nullptr;
33  }
34  return &it->second;
35 }
36 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
37  const std::string& name) {
38  const auto it = udf_functions_.find(to_upper(name));
39  if (it == udf_functions_.end()) {
40  return nullptr;
41  }
42  return &it->second;
43 }
44 
45 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
46  const std::string& name,
47  const bool is_gpu) {
48  std::vector<ExtensionFunction> ext_funcs = {};
49  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
50  const auto uname = to_upper(name);
51  for (auto funcs : collections) {
52  const auto it = funcs->find(uname);
53  if (it == funcs->end()) {
54  continue;
55  }
56  auto ext_func_sigs = it->second;
57  std::copy_if(ext_func_sigs.begin(),
58  ext_func_sigs.end(),
59  std::back_inserter(ext_funcs),
60  [is_gpu](auto sig) { return (is_gpu ? sig.isGPU() : sig.isCPU()); });
61  }
62  return ext_funcs;
63 }
64 
65 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
66  const std::string& name,
67  size_t arity) {
68  std::vector<ExtensionFunction> ext_funcs = {};
69  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
70  const auto uname = to_upper(name);
71  for (auto funcs : collections) {
72  const auto it = funcs->find(uname);
73  if (it == funcs->end()) {
74  continue;
75  }
76  auto ext_func_sigs = it->second;
77  std::copy_if(ext_func_sigs.begin(),
78  ext_func_sigs.end(),
79  std::back_inserter(ext_funcs),
80  [arity](auto sig) { return arity == sig.getArgs().size(); });
81  }
82  return ext_funcs;
83 }
84 
85 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
86  const std::string& name,
87  size_t arity,
88  const SQLTypeInfo& rtype) {
89  std::vector<ExtensionFunction> ext_funcs = {};
90  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
91  const auto uname = to_upper(name);
92  for (auto funcs : collections) {
93  const auto it = funcs->find(uname);
94  if (it == funcs->end()) {
95  continue;
96  }
97  auto ext_func_sigs = it->second;
98  std::copy_if(ext_func_sigs.begin(),
99  ext_func_sigs.end(),
100  std::back_inserter(ext_funcs),
101  [arity, rtype](auto sig) {
102  // Ideally, arity should be equal to the number of
103  // sig arguments but there seems to be many cases
104  // where some sig arguments will be represented
105  // with multiple arguments, for instance, array
106  // argument is translated to data pointer and array
107  // size arguments.
108  if (arity > sig.getArgs().size()) {
109  return false;
110  }
111  auto rt = rtype.get_type();
112  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
113  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
114  });
115  }
116  return ext_funcs;
117 }
118 
119 namespace {
120 
121 // Returns the LLVM name for `type`.
// `byval` is consulted only by the returns that yield "{T*, i64}" struct
// types; when byval is false those cases serialize as an opaque "i8*".
// Any type the switch does not handle trips CHECK(false).
// NOTE(review): the `case` labels of this switch are not visible in this copy
// of the file, so the enum value behind each string cannot be confirmed here.
122 std::string serialize_type(const ExtArgumentType type, bool byval = true) {
123  switch (type) {
125  return "i8"; // clang converts bool to i8
127  return "i8";
129  return "i16";
131  return "i32";
133  return "i64";
135  return "float";
137  return "double";
139  return "void";
141  return "i8*";
143  return "i16*";
145  return "i32*";
147  return "i64*";
149  return "float*";
151  return "double*";
153  return "i1*";
155  return "{i8*, i64, i8}*";
157  return "{i16*, i64, i8}*";
159  return "{i32*, i64, i8}*";
161  return "{i64*, i64, i8}*";
163  return "{float*, i64, i8}*";
165  return "{double*, i64, i8}*";
167  return "{i1*, i64, i8}*";
169  return "geo_point";
171  return "geo_linestring";
173  return "geo_polygon";
175  return "geo_multi_polygon";
177  return "cursor";
179  return (byval ? "{i8*, i64}" : "i8*");
181  return (byval ? "{i16*, i64}" : "i8*");
183  return (byval ? "{i32*, i64}" : "i8*");
185  return (byval ? "{i64*, i64}" : "i8*");
187  return (byval ? "{float*, i64}" : "i8*");
189  return (byval ? "{double*, i64}" : "i8*");
191  return (byval ? "{i1*, i64}" : "i8*");
// NOTE(review): "text_encoding_node" differs from the "text_encoding_none"
// spelling accepted by deserialize_type(), so this string does not round-trip.
// Likely a typo; confirm before relying on serialize/deserialize symmetry.
193  return "text_encoding_node";
195  return "text_encoding_dict8";
197  return "text_encoding_dict16";
199  return "text_encoding_dict32";
201  return "column_list_int8";
203  return "column_list_int16";
205  return "column_list_int32";
207  return "column_list_int64";
209  return "column_list_float";
211  return "column_list_double";
213  return "column_list_bool";
214  default:
215  CHECK(false);
216  }
217  CHECK(false);
218  return "";
219 }
220 
// Strips the suffix that follows the first "__" in a function name: returns
// the part before the first "__" (e.g. "foo__cpu" -> "foo"), or `str`
// unchanged if it contains no "__". A name that starts with "__" (empty
// prefix) fails CHECK_GT.
221 std::string drop_suffix(const std::string& str) {
222  const auto idx = str.find("__");
223  if (idx == std::string::npos) {
224  return str;
225  }
226  CHECK_GT(idx, std::string::size_type(0));
227  return str.substr(0, idx);
228 }
229 
230 } // namespace
231 
233  /* This function is mostly used for scalar types.
234  For non-scalar types, NULL is returned as a placeholder.
235  */
236 
237  switch (ext_arg_type) {
239  return SQLTypeInfo(kBOOLEAN, false);
241  return SQLTypeInfo(kTINYINT, false);
243  return SQLTypeInfo(kSMALLINT, false);
245  return SQLTypeInfo(kINT, false);
247  return SQLTypeInfo(kBIGINT, false);
249  return SQLTypeInfo(kFLOAT, false);
251  return SQLTypeInfo(kDOUBLE, false);
257  return generate_array_type(kINT);
261  return generate_array_type(kFLOAT);
271  return generate_column_type(kINT);
281  return SQLTypeInfo(kTEXT, false, kENCODING_NONE);
285  return SQLTypeInfo(kTEXT, false, kENCODING_DICT);
291  return generate_column_type(kINT);
300  default:
301  LOG(FATAL) << "ExtArgumentType `" << serialize_type(ext_arg_type)
302  << "` cannot be converted to SQLTypeInfo.";
303  }
304  return SQLTypeInfo(kNULLT, false);
305 }
306 
308  const std::vector<ExtensionFunction>& ext_funcs,
309  std::string tab) {
310  std::string r = "";
311  for (auto sig : ext_funcs) {
312  r += tab + sig.toString() + "\n";
313  }
314  return r;
315 }
316 
318  const std::vector<SQLTypeInfo>& arg_types) {
319  std::string r = "";
320  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
321  r += sig->get_type_name();
322  sig++;
323  if (sig != arg_types.end()) {
324  r += ", ";
325  }
326  }
327  return r;
328 }
329 
331  const std::vector<ExtArgumentType>& sig_types) {
332  std::string r = "";
333  for (auto t = sig_types.begin(); t != sig_types.end();) {
334  r += serialize_type(*t);
335  t++;
336  if (t != sig_types.end()) {
337  r += ", ";
338  }
339  }
340  return r;
341 }
342 
344  const std::vector<ExtArgumentType>& sig_types) {
345  std::string r = "";
346  for (auto t = sig_types.begin(); t != sig_types.end();) {
348  t++;
349  if (t != sig_types.end()) {
350  r += ", ";
351  }
352  }
353  return r;
354 }
355 
357  return serialize_type(sig_type);
358 }
359 
361  switch (sig_type) {
363  return "TINYINT";
365  return "SMALLINT";
367  return "INTEGER";
369  return "BIGINT";
371  return "FLOAT";
373  return "DOUBLE";
375  return "BOOLEAN";
377  return "TINYINT[]";
379  return "SMALLINT[]";
381  return "INT[]";
383  return "BIGINT[]";
385  return "FLOAT[]";
387  return "DOUBLE[]";
389  return "BOOLEAN[]";
391  return "ARRAY<TINYINT>";
393  return "ARRAY<SMALLINT>";
395  return "ARRAY<INT>";
397  return "ARRAY<BIGINT>";
399  return "ARRAY<FLOAT>";
401  return "ARRAY<DOUBLE>";
403  return "ARRAY<BOOLEAN>";
405  return "COLUMN<TINYINT>";
407  return "COLUMN<SMALLINT>";
409  return "COLUMN<INT>";
411  return "COLUMN<BIGINT>";
413  return "COLUMN<FLOAT>";
415  return "COLUMN<DOUBLE>";
417  return "COLUMN<BOOLEAN>";
419  return "CURSOR";
421  return "POINT";
423  return "LINESTRING";
425  return "POLYGON";
427  return "MULTIPOLYGON";
429  return "VOID";
431  return "TEXT ENCODING NONE";
433  return "TEXT ENCODING DICT(8)";
435  return "TEXT ENCODING DICT(16)";
437  return "TEXT ENCODING DICT(32)";
439  return "COLUMNLIST<TINYINT>";
441  return "COLUMNLIST<SMALLINT>";
443  return "COLUMNLIST<INT>";
445  return "COLUMNLIST<BIGINT>";
447  return "COLUMNLIST<FLOAT>";
449  return "COLUMNLIST<DOUBLE>";
451  return "COLUMNLIST<BOOLEAN>";
452  default:
453  UNREACHABLE();
454  }
455  return "";
456 }
457 
458 const std::string ExtensionFunction::getName(bool keep_suffix) const {
459  return (keep_suffix ? name_ : drop_suffix(name_));
460 }
461 
462 std::string ExtensionFunction::toString() const {
463  return getName() + "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
465 }
466 
467 std::string ExtensionFunction::toStringSQL() const {
468  return getName(/* keep_suffix = */ false) + "(" +
471 }
472 
473 // Converts the extension function signatures to their LLVM representation.
475  const std::unordered_set<std::string>& udf_decls,
476  const bool is_gpu) {
477  std::vector<std::string> declarations;
478  for (const auto& kv : functions_) {
479  const auto& signatures = kv.second;
480  CHECK(!signatures.empty());
481  for (const auto& signature : kv.second) {
482  // If there is a udf function declaration matching an extension function signature
483  // do not emit a duplicate declaration.
484  if (!udf_decls.empty() && udf_decls.find(signature.getName()) != udf_decls.end()) {
485  continue;
486  }
487 
488  std::string decl_prefix;
489  std::vector<std::string> arg_strs;
490 
491  if (is_ext_arg_type_array(signature.getRet())) {
492  decl_prefix = "declare void @" + signature.getName();
493  arg_strs.emplace_back(serialize_type(signature.getRet()));
494  } else {
495  decl_prefix =
496  "declare " + serialize_type(signature.getRet()) + " @" + signature.getName();
497  }
498  for (const auto arg : signature.getArgs()) {
499  arg_strs.push_back(serialize_type(arg));
500  }
501  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
502  ");");
503  }
504  }
505 
507  if (kv.second.isRuntime() || kv.second.useDefaultSizer()) {
508  // Runtime UDTFs are defined in LLVM/NVVM IR module
509  // UDTFs using default sizer share LLVM IR
510  continue;
511  }
512  if (!((is_gpu && kv.second.isGPU()) || (!is_gpu && kv.second.isCPU()))) {
513  continue;
514  }
515  std::string decl_prefix{"declare " + serialize_type(ExtArgumentType::Int32) + " @" +
516  kv.first};
517  std::vector<std::string> arg_strs;
518  for (const auto arg : kv.second.getArgs(/* ensure_column = */ true)) {
519  arg_strs.push_back(serialize_type(arg, /* byval= */ kv.second.isRuntime()));
520  }
521  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
522  ");");
523  }
524  return declarations;
525 }
526 
527 namespace {
528 
530  if (type_name == "bool" || type_name == "i1") {
531  return ExtArgumentType::Bool;
532  }
533  if (type_name == "i8") {
534  return ExtArgumentType::Int8;
535  }
536  if (type_name == "i16") {
537  return ExtArgumentType::Int16;
538  }
539  if (type_name == "i32") {
540  return ExtArgumentType::Int32;
541  }
542  if (type_name == "i64") {
543  return ExtArgumentType::Int64;
544  }
545  if (type_name == "float") {
546  return ExtArgumentType::Float;
547  }
548  if (type_name == "double") {
550  }
551  if (type_name == "void") {
552  return ExtArgumentType::Void;
553  }
554  if (type_name == "i8*") {
555  return ExtArgumentType::PInt8;
556  }
557  if (type_name == "i16*") {
559  }
560  if (type_name == "i32*") {
562  }
563  if (type_name == "i64*") {
565  }
566  if (type_name == "float*") {
568  }
569  if (type_name == "double*") {
571  }
572  if (type_name == "i1*" || type_name == "bool*") {
573  return ExtArgumentType::PBool;
574  }
575  if (type_name == "{i8*, i64, i8}*") {
577  }
578  if (type_name == "{i16*, i64, i8}*") {
580  }
581  if (type_name == "{i32*, i64, i8}*") {
583  }
584  if (type_name == "{i64*, i64, i8}*") {
586  }
587  if (type_name == "{float*, i64, i8}*") {
589  }
590  if (type_name == "{double*, i64, i8}*") {
592  }
593  if (type_name == "{i1*, i64, i8}*" || type_name == "{bool*, i64, i8}*") {
595  }
596  if (type_name == "geo_point") {
598  }
599  if (type_name == "geo_linestring") {
601  }
602  if (type_name == "geo_polygon") {
604  }
605  if (type_name == "geo_multi_polygon") {
607  }
608  if (type_name == "cursor") {
610  }
611  if (type_name == "{i8*, i64}") {
613  }
614  if (type_name == "{i16*, i64}") {
616  }
617  if (type_name == "{i32*, i64}") {
619  }
620  if (type_name == "{i64*, i64}") {
622  }
623  if (type_name == "{float*, i64}") {
625  }
626  if (type_name == "{double*, i64}") {
628  }
629  if (type_name == "{i1*, i64}" || type_name == "{bool*, i64}") {
631  }
632  if (type_name == "text_encoding_none") {
634  }
635  if (type_name == "text_encoding_dict8") {
637  }
638  if (type_name == "text_encoding_dict16") {
640  }
641  if (type_name == "text_encoding_dict32") {
643  }
644  if (type_name == "column_list_int8") {
646  }
647  if (type_name == "column_list_int16") {
649  }
650  if (type_name == "column_list_int32") {
652  }
653  if (type_name == "column_list_int64") {
655  }
656  if (type_name == "column_list_float") {
658  }
659  if (type_name == "column_list_double") {
661  }
662  if (type_name == "column_list_bool") {
664  }
665  CHECK(false);
666  return ExtArgumentType::Int16;
667 }
668 
669 } // namespace
670 
671 using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
672 
674  const std::string& json_func_sigs) {
675  rapidjson::Document func_sigs;
676  func_sigs.Parse(json_func_sigs.c_str());
677  CHECK(func_sigs.IsArray());
678  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
679  ++func_sigs_it) {
680  CHECK(func_sigs_it->IsObject());
681  const auto name = json_str(field(*func_sigs_it, "name"));
682  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
683  std::vector<ExtArgumentType> args;
684  const auto& args_serialized = field(*func_sigs_it, "args");
685  CHECK(args_serialized.IsArray());
686  for (auto args_serialized_it = args_serialized.Begin();
687  args_serialized_it != args_serialized.End();
688  ++args_serialized_it) {
689  args.push_back(deserialize_type(json_str(*args_serialized_it)));
690  }
691  signatures[to_upper(drop_suffix(name))].emplace_back(name, args, ret);
692  }
693 }
694 
695 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
696 // them to its operator table and shares the list with the execution layer in
697 // JSON format. Build an in-memory representation of that list here so that it
698 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
699 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
700  // Valid json_func_sigs example:
701  // [
702  // {
703  // "name":"sum",
704  // "ret":"i32",
705  // "args":[
706  // "i32",
707  // "i32"
708  // ]
709  // }
710  // ]
711 
712  addCommon(functions_, json_func_sigs);
713 }
714 
715 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
716  if (!json_func_sigs.empty()) {
717  addCommon(udf_functions_, json_func_sigs);
718  }
719 }
720 
722  rt_udf_functions_.clear();
723 }
724 
725 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
726  if (!json_func_sigs.empty()) {
727  addCommon(rt_udf_functions_, json_func_sigs);
728  }
729 }
730 
731 std::unordered_map<std::string, std::vector<ExtensionFunction>>
733 
734 std::unordered_map<std::string, std::vector<ExtensionFunction>>
736 
737 std::unordered_map<std::string, std::vector<ExtensionFunction>>
739 
740 std::string toString(const ExtArgumentType& sig_type) {
741  return ExtensionFunctionsWhitelist::toString(sig_type);
742 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name, const bool is_gpu)
static void addUdfs(const std::string &json_func_sigs)
std::string toString(const ExtArgumentType &sig_type)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:194
auto generate_column_type(const SQLTypes subtype)
Definition: sqltypes.h:981
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
string name
Definition: setup.in.py:72
tuple r
Definition: test_fsi.py:16
std::string join(T const &container, std::string const &delim)
const std::vector< ExtArgumentType > args_
static void add(const std::string &json_func_sigs)
#define UNREACHABLE()
Definition: Logger.h:247
std::unordered_map< std::string, std::vector< ExtensionFunction >> SignatureMap
const std::string getName(bool keep_suffix=true) const
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:314
std::string toStringSQL() const
#define CHECK_GT(x, y)
Definition: Logger.h:215
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
ExtArgumentType deserialize_type(const std::string &type_name)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
std::string serialize_type(const ExtArgumentType type, bool byval=true)
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs)
std::string to_upper(const std::string &str)
auto generate_array_type(const SQLTypes subtype)
Definition: sqltypes.h:975
Definition: sqltypes.h:51
const ExtArgumentType ret_
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:203
char * t
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:44
constexpr auto type_name() noexcept
static std::vector< std::string > getLLVMDeclarations(const std::unordered_set< std::string > &udf_decls, const bool is_gpu=false)
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static std::string toStringSQL(const std::vector< ExtArgumentType > &sig_types)
static void addRTUdfs(const std::string &json_func_sigs)