OmniSciDB  04ee39c94c
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <iostream>
19 #include "JsonAccessors.h"
20 
21 #include "../Shared/StringTransform.h"
22 
23 #include <boost/algorithm/string/join.hpp>
24 
25 // Get the list of all type specializations for the given function name.
26 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
27  const std::string& name) {
28  const auto it = functions_.find(to_upper(name));
29  if (it == functions_.end()) {
30  return nullptr;
31  }
32  return &it->second;
33 }
34 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
35  const std::string& name) {
36  const auto it = udf_functions_.find(to_upper(name));
37  if (it == udf_functions_.end()) {
38  return nullptr;
39  }
40  return &it->second;
41 }
42 
43 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
44  const std::string& name) {
45  std::vector<ExtensionFunction> ext_funcs = {};
46  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
47  const auto uname = to_upper(name);
48  for (auto funcs : collections) {
49  const auto it = funcs->find(uname);
50  if (it == funcs->end()) {
51  continue;
52  }
53  auto ext_func_sigs = it->second;
54  std::copy(ext_func_sigs.begin(), ext_func_sigs.end(), std::back_inserter(ext_funcs));
55  }
56  return ext_funcs;
57 }
58 
59 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
60  const std::string& name,
61  size_t arity) {
62  std::vector<ExtensionFunction> ext_funcs = {};
63  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
64  const auto uname = to_upper(name);
65  for (auto funcs : collections) {
66  const auto it = funcs->find(uname);
67  if (it == funcs->end()) {
68  continue;
69  }
70  auto ext_func_sigs = it->second;
71  std::copy_if(ext_func_sigs.begin(),
72  ext_func_sigs.end(),
73  std::back_inserter(ext_funcs),
74  [arity](auto sig) { return arity == sig.getArgs().size(); });
75  }
76  return ext_funcs;
77 }
78 
79 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
80  const std::string& name,
81  size_t arity,
82  const SQLTypeInfo& rtype) {
83  std::vector<ExtensionFunction> ext_funcs = {};
84  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
85  const auto uname = to_upper(name);
86  for (auto funcs : collections) {
87  const auto it = funcs->find(uname);
88  if (it == funcs->end()) {
89  continue;
90  }
91  auto ext_func_sigs = it->second;
92  std::copy_if(ext_func_sigs.begin(),
93  ext_func_sigs.end(),
94  std::back_inserter(ext_funcs),
95  [arity, rtype](auto sig) {
96  // Ideally, arity should be equal to the number of
97  // sig arguments but there seems to be many cases
98  // where some sig arguments will be represented
99  // with multiple arguments, for instance, array
100  // argument is translated to data pointer and array
101  // size arguments.
102  if (arity > sig.getArgs().size()) {
103  return false;
104  }
105  auto rt = rtype.get_type();
106  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
107  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
108  });
109  }
110  return ext_funcs;
111 }
112 
113 namespace {
114 
115 // Returns the LLVM name for `type`.
// Maps an ExtArgumentType to one of the strings "i1", "i8", "i16", "i32",
// "i64", "float", "double" or the pointer forms "i8*" .. "double*".
// A value that reaches `default` trips CHECK(false).
// NOTE(review): the `case` labels of this switch are missing from this
// listing, so which enumerator yields which string cannot be confirmed here.
116 std::string serialize_type(const ExtArgumentType type) {
117  switch (type) {
119  return "i1";
121  return "i8";
123  return "i16";
125  return "i32";
127  return "i64";
129  return "float";
131  return "double";
133  return "i8*";
135  return "i16*";
137  return "i32*";
139  return "i64*";
141  return "float*";
143  return "double*";
144  default:
145  CHECK(false);
146  }
 // Only reached if the switch is left without returning; the empty string
 // below follows another CHECK(false).
147  CHECK(false);
148  return "";
149 }
150 
151 } // namespace
152 
154  /* This function is mostly used for scalar types.
155  For non-scalar types, NULL is returned as a placeholder.
156  */
 // NOTE(review): the signature line and the `case` labels are missing from
 // this listing; the visible returns cover kBOOLEAN, kTINYINT, kSMALLINT,
 // kINT, kBIGINT, kFLOAT and kDOUBLE.
157  switch (ext_arg_type) {
159  return SQLTypeInfo(kBOOLEAN, false);
161  return SQLTypeInfo(kTINYINT, false);
163  return SQLTypeInfo(kSMALLINT, false);
165  return SQLTypeInfo(kINT, false);
167  return SQLTypeInfo(kBIGINT, false);
169  return SQLTypeInfo(kFLOAT, false);
171  return SQLTypeInfo(kDOUBLE, false);
172  default:
 // Any other type is reported via a WARNING log (using its serialize_type
 // name) and then falls through to the kNULLT result below.
173  LOG(WARNING) << "ExtArgumentType `" << serialize_type(ext_arg_type)
174  << "` cannot be converted to SQLTypeInfo. Returning nulltype.";
175  }
176  return SQLTypeInfo(kNULLT, false);
177 }
178 
180  const std::vector<ExtensionFunction>& ext_funcs,
181  std::string tab) {
 // Builds a multi-line listing: for each signature, `tab` followed by
 // sig.toString() and a newline. Returns "" for an empty vector.
 // NOTE(review): the first line of this signature is missing from this
 // listing. `auto sig` copies each ExtensionFunction; `const auto&` would
 // avoid that.
182  std::string r = "";
183  for (auto sig : ext_funcs) {
184  r += tab + sig.toString() + "\n";
185  }
186  return r;
187 }
188 
190  const std::vector<SQLTypeInfo>& arg_types) {
 // Joins get_type_name() of every element with ", " (no trailing separator).
 // Returns "" for an empty vector.
 // NOTE(review): the first line of this signature is missing from this
 // listing.
191  std::string r = "";
192  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
193  r += sig->get_type_name();
 // Advance first so the separator is added only when another element follows.
194  sig++;
195  if (sig != arg_types.end()) {
196  r += ", ";
197  }
198  }
199  return r;
200 }
201 
203  const std::vector<ExtArgumentType>& sig_types) {
 // Joins serialize_type() of every element with ", " (no trailing separator).
 // Returns "" for an empty vector.
 // NOTE(review): the first line of this signature is missing from this
 // listing.
204  std::string r = "";
205  for (auto t = sig_types.begin(); t != sig_types.end();) {
206  r += serialize_type(*t);
 // Advance first so the separator is added only when another element follows.
207  t++;
208  if (t != sig_types.end()) {
209  r += ", ";
210  }
211  }
212  return r;
213 }
214 
215 std::string ExtensionFunction::toString() const {
216  return getName() + "(" + ExtensionFunctionsWhitelist::toString(getArgs()) + ") -> " +
217  serialize_type(getRet());
218 }
219 
220 // Converts the extension function signatures to their LLVM representation.
222  std::vector<std::string> declarations;
 // NOTE(review): the signature line of this function is missing from this
 // listing. Only functions_ is walked here; udf_functions_ and
 // rt_udf_functions_ are not consulted.
223  for (const auto& kv : functions_) {
224  const auto& signatures = kv.second;
 // Every key in functions_ is expected to hold at least one signature.
225  CHECK(!signatures.empty());
226  for (const auto& signature : kv.second) {
 // Each entry has the form `declare <ret> @<name>(<arg>, <arg>, ...);`.
227  std::string decl_prefix{"declare " + serialize_type(signature.getRet()) + " @" +
228  signature.getName()};
229  std::vector<std::string> arg_strs;
230  for (const auto arg : signature.getArgs()) {
231  arg_strs.push_back(serialize_type(arg));
232  }
233  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
234  ");");
235  }
236  }
237  return declarations;
238 }
239 
240 namespace {
241 
// Parses a type-name string ("bool", "i1", "i8", ..., "double*") into an
// ExtArgumentType. An unrecognized name trips CHECK(false).
// NOTE(review): the return statements for "double" and for the pointer types
// other than "i8*" are missing from this listing.
242 ExtArgumentType deserialize_type(const std::string& type_name) {
243  if (type_name == "bool" || type_name == "i1") {
244  return ExtArgumentType::Int8; // need to handle the possibility of nulls
245  }
246  if (type_name == "i8") {
247  return ExtArgumentType::Int8;
248  }
249  if (type_name == "i16") {
250  return ExtArgumentType::Int16;
251  }
252  if (type_name == "i32") {
253  return ExtArgumentType::Int32;
254  }
255  if (type_name == "i64") {
256  return ExtArgumentType::Int64;
257  }
258  if (type_name == "float") {
259  return ExtArgumentType::Float;
260  }
261  if (type_name == "double") {
263  }
264  if (type_name == "i8*") {
265  return ExtArgumentType::PInt8;
266  }
267  if (type_name == "i16*") {
269  }
270  if (type_name == "i32*") {
272  }
273  if (type_name == "i64*") {
275  }
276  if (type_name == "float*") {
278  }
279  if (type_name == "double*") {
281  }
 // No name matched: fail the check. The return below follows CHECK(false).
282  CHECK(false);
283  return ExtArgumentType::Int16;
284 }
285 
286 std::string drop_suffix(const std::string& str) {
287  const auto idx = str.find("__");
288  if (idx == std::string::npos) {
289  return str;
290  }
291  CHECK_GT(idx, std::string::size_type(0));
292  return str.substr(0, idx);
293 }
294 
295 } // namespace
296 
297 using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
298 
300  const std::string& json_func_sigs) {
 // Parses `json_func_sigs` and appends one ExtensionFunction per array
 // element to `signatures`.
 // NOTE(review): the first line of this signature is missing from this
 // listing; `signatures` is presumably the map parameter declared there.
301  rapidjson::Document func_sigs;
302  func_sigs.Parse(json_func_sigs.c_str());
 // The top-level JSON value must be an array.
303  CHECK(func_sigs.IsArray());
304  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
305  ++func_sigs_it) {
 // Each element must be an object; its "name", "ret" and "args" fields are
 // read below.
306  CHECK(func_sigs_it->IsObject());
307  const auto name = json_str(field(*func_sigs_it, "name"));
308  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
309  std::vector<ExtArgumentType> args;
310  const auto& args_serialized = field(*func_sigs_it, "args");
311  CHECK(args_serialized.IsArray());
312  for (auto args_serialized_it = args_serialized.Begin();
313  args_serialized_it != args_serialized.End();
314  ++args_serialized_it) {
315  args.push_back(deserialize_type(json_str(*args_serialized_it)));
316  }
 // The map key is the name with any "__" suffix dropped and then upper-cased;
 // the stored ExtensionFunction keeps the full, unmodified name.
317  signatures[to_upper(drop_suffix(name))].emplace_back(name, args, ret);
318  }
319 }
320 
321 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
322 // them to its operator table and shares the list with the execution layer in
323 // JSON format. Build an in-memory representation of that list here so that it
324 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
325 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
326  // Valid json_func_sigs example:
327  // [
328  // {
329  // "name":"sum",
330  // "ret":"i32",
331  // "args":[
332  // "i32",
333  // "i32"
334  // ]
335  // }
336  // ]
337 
338  addCommon(functions_, json_func_sigs);
339 }
340 
341 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
342  if (!json_func_sigs.empty()) {
343  addCommon(udf_functions_, json_func_sigs);
344  }
345 }
346 
 // Removes every entry from rt_udf_functions_.
 // NOTE(review): the signature line of this function is missing from this
 // listing.
348  rt_udf_functions_.clear();
349 }
350 
351 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
352  if (!json_func_sigs.empty()) {
353  addCommon(rt_udf_functions_, json_func_sigs);
354  }
355 }
356 
// NOTE(review): the declarator line after each type below is missing from
// this listing. Presumably these are the out-of-class definitions of the
// static maps functions_, udf_functions_ and rt_udf_functions_ — confirm
// against the original file.
357 std::unordered_map<std::string, std::vector<ExtensionFunction>>
359 
360 std::unordered_map<std::string, std::vector<ExtensionFunction>>
362 
363 std::unordered_map<std::string, std::vector<ExtensionFunction>>
static void addUdfs(const std::string &json_func_sigs)
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:182
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
std::string join(T const &container, std::string const &delim)
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:323
static void add(const std::string &json_func_sigs)
static std::vector< std::string > getLLVMDeclarations()
std::unordered_map< std::string, std::vector< ExtensionFunction > > SignatureMap
#define CHECK_GT(x, y)
Definition: Logger.h:199
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
ExtArgumentType deserialize_type(const std::string &type_name)
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs)
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:823
std::string to_upper(const std::string &str)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:187
Definition: sqltypes.h:47
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static void addRTUdfs(const std::string &json_func_sigs)