OmniSciDB  c07336695a
ExtensionFunctionsBinding.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 #include <algorithm>
19 
// A rather crude function binding logic based on the types of the arguments.
// We want it to be possible to write specialized versions of functions to be
// exposed as SQL extensions. This is important especially for performance
// reasons, since double operations can be significantly slower than float. We
// compute a score for each candidate signature based on the conversions
// required to get from the function arguments as specified in the SQL query
// to the versions in ExtensionFunctions.hpp.
27 
/*
  New implementation for binding a SQL function operator to the
  optimal candidate among all available extension functions.
 */
32 
33 namespace {
34 
35 static int match_arguments(const SQLTypeInfo& arg_type,
36  int sig_pos,
37  const std::vector<ExtArgumentType>& sig_types,
38  int& penalty_score) {
39  /*
40  Returns non-negative integer `offset` if `arg_type` and
41  `sig_types[sig_pos:sig_pos + offset]` match.
42 
43  The `offset` value can be interpreted as the number of extension
44  function arguments that is consumed by the given `arg_type`. For
45  instance, for scalar types the offset is always 1, for array
46  types the offset is 2: one argument for array pointer value and
47  one argument for the array size value, etc.
48 
49  Returns -1 when the types of an argument and the corresponding
50  extension function argument(s) mismatch, or when downcasting would
51  be effective.
52 
53  In case of non-negative `offset` result, the function updates
54  penalty_score argument as follows:
55 
56  add 1000 if arg_type is non-scalar, otherwise:
57  add 1000 * sizeof(sig_type) / sizeof(arg_type)
58  add 1000000 if type kinds differ (integer vs double, for instance)
59 
60  */
61  auto stype = sig_types[sig_pos];
62  int max_pos = sig_types.size() - 1;
63  switch (arg_type.get_type()) {
64  case kBOOLEAN:
65  if (stype == ExtArgumentType::Bool) {
66  penalty_score += 1000;
67  return 1;
68  }
69  break;
70  case kTINYINT:
71  switch (stype) {
73  penalty_score += 1000;
74  break;
76  penalty_score += 2000;
77  break;
79  penalty_score += 4000;
80  break;
82  penalty_score += 8000;
83  break;
85  penalty_score += 1008000;
86  break; // temporary: allow integers as double arguments
87  default:
88  return -1;
89  }
90  return 1;
91  case kSMALLINT:
92  switch (stype) {
94  penalty_score += 1000;
95  break;
97  penalty_score += 2000;
98  break;
100  penalty_score += 4000;
101  break;
103  penalty_score += 1004000;
104  break; // temporary: allow integers as double arguments
105  default:
106  return -1;
107  }
108  return 1;
109  case kINT:
110  switch (stype) {
112  penalty_score += 1000;
113  break;
115  penalty_score += 2000;
116  break;
118  penalty_score += 1002000;
119  break; // temporary: allow integers as double arguments
120  default:
121  return -1;
122  }
123  return 1;
124  case kBIGINT:
125  switch (stype) {
127  penalty_score += 1000;
128  break;
130  penalty_score += 1001000;
131  break; // temporary: allow integers as double arguments
132  default:
133  return -1;
134  }
135  return 1;
136  case kFLOAT:
137  switch (stype) {
139  penalty_score += 1000;
140  break;
142  penalty_score += 2000;
143  break; // is it ok to use floats as double arguments?
144  default:
145  return -1;
146  }
147  return 1;
148  case kDOUBLE:
149  if (stype == ExtArgumentType::Double) {
150  penalty_score += 1000;
151  return 1;
152  }
153  break;
154  case kLINESTRING:
155  case kPOINT:
156  case kARRAY:
157  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
158  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
159  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble) &&
160  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
161  penalty_score += 1000;
162  return 2;
163  }
164  break;
165  case kPOLYGON:
166  if (stype == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
167  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
168  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
169  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
170  penalty_score += 1000;
171  return 4;
172  }
173  break;
174  case kMULTIPOLYGON:
175  if (stype == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
176  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
177  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
178  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
179  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
180  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
181  penalty_score += 1000;
182  return 6;
183  }
184  break;
185  case kDECIMAL:
186  case kNUMERIC:
187  if (stype == ExtArgumentType::Double && arg_type.get_logical_size() == 8) {
188  penalty_score += 1000;
189  return 1;
190  }
191  if (stype == ExtArgumentType::Float && arg_type.get_logical_size() == 4) {
192  penalty_score += 1000;
193  return 1;
194  }
195  break;
196  case kNULLT: // NULL maps to a pointer and size argument
197  if ((stype == ExtArgumentType::PInt8 || stype == ExtArgumentType::PInt16 ||
198  stype == ExtArgumentType::PInt32 || stype == ExtArgumentType::PInt64 ||
199  stype == ExtArgumentType::PFloat || stype == ExtArgumentType::PDouble) &&
200  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
201  penalty_score += 1000;
202  return 2;
203  }
204  break;
205  /* Not implemented types:
206  kCHAR
207  kVARCHAR
208  kTIME
209  kTIMESTAMP
210  kTEXT
211  kDATE
212  kINTERVAL_DAY_TIME
213  kINTERVAL_YEAR_MONTH
214  kGEOMETRY
215  kGEOGRAPHY
216  kEVAL_CONTEXT_TYPE
217  */
218  default:
219  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
220  ": support for " + arg_type.get_type_name() +
221  "(type=" + std::to_string(arg_type.get_type()) + ")" +
222  +" not implemented: \n pos=" + std::to_string(sig_pos) +
223  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
224  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
225  }
226  return -1;
227 }
228 
229 } // namespace
230 
233  const std::vector<ExtensionFunction>& ext_funcs) {
234  // worker function
235  /*
236  Return extension function that has the following properties
237 
238  1. each argument type in `arg_types` matches with extension
239  function argument types.
240 
241  For scalar types, the matching means that the types are either
242  equal or the argument type is smaller than the corresponding
243  the extension function argument type. This ensures that no
244  information is lost when casting of argument values is
245  required.
246 
247  For array and geo types, the matching means that the argument
248  type matches exactly with a group of extension function
249  argument types. See `match_arguments`.
250 
251  2. has minimal penalty score among all implementations of the
252  extension function with given `name`, see `get_penalty_score`
253  for the definition of penalty score.
254 
255  It is assumed that function_oper and extension functions in
256  ext_funcs have the same name.
257  */
258  int minimal_score = std::numeric_limits<int>::max();
259  int index = -1;
260  int optimal = -1;
261  for (auto ext_func : ext_funcs) {
262  index++;
263  auto ext_func_args = ext_func.getArgs();
264  /* In general, `arg_types.size() <= ext_func_args.size()` because
265  non-scalar arguments (such as arrays and geo-objects) are
266  mapped to multiple `ext_func` arguments. */
267  if (func_args.size() <= ext_func_args.size()) {
268  /* argument type must fit into the corresponding signature
269  argument type, reject signature if not */
270  int penalty_score = 0;
271  int pos = 0;
272  for (auto atype : func_args) {
273  int offset =
274  match_arguments(atype->get_type_info(), pos, ext_func_args, penalty_score);
275  if (offset < 0) {
276  // atype does not match with ext_func argument
277  pos = -1;
278  break;
279  }
280  pos += offset;
281  }
282  if (pos >= 0) {
283  // prefer smaller return types
284  penalty_score += ext_arg_type_to_type_info(ext_func.getRet()).get_logical_size();
285  if (penalty_score < minimal_score) {
286  optimal = index;
287  minimal_score = penalty_score;
288  }
289  }
290  }
291  }
292 
293  if (optimal == -1) {
294  /* no extension function found that argument types would match
295  with types in `arg_types` */
296  std::vector<SQLTypeInfo> arg_types;
297  for (size_t i = 0; i < func_args.size(); ++i) {
298  arg_types.push_back(func_args[i]->get_type_info());
299  }
300  auto sarg_types = ExtensionFunctionsWhitelist::toString(arg_types);
301  if (!ext_funcs.size()) {
302  throw std::runtime_error("Function " + name + "(" + sarg_types +
303  ") not supported.");
304  }
305  auto choices = ExtensionFunctionsWhitelist::toString(ext_funcs, " ");
306  throw std::runtime_error(
307  "Function " + name + "(" + sarg_types +
308  ") not supported.\n Existing extension function implementations:\n" + choices);
309  }
310  return ext_funcs[optimal];
311 }
312 
314  Analyzer::ExpressionPtrVector func_args) {
315  // used in RelAlgTranslator.cpp
316  std::vector<ExtensionFunction> ext_funcs =
318  return bind_function(name, func_args, ext_funcs);
319 }
320 
322  // used in ExtensionIR.cpp
323  auto name = function_oper->getName();
324  Analyzer::ExpressionPtrVector func_args = {};
325  for (size_t i = 0; i < function_oper->getArity(); ++i) {
326  func_args.push_back(function_oper->getOwnArg(i));
327  }
328  return bind_function(name, func_args);
329 }
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
Definition: Analyzer.h:1259
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
std::string getName() const
Definition: Analyzer.h:1250
ExtensionFunction bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< ExtensionFunction > &ext_funcs)
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:319
size_t getArity() const
Definition: Analyzer.h:1252
std::string to_string(char const *&&v)
std::string get_type_name() const
Definition: sqltypes.h:422
int get_logical_size() const
Definition: sqltypes.h:330
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
std::vector< ExpressionPtr > ExpressionPtrVector
Definition: Analyzer.h:181
static int match_arguments(const SQLTypeInfo &arg_type, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
Definition: sqltypes.h:47
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)