OmniSciDB  06b3bd477c
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include "ExternalExecutor.h"
19 
20 #include <algorithm>
21 
22 // A rather crude function binding logic based on the types of the arguments.
23 // We want it to be possible to write specialized versions of functions to be
24 // exposed as SQL extensions. This is important especially for performance
25 // reasons, since double operations can be significantly slower than float. We
26 // compute a score for each candidate signature based on conversions required to
27 // go from the function arguments as specified in the SQL query to the versions in
28 // ExtensionFunctions.hpp.
29 
30 /*
31  New implementation for binding a SQL function operator to the
32  optimal candidate within all available extension functions.
33  */
34 
35 namespace {
36 
38  switch (ext_arg_array_type) {
40  return ExtArgumentType::Int8;
52  return ExtArgumentType::Bool;
53  default:
54  UNREACHABLE();
55  }
56  return ExtArgumentType{};
57 }
58 
59 static int match_arguments(const SQLTypeInfo& arg_type,
60  int sig_pos,
61  const std::vector<ExtArgumentType>& sig_types,
62  int& penalty_score) {
63  /*
64  Returns non-negative integer `offset` if `arg_type` and
65  `sig_types[sig_pos:sig_pos + offset]` match.
66 
67  The `offset` value can be interpreted as the number of extension
68  function arguments that are consumed by the given `arg_type`. For
69  instance, for scalar types the offset is always 1, for array
70  types the offset is 2: one argument for array pointer value and
71  one argument for the array size value, etc.
72 
73  Returns -1 when the types of an argument and the corresponding
74  extension function argument(s) mismatch, or when downcasting would
75  be required.
76 
77  In case of non-negative `offset` result, the function updates
78  penalty_score argument as follows:
79 
80  add 1000 if arg_type is non-scalar, otherwise:
81  add 1000 * sizeof(sig_type) / sizeof(arg_type)
82  add 1000000 if type kinds differ (integer vs double, for instance)
83 
84  */
85  auto stype = sig_types[sig_pos];
86  int max_pos = sig_types.size() - 1;
87 
88  switch (arg_type.get_type()) {
89  case kBOOLEAN:
90  if (stype == ExtArgumentType::Bool) {
91  penalty_score += 1000;
92  return 1;
93  }
94  break;
95  case kTINYINT:
96  switch (stype) {
98  penalty_score += 1000;
99  break;
101  penalty_score += 2000;
102  break;
104  penalty_score += 4000;
105  break;
107  penalty_score += 8000;
108  break;
110  penalty_score += 1008000;
111  break; // temporary: allow integers as double arguments
112  default:
113  return -1;
114  }
115  return 1;
116  case kSMALLINT:
117  switch (stype) {
119  penalty_score += 1000;
120  break;
122  penalty_score += 2000;
123  break;
125  penalty_score += 4000;
126  break;
128  penalty_score += 1004000;
129  break; // temporary: allow integers as double arguments
130  default:
131  return -1;
132  }
133  return 1;
134  case kINT:
135  switch (stype) {
137  penalty_score += 1000;
138  break;
140  penalty_score += 2000;
141  break;
143  penalty_score += 1002000;
144  break; // temporary: allow integers as double arguments
145  default:
146  return -1;
147  }
148  return 1;
149  case kBIGINT:
150  switch (stype) {
152  penalty_score += 1000;
153  break;
155  penalty_score += 1001000;
156  break; // temporary: allow integers as double arguments
157  default:
158  return -1;
159  }
160  return 1;
161  case kFLOAT:
162  switch (stype) {
164  penalty_score += 1000;
165  break;
167  penalty_score += 2000;
168  break; // is it ok to use floats as double arguments?
169  default:
170  return -1;
171  }
172  return 1;
173  case kDOUBLE:
174  if (stype == ExtArgumentType::Double) {
175  penalty_score += 1000;
176  return 1;
177  }
178  break;
179 
180  case kPOINT:
181  case kLINESTRING:
182  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
183  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
184  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble) &&
185  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
186  penalty_score += 1000;
187  return 2;
188  } else if (stype == ExtArgumentType::GeoPoint ||
190  penalty_score += 1000;
191  return 1;
192  }
193  break;
194  case kARRAY:
195  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
196  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
197  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble ||
198  stype == ExtArgumentType::PBool) &&
199  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
200  penalty_score += 1000;
201  return 2;
202  } else if (is_ext_arg_type_array(stype)) {
203  // array arguments must match exactly
204  CHECK(arg_type.is_array());
205  const auto stype_ti = ext_arg_type_to_type_info(get_array_arg_elem_type(stype));
206  if (arg_type.get_elem_type() == kBOOLEAN && stype_ti.get_type() == kTINYINT) {
207  /* Boolean array has the same low-level structure as Int8 array. */
208  penalty_score += 1000;
209  return 1;
210  } else if (arg_type.get_elem_type().get_type() == stype_ti.get_type()) {
211  penalty_score += 1000;
212  return 1;
213  } else {
214  return -1;
215  }
216  }
217  break;
218  case kPOLYGON:
219  if (stype == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
220  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
221  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
222  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
223  penalty_score += 1000;
224  return 4;
225  } else if (stype == ExtArgumentType::GeoPolygon) {
226  penalty_score += 1000;
227  return 1;
228  }
229 
230  break;
231 
232  case kMULTIPOLYGON:
233  if (stype == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
234  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
235  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
236  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
237  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
238  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
239  penalty_score += 1000;
240  return 6;
241  } else if (stype == ExtArgumentType::GeoMultiPolygon) {
242  penalty_score += 1000;
243  return 1;
244  }
245  break;
246  case kDECIMAL:
247  case kNUMERIC:
248  if (stype == ExtArgumentType::Double && arg_type.get_logical_size() == 8) {
249  penalty_score += 1000;
250  return 1;
251  }
252  if (stype == ExtArgumentType::Float && arg_type.get_logical_size() == 4) {
253  penalty_score += 1000;
254  return 1;
255  }
256  break;
257  case kNULLT: // NULL maps to a pointer and size argument
258  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
259  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
260  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble ||
261  stype == ExtArgumentType::PBool) &&
262  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
263  penalty_score += 1000;
264  return 2;
265  }
266  break;
267  /* Not implemented types:
268  kCHAR
269  kVARCHAR
270  kTIME
271  kTIMESTAMP
272  kTEXT
273  kDATE
274  kINTERVAL_DAY_TIME
275  kINTERVAL_YEAR_MONTH
276  kGEOMETRY
277  kGEOGRAPHY
278  kEVAL_CONTEXT_TYPE
279  kVOID
280  kCURSOR
281  */
282  default:
283  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
284  ": support for " + arg_type.get_type_name() +
285  "(type=" + std::to_string(arg_type.get_type()) + ")" +
286  +" not implemented: \n pos=" + std::to_string(sig_pos) +
287  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
288  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
289  }
290  return -1;
291 }
292 
293 } // namespace
294 
297  const std::vector<ExtensionFunction>& ext_funcs) {
298  // worker function
299  /*
300  Return extension function that has the following properties
301 
302  1. the type of each argument in `func_args` matches with extension
303  function argument types.
304 
305  For scalar types, the matching means that the types are either
306  equal or the argument type is smaller than the corresponding
307  extension function argument type. This ensures that no
308  information is lost when casting of argument values is
309  required.
310 
311  For array and geo types, the matching means that the argument
312  type matches exactly with a group of extension function
313  argument types. See `match_arguments`.
314 
315  2. has minimal penalty score among all implementations of the
316  extension function with given `name`, see `match_arguments`
317  for the definition of penalty score.
318 
319  It is assumed that function_oper and extension functions in
320  ext_funcs have the same name.
321  */
322  int minimal_score = std::numeric_limits<int>::max();
323  int index = -1;
324  int optimal = -1;
325  for (auto ext_func : ext_funcs) {
326  index++;
327  auto ext_func_args = ext_func.getArgs();
328  /* In general, `func_args.size() <= ext_func_args.size()` because
329  non-scalar arguments (such as arrays and geo-objects) are
330  mapped to multiple `ext_func` arguments. */
331  if (func_args.size() <= ext_func_args.size()) {
332  /* argument type must fit into the corresponding signature
333  argument type, reject signature if not */
334  int penalty_score = 0;
335  int pos = 0;
336  for (auto atype : func_args) {
337  int offset =
338  match_arguments(atype->get_type_info(), pos, ext_func_args, penalty_score);
339  if (offset < 0) {
340  // atype does not match with ext_func argument
341  pos = -1;
342  break;
343  }
344  pos += offset;
345  }
346  if (pos >= 0) {
347  // prefer smaller return types
348  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
349  if (penalty_score < minimal_score) {
350  optimal = index;
351  minimal_score = penalty_score;
352  }
353  }
354  }
355  }
356 
357  if (optimal == -1) {
358  /* no extension function found that argument types would match
359  with types in `arg_types` */
360  std::vector<SQLTypeInfo> arg_types;
361  for (size_t i = 0; i < func_args.size(); ++i) {
362  arg_types.push_back(func_args[i]->get_type_info());
363  }
364  auto sarg_types = ExtensionFunctionsWhitelist::toString(arg_types);
365  if (!ext_funcs.size()) {
366  throw NativeExecutionError("Function " + name + "(" + sarg_types +
367  ") not supported.");
368  }
369  auto choices = ExtensionFunctionsWhitelist::toString(ext_funcs, " ");
370  throw std::runtime_error(
371  "Function " + name + "(" + sarg_types +
372  ") not supported.\n Existing extension function implementations:\n" + choices);
373  }
374  return ext_funcs[optimal];
375 }
376 
378  Analyzer::ExpressionPtrVector func_args) {
379  // used in RelAlgTranslator.cpp
380  std::vector<ExtensionFunction> ext_funcs =
382  return bind_function(name, func_args, ext_funcs);
383 }
384 
386  // used in ExtensionIR.cpp
387  auto name = function_oper->getName();
388  Analyzer::ExpressionPtrVector func_args = {};
389  for (size_t i = 0; i < function_oper->getArity(); ++i) {
390  func_args.push_back(function_oper->getOwnArg(i));
391  }
392  return bind_function(name, func_args);
393 }
394 
395 bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type) {
396  switch (ext_arg_type) {
404  return true;
405 
406  default:
407  return false;
408  }
409 }
410 
411 bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type) {
412  switch (ext_arg_type) {
417  return true;
418 
419  default:
420  return false;
421  }
422 }
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
size_t getArity() const
Definition: Analyzer.h:1361
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type)
ExtensionFunction bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< ExtensionFunction > &ext_funcs)
#define UNREACHABLE()
Definition: Logger.h:241
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:258
std::string to_string(char const *&&v)
CHECK(cgen_state)
int get_logical_size() const
Definition: sqltypes.h:269
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:1368
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
std::string get_type_name() const
Definition: sqltypes.h:361
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:183
static int match_arguments(const SQLTypeInfo &arg_type, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
Definition: sqltypes.h:46
std::string getName() const
Definition: Analyzer.h:1359
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:622
bool is_array() const
Definition: sqltypes.h:423
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)