OmniSciDB  1dac507f6e
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 
20 // A rather crude function binding logic based on the types of the arguments.
21 // We want it to be possible to write specialized versions of functions to be
22 // exposed as SQL extensions. This is especially important for performance
23 // reasons, since double operations can be significantly slower than float. We
24 // compute a score for each candidate signature based on the conversions required
25 // to go from the function arguments as specified in the SQL query to the versions
26 // in ExtensionFunctions.hpp.
27 
28 /*
29  New implementation for binding a SQL function operator to the
30  optimal candidate among all available extension functions.
31  */
32 
33 namespace {
34 
35 static int match_arguments(const SQLTypeInfo& arg_type,
36  int sig_pos,
37  const std::vector<ExtArgumentType>& sig_types,
38  int& penalty_score) {
39  /*
40  Returns a non-negative integer `offset` if `arg_type` matches
41  `sig_types[sig_pos:sig_pos + offset]`.
42 
43  The `offset` value is the number of extension function arguments
44  consumed by the given `arg_type`: 1 for scalars and for arguments
45  bound to a single array/geo signature type; 2 for a pointer + Int64
46  size pair (arrays, points, linestrings, NULL); 4 for polygons and
47  6 for multipolygons bound to pointer/size groups.
48 
49  Returns -1 when the types of an argument and the corresponding
50  extension function argument(s) mismatch, or when downcasting would
51  be required. Throws std::runtime_error for unimplemented SQL types.
52 
53  In case of a non-negative `offset` result, the function updates
54  the `penalty_score` argument as follows:
55 
56  add 1000 if arg_type is non-scalar, otherwise:
57  add 1000 * sizeof(sig_type) / sizeof(arg_type)
58  add 1000000 if type kinds differ (integer vs double, for instance)
59 
60  */
  // NOTE(review): `sig_pos` is not range-checked before indexing `sig_types`;
  // the caller has to guarantee 0 <= sig_pos < sig_types.size() -- TODO confirm
  // that every caller does.
61  auto stype = sig_types[sig_pos];
  // Index of the last signature entry; used below to make sure that the
  // look-ahead reads `sig_types[sig_pos + k]` stay inside the vector.
62  int max_pos = sig_types.size() - 1;
63  switch (arg_type.get_type()) {
64  case kBOOLEAN:
65  if (stype == ExtArgumentType::Bool) {
66  penalty_score += 1000;
67  return 1;
68  }
69  break;
70  case kTINYINT:
  // NOTE(review): the `case` labels of the inner switches for the integer and
  // float kinds are not visible in this listing. Going by the penalty rule in
  // the header comment, the 1000/2000/4000/8000 steps presumably correspond to
  // Int8/Int16/Int32/Int64 and the 1000000+ values to Double -- confirm
  // against the original source.
71  switch (stype) {
73  penalty_score += 1000;
74  break;
76  penalty_score += 2000;
77  break;
79  penalty_score += 4000;
80  break;
82  penalty_score += 8000;
83  break;
85  penalty_score += 1008000;
86  break; // temporary: allow integers as double arguments
87  default:
88  return -1;
89  }
90  return 1;
91  case kSMALLINT:
92  switch (stype) {
94  penalty_score += 1000;
95  break;
97  penalty_score += 2000;
98  break;
100  penalty_score += 4000;
101  break;
103  penalty_score += 1004000;
104  break; // temporary: allow integers as double arguments
105  default:
106  return -1;
107  }
108  return 1;
109  case kINT:
110  switch (stype) {
112  penalty_score += 1000;
113  break;
115  penalty_score += 2000;
116  break;
118  penalty_score += 1002000;
119  break; // temporary: allow integers as double arguments
120  default:
121  return -1;
122  }
123  return 1;
124  case kBIGINT:
125  switch (stype) {
127  penalty_score += 1000;
128  break;
130  penalty_score += 1001000;
131  break; // temporary: allow integers as double arguments
132  default:
133  return -1;
134  }
135  return 1;
136  case kFLOAT:
137  switch (stype) {
139  penalty_score += 1000;
140  break;
142  penalty_score += 2000;
143  break; // is it ok to use floats as double arguments?
144  default:
145  return -1;
146  }
147  return 1;
148  case kDOUBLE:
149  if (stype == ExtArgumentType::Double) {
150  penalty_score += 1000;
151  return 1;
152  }
153  break;
154 
155  case kPOINT:
156  case kLINESTRING:
  // Either a (pointer, Int64 size) pair of signature entries, or a single
  // geo signature entry (second half of that condition is not visible in
  // this listing).
157  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
158  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
159  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble) &&
160  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
161  penalty_score += 1000;
162  return 2;
163  } else if (stype == ExtArgumentType::GeoPoint ||
165  penalty_score += 1000;
166  return 1;
167  }
168  break;
169  case kARRAY:
  // Either a (pointer, Int64 size) pair, or a single signature entry for
  // which is_ext_arg_type_array() returns true.
170  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
171  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
172  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble) &&
173  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
174  penalty_score += 1000;
175  return 2;
176  } else if (is_ext_arg_type_array(stype)) {
177  penalty_score += 1000;
178  return 1;
179  }
180  break;
181  case kPOLYGON:
  // Four entries (PInt8, Int64, PInt32, Int64) or a single GeoPolygon entry.
  // NOTE(review): `sig_pos + 3 < max_pos` requires at least one more
  // signature entry AFTER the four consumed ones, whereas the two-entry
  // cases above use `sig_pos < max_pos` (i.e. `sig_pos + 1 <= max_pos`).
  // This looks like an off-by-one (`<=` intended?) unless a polygon group is
  // never the trailing argument -- confirm before changing.
182  if (stype == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
183  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
184  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
185  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
186  penalty_score += 1000;
187  return 4;
188  } else if (stype == ExtArgumentType::GeoPolygon) {
189  penalty_score += 1000;
190  return 1;
191  }
192 
193  break;
194 
195  case kMULTIPOLYGON:
  // Six entries (PInt8, Int64, PInt32, Int64, PInt32, Int64).
  // NOTE(review): same bound question as for kPOLYGON -- `sig_pos + 5 <
  // max_pos` rejects a group that ends exactly at the last signature entry.
196  if (stype == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
197  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
198  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
199  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
200  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
201  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
202  penalty_score += 1000;
203  return 6;
204  }
205  break;
206  case kDECIMAL:
207  case kNUMERIC:
  // Decimals bind to Double when their logical size is 8 and to Float when
  // it is 4; any other combination falls through to the final `return -1`.
208  if (stype == ExtArgumentType::Double && arg_type.get_logical_size() == 8) {
209  penalty_score += 1000;
210  return 1;
211  }
212  if (stype == ExtArgumentType::Float && arg_type.get_logical_size() == 4) {
213  penalty_score += 1000;
214  return 1;
215  }
216  break;
217  case kNULLT: // NULL maps to a pointer and size argument
218  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
219  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
220  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble) &&
221  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
222  penalty_score += 1000;
223  return 2;
224  }
225  break;
226  /* Not implemented types:
227  kCHAR
228  kVARCHAR
229  kTIME
230  kTIMESTAMP
231  kTEXT
232  kDATE
233  kINTERVAL_DAY_TIME
234  kINTERVAL_YEAR_MONTH
235  kGEOMETRY
236  kGEOGRAPHY
237  kEVAL_CONTEXT_TYPE
238  kVOID
239  kCURSOR
240  */
241  default:
  // Any SQL type without a case above is reported as an error instead of
  // being treated as a plain mismatch.
  // NOTE(review): the doubled `+ +` before " not implemented" applies a unary
  // plus to the string literal; it compiles, but is most likely a typo.
242  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
243  ": support for " + arg_type.get_type_name() +
244  "(type=" + std::to_string(arg_type.get_type()) + ")" +
245  +" not implemented: \n pos=" + std::to_string(sig_pos) +
246  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
247  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
248  }
  // Reached through the `break`s above: a supported SQL type whose signature
  // entry (or entry group) did not match.
249  return -1;
250 }
251 
252 } // namespace
253 
256  const std::vector<ExtensionFunction>& ext_funcs) {
257  // worker function
258  /*
259  Return the extension function that has the following properties:
260 
261  1. each argument type in `func_args` matches with extension
262  function argument types.
263 
264  For scalar types, the matching means that the types are either
265  equal or the argument type is smaller than the corresponding
266  extension function argument type. This ensures that no
267  information is lost when casting of argument values is
268  required.
269 
270  For array and geo types, the matching means that the argument
271  type matches exactly with a group of extension function
272  argument types. See `match_arguments`.
273 
274  2. has minimal penalty score among all candidates in `ext_funcs`;
275  the score is accumulated by `match_arguments` plus the logical
276  size of the candidate's return type.
277 
278  It is assumed that all extension functions in `ext_funcs` implement
279  the function called `name`; here `name` is only used in error messages.
280  */
281  int minimal_score = std::numeric_limits<int>::max();
282  int index = -1;
283  int optimal = -1;
  // NOTE(review): `auto ext_func` takes each candidate by value, i.e. one
  // copy of an ExtensionFunction per iteration.
284  for (auto ext_func : ext_funcs) {
285  index++;
286  auto ext_func_args = ext_func.getArgs();
287  /* In general, `func_args.size() <= ext_func_args.size()` because
288  non-scalar arguments (such as arrays and geo-objects) are
289  mapped to multiple `ext_func` arguments. */
290  if (func_args.size() <= ext_func_args.size()) {
291  /* argument type must fit into the corresponding signature
292  argument type, reject signature if not */
293  int penalty_score = 0;
  // `pos` is the index of the next unconsumed signature entry; it is set
  // to -1 as soon as one argument fails to match.
294  int pos = 0;
295  for (auto atype : func_args) {
  // NOTE(review): `pos` is not compared against ext_func_args.size()
  // before this call. Because one argument can consume several entries,
  // `pos` may presumably run past the end of the signature while
  // arguments remain -- confirm whether match_arguments can be handed an
  // out-of-range index.
296  int offset =
297  match_arguments(atype->get_type_info(), pos, ext_func_args, penalty_score);
298  if (offset < 0) {
299  // atype does not match with ext_func argument
300  pos = -1;
301  break;
302  }
303  pos += offset;
304  }
  // NOTE(review): only `pos >= 0` is required here, not
  // `pos == ext_func_args.size()`, so a candidate with unconsumed trailing
  // signature entries is still accepted -- confirm this is intended.
305  if (pos >= 0) {
306  // prefer smaller return types
307  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
  // Strict `<`: among equally scored candidates the first one in
  // `ext_funcs` is kept.
308  if (penalty_score < minimal_score) {
309  optimal = index;
310  minimal_score = penalty_score;
311  }
312  }
313  }
314  }
315 
316  if (optimal == -1) {
317  /* no extension function found whose argument types would match
318  the types of `func_args`; build a descriptive error */
319  std::vector<SQLTypeInfo> arg_types;
320  for (size_t i = 0; i < func_args.size(); ++i) {
321  arg_types.push_back(func_args[i]->get_type_info());
322  }
323  auto sarg_types = ExtensionFunctionsWhitelist::toString(arg_types);
  // No candidates at all: short message without the list of choices.
324  if (!ext_funcs.size()) {
325  throw std::runtime_error("Function " + name + "(" + sarg_types +
326  ") not supported.");
327  }
  // Candidates exist but none matched: list them in the message.
328  auto choices = ExtensionFunctionsWhitelist::toString(ext_funcs, " ");
329  throw std::runtime_error(
330  "Function " + name + "(" + sarg_types +
331  ") not supported.\n Existing extension function implementations:\n" + choices);
332  }
333  return ext_funcs[optimal];
334 }
335 
337  Analyzer::ExpressionPtrVector func_args) {
338  // used in RelAlgTranslator.cpp
  // Convenience overload: collects the candidate extension functions into
  // `ext_funcs` and delegates the selection to the three-argument
  // bind_function.
  // NOTE(review): the initializer of `ext_funcs` is not visible in this
  // listing; presumably it looks the candidates up by `name` -- confirm.
339  std::vector<ExtensionFunction> ext_funcs =
341  return bind_function(name, func_args, ext_funcs);
342 }
343 
345  // used in ExtensionIR.cpp
  // Convenience overload: takes the name and the arguments from
  // `function_oper` and delegates to bind_function(name, func_args).
346  auto name = function_oper->getName();
347  Analyzer::ExpressionPtrVector func_args = {};
  // Gather all getArity() arguments of the operator, in order.
348  for (size_t i = 0; i < function_oper->getArity(); ++i) {
349  func_args.push_back(function_oper->getOwnArg(i));
350  }
351  return bind_function(name, func_args);
352 }
353 
// Returns true when `ext_arg_type` is one of the enumerators listed as `case`
// labels of the switch, false for every other value.
// NOTE(review): the case labels are not visible in this listing; judging by
// the function name they are presumably the array signature types -- confirm.
354 bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type) {
355  switch (ext_arg_type) {
362  return true;
363 
364  default:
365  return false;
366  }
367 }
368 
// Returns true when `ext_arg_type` is one of the enumerators listed as `case`
// labels of the switch, false for every other value.
// NOTE(review): the case labels are not visible in this listing; judging by
// the function name they are presumably the geo signature types -- confirm.
369 bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type) {
370  switch (ext_arg_type) {
374  return true;
375 
376  default:
377  return false;
378  }
379 }
size_t getArity() const
Definition: Analyzer.h:1309
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type)
ExtensionFunction bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< ExtensionFunction > &ext_funcs)
std::string to_string(char const *&&v)
std::string get_type_name() const
Definition: sqltypes.h:429
int get_logical_size() const
Definition: sqltypes.h:337
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:1316
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:326
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:182
static int match_arguments(const SQLTypeInfo &arg_type, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
Definition: sqltypes.h:48
std::string getName() const
Definition: Analyzer.h:1307
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)