OmniSciDB  dfae7c3b14
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 #include "ExternalExecutor.h"
20 
21 // A rather crude function binding logic based on the types of the arguments.
22 // We want it to be possible to write specialized versions of functions to be
23 // exposed as SQL extensions. This is important especially for performance
24 // reasons, since double operations can be significantly slower than float. We
25 // compute a score for each candidate signature based on the conversions required
26 // to go from the function arguments as specified in the SQL query to the
27 // parameter types of the candidate implementations in ExtensionFunctions.hpp.
28 
29 /*
30  New implementation for binding a SQL function operator to the
31  optimal candidate among all available extension functions.
32  */
33 
34 namespace {
35 
// Returns the scalar element type (ExtArgumentType::Int8 ... Bool) of a column
// extension-argument type; match_arguments uses it to compare column slots
// against the element type of a COLUMN argument.
// NOTE(review): presumably ColumnInt8 -> Int8, ..., ColumnBool -> Bool — confirm
// against the case labels. Unsupported inputs hit UNREACHABLE().
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
52  default:
53  UNREACHABLE();
54  }
// Fallback after UNREACHABLE(); presumably only present so every control path
// returns a value.
55  return ExtArgumentType{};
56 }
57 
// Returns the scalar element type (ExtArgumentType::Int8 ... Bool) of an array
// extension-argument type; match_arguments uses it to compare array slots
// against the element type of an ARRAY argument.
// NOTE(review): presumably ArrayInt8 -> Int8, ..., ArrayBool -> Bool — confirm
// against the case labels. Unsupported inputs hit UNREACHABLE().
59  switch (ext_arg_array_type) {
61  return ExtArgumentType::Int8;
73  return ExtArgumentType::Bool;
74  default:
75  UNREACHABLE();
76  }
// Fallback after UNREACHABLE(); presumably only present so every control path
// returns a value.
77  return ExtArgumentType{};
78 }
79 
80 static int match_arguments(const SQLTypeInfo& arg_type,
81  int sig_pos,
82  const std::vector<ExtArgumentType>& sig_types,
83  int& penalty_score) {
84  /*
85  Returns non-negative integer `offset` if `arg_type` and
86  `sig_types[sig_pos:sig_pos + offset]` match.
87 
88  The `offset` value can be interpreted as the number of extension
89  function arguments that is consumed by the given `arg_type`. For
90  instance, for scalar types the offset is always 1, for array
91  types the offset is 2: one argument for array pointer value and
92  one argument for the array size value, etc.
93 
94  Returns -1 when the types of an argument and the corresponding
95  extension function argument(s) mismatch, or when downcasting would
96  be effective.
97 
98  In case of non-negative `offset` result, the function updates
99  penalty_score argument as follows:
100 
101  add 1000 if arg_type is non-scalar, otherwise:
102  add 1000 * sizeof(sig_type) / sizeof(arg_type)
103  add 1000000 if type kinds differ (integer vs double, for instance)
104 
105  */
// Additional notes:
//  - "downcasting would be effective" above means "downcasting would be
//    required": the scalar integer/float cases fall to `default: return -1`
//    for any signature type they do not list.
//  - Multi-slot encodings consume several signature slots: POINT/LINESTRING,
//    ARRAY and NULL consume 2 (pointer + Int64 size), POLYGON consumes 4 and
//    MULTIPOLYGON consumes 6; the Geo*/array/column enum slots consume 1.
//  - Unknown SQL types throw std::runtime_error (see `default:` below).
// NOTE(review): sig_types[sig_pos] is read without a bounds check; callers must
// guarantee 0 <= sig_pos < sig_types.size().
106  auto stype = sig_types[sig_pos];
107  int max_pos = sig_types.size() - 1;
108  switch (arg_type.get_type()) {
109  case kBOOLEAN:
110  if (stype == ExtArgumentType::Bool) {
111  penalty_score += 1000;
112  return 1;
113  }
114  break;
115  case kTINYINT:
116  switch (stype) {
118  penalty_score += 1000;
119  break;
121  penalty_score += 2000;
122  break;
124  penalty_score += 4000;
125  break;
127  penalty_score += 8000;
128  break;
130  penalty_score += 1008000;
131  break; // temporary: allow integers as double arguments
132  default:
133  return -1;
134  }
135  return 1;
136  case kSMALLINT:
137  switch (stype) {
139  penalty_score += 1000;
140  break;
142  penalty_score += 2000;
143  break;
145  penalty_score += 4000;
146  break;
148  penalty_score += 1004000;
149  break; // temporary: allow integers as double arguments
150  default:
151  return -1;
152  }
153  return 1;
154  case kINT:
155  switch (stype) {
157  penalty_score += 1000;
158  break;
160  penalty_score += 2000;
161  break;
163  penalty_score += 1002000;
164  break; // temporary: allow integers as double arguments
165  default:
166  return -1;
167  }
168  return 1;
169  case kBIGINT:
170  switch (stype) {
172  penalty_score += 1000;
173  break;
175  penalty_score += 1001000;
176  break; // temporary: allow integers as double arguments
177  default:
178  return -1;
179  }
180  return 1;
181  case kFLOAT:
182  switch (stype) {
184  penalty_score += 1000;
185  break;
187  penalty_score += 2000;
188  break; // is it ok to use floats as double arguments?
189  default:
190  return -1;
191  }
192  return 1;
193  case kDOUBLE:
194  if (stype == ExtArgumentType::Double) {
195  penalty_score += 1000;
196  return 1;
197  }
198  break;
199 
// POINT/LINESTRING bind either to a (pointer, Int64 size) pair (offset 2) or to
// a single Geo* slot (offset 1).
200  case kPOINT:
201  case kLINESTRING:
202  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
203  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
204  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble) &&
205  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
206  penalty_score += 1000;
207  return 2;
208  } else if (stype == ExtArgumentType::GeoPoint ||
210  penalty_score += 1000;
211  return 1;
212  }
213  break;
// ARRAY binds either to a (pointer, Int64 size) pair, or to a single array-typed
// slot whose element type must match exactly (BOOLEAN elements may bind to an
// Int8 element slot).
214  case kARRAY:
215  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
216  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
217  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble ||
218  stype == ExtArgumentType::PBool) &&
219  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
220  penalty_score += 1000;
221  return 2;
222  } else if (is_ext_arg_type_array(stype)) {
223  // array arguments must match exactly
224  CHECK(arg_type.is_array());
225  const auto stype_ti = ext_arg_type_to_type_info(get_array_arg_elem_type(stype));
226  if (arg_type.get_elem_type() == kBOOLEAN && stype_ti.get_type() == kTINYINT) {
227  /* Boolean array has the same low-level structure as Int8 array. */
228  penalty_score += 1000;
229  return 1;
230  } else if (arg_type.get_elem_type().get_type() == stype_ti.get_type()) {
231  penalty_score += 1000;
232  return 1;
233  } else {
234  return -1;
235  }
236  }
237  break;
// POLYGON binds to (PInt8, Int64, PInt32, Int64) = 4 slots, or to a GeoPolygon slot.
// NOTE(review): the group reads indices up to sig_pos + 3, so `sig_pos + 3 <= max_pos`
// would suffice; `<` additionally requires at least one signature slot after the
// group, so a signature ending with the polygon group never matches. Confirm
// whether this is intentional (the POINT case uses the exact bound).
238  case kPOLYGON:
239  if (stype == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
240  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
241  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
242  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
243  penalty_score += 1000;
244  return 4;
245  } else if (stype == ExtArgumentType::GeoPolygon) {
246  penalty_score += 1000;
247  return 1;
248  }
249 
250  break;
251 
// MULTIPOLYGON binds to (PInt8, Int64, PInt32, Int64, PInt32, Int64) = 6 slots,
// or to a GeoMultiPolygon slot.
// NOTE(review): same bound question as POLYGON; indices read go up to
// sig_pos + 5, so `sig_pos + 5 <= max_pos` would suffice.
252  case kMULTIPOLYGON:
253  if (stype == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
254  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
255  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
256  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
257  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
258  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
259  penalty_score += 1000;
260  return 6;
261  } else if (stype == ExtArgumentType::GeoMultiPolygon) {
262  penalty_score += 1000;
263  return 1;
264  }
265  break;
// DECIMAL/NUMERIC only binds to a floating-point slot whose size equals the
// decimal's logical size: 8 bytes -> Double, 4 bytes -> Float.
266  case kDECIMAL:
267  case kNUMERIC:
268  if (stype == ExtArgumentType::Double && arg_type.get_logical_size() == 8) {
269  penalty_score += 1000;
270  return 1;
271  }
272  if (stype == ExtArgumentType::Float && arg_type.get_logical_size() == 4) {
273  penalty_score += 1000;
274  return 1;
275  }
276  break;
277  case kNULLT: // NULL maps to a pointer and size argument
278  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
279  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
280  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble ||
281  stype == ExtArgumentType::PBool) &&
282  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
283  penalty_score += 1000;
284  return 2;
285  }
286  break;
// COLUMN (table-function inputs) binds to a single column-typed slot whose
// element type must match exactly (BOOLEAN may bind to an Int8 column slot).
287  case kCOLUMN:
288  if (is_ext_arg_type_column(stype)) {
289  // column arguments must match exactly
290  CHECK(arg_type.is_column());
291  const auto stype_ti = ext_arg_type_to_type_info(get_column_arg_elem_type(stype));
292  if (arg_type.get_elem_type() == kBOOLEAN && stype_ti.get_type() == kTINYINT) {
293  /* Boolean column has the same low-level structure as Int8 column. */
294  penalty_score += 1000;
295  return 1;
296  } else if (arg_type.get_elem_type().get_type() == stype_ti.get_type()) {
297  penalty_score += 1000;
298  return 1;
299  } else {
300  return -1;
301  }
302  }
303  break;
304  /* Not implemented types:
305  kCHAR
306  kVARCHAR
307  kTIME
308  kTIMESTAMP
309  kTEXT
310  kDATE
311  kINTERVAL_DAY_TIME
312  kINTERVAL_YEAR_MONTH
313  kGEOMETRY
314  kGEOGRAPHY
315  kEVAL_CONTEXT_TYPE
316  kVOID
317  kCURSOR
318  */
319  default:
// Unsupported SQL argument type: fail loudly with the offending type, position
// and the full signature. NOTE(review): the `+` directly before the
// " not implemented" literal is a stray unary plus on the literal's decayed
// pointer; it is harmless but looks like a typo.
320  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
321  ": support for " + arg_type.get_type_name() +
322  "(type=" + std::to_string(arg_type.get_type()) + ")" +
323  +" not implemented: \n pos=" + std::to_string(sig_pos) +
324  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
325  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
326  }
// Reached via `break` from any case whose signature check failed: mismatch.
327  return -1;
328 }
329 
330 } // namespace
331 
332 template <typename T>
333 T bind_function(std::string name,
335  const std::vector<T>& ext_funcs) {
336  /* worker function
337 
338  Template type T must implement the following methods:
339 
340  std::vector<ExtArgumentType> getInputArgs()
341  */
342  /*
343  Return extension function/table function that has the following
344  properties
345 
346  1. each argument type in `arg_types` matches with extension
347  function argument types.
348 
349  For scalar types, the matching means that the types are either
350  equal or the argument type is smaller than the corresponding
351  the extension function argument type. This ensures that no
352  information is lost when casting of argument values is
353  required.
354 
355  For array and geo types, the matching means that the argument
356  type matches exactly with a group of extension function
357  argument types. See `match_arguments`.
358 
359  2. has minimal penalty score among all implementations of the
360  extension function with given `name`, see `get_penalty_score`
361  for the definition of penalty score.
362 
363  It is assumed that function_oper and extension functions in
364  ext_funcs have the same name.
365  */
366  int minimal_score = std::numeric_limits<int>::max();
367  int index = -1;
368  int optimal = -1;
369 
370  std::vector<SQLTypeInfo> type_infos;
371  for (auto atype : func_args) {
372  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
373  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
374  auto ti = SQLTypeInfo(kCOLUMN, false);
375  ti.set_subtype(atype->get_type_info().get_type());
376  type_infos.push_back(ti);
377  continue;
378  }
379  }
380  type_infos.push_back(atype->get_type_info());
381  }
382  for (auto ext_func : ext_funcs) {
383  index++;
384  auto ext_func_args = ext_func.getInputArgs();
385  /* In general, `arg_types.size() <= ext_func_args.size()` because
386  non-scalar arguments (such as arrays and geo-objects) are
387  mapped to multiple `ext_func` arguments. */
388  if (func_args.size() <= ext_func_args.size()) {
389  /* argument type must fit into the corresponding signature
390  argument type, reject signature if not */
391  int penalty_score = 0;
392  int pos = 0;
393  for (auto ti : type_infos) {
394  int offset = match_arguments(ti, pos, ext_func_args, penalty_score);
395  if (offset < 0) {
396  // atype does not match with ext_func argument
397  pos = -1;
398  break;
399  }
400  pos += offset;
401  }
402  if (pos >= 0) {
403  // prefer smaller return types
404  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
405  if (penalty_score < minimal_score) {
406  optimal = index;
407  minimal_score = penalty_score;
408  }
409  }
410  }
411  }
412 
413  if (optimal == -1) {
414  /* no extension function found that argument types would match
415  with types in `arg_types` */
416  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos);
417  if (!ext_funcs.size()) {
418  throw NativeExecutionError("Function " + name + "(" + sarg_types +
419  ") not supported.");
420  }
421  std::string choices;
422  for (const auto& ext_func : ext_funcs) {
423  choices += "\n " + ext_func.toStringSQL();
424  }
425  throw std::runtime_error(
426  "Function " + name + "(" + sarg_types +
427  ") not supported.\n Existing extension function implementations:" + choices);
428  }
429  return ext_funcs[optimal];
430 }
431 
433  std::string name,
435  const std::vector<table_functions::TableFunction>& table_funcs) {
// Selects the best-matching table-function overload from `table_funcs` for a
// call to `name` with `input_args` by delegating to
// bind_function<table_functions::TableFunction>, which also binds column
// references as COLUMN types.
436  return bind_function<table_functions::TableFunction>(name, input_args, table_funcs);
437 }
438 
440  Analyzer::ExpressionPtrVector func_args) {
441  // used in RelAlgTranslator.cpp
// Binds a scalar extension-function call by name: builds the candidate list
// `ext_funcs` (presumably ExtensionFunctionsWhitelist::get_ext_funcs(name) —
// confirm) and returns the candidate bind_function<ExtensionFunction> selects.
442  std::vector<ExtensionFunction> ext_funcs =
444  return bind_function<ExtensionFunction>(name, func_args, ext_funcs);
445 }
446 
448  // used in ExtensionsIR.cpp
// Collects the operator's own arguments (getOwnArg(0) .. getOwnArg(getArity() - 1))
// into an expression vector and binds them by the operator's name.
449  auto name = function_oper->getName();
450  Analyzer::ExpressionPtrVector func_args = {};
451  for (size_t i = 0; i < function_oper->getArity(); ++i) {
452  func_args.push_back(function_oper->getOwnArg(i));
453  }
454  return bind_function(name, func_args);
455 }
456 
458  std::string name,
460  const bool is_gpu) {
461  // used in RelAlgExecutor.cpp
// Builds the candidate table-function list for `name`. `is_gpu` is presumably
// forwarded to TableFunctionsFactory::get_table_funcs(name, is_gpu) to pick the
// device-specific implementations (confirm). It then delegates to the
// bind_table_function overload that takes an explicit candidate list.
462  std::vector<table_functions::TableFunction> table_funcs =
464  return bind_table_function(name, input_args, table_funcs);
465 }
466 
// Returns true iff `ext_arg_type` is one of the enumerators listed in the switch
// (presumably the Array* argument types — confirm); every other type yields false.
// match_arguments uses it to recognize single-slot array parameters.
467 bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type) {
468  switch (ext_arg_type) {
476  return true;
477 
478  default:
479  return false;
480  }
481 }
482 
// Returns true iff `ext_arg_type` is one of the enumerators listed in the switch
// (presumably the Column* argument types — confirm); every other type yields false.
// match_arguments uses it to recognize column parameters of table functions.
483 bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type) {
484  switch (ext_arg_type) {
492  return true;
493 
494  default:
495  return false;
496  }
497 }
498 
// Returns true iff `ext_arg_type` is one of the enumerators listed in the switch
// (presumably the Geo* argument types such as GeoPoint, GeoPolygon and
// GeoMultiPolygon — confirm); every other type yields false.
499 bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type) {
500  switch (ext_arg_type) {
505  return true;
506 
507  default:
508  return false;
509  }
510 }
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:1367
bool is_array() const
Definition: sqltypes.h:425
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
std::string getName() const
Definition: Analyzer.h:1358
bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type)
T bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs)
#define UNREACHABLE()
Definition: Logger.h:241
size_t getArity() const
Definition: Analyzer.h:1360
name
Definition: setup.py:35
std::string to_string(char const *&&v)
int get_logical_size() const
Definition: sqltypes.h:270
bool is_column() const
Definition: sqltypes.h:430
const table_functions::TableFunction bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs)
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
std::string get_type_name() const
Definition: sqltypes.h:362
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:624
#define CHECK(condition)
Definition: Logger.h:197
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:259
static std::vector< TableFunction > get_table_funcs(const std::string &name, const bool is_gpu)
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:182
static int match_arguments(const SQLTypeInfo &arg_type, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
Definition: sqltypes.h:47
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)