OmniSciDB  1dac507f6e
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string/join.hpp>
20 #include <iostream>
21 
24 #include "Shared/StringTransform.h"
25 
26 // Get the list of all type specializations for the given function name.
27 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
28  const std::string& name) {
29  const auto it = functions_.find(to_upper(name));
30  if (it == functions_.end()) {
31  return nullptr;
32  }
33  return &it->second;
34 }
35 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
36  const std::string& name) {
37  const auto it = udf_functions_.find(to_upper(name));
38  if (it == udf_functions_.end()) {
39  return nullptr;
40  }
41  return &it->second;
42 }
43 
44 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
45  const std::string& name) {
46  std::vector<ExtensionFunction> ext_funcs = {};
47  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
48  const auto uname = to_upper(name);
49  for (auto funcs : collections) {
50  const auto it = funcs->find(uname);
51  if (it == funcs->end()) {
52  continue;
53  }
54  auto ext_func_sigs = it->second;
55  std::copy(ext_func_sigs.begin(), ext_func_sigs.end(), std::back_inserter(ext_funcs));
56  }
57  return ext_funcs;
58 }
59 
60 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
61  const std::string& name,
62  size_t arity) {
63  std::vector<ExtensionFunction> ext_funcs = {};
64  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
65  const auto uname = to_upper(name);
66  for (auto funcs : collections) {
67  const auto it = funcs->find(uname);
68  if (it == funcs->end()) {
69  continue;
70  }
71  auto ext_func_sigs = it->second;
72  std::copy_if(ext_func_sigs.begin(),
73  ext_func_sigs.end(),
74  std::back_inserter(ext_funcs),
75  [arity](auto sig) { return arity == sig.getArgs().size(); });
76  }
77  return ext_funcs;
78 }
79 
80 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
81  const std::string& name,
82  size_t arity,
83  const SQLTypeInfo& rtype) {
84  std::vector<ExtensionFunction> ext_funcs = {};
85  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
86  const auto uname = to_upper(name);
87  for (auto funcs : collections) {
88  const auto it = funcs->find(uname);
89  if (it == funcs->end()) {
90  continue;
91  }
92  auto ext_func_sigs = it->second;
93  std::copy_if(ext_func_sigs.begin(),
94  ext_func_sigs.end(),
95  std::back_inserter(ext_funcs),
96  [arity, rtype](auto sig) {
97  // Ideally, arity should be equal to the number of
98  // sig arguments but there seems to be many cases
99  // where some sig arguments will be represented
100  // with multiple arguments, for instance, array
101  // argument is translated to data pointer and array
102  // size arguments.
103  if (arity > sig.getArgs().size()) {
104  return false;
105  }
106  auto rt = rtype.get_type();
107  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
108  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
109  });
110  }
111  return ext_funcs;
112 }
113 
114 namespace {
115 
116 // Returns the LLVM name for `type`.
117 std::string serialize_type(const ExtArgumentType type) {
118  switch (type) {
120  return "i1";
122  return "i8";
124  return "i16";
126  return "i32";
128  return "i64";
130  return "float";
132  return "double";
134  return "void";
136  return "i8*";
138  return "i16*";
140  return "i32*";
142  return "i64*";
144  return "float*";
146  return "double*";
148  return "array_i8";
150  return "array_i16";
152  return "array_i32";
154  return "array_i64";
156  return "array_float";
158  return "array_double";
160  return "geo_point";
162  return "geo_linestring";
164  return "geo_polygon";
166  return "cursor";
167  default:
168  CHECK(false);
169  }
170  CHECK(false);
171  return "";
172 }
173 
174 } // namespace
175 
177  /* This function is mostly used for scalar types.
178  For non-scalar types, NULL is returned as a placeholder.
179  */
180  switch (ext_arg_type) {
182  return SQLTypeInfo(kBOOLEAN, false);
184  return SQLTypeInfo(kTINYINT, false);
186  return SQLTypeInfo(kSMALLINT, false);
188  return SQLTypeInfo(kINT, false);
190  return SQLTypeInfo(kBIGINT, false);
192  return SQLTypeInfo(kFLOAT, false);
194  return SQLTypeInfo(kDOUBLE, false);
195  default:
196  LOG(WARNING) << "ExtArgumentType `" << serialize_type(ext_arg_type)
197  << "` cannot be converted to SQLTypeInfo. Returning nulltype.";
198  }
199  return SQLTypeInfo(kNULLT, false);
200 }
201 
203  const std::vector<ExtensionFunction>& ext_funcs,
204  std::string tab) {
205  std::string r = "";
206  for (auto sig : ext_funcs) {
207  r += tab + sig.toString() + "\n";
208  }
209  return r;
210 }
211 
213  const std::vector<SQLTypeInfo>& arg_types) {
214  std::string r = "";
215  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
216  r += sig->get_type_name();
217  sig++;
218  if (sig != arg_types.end()) {
219  r += ", ";
220  }
221  }
222  return r;
223 }
224 
226  const std::vector<ExtArgumentType>& sig_types) {
227  std::string r = "";
228  for (auto t = sig_types.begin(); t != sig_types.end();) {
229  r += serialize_type(*t);
230  t++;
231  if (t != sig_types.end()) {
232  r += ", ";
233  }
234  }
235  return r;
236 }
237 
238 std::string ExtensionFunction::toString() const {
239  return getName() + "(" + ExtensionFunctionsWhitelist::toString(getArgs()) + ") -> " +
241 }
242 
243 // Converts the extension function signatures to their LLVM representation.
245  const std::unordered_set<std::string>& udf_decls) {
246  std::vector<std::string> declarations;
247  for (const auto& kv : functions_) {
248  const auto& signatures = kv.second;
249  CHECK(!signatures.empty());
250  for (const auto& signature : kv.second) {
251  // If there is a udf function declaration matching an extension function signature
252  // do not emit a duplicate declaration.
253  if (!udf_decls.empty() && udf_decls.find(signature.getName()) != udf_decls.end()) {
254  continue;
255  }
256 
257  std::string decl_prefix{"declare " + serialize_type(signature.getRet()) + " @" +
258  signature.getName()};
259  std::vector<std::string> arg_strs;
260  for (const auto arg : signature.getArgs()) {
261  arg_strs.push_back(serialize_type(arg));
262  }
263  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
264  ");");
265  }
266  }
267 
269  if (kv.second.isRuntime()) {
270  // Runtime UDTFs are defined in LLVM/NVVM IR module
271  continue;
272  }
273  std::string decl_prefix{"declare " + serialize_type(ExtArgumentType::Int32) + " @" +
274  kv.first};
275  std::vector<std::string> arg_strs;
276  for (const auto arg : kv.second.getArgs()) {
277  arg_strs.push_back(serialize_type(arg));
278  }
279  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
280  ");");
281  }
282  return declarations;
283 }
284 
285 namespace {
286 
287 ExtArgumentType deserialize_type(const std::string& type_name) {
288  if (type_name == "bool" || type_name == "i1") {
289  return ExtArgumentType::Int8; // need to handle the possibility of nulls
290  }
291  if (type_name == "i8") {
292  return ExtArgumentType::Int8;
293  }
294  if (type_name == "i16") {
295  return ExtArgumentType::Int16;
296  }
297  if (type_name == "i32") {
298  return ExtArgumentType::Int32;
299  }
300  if (type_name == "i64") {
301  return ExtArgumentType::Int64;
302  }
303  if (type_name == "float") {
304  return ExtArgumentType::Float;
305  }
306  if (type_name == "double") {
308  }
309  if (type_name == "void") {
310  return ExtArgumentType::Void;
311  }
312  if (type_name == "i8*") {
313  return ExtArgumentType::PInt8;
314  }
315  if (type_name == "i16*") {
317  }
318  if (type_name == "i32*") {
320  }
321  if (type_name == "i64*") {
323  }
324  if (type_name == "float*") {
326  }
327  if (type_name == "double*") {
329  }
330  if (type_name == "array_i8") {
332  }
333  if (type_name == "array_i16") {
335  }
336  if (type_name == "array_i32") {
338  }
339  if (type_name == "array_i64") {
341  }
342  if (type_name == "array_float") {
344  }
345  if (type_name == "array_double") {
347  }
348  if (type_name == "geo_point") {
350  }
351  if (type_name == "geo_linestring") {
353  }
354  if (type_name == "geo_polygon") {
356  }
357  if (type_name == "cursor") {
359  }
360 
361  CHECK(false);
362  return ExtArgumentType::Int16;
363 }
364 
365 std::string drop_suffix(const std::string& str) {
366  const auto idx = str.find("__");
367  if (idx == std::string::npos) {
368  return str;
369  }
370  CHECK_GT(idx, std::string::size_type(0));
371  return str.substr(0, idx);
372 }
373 
374 } // namespace
375 
376 using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
377 
379  const std::string& json_func_sigs) {
380  rapidjson::Document func_sigs;
381  func_sigs.Parse(json_func_sigs.c_str());
382  CHECK(func_sigs.IsArray());
383  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
384  ++func_sigs_it) {
385  CHECK(func_sigs_it->IsObject());
386  const auto name = json_str(field(*func_sigs_it, "name"));
387  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
388  std::vector<ExtArgumentType> args;
389  const auto& args_serialized = field(*func_sigs_it, "args");
390  CHECK(args_serialized.IsArray());
391  for (auto args_serialized_it = args_serialized.Begin();
392  args_serialized_it != args_serialized.End();
393  ++args_serialized_it) {
394  args.push_back(deserialize_type(json_str(*args_serialized_it)));
395  }
396  signatures[to_upper(drop_suffix(name))].emplace_back(name, args, ret);
397  }
398 }
399 
400 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
401 // them to its operator table and shares the list with the execution layer in
402 // JSON format. Build an in-memory representation of that list here so that it
403 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
404 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
405  // Valid json_func_sigs example:
406  // [
407  // {
408  // "name":"sum",
409  // "ret":"i32",
410  // "args":[
411  // "i32",
412  // "i32"
413  // ]
414  // }
415  // ]
416 
417  addCommon(functions_, json_func_sigs);
418 }
419 
420 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
421  if (!json_func_sigs.empty()) {
422  addCommon(udf_functions_, json_func_sigs);
423  }
424 }
425 
427  rt_udf_functions_.clear();
428 }
429 
430 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
431  if (!json_func_sigs.empty()) {
432  addCommon(rt_udf_functions_, json_func_sigs);
433  }
434 }
435 
436 std::unordered_map<std::string, std::vector<ExtensionFunction>>
438 
439 std::unordered_map<std::string, std::vector<ExtensionFunction>>
441 
442 std::unordered_map<std::string, std::vector<ExtensionFunction>>
static void addUdfs(const std::string &json_func_sigs)
const std::vector< ExtArgumentType > & getArgs() const
const std::string & getName() const
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
const ExtArgumentType getRet() const
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:185
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
std::string join(T const &container, std::string const &delim)
static void add(const std::string &json_func_sigs)
std::unordered_map< std::string, std::vector< ExtensionFunction >> SignatureMap
#define CHECK_GT(x, y)
Definition: Logger.h:202
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
CHECK(cgen_state)
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
ExtArgumentType deserialize_type(const std::string &type_name)
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:326
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs)
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:852
std::string to_upper(const std::string &str)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
static std::vector< std::string > getLLVMDeclarations(const std::unordered_set< std::string > &udf_decls)
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:48
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static void addRTUdfs(const std::string &json_func_sigs)