OmniSciDB  dfae7c3b14
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string/join.hpp>
20 #include <iostream>
21 
25 #include "Shared/StringTransform.h"
26 
27 // Get the list of all type specializations for the given function name.
28 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
29  const std::string& name) {
30  const auto it = functions_.find(to_upper(name));
31  if (it == functions_.end()) {
32  return nullptr;
33  }
34  return &it->second;
35 }
36 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
37  const std::string& name) {
38  const auto it = udf_functions_.find(to_upper(name));
39  if (it == udf_functions_.end()) {
40  return nullptr;
41  }
42  return &it->second;
43 }
44 
45 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
46  const std::string& name) {
47  std::vector<ExtensionFunction> ext_funcs = {};
48  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
49  const auto uname = to_upper(name);
50  for (auto funcs : collections) {
51  const auto it = funcs->find(uname);
52  if (it == funcs->end()) {
53  continue;
54  }
55  auto ext_func_sigs = it->second;
56  std::copy(ext_func_sigs.begin(), ext_func_sigs.end(), std::back_inserter(ext_funcs));
57  }
58  return ext_funcs;
59 }
60 
61 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
62  const std::string& name,
63  size_t arity) {
64  std::vector<ExtensionFunction> ext_funcs = {};
65  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
66  const auto uname = to_upper(name);
67  for (auto funcs : collections) {
68  const auto it = funcs->find(uname);
69  if (it == funcs->end()) {
70  continue;
71  }
72  auto ext_func_sigs = it->second;
73  std::copy_if(ext_func_sigs.begin(),
74  ext_func_sigs.end(),
75  std::back_inserter(ext_funcs),
76  [arity](auto sig) { return arity == sig.getArgs().size(); });
77  }
78  return ext_funcs;
79 }
80 
81 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
82  const std::string& name,
83  size_t arity,
84  const SQLTypeInfo& rtype) {
85  std::vector<ExtensionFunction> ext_funcs = {};
86  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
87  const auto uname = to_upper(name);
88  for (auto funcs : collections) {
89  const auto it = funcs->find(uname);
90  if (it == funcs->end()) {
91  continue;
92  }
93  auto ext_func_sigs = it->second;
94  std::copy_if(ext_func_sigs.begin(),
95  ext_func_sigs.end(),
96  std::back_inserter(ext_funcs),
97  [arity, rtype](auto sig) {
98  // Ideally, arity should be equal to the number of
99  // sig arguments but there seems to be many cases
100  // where some sig arguments will be represented
101  // with multiple arguments, for instance, array
102  // argument is translated to data pointer and array
103  // size arguments.
104  if (arity > sig.getArgs().size()) {
105  return false;
106  }
107  auto rt = rtype.get_type();
108  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
109  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
110  });
111  }
112  return ext_funcs;
113 }
114 
115 namespace {
116 
117 // Returns the LLVM name for `type`.
// The returned text is spliced into LLVM IR `declare` lines by
// getLLVMDeclarations(). In the order of the returns below, the groups are:
// scalars and void, pointers ("T*"), "{T*, i64, i8}*" structs (presumably
// array types), geo/cursor placeholder names, and "{T*, i64}" structs.
// `byval` only affects that last group: when false, those types are spelled
// as an opaque "i8*" instead of a "{T*, i64}" struct value.
// NOTE(review): this listing is a Doxygen extraction; the `case ...:` labels
// (original lines 120, 122, ..., 186) are missing and must be restored from
// the real source. The "{T*, i64}" group is presumably the Column* types.
118 std::string serialize_type(const ExtArgumentType type, bool byval = true) {
119  switch (type) {
121  return "i8"; // clang converts bool to i8
123  return "i8";
125  return "i16";
127  return "i32";
129  return "i64";
131  return "float";
133  return "double";
135  return "void";
137  return "i8*";
139  return "i16*";
141  return "i32*";
143  return "i64*";
145  return "float*";
147  return "double*";
149  return "i1*";
151  return "{i8*, i64, i8}*";
153  return "{i16*, i64, i8}*";
155  return "{i32*, i64, i8}*";
157  return "{i64*, i64, i8}*";
159  return "{float*, i64, i8}*";
161  return "{double*, i64, i8}*";
163  return "{i1*, i64, i8}*";
165  return "geo_point";
167  return "geo_linestring";
169  return "geo_polygon";
171  return "geo_multi_polygon";
173  return "cursor";
175  return (byval ? "{i8*, i64}" : "i8*");
177  return (byval ? "{i16*, i64}" : "i8*");
179  return (byval ? "{i32*, i64}" : "i8*");
181  return (byval ? "{i64*, i64}" : "i8*");
183  return (byval ? "{float*, i64}" : "i8*");
185  return (byval ? "{double*, i64}" : "i8*");
187  return (byval ? "{i1*, i64}" : "i8*");
188  default:
// Any type without a case above fails the CHECK.
189  CHECK(false);
190  }
// Only reached for unsupported types; the empty return just satisfies the
// compiler's missing-return analysis.
191  CHECK(false);
192  return "";
193 }
194 
195 std::string drop_suffix(const std::string& str) {
196  const auto idx = str.find("__");
197  if (idx == std::string::npos) {
198  return str;
199  }
200  CHECK_GT(idx, std::string::size_type(0));
201  return str.substr(0, idx);
202 }
203 
204 } // namespace
205 
// NOTE(review): Doxygen extraction. The signature line (original line 206;
// presumably `SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType
// ext_arg_type) {`) and every `case ...:` label are missing from this listing.
207  /* Converts an ExtArgumentType to the equivalent SQLTypeInfo.
208  Scalar types map to the matching scalar SQL type; array and column types
     map to kARRAY / kCOLUMN with the element type set as the subtype. Any
     other type logs a WARNING and yields a kNULLT SQLTypeInfo as a placeholder.
209  */
210 
// Builds a kARRAY type whose element type (subtype) is `subtype`.
211  auto generate_array_type = [](const auto subtype) {
212  auto ti = SQLTypeInfo(kARRAY, false);
213  ti.set_subtype(subtype);
214  return ti;
215  };
216 
// Builds a kCOLUMN type whose element type (subtype) is `subtype`.
217  auto generate_column_type = [](const auto subtype) {
218  auto ti = SQLTypeInfo(kCOLUMN, false);
219  ti.set_subtype(subtype);
220  return ti;
221  };
222 
223  switch (ext_arg_type) {
225  return SQLTypeInfo(kBOOLEAN, false);
227  return SQLTypeInfo(kTINYINT, false);
229  return SQLTypeInfo(kSMALLINT, false);
231  return SQLTypeInfo(kINT, false);
233  return SQLTypeInfo(kBIGINT, false);
235  return SQLTypeInfo(kFLOAT, false);
237  return SQLTypeInfo(kDOUBLE, false);
239  return generate_array_type(kTINYINT);
241  return generate_array_type(kSMALLINT);
243  return generate_array_type(kINT);
245  return generate_array_type(kBIGINT);
247  return generate_array_type(kFLOAT);
249  return generate_array_type(kDOUBLE);
251  return generate_array_type(kBOOLEAN);
253  return generate_column_type(kTINYINT);
255  return generate_column_type(kSMALLINT);
257  return generate_column_type(kINT);
259  return generate_column_type(kBIGINT);
261  return generate_column_type(kFLOAT);
263  return generate_column_type(kDOUBLE);
265  return generate_column_type(kBOOLEAN);
266  default:
// Unmapped types are not fatal: warn and fall through to the kNULLT return.
267  LOG(WARNING) << "ExtArgumentType `" << serialize_type(ext_arg_type)
268  << "` cannot be converted to SQLTypeInfo. Returning nulltype.";
269  }
270  return SQLTypeInfo(kNULLT, false);
271 }
272 
// NOTE(review): Doxygen extraction. The signature lines of the overloads in
// this span (original lines 273, 283, 296, 309, 322) and line 313 are missing
// and must be restored from the real source.
//
// Renders each ExtensionFunction via its toString(), one per line, with each
// line prefixed by `tab`.
// NOTE(review): `auto sig` copies every ExtensionFunction; `const auto&`
// would avoid that.
274  const std::vector<ExtensionFunction>& ext_funcs,
275  std::string tab) {
276  std::string r = "";
277  for (auto sig : ext_funcs) {
278  r += tab + sig.toString() + "\n";
279  }
280  return r;
281 }
282 
// Joins the SQL type names (SQLTypeInfo::get_type_name()) of `arg_types`
// with ", ".
284  const std::vector<SQLTypeInfo>& arg_types) {
285  std::string r = "";
286  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
287  r += sig->get_type_name();
288  sig++;
289  if (sig != arg_types.end()) {
290  r += ", ";
291  }
292  }
293  return r;
294 }
295 
// Joins the LLVM type names of `sig_types` with ", ". serialize_type() is
// called with its default byval = true.
297  const std::vector<ExtArgumentType>& sig_types) {
298  std::string r = "";
299  for (auto t = sig_types.begin(); t != sig_types.end();) {
300  r += serialize_type(*t);
301  t++;
302  if (t != sig_types.end()) {
303  r += ", ";
304  }
305  }
306  return r;
307 }
308 
// Joins `sig_types` with ", ". The per-element append (original line 313) is
// missing here; presumably it appends toStringSQL(*t). Confirm against the
// real source.
310  const std::vector<ExtArgumentType>& sig_types) {
311  std::string r = "";
312  for (auto t = sig_types.begin(); t != sig_types.end();) {
314  t++;
315  if (t != sig_types.end()) {
316  r += ", ";
317  }
318  }
319  return r;
320 }
321 
// Returns the LLVM type name of a single argument type (via serialize_type()).
323  return serialize_type(sig_type);
324 }
325 
// NOTE(review): Doxygen extraction. The signature line (original line 326)
// and every `case ...:` label are missing from this listing.
//
// Maps a single ExtArgumentType to the type name used in SQL-facing
// signatures. Three composite spellings appear below: "T[]", "ARRAY<T>" and
// "COLUMN<T>". These are presumably the pointer, array and column argument
// types respectively (the case labels are missing, so confirm).
// Note the asymmetry: the scalar 32-bit integer is spelled "INTEGER", while the
// composite forms use "INT". Unhandled types hit UNREACHABLE().
327  switch (sig_type) {
329  return "TINYINT";
331  return "SMALLINT";
333  return "INTEGER";
335  return "BIGINT";
337  return "FLOAT";
339  return "DOUBLE";
341  return "BOOLEAN";
343  return "TINYINT[]";
345  return "SMALLINT[]";
347  return "INT[]";
349  return "BIGINT[]";
351  return "FLOAT[]";
353  return "DOUBLE[]";
355  return "BOOLEAN[]";
357  return "ARRAY<TINYINT>";
359  return "ARRAY<SMALLINT>";
361  return "ARRAY<INT>";
363  return "ARRAY<BIGINT>";
365  return "ARRAY<FLOAT>";
367  return "ARRAY<DOUBLE>";
369  return "ARRAY<BOOLEAN>";
371  return "COLUMN<TINYINT>";
373  return "COLUMN<SMALLINT>";
375  return "COLUMN<INT>";
377  return "COLUMN<BIGINT>";
379  return "COLUMN<FLOAT>";
381  return "COLUMN<DOUBLE>";
383  return "COLUMN<BOOLEAN>";
385  return "CURSOR";
387  return "POINT";
389  return "LINESTRING";
391  return "POLYGON";
393  return "MULTIPOLYGON";
395  return "VOID";
396  default:
397  UNREACHABLE();
398  }
// Only satisfies the compiler's missing-return analysis.
399  return "";
400 }
401 
402 const std::string ExtensionFunction::getName(bool keep_suffix) const {
403  return (keep_suffix ? name_ : drop_suffix(name_));
404 }
405 
406 std::string ExtensionFunction::toString() const {
407  return getName() + "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
408  serialize_type(ret_);
409 }
410 
// SQL-facing signature: the name with any "__" suffix dropped, followed by the
// parenthesized argument list.
// NOTE(review): original lines 413-414 (the rest of this expression) are
// missing from this Doxygen listing. Presumably they append
// ExtensionFunctionsWhitelist::toStringSQL(args_) and the SQL return type;
// confirm against the real source.
411 std::string ExtensionFunction::toStringSQL() const {
412  return getName(/* keep_suffix = */ false) + "(" +
415 }
416 
417 // Converts the extension function signatures to their LLVM representation.
// Returns one LLVM IR `declare` line per signature in functions_, skipping any
// signature whose (suffixed) name is already in `udf_decls`.
// A second loop adds an i32-returning declaration for every non-runtime entry.
// Its header (original line 449) is missing from this listing; it presumably
// iterates the table-function registry, so confirm.
// NOTE(review): the signature line (original line 418) is also missing.
419  const std::unordered_set<std::string>& udf_decls) {
420  std::vector<std::string> declarations;
421  for (const auto& kv : functions_) {
422  const auto& signatures = kv.second;
423  CHECK(!signatures.empty());
424  for (const auto& signature : kv.second) {
425  // If there is a udf function declaration matching an extension function signature
426  // do not emit a duplicate declaration.
427  if (!udf_decls.empty() && udf_decls.find(signature.getName()) != udf_decls.end()) {
428  continue;
429  }
430 
431  std::string decl_prefix;
432  std::vector<std::string> arg_strs;
433 
// Array-valued results are not returned directly: the function is declared
// `void` and the serialized return type becomes its first parameter.
434  if (is_ext_arg_type_array(signature.getRet())) {
435  decl_prefix = "declare void @" + signature.getName();
436  arg_strs.emplace_back(serialize_type(signature.getRet()));
437  } else {
438  decl_prefix =
439  "declare " + serialize_type(signature.getRet()) + " @" + signature.getName();
440  }
441  for (const auto arg : signature.getArgs()) {
442  arg_strs.push_back(serialize_type(arg));
443  }
444  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
445  ");");
446  }
447  }
448 
450  if (kv.second.isRuntime()) {
451  // Runtime UDTFs are defined in LLVM/NVVM IR module
452  continue;
453  }
454  std::string decl_prefix{"declare " + serialize_type(ExtArgumentType::Int32) + " @" +
455  kv.first};
456  std::vector<std::string> arg_strs;
// NOTE(review): runtime entries were skipped above, so isRuntime() is always
// false here and byval is effectively always false.
457  for (const auto arg : kv.second.getArgs()) {
458  arg_strs.push_back(serialize_type(arg, /* byval= */ kv.second.isRuntime()));
459  }
460  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
461  ");");
462  }
463  return declarations;
464 }
465 
466 namespace {
467 
// Parses an LLVM-style type name (as found in the JSON signatures) back into
// an ExtArgumentType. This is largely the inverse of serialize_type(), with
// extra "bool"-spelled aliases accepted for the i1 forms.
// Round-trip caveats:
// - serialize_type() emits "i8" for both Bool and Int8, and "i8" always
//   deserializes to Int8.
// - The byval = false spelling "i8*" deserializes to PInt8.
// Unknown names fail the CHECK.
// NOTE(review): Doxygen extraction. Most `return ExtArgumentType::...;` lines
// (e.g. original 488, 497, 500, ...) are missing from this listing.
468 ExtArgumentType deserialize_type(const std::string& type_name) {
469  if (type_name == "bool" || type_name == "i1") {
470  return ExtArgumentType::Bool;
471  }
472  if (type_name == "i8") {
473  return ExtArgumentType::Int8;
474  }
475  if (type_name == "i16") {
476  return ExtArgumentType::Int16;
477  }
478  if (type_name == "i32") {
479  return ExtArgumentType::Int32;
480  }
481  if (type_name == "i64") {
482  return ExtArgumentType::Int64;
483  }
484  if (type_name == "float") {
485  return ExtArgumentType::Float;
486  }
487  if (type_name == "double") {
489  }
490  if (type_name == "void") {
491  return ExtArgumentType::Void;
492  }
493  if (type_name == "i8*") {
494  return ExtArgumentType::PInt8;
495  }
496  if (type_name == "i16*") {
498  }
499  if (type_name == "i32*") {
501  }
502  if (type_name == "i64*") {
504  }
505  if (type_name == "float*") {
507  }
508  if (type_name == "double*") {
510  }
511  if (type_name == "i1*" || type_name == "bool*") {
512  return ExtArgumentType::PBool;
513  }
514  if (type_name == "{i8*, i64, i8}*") {
516  }
517  if (type_name == "{i16*, i64, i8}*") {
519  }
520  if (type_name == "{i32*, i64, i8}*") {
522  }
523  if (type_name == "{i64*, i64, i8}*") {
525  }
526  if (type_name == "{float*, i64, i8}*") {
528  }
529  if (type_name == "{double*, i64, i8}*") {
531  }
532  if (type_name == "{i1*, i64, i8}*" || type_name == "{bool*, i64, i8}*") {
534  }
535  if (type_name == "geo_point") {
537  }
538  if (type_name == "geo_linestring") {
540  }
541  if (type_name == "geo_polygon") {
543  }
544  if (type_name == "geo_multi_polygon") {
546  }
547  if (type_name == "cursor") {
549  }
550  if (type_name == "{i8*, i64}") {
552  }
553  if (type_name == "{i16*, i64}") {
555  }
556  if (type_name == "{i32*, i64}") {
558  }
559  if (type_name == "{i64*, i64}") {
561  }
562  if (type_name == "{float*, i64}") {
564  }
565  if (type_name == "{double*, i64}") {
567  }
568  if (type_name == "{i1*, i64}" || type_name == "{bool*, i64}") {
570  }
// Unrecognized type name: fail the CHECK. The Int16 return is only a
// placeholder so the function has a return on every path.
571  CHECK(false);
572  return ExtArgumentType::Int16;
573 }
574 
575 } // namespace
576 
// Map from upper-cased, suffix-free function name to all registered signatures.
577 using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
578 
// Parses `json_func_sigs`, a JSON array of {"name", "ret", "args"} objects, and
// appends one ExtensionFunction per entry to `signatures`.
// Each entry is keyed by its name with any "__" suffix dropped, then
// upper-cased; the stored ExtensionFunction keeps the full name.
// Structural problems fail a CHECK: a non-array document, a non-object entry,
// or a non-array "args". A JSON parse error presumably surfaces as the
// IsArray() CHECK failure, since Parse()'s status is not inspected.
// NOTE(review): the signature line (original line 579) is missing from this
// Doxygen listing.
580  const std::string& json_func_sigs) {
581  rapidjson::Document func_sigs;
582  func_sigs.Parse(json_func_sigs.c_str());
583  CHECK(func_sigs.IsArray());
584  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
585  ++func_sigs_it) {
586  CHECK(func_sigs_it->IsObject());
587  const auto name = json_str(field(*func_sigs_it, "name"));
588  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
589  std::vector<ExtArgumentType> args;
590  const auto& args_serialized = field(*func_sigs_it, "args");
591  CHECK(args_serialized.IsArray());
592  for (auto args_serialized_it = args_serialized.Begin();
593  args_serialized_it != args_serialized.End();
594  ++args_serialized_it) {
595  args.push_back(deserialize_type(json_str(*args_serialized_it)));
596  }
597  signatures[to_upper(drop_suffix(name))].emplace_back(name, args, ret);
598  }
599 }
600 
601 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
602 // them to its operator table and shares the list with the execution layer in
603 // JSON format. Build an in-memory representation of that list here so that it
604 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
605 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
606  // Valid json_func_sigs example:
607  // [
608  // {
609  // "name":"sum",
610  // "ret":"i32",
611  // "args":[
612  // "i32",
613  // "i32"
614  // ]
615  // }
616  // ]
617 
618  addCommon(functions_, json_func_sigs);
619 }
620 
621 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
622  if (!json_func_sigs.empty()) {
623  addCommon(udf_functions_, json_func_sigs);
624  }
625 }
626 
// NOTE(review): the signature line of this definition (original line 627) is
// missing from this Doxygen listing. The body discards every runtime UDF
// signature previously registered through addRTUdfs().
628  rt_udf_functions_.clear();
629 }
630 
631 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
632  if (!json_func_sigs.empty()) {
633  addCommon(rt_udf_functions_, json_func_sigs);
634  }
635 }
636 
// NOTE(review): three definitions of type
// std::unordered_map<std::string, std::vector<ExtensionFunction>> whose
// declarator lines (original 638, 641, 644) are missing from this Doxygen
// listing. They are presumably the out-of-class definitions of the static
// members functions_, udf_functions_ and rt_udf_functions_; confirm against
// the real source.
637 std::unordered_map<std::string, std::vector<ExtensionFunction>>
639 
640 std::unordered_map<std::string, std::vector<ExtensionFunction>>
642 
643 std::unordered_map<std::string, std::vector<ExtensionFunction>>
645 
646 std::string toString(const ExtArgumentType& sig_type) {
647  return ExtensionFunctionsWhitelist::toString(sig_type);
648 }
static void addUdfs(const std::string &json_func_sigs)
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:188
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
std::string join(T const &container, std::string const &delim)
static void add(const std::string &json_func_sigs)
#define UNREACHABLE()
Definition: Logger.h:241
std::unordered_map< std::string, std::vector< ExtensionFunction > > SignatureMap
#define CHECK_GT(x, y)
Definition: Logger.h:209
name
Definition: setup.py:35
const std::string getName(bool keep_suffix=true) const
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
ExtArgumentType deserialize_type(const std::string &type_name)
std::string serialize_type(const ExtArgumentType type, bool byval=true)
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs)
std::string to_upper(const std::string &str)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
static std::vector< std::string > getLLVMDeclarations(const std::unordered_set< std::string > &udf_decls)
#define CHECK(condition)
Definition: Logger.h:197
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:259
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:47
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static std::string toStringSQL(const std::vector< ExtArgumentType > &sig_types)
static void addRTUdfs(const std::string &json_func_sigs)