OmniSciDB  085a039ca4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string/join.hpp>
20 #include <iostream>
21 
25 #include "Shared/StringTransform.h"
26 
27 // Get the list of all type specializations for the given function name.
28 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
29  const std::string& name) {
30  const auto it = functions_.find(to_upper(name));
31  if (it == functions_.end()) {
32  return nullptr;
33  }
34  return &it->second;
35 }
36 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
37  const std::string& name) {
38  const auto it = udf_functions_.find(to_upper(name));
39  if (it == udf_functions_.end()) {
40  return nullptr;
41  }
42  return &it->second;
43 }
44 
45 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
46  const std::string& name,
47  const bool is_gpu) {
48  std::vector<ExtensionFunction> ext_funcs = {};
49  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
50  const auto uname = to_upper(name);
51  for (auto funcs : collections) {
52  const auto it = funcs->find(uname);
53  if (it == funcs->end()) {
54  continue;
55  }
56  auto ext_func_sigs = it->second;
57  std::copy_if(ext_func_sigs.begin(),
58  ext_func_sigs.end(),
59  std::back_inserter(ext_funcs),
60  [is_gpu](auto sig) { return (is_gpu ? sig.isGPU() : sig.isCPU()); });
61  }
62  return ext_funcs;
63 }
64 
65 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
66  const std::string& name,
67  size_t arity) {
68  std::vector<ExtensionFunction> ext_funcs = {};
69  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
70  const auto uname = to_upper(name);
71  for (auto funcs : collections) {
72  const auto it = funcs->find(uname);
73  if (it == funcs->end()) {
74  continue;
75  }
76  auto ext_func_sigs = it->second;
77  std::copy_if(ext_func_sigs.begin(),
78  ext_func_sigs.end(),
79  std::back_inserter(ext_funcs),
80  [arity](auto sig) { return arity == sig.getArgs().size(); });
81  }
82  return ext_funcs;
83 }
84 
85 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
86  const std::string& name,
87  size_t arity,
88  const SQLTypeInfo& rtype) {
89  std::vector<ExtensionFunction> ext_funcs = {};
90  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
91  const auto uname = to_upper(name);
92  for (auto funcs : collections) {
93  const auto it = funcs->find(uname);
94  if (it == funcs->end()) {
95  continue;
96  }
97  auto ext_func_sigs = it->second;
98  std::copy_if(ext_func_sigs.begin(),
99  ext_func_sigs.end(),
100  std::back_inserter(ext_funcs),
101  [arity, rtype](auto sig) {
102  // Ideally, arity should be equal to the number of
103  // sig arguments but there seems to be many cases
104  // where some sig arguments will be represented
105  // with multiple arguments, for instance, array
106  // argument is translated to data pointer and array
107  // size arguments.
108  if (arity > sig.getArgs().size()) {
109  return false;
110  }
111  auto rt = rtype.get_type();
112  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
113  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
114  });
115  }
116  return ext_funcs;
117 }
118 
119 namespace {
120 
121 // Returns the LLVM name for `type`.
123  bool byval = true,
124  bool declare = false) {
125  switch (type) {
127  return "i8"; // clang converts bool to i8
129  return "i8";
131  return "i16";
133  return "i32";
135  return "i64";
137  return "float";
139  return "double";
141  return "void";
143  return "i8*";
145  return "i16*";
147  return "i32*";
149  return "i64*";
151  return "float*";
153  return "double*";
155  return "i1*";
157  return (declare ? "{i8*, i64, i8}*" : "Array<i8>");
159  return (declare ? "{i16*, i64, i8}*" : "Array<i16>");
161  return (declare ? "{i32*, i64, i8}*" : "Array<i32>");
163  return (declare ? "{i64*, i64, i8}*" : "Array<i64>");
165  return (declare ? "{float*, i64, i8}*" : "Array<float>");
167  return (declare ? "{double*, i64, i8}*" : "Array<double>");
169  return (declare ? "{i1*, i64, i8}*" : "Array<i1>");
171  return "geo_point";
173  return "geo_linestring";
175  return "geo_polygon";
177  return "geo_multi_polygon";
179  return "cursor";
181  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<i8>");
183  return (declare ? (byval ? "{i16*, i64}" : "i8*") : "Column<i16>");
185  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<i32>");
187  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<i64>");
189  return (declare ? (byval ? "{float*, i64}" : "i8*") : "Column<float>");
191  return (declare ? (byval ? "{double*, i64}" : "i8*") : "Column<double>");
193  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<bool>");
195  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<TextEncodingDict>");
197  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<timestamp>");
199  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "TextEncodingNone");
201  return (declare ? "{i8*, i32}*" : "TextEncodingDict");
203  return (declare ? "{ i64 }" : "timestamp");
205  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i8>");
207  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i16>");
209  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i32>");
211  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i64>");
213  return (declare ? "{i8**, i64, i64}*" : "ColumnList<float>");
215  return (declare ? "{i8**, i64, i64}*" : "ColumnList<double>");
217  return (declare ? "{i8**, i64, i64}*" : "ColumnList<bool>");
219  return (declare ? "{i8**, i64, i64}*" : "ColumnList<TextEncodingDict>");
220  default:
221  CHECK(false);
222  }
223  CHECK(false);
224  return "";
225 }
226 
227 std::string drop_suffix(const std::string& str) {
228  const auto idx = str.find("__");
229  if (idx == std::string::npos) {
230  return str;
231  }
232  CHECK_GT(idx, std::string::size_type(0));
233  return str.substr(0, idx);
234 }
235 
236 } // namespace
237 
239  /* This function is mostly used for scalar types.
240  For non-scalar types, NULL is returned as a placeholder.
241  */
242 
243  switch (ext_arg_type) {
245  return SQLTypeInfo(kBOOLEAN, false);
247  return SQLTypeInfo(kTINYINT, false);
249  return SQLTypeInfo(kSMALLINT, false);
251  return SQLTypeInfo(kINT, false);
253  return SQLTypeInfo(kBIGINT, false);
255  return SQLTypeInfo(kFLOAT, false);
257  return SQLTypeInfo(kDOUBLE, false);
263  return generate_array_type(kINT);
267  return generate_array_type(kFLOAT);
277  return generate_column_type(kINT);
287  return generate_column_type(kTEXT, kENCODING_DICT, 0 /* comp_param */);
289  return SQLTypeInfo(kTEXT, false, kENCODING_NONE);
291  return SQLTypeInfo(kTEXT, false, kENCODING_DICT);
293  return SQLTypeInfo(kTIMESTAMP, 9, 0, false);
301  return generate_column_type(kINT);
311  return generate_column_type(kTEXT, kENCODING_DICT, 0 /* comp_param */);
312  default:
313  LOG(FATAL) << "ExtArgumentType `" << serialize_type(ext_arg_type)
314  << "` cannot be converted to SQLTypeInfo.";
315  }
316  return SQLTypeInfo(kNULLT, false);
317 }
318 
320  const std::vector<ExtensionFunction>& ext_funcs,
321  std::string tab) {
322  std::string r = "";
323  for (auto sig : ext_funcs) {
324  r += tab + sig.toString() + "\n";
325  }
326  return r;
327 }
328 
330  const std::vector<SQLTypeInfo>& arg_types) {
331  std::string r = "";
332  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
333  r += sig->get_type_name();
334  sig++;
335  if (sig != arg_types.end()) {
336  r += ", ";
337  }
338  }
339  return r;
340 }
341 
343  const std::vector<ExtArgumentType>& sig_types) {
344  std::string r = "";
345  for (auto t = sig_types.begin(); t != sig_types.end();) {
346  r += serialize_type(*t, /* byval */ false, /* declare */ false);
347  t++;
348  if (t != sig_types.end()) {
349  r += ", ";
350  }
351  }
352  return r;
353 }
354 
356  const std::vector<ExtArgumentType>& sig_types) {
357  std::string r = "";
358  for (auto t = sig_types.begin(); t != sig_types.end();) {
360  t++;
361  if (t != sig_types.end()) {
362  r += ", ";
363  }
364  }
365  return r;
366 }
367 
369  return serialize_type(sig_type, /* byval */ false, /* declare */ false);
370 }
371 
373  switch (sig_type) {
375  return "TINYINT";
377  return "SMALLINT";
379  return "INTEGER";
381  return "BIGINT";
383  return "FLOAT";
385  return "DOUBLE";
387  return "BOOLEAN";
389  return "TINYINT[]";
391  return "SMALLINT[]";
393  return "INT[]";
395  return "BIGINT[]";
397  return "FLOAT[]";
399  return "DOUBLE[]";
401  return "BOOLEAN[]";
403  return "ARRAY<TINYINT>";
405  return "ARRAY<SMALLINT>";
407  return "ARRAY<INT>";
409  return "ARRAY<BIGINT>";
411  return "ARRAY<FLOAT>";
413  return "ARRAY<DOUBLE>";
415  return "ARRAY<BOOLEAN>";
417  return "COLUMN<TINYINT>";
419  return "COLUMN<SMALLINT>";
421  return "COLUMN<INT>";
423  return "COLUMN<BIGINT>";
425  return "COLUMN<FLOAT>";
427  return "COLUMN<DOUBLE>";
429  return "COLUMN<BOOLEAN>";
431  return "COLUMN<TEXT ENCODING DICT>";
433  return "COLUMN<TIMESTAMP(9)>";
435  return "CURSOR";
437  return "POINT";
439  return "LINESTRING";
441  return "POLYGON";
443  return "MULTIPOLYGON";
445  return "VOID";
447  return "TEXT ENCODING NONE";
449  return "TEXT ENCODING DICT";
451  return "TIMESTAMP(9)";
453  return "COLUMNLIST<TINYINT>";
455  return "COLUMNLIST<SMALLINT>";
457  return "COLUMNLIST<INT>";
459  return "COLUMNLIST<BIGINT>";
461  return "COLUMNLIST<FLOAT>";
463  return "COLUMNLIST<DOUBLE>";
465  return "COLUMNLIST<BOOLEAN>";
467  return "COLUMNLIST<TEXT ENCODING DICT>";
468  default:
469  UNREACHABLE();
470  }
471  return "";
472 }
473 
474 const std::string ExtensionFunction::getName(bool keep_suffix) const {
475  return (keep_suffix ? name_ : drop_suffix(name_));
476 }
477 
478 std::string ExtensionFunction::toString() const {
479  return getName() + "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
480  serialize_type(ret_, /* byval */ false, /* declare */ false);
481 }
482 
483 std::string ExtensionFunction::toStringSQL() const {
484  return getName(/* keep_suffix = */ false) + "(" +
487 }
488 
489 // Converts the extension function signatures to their LLVM representation.
491  const std::unordered_set<std::string>& udf_decls,
492  const bool is_gpu) {
493  std::vector<std::string> declarations;
494  for (const auto& kv : functions_) {
495  const auto& signatures = kv.second;
496  CHECK(!signatures.empty());
497  for (const auto& signature : kv.second) {
498  // If there is a udf function declaration matching an extension function signature
499  // do not emit a duplicate declaration.
500  if (!udf_decls.empty() && udf_decls.find(signature.getName()) != udf_decls.end()) {
501  continue;
502  }
503 
504  std::string decl_prefix;
505  std::vector<std::string> arg_strs;
506 
507  if (is_ext_arg_type_array(signature.getRet())) {
508  decl_prefix = "declare void @" + signature.getName();
509  arg_strs.emplace_back(
510  serialize_type(signature.getRet(), /* byval */ true, /* declare */ true));
511  } else {
512  decl_prefix =
513  "declare " +
514  serialize_type(signature.getRet(), /* byval */ true, /* declare */ true) +
515  " @" + signature.getName();
516  }
517  for (const auto arg : signature.getArgs()) {
518  arg_strs.push_back(serialize_type(arg, /* byval */ false, /* declare */ true));
519  }
520  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
521  ");");
522  }
523  }
524 
526  if (kv.second.isRuntime() || kv.second.useDefaultSizer()) {
527  // Runtime UDTFs are defined in LLVM/NVVM IR module
528  // UDTFs using default sizer share LLVM IR
529  continue;
530  }
531  if (!((is_gpu && kv.second.isGPU()) || (!is_gpu && kv.second.isCPU()))) {
532  continue;
533  }
534  std::string decl_prefix{
535  "declare " +
536  serialize_type(ExtArgumentType::Int32, /* byval */ true, /* declare */ true) +
537  " @" + kv.first};
538  std::vector<std::string> arg_strs;
539  for (const auto arg : kv.second.getArgs(/* ensure_column = */ true)) {
540  arg_strs.push_back(
541  serialize_type(arg, /* byval= */ kv.second.isRuntime(), /* declare= */ true));
542  }
543  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
544  ");");
545  }
546  return declarations;
547 }
548 
549 namespace {
550 
552  if (type_name == "bool" || type_name == "i1") {
553  return ExtArgumentType::Bool;
554  }
555  if (type_name == "i8") {
556  return ExtArgumentType::Int8;
557  }
558  if (type_name == "i16") {
559  return ExtArgumentType::Int16;
560  }
561  if (type_name == "i32") {
562  return ExtArgumentType::Int32;
563  }
564  if (type_name == "i64") {
565  return ExtArgumentType::Int64;
566  }
567  if (type_name == "float") {
568  return ExtArgumentType::Float;
569  }
570  if (type_name == "double") {
572  }
573  if (type_name == "void") {
574  return ExtArgumentType::Void;
575  }
576  if (type_name == "i8*") {
577  return ExtArgumentType::PInt8;
578  }
579  if (type_name == "i16*") {
581  }
582  if (type_name == "i32*") {
584  }
585  if (type_name == "i64*") {
587  }
588  if (type_name == "float*") {
590  }
591  if (type_name == "double*") {
593  }
594  if (type_name == "i1*" || type_name == "bool*") {
595  return ExtArgumentType::PBool;
596  }
597  if (type_name == "Array<i8>") {
599  }
600  if (type_name == "Array<i16>") {
602  }
603  if (type_name == "Array<i32>") {
605  }
606  if (type_name == "Array<i64>") {
608  }
609  if (type_name == "Array<float>") {
611  }
612  if (type_name == "Array<double>") {
614  }
615  if (type_name == "Array<bool>" || type_name == "Array<i1>") {
617  }
618  if (type_name == "geo_point") {
620  }
621  if (type_name == "geo_linestring") {
623  }
624  if (type_name == "geo_polygon") {
626  }
627  if (type_name == "geo_multi_polygon") {
629  }
630  if (type_name == "cursor") {
632  }
633  if (type_name == "Column<i8>") {
635  }
636  if (type_name == "Column<i16>") {
638  }
639  if (type_name == "Column<i32>") {
641  }
642  if (type_name == "Column<i64>") {
644  }
645  if (type_name == "Column<float>") {
647  }
648  if (type_name == "Column<double>") {
650  }
651  if (type_name == "Column<bool>") {
653  }
654  if (type_name == "Column<TextEncodingDict>") {
656  }
657  if (type_name == "Column<timestamp>") {
659  }
660  if (type_name == "TextEncodingNone") {
662  }
663  if (type_name == "TextEncodingDict") {
665  }
666  if (type_name == "timestamp") {
668  }
669  if (type_name == "ColumnList<i8>") {
671  }
672  if (type_name == "ColumnList<i16>") {
674  }
675  if (type_name == "ColumnList<i32>") {
677  }
678  if (type_name == "ColumnList<i64>") {
680  }
681  if (type_name == "ColumnList<float>") {
683  }
684  if (type_name == "ColumnList<double>") {
686  }
687  if (type_name == "ColumnList<bool>") {
689  }
690  if (type_name == "ColumnList<TextEncodingDict>") {
692  }
693  CHECK(false);
694  return ExtArgumentType::Int16;
695 }
696 
697 } // namespace
698 
699 using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
700 
702  const std::string& json_func_sigs) {
703  rapidjson::Document func_sigs;
704  func_sigs.Parse(json_func_sigs.c_str());
705  CHECK(func_sigs.IsArray());
706  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
707  ++func_sigs_it) {
708  CHECK(func_sigs_it->IsObject());
709  const auto name = json_str(field(*func_sigs_it, "name"));
710  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
711  std::vector<ExtArgumentType> args;
712  const auto& args_serialized = field(*func_sigs_it, "args");
713  CHECK(args_serialized.IsArray());
714  for (auto args_serialized_it = args_serialized.Begin();
715  args_serialized_it != args_serialized.End();
716  ++args_serialized_it) {
717  args.push_back(deserialize_type(json_str(*args_serialized_it)));
718  }
719  signatures[to_upper(drop_suffix(name))].emplace_back(name, args, ret);
720  }
721 }
722 
723 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
724 // them to its operator table and shares the list with the execution layer in
725 // JSON format. Build an in-memory representation of that list here so that it
726 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
727 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
728  // Valid json_func_sigs example:
729  // [
730  // {
731  // "name":"sum",
732  // "ret":"i32",
733  // "args":[
734  // "i32",
735  // "i32"
736  // ]
737  // }
738  // ]
739 
740  addCommon(functions_, json_func_sigs);
741 }
742 
743 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
744  if (!json_func_sigs.empty()) {
745  addCommon(udf_functions_, json_func_sigs);
746  }
747 }
748 
750  rt_udf_functions_.clear();
751 }
752 
753 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
754  if (!json_func_sigs.empty()) {
755  addCommon(rt_udf_functions_, json_func_sigs);
756  }
757 }
758 
759 std::unordered_map<std::string, std::vector<ExtensionFunction>>
761 
762 std::unordered_map<std::string, std::vector<ExtensionFunction>>
764 
765 std::unordered_map<std::string, std::vector<ExtensionFunction>>
767 
768 std::string toString(const ExtArgumentType& sig_type) {
769  return ExtensionFunctionsWhitelist::toString(sig_type);
770 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name, const bool is_gpu)
static void addUdfs(const std::string &json_func_sigs)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:217
auto generate_column_type(const SQLTypes subtype)
Definition: sqltypes.h:1124
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
std::string join(T const &container, std::string const &delim)
const std::vector< ExtArgumentType > args_
static void add(const std::string &json_func_sigs)
#define UNREACHABLE()
Definition: Logger.h:267
std::unordered_map< std::string, std::vector< ExtensionFunction >> SignatureMap
const std::string getName(bool keep_suffix=true) const
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:329
std::string toStringSQL() const
#define CHECK_GT(x, y)
Definition: Logger.h:235
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
ExtArgumentType deserialize_type(const std::string &type_name)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
std::string toString(const Executor::ExtModuleKinds &kind)
Definition: Execute.h:1453
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs)
std::string to_upper(const std::string &str)
auto generate_array_type(const SQLTypes subtype)
Definition: sqltypes.h:1118
Definition: sqltypes.h:52
std::string serialize_type(const ExtArgumentType type, bool byval=true, bool declare=false)
const ExtArgumentType ret_
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:223
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:45
string name
Definition: setup.in.py:72
constexpr auto type_name() noexcept
static std::vector< std::string > getLLVMDeclarations(const std::unordered_set< std::string > &udf_decls, const bool is_gpu=false)
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static std::string toStringSQL(const std::vector< ExtArgumentType > &sig_types)
static void addRTUdfs(const std::string &json_func_sigs)