OmniSciDB  bf83d84833
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 #include "ExternalExecutor.h"
20 
21 // A rather crude function binding logic based on the types of the arguments.
22 // We want it to be possible to write specialized versions of functions to be
23 // exposed as SQL extensions. This is important especially for performance
24 // reasons, since double operations can be significantly slower than float. We
25 // compute a score for each candidate signature based on the conversions required
26 // to go from the function arguments as specified in the SQL query to the versions in
27 // ExtensionFunctions.hpp.
28 
29 /*
30  New implementation for binding a SQL function operator to the
31  optimal candidate among all available extension functions.
32  */
33 
34 namespace {
35 
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
52  default:
53  UNREACHABLE();
54  }
55  return ExtArgumentType{};
56 }
57 
59  switch (ext_arg_array_type) {
61  return ExtArgumentType::Int8;
73  return ExtArgumentType::Bool;
74  default:
75  UNREACHABLE();
76  }
77  return ExtArgumentType{};
78 }
79 
80 static int match_arguments(const SQLTypeInfo& arg_type,
81  int sig_pos,
82  const std::vector<ExtArgumentType>& sig_types,
83  int& penalty_score) {
84  /*
85  Returns non-negative integer `offset` if `arg_type` and
86  `sig_types[sig_pos:sig_pos + offset]` match.
87 
88  The `offset` value can be interpreted as the number of extension
89  function arguments that is consumed by the given `arg_type`. For
90  instance, for scalar types the offset is always 1, for array
91  types the offset is 2: one argument for array pointer value and
92  one argument for the array size value, etc.
93 
94  Returns -1 when the types of an argument and the corresponding
95  extension function argument(s) mismatch, or when downcasting would
96  be effective.
97 
98  In case of non-negative `offset` result, the function updates
99  penalty_score argument as follows:
100 
101  add 1000 if arg_type is non-scalar, otherwise:
102  add 1000 * sizeof(sig_type) / sizeof(arg_type)
103  add 1000000 if type kinds differ (integer vs double, for instance)
104 
105  */
106  auto stype = sig_types[sig_pos];
107  int max_pos = sig_types.size() - 1;
108  switch (arg_type.get_type()) {
109  case kBOOLEAN:
110  if (stype == ExtArgumentType::Bool) {
111  penalty_score += 1000;
112  return 1;
113  }
114  break;
115  case kTINYINT:
116  switch (stype) {
118  penalty_score += 1000;
119  break;
121  penalty_score += 2000;
122  break;
124  penalty_score += 4000;
125  break;
127  penalty_score += 8000;
128  break;
130  penalty_score += 1008000;
131  break; // temporary: allow integers as double arguments
132  default:
133  return -1;
134  }
135  return 1;
136  case kSMALLINT:
137  switch (stype) {
139  penalty_score += 1000;
140  break;
142  penalty_score += 2000;
143  break;
145  penalty_score += 4000;
146  break;
148  penalty_score += 1004000;
149  break; // temporary: allow integers as double arguments
150  default:
151  return -1;
152  }
153  return 1;
154  case kINT:
155  switch (stype) {
157  penalty_score += 1000;
158  break;
160  penalty_score += 2000;
161  break;
163  penalty_score += 1002000;
164  break; // temporary: allow integers as double arguments
165  default:
166  return -1;
167  }
168  return 1;
169  case kBIGINT:
170  switch (stype) {
172  penalty_score += 1000;
173  break;
175  penalty_score += 1001000;
176  break; // temporary: allow integers as double arguments
177  default:
178  return -1;
179  }
180  return 1;
181  case kFLOAT:
182  switch (stype) {
184  penalty_score += 1000;
185  break;
187  penalty_score += 2000;
188  break; // is it ok to use floats as double arguments?
189  default:
190  return -1;
191  }
192  return 1;
193  case kDOUBLE:
194  if (stype == ExtArgumentType::Double) {
195  penalty_score += 1000;
196  return 1;
197  }
198  break;
199 
200  case kPOINT:
201  case kLINESTRING:
202  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
203  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
204  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble) &&
205  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
206  penalty_score += 1000;
207  return 2;
208  } else if (stype == ExtArgumentType::GeoPoint ||
210  penalty_score += 1000;
211  return 1;
212  }
213  break;
214  case kARRAY:
215  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
216  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
217  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble ||
218  stype == ExtArgumentType::PBool) &&
219  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
220  penalty_score += 1000;
221  return 2;
222  } else if (is_ext_arg_type_array(stype)) {
223  // array arguments must match exactly
224  CHECK(arg_type.is_array());
225  const auto stype_ti = ext_arg_type_to_type_info(get_array_arg_elem_type(stype));
226  if (arg_type.get_elem_type() == kBOOLEAN && stype_ti.get_type() == kTINYINT) {
227  /* Boolean array has the same low-level structure as Int8 array. */
228  penalty_score += 1000;
229  return 1;
230  } else if (arg_type.get_elem_type().get_type() == stype_ti.get_type()) {
231  penalty_score += 1000;
232  return 1;
233  } else {
234  return -1;
235  }
236  }
237  break;
238  case kPOLYGON:
239  if (stype == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
240  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
241  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
242  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
243  penalty_score += 1000;
244  return 4;
245  } else if (stype == ExtArgumentType::GeoPolygon) {
246  penalty_score += 1000;
247  return 1;
248  }
249 
250  break;
251 
252  case kMULTIPOLYGON:
253  if (stype == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
254  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
255  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
256  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
257  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
258  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
259  penalty_score += 1000;
260  return 6;
261  } else if (stype == ExtArgumentType::GeoMultiPolygon) {
262  penalty_score += 1000;
263  return 1;
264  }
265  break;
266  case kDECIMAL:
267  case kNUMERIC:
268  if (stype == ExtArgumentType::Double && arg_type.get_logical_size() == 8) {
269  penalty_score += 1000;
270  return 1;
271  }
272  if (stype == ExtArgumentType::Float && arg_type.get_logical_size() == 4) {
273  penalty_score += 1000;
274  return 1;
275  }
276  break;
277  case kNULLT: // NULL maps to a pointer and size argument
278  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
279  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
280  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble ||
281  stype == ExtArgumentType::PBool) &&
282  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
283  penalty_score += 1000;
284  return 2;
285  }
286  break;
287  case kCOLUMN:
288  if (is_ext_arg_type_column(stype)) {
289  // column arguments must match exactly
290  const auto stype_ti = ext_arg_type_to_type_info(get_column_arg_elem_type(stype));
291  if (arg_type.get_elem_type() == kBOOLEAN && stype_ti.get_type() == kTINYINT) {
292  /* Boolean column has the same low-level structure as Int8 column. */
293  penalty_score += 1000;
294  return 1;
295  } else if (arg_type.get_elem_type().get_type() == stype_ti.get_type()) {
296  penalty_score += 1000;
297  return 1;
298  } else {
299  return -1;
300  }
301  }
302  break;
303  case kTEXT:
304  switch (arg_type.get_compression()) {
305  case kENCODING_NONE:
306  if (stype == ExtArgumentType::TextEncodingNone) {
307  penalty_score += 1000;
308  return 1;
309  }
310  return -1;
311  case kENCODING_DICT:
313  penalty_score += 1000;
314  return 1;
315  }
316  default:;
317  // todo: dict(8) and dict(16) encodings
318  }
319  /* Not implemented types:
320  kCHAR
321  kVARCHAR
322  kTIME
323  kTIMESTAMP
324  kDATE
325  kINTERVAL_DAY_TIME
326  kINTERVAL_YEAR_MONTH
327  kGEOMETRY
328  kGEOGRAPHY
329  kEVAL_CONTEXT_TYPE
330  kVOID
331  kCURSOR
332  */
333  default:
334  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
335  ": support for " + arg_type.get_type_name() +
336  "(type=" + std::to_string(arg_type.get_type()) + ")" +
337  +" not implemented: \n pos=" + std::to_string(sig_pos) +
338  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
339  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
340  }
341  return -1;
342 }
343 
/**
 * Returns true when `str` is a valid UDF/UDTF identifier.
 *
 * A valid identifier is non-empty, starts with a letter or an underscore,
 * and every following character is a letter, a digit or an underscore.
 *
 * Each character is converted to `unsigned char` before it is handed to
 * std::isalpha / std::isalnum: calling those functions with a negative
 * value other than EOF (which a plain `char` holding a non-ASCII byte may
 * be) is undefined behaviour.
 */
bool is_valid_identifier(std::string str) {
  // The implicit char -> unsigned char conversion at the call site is the
  // value-preserving cast the <cctype> classifiers require.
  const auto is_leading_char = [](const unsigned char c) {
    return std::isalpha(c) || c == '_';
  };
  const auto is_trailing_char = [](const unsigned char c) {
    return std::isalnum(c) || c == '_';
  };

  if (str.empty() || !is_leading_char(str.front())) {
    return false;
  }
  return std::all_of(str.begin() + 1, str.end(), is_trailing_char);
}
361 
362 } // namespace
363 
364 template <typename T>
365 T bind_function(std::string name,
367  const std::vector<T>& ext_funcs,
368  const std::string processor) {
369  /* worker function
370 
371  Template type T must implement the following methods:
372 
373  std::vector<ExtArgumentType> getInputArgs()
374  */
375  /*
376  Return extension function/table function that has the following
377  properties
378 
379  1. each argument type in `arg_types` matches with extension
380  function argument types.
381 
382  For scalar types, the matching means that the types are either
383  equal or the argument type is smaller than the corresponding
384  the extension function argument type. This ensures that no
385  information is lost when casting of argument values is
386  required.
387 
388  For array and geo types, the matching means that the argument
389  type matches exactly with a group of extension function
390  argument types. See `match_arguments`.
391 
392  2. has minimal penalty score among all implementations of the
393  extension function with given `name`, see `get_penalty_score`
394  for the definition of penalty score.
395 
396  It is assumed that function_oper and extension functions in
397  ext_funcs have the same name.
398  */
399 
400  if (!is_valid_identifier(name)) {
401  throw NativeExecutionError(
402  "Cannot bind function with invalid UDF/UDTF function name: " + name);
403  }
404 
405  int minimal_score = std::numeric_limits<int>::max();
406  int index = -1;
407  int optimal = -1;
408 
409  std::vector<SQLTypeInfo> type_infos;
410  for (auto atype : func_args) {
411  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
412  if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
413  auto ti = SQLTypeInfo(
414  kCOLUMN, 0, 0, false, kENCODING_NONE, 0, atype->get_type_info().get_type());
415  type_infos.push_back(ti);
416  continue;
417  }
418  }
419  type_infos.push_back(atype->get_type_info());
420  }
421 
422  for (auto ext_func : ext_funcs) {
423  index++;
424  auto ext_func_args = ext_func.getInputArgs();
425  /* In general, `arg_types.size() <= ext_func_args.size()` because
426  non-scalar arguments (such as arrays and geo-objects) are
427  mapped to multiple `ext_func` arguments. */
428  if (func_args.size() <= ext_func_args.size()) {
429  /* argument type must fit into the corresponding signature
430  argument type, reject signature if not */
431  int penalty_score = 0;
432  int pos = 0;
433  for (auto ti : type_infos) {
434  int offset = match_arguments(ti, pos, ext_func_args, penalty_score);
435  if (offset < 0) {
436  // atype does not match with ext_func argument
437  pos = -1;
438  break;
439  }
440  pos += offset;
441  }
442  if (pos >= 0) {
443  // prefer smaller return types
444  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
445  if (penalty_score < minimal_score) {
446  optimal = index;
447  minimal_score = penalty_score;
448  }
449  }
450  }
451  }
452 
453  if (optimal == -1) {
454  /* no extension function found that argument types would match
455  with types in `arg_types` */
456  auto sarg_types = ExtensionFunctionsWhitelist::toString(type_infos);
457  std::string message;
458  if (!ext_funcs.size()) {
459  message = "Function " + name + "(" + sarg_types + ") not supported.";
460  throw ExtensionFunctionBindingError(message);
461  } else {
462  if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
463  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
464  " UDTF implementation.";
465  } else if constexpr (std::is_same_v<T, ExtensionFunction>) {
466  message = "Could not bind " + name + "(" + sarg_types + ") to any " + processor +
467  " UDF implementation.";
468  } else {
469  LOG(FATAL) << "bind_function: unknown extension function type "
470  << typeid(T).name();
471  }
472  message += "\n Existing extension function implementations:";
473  for (const auto& ext_func : ext_funcs) {
474  message += "\n " + ext_func.toStringSQL();
475  }
476  }
477  throw ExtensionFunctionBindingError(message);
478  }
479  return ext_funcs[optimal];
480 }
481 
483  std::string name,
485  const std::vector<table_functions::TableFunction>& table_funcs,
486  const bool is_gpu) {
487  std::string processor = (is_gpu ? "GPU" : "CPU");
488  return bind_function<table_functions::TableFunction>(
489  name, input_args, table_funcs, processor);
490 }
491 
493  Analyzer::ExpressionPtrVector func_args) {
494  // used in RelAlgTranslator.cpp, first try GPU UDFs, then fall back
495  // to CPU UDFs.
496  bool is_gpu = true;
497  std::string processor = "GPU";
498  auto ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
499  if (!ext_funcs.size()) {
500  is_gpu = false;
501  processor = "CPU";
502  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
503  }
504  try {
505  return bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor);
506  } catch (ExtensionFunctionBindingError& e) {
507  if (is_gpu) {
508  is_gpu = false;
509  processor = "GPU|CPU";
510  ext_funcs = ExtensionFunctionsWhitelist::get_ext_funcs(name, is_gpu);
511  return bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor);
512  } else {
513  throw;
514  }
515  }
516 }
517 
520  const bool is_gpu) {
521  // used below
522  std::vector<ExtensionFunction> ext_funcs =
524  std::string processor = (is_gpu ? "GPU" : "CPU");
525  return bind_function<ExtensionFunction>(name, func_args, ext_funcs, processor);
526 }
527 
529  const bool is_gpu) {
530  // used in ExtensionsIR.cpp
531  auto name = function_oper->getName();
532  Analyzer::ExpressionPtrVector func_args = {};
533  for (size_t i = 0; i < function_oper->getArity(); ++i) {
534  func_args.push_back(function_oper->getOwnArg(i));
535  }
536  return bind_function(name, func_args, is_gpu);
537 }
538 
540  std::string name,
542  const bool is_gpu) {
543  // used in RelAlgExecutor.cpp
544  std::vector<table_functions::TableFunction> table_funcs =
546  return bind_table_function(name, input_args, table_funcs, is_gpu);
547 }
548 
549 bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type) {
550  switch (ext_arg_type) {
558  return true;
559 
560  default:
561  return false;
562  }
563 }
564 
565 bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type) {
566  switch (ext_arg_type) {
574  return true;
575 
576  default:
577  return false;
578  }
579 }
580 
581 bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type) {
582  switch (ext_arg_type) {
587  return true;
588 
589  default:
590  return false;
591  }
592 }
593 
594 bool is_ext_arg_type_pointer(const ExtArgumentType ext_arg_type) {
595  switch (ext_arg_type) {
603  return true;
604 
605  default:
606  return false;
607  }
608 }
609 
610 bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type) {
611  switch (ext_arg_type) {
619  return true;
620 
621  default:
622  return false;
623  }
624 }
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name, const bool is_gpu)
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
size_t getArity() const
Definition: Analyzer.h:1360
bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type)
#define LOG(tag)
Definition: Logger.h:188
#define UNREACHABLE()
Definition: Logger.h:241
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:311
std::string to_string(char const *&&v)
int get_logical_size() const
Definition: sqltypes.h:322
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:1367
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
Definition: sqltypes.h:51
const table_functions::TableFunction bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:319
bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type)
std::string get_type_name() const
Definition: sqltypes.h:414
T bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:197
static std::vector< TableFunction > get_table_funcs(const std::string &name, const bool is_gpu)
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:182
static int match_arguments(const SQLTypeInfo &arg_type, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
Definition: sqltypes.h:44
std::string getName() const
Definition: Analyzer.h:1358
string name
Definition: setup.py:35
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:697
bool is_array() const
Definition: sqltypes.h:486
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
bool is_ext_arg_type_pointer(const ExtArgumentType ext_arg_type)