OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
anonymous_namespace{ExtensionFunctionsBinding.cpp} Namespace Reference

Functions

ExtArgumentType get_column_arg_elem_type (const ExtArgumentType ext_arg_column_type)
 
ExtArgumentType get_column_list_arg_elem_type (const ExtArgumentType ext_arg_column_list_type)
 
ExtArgumentType get_array_arg_elem_type (const ExtArgumentType ext_arg_array_type)
 
static int match_numeric_argument (const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
 
static int match_arguments (const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
 
bool is_valid_identifier (std::string str)
 

Function Documentation

ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_array_arg_elem_type ( const ExtArgumentType  ext_arg_array_type)

Definition at line 119 of file ExtensionFunctionsBinding.cpp.

References ArrayBool, ArrayDouble, ArrayFloat, ArrayInt16, ArrayInt32, ArrayInt64, ArrayInt8, ArrayTextEncodingDict, Bool, Double, Float, Int16, Int32, Int64, Int8, TextEncodingDict, and UNREACHABLE.

Referenced by match_arguments().

+ Here is the caller graph for this function:

ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_arg_elem_type ( const ExtArgumentType  ext_arg_column_type)

Definition at line 36 of file ExtensionFunctionsBinding.cpp.

References ArrayBool, ArrayDouble, ArrayFloat, ArrayInt16, ArrayInt32, ArrayInt64, ArrayInt8, ArrayTextEncodingDict, Bool, ColumnArrayBool, ColumnArrayDouble, ColumnArrayFloat, ColumnArrayInt16, ColumnArrayInt32, ColumnArrayInt64, ColumnArrayInt8, ColumnArrayTextEncodingDict, ColumnBool, ColumnDouble, ColumnFloat, ColumnInt16, ColumnInt32, ColumnInt64, ColumnInt8, ColumnTextEncodingDict, ColumnTimestamp, Double, Float, Int16, Int32, Int64, Int8, TextEncodingDict, Timestamp, and UNREACHABLE.

Referenced by match_arguments().

36  {
37  switch (ext_arg_column_type) {
39  return ExtArgumentType::Int8;
51  return ExtArgumentType::Bool;
72  default:
73  UNREACHABLE();
74  }
75  return ExtArgumentType{};
76 }
#define UNREACHABLE()
Definition: Logger.h:337

+ Here is the caller graph for this function:

ExtArgumentType anonymous_namespace{ExtensionFunctionsBinding.cpp}::get_column_list_arg_elem_type ( const ExtArgumentType  ext_arg_column_list_type)

Definition at line 78 of file ExtensionFunctionsBinding.cpp.

References ArrayBool, ArrayDouble, ArrayFloat, ArrayInt16, ArrayInt32, ArrayInt64, ArrayInt8, ArrayTextEncodingDict, Bool, ColumnListArrayBool, ColumnListArrayDouble, ColumnListArrayFloat, ColumnListArrayInt16, ColumnListArrayInt32, ColumnListArrayInt64, ColumnListArrayInt8, ColumnListArrayTextEncodingDict, ColumnListBool, ColumnListDouble, ColumnListFloat, ColumnListInt16, ColumnListInt32, ColumnListInt64, ColumnListInt8, ColumnListTextEncodingDict, Double, Float, Int16, Int32, Int64, Int8, TextEncodingDict, and UNREACHABLE.

Referenced by match_arguments().

79  {
80  switch (ext_arg_column_list_type) {
82  return ExtArgumentType::Int8;
94  return ExtArgumentType::Bool;
113  default:
114  UNREACHABLE();
115  }
116  return ExtArgumentType{};
117 }
#define UNREACHABLE()
Definition: Logger.h:337

+ Here is the caller graph for this function:

bool anonymous_namespace{ExtensionFunctionsBinding.cpp}::is_valid_identifier ( std::string  str)

Definition at line 494 of file ExtensionFunctionsBinding.cpp.

Referenced by bind_function().

/**
 * @brief Checks whether a string is a valid identifier.
 *
 * A valid identifier is non-empty, starts with an alphabetic character or an
 * underscore, and continues with alphanumeric characters or underscores only.
 * Classification uses std::isalpha / std::isalnum, i.e. it follows the
 * currently installed C locale.
 *
 * Each character is converted to unsigned char before classification: passing
 * a negative value (for example a non-ASCII byte on platforms where plain char
 * is signed) to the <cctype> functions is undefined behavior.
 *
 * @param str candidate identifier
 * @return true if `str` is a valid identifier, false otherwise (including
 *         for the empty string)
 */
bool is_valid_identifier(std::string str) {
  if (str.empty()) {
    return false;
  }

  // First character: letter or underscore, digits are not allowed here.
  const auto is_identifier_start = [](const char c) {
    return c == '_' || std::isalpha(static_cast<unsigned char>(c)) != 0;
  };
  // Remaining characters: letters, digits or underscore.
  const auto is_identifier_part = [](const char c) {
    return c == '_' || std::isalnum(static_cast<unsigned char>(c)) != 0;
  };

  if (!is_identifier_start(str[0])) {
    return false;
  }

  for (size_t i = 1; i < str.size(); ++i) {
    if (!is_identifier_part(str[i])) {
      return false;
    }
  }

  return true;
}

+ Here is the caller graph for this function:

static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_arguments ( const SQLTypeInfo arg_type,
const bool  is_arg_literal,
int  sig_pos,
const std::vector< ExtArgumentType > &  sig_types,
int &  penalty_score 
)
static

Definition at line 235 of file ExtensionFunctionsBinding.cpp.

References CHECK, DayTimeInterval, ext_arg_type_to_type_info(), GeoLineString, GeoMultiLineString, GeoMultiPoint, GeoMultiPolygon, GeoPoint, GeoPolygon, get_array_arg_elem_type(), get_column_arg_elem_type(), get_column_list_arg_elem_type(), SQLTypeInfo::get_compression(), SQLTypeInfo::get_elem_type(), SQLTypeInfo::get_type(), SQLTypeInfo::get_type_name(), Int64, SQLTypeInfo::is_array(), is_ext_arg_type_array(), is_ext_arg_type_column(), is_ext_arg_type_column_list(), kARRAY, kBIGINT, kBOOLEAN, kCOLUMN, kCOLUMN_LIST, kDECIMAL, kDOUBLE, kENCODING_DICT, kENCODING_NONE, kFLOAT, kINT, kINTERVAL_DAY_TIME, kINTERVAL_YEAR_MONTH, kLINESTRING, kMULTILINESTRING, kMULTIPOINT, kMULTIPOLYGON, kNULLT, kNUMERIC, kPOINT, kPOLYGON, kSMALLINT, kTEXT, kTIMESTAMP, kTINYINT, kVARCHAR, match_numeric_argument(), PBool, PDouble, PFloat, PInt16, PInt32, PInt64, PInt8, TextEncodingDict, TextEncodingNone, Timestamp, to_string(), ExtensionFunctionsWhitelist::toString(), UNREACHABLE, and YearMonthTimeInterval.

Referenced by bind_function().

239  {
240  /*
241  Returns non-negative integer `offset` if `arg_type` and
242  `sig_types[sig_pos:sig_pos + offset]` match.
243 
244  The `offset` value can be interpreted as the number of extension
245  function arguments that is consumed by the given `arg_type`. For
246  instance, for scalar types the offset is always 1, for array
247  types the offset is 2: one argument for array pointer value and
248  one argument for the array size value, etc.
249 
250  Returns -1 when the types of an argument and the corresponding
251  extension function argument(s) mismatch, or when downcasting would
252  be effective.
253 
254  In case of non-negative `offset` result, the function updates
255  penalty_score argument as follows:
256 
257  add 1000 if arg_type is non-scalar, otherwise:
258  add 1000 * sizeof(sig_type) / sizeof(arg_type)
259  add 1000000 if type kinds differ (integer vs double, for instance)
260 
261  */
262  int max_pos = sig_types.size() - 1;
263  if (sig_pos > max_pos) {
264  return -1;
265  }
266  auto sig_type = sig_types[sig_pos];
267  switch (arg_type.get_type()) {
268  case kBOOLEAN:
269  case kTINYINT:
270  case kSMALLINT:
271  case kINT:
272  case kBIGINT:
273  case kFLOAT:
274  case kDOUBLE:
275  case kDECIMAL:
276  case kNUMERIC:
277  return match_numeric_argument(arg_type, is_arg_literal, sig_type, penalty_score);
278  case kPOINT:
279  case kMULTIPOINT:
280  case kLINESTRING:
281  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
282  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
283  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble) &&
284  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
285  penalty_score += 1000;
286  return 2;
287  } else if (sig_type == ExtArgumentType::GeoPoint ||
288  sig_type == ExtArgumentType::GeoMultiPoint ||
289  sig_type == ExtArgumentType::GeoLineString) {
290  penalty_score += 1000;
291  return 1;
292  }
293  return -1;
294  case kMULTILINESTRING:
295  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
296  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
297  sig_types[sig_pos + 2] == ExtArgumentType::PInt8 &&
298  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
299  penalty_score += 1000;
300  return 4;
301  } else if (sig_type == ExtArgumentType::GeoMultiLineString) {
302  penalty_score += 1000;
303  return 1;
304  }
305  break;
306  case kARRAY:
307  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
308  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
309  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
310  sig_type == ExtArgumentType::PBool) &&
311  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
312  penalty_score += 1000;
313  return 2;
314  } else if (is_ext_arg_type_array(sig_type)) {
315  // array arguments must match exactly
316  CHECK(arg_type.is_array());
317  const auto sig_type_ti =
319  if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
320  sig_type_ti.get_type() == kTINYINT) {
321  /* Boolean array has the same low-level structure as Int8 array. */
322  penalty_score += 1000;
323  return 1;
324  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
325  penalty_score += 1000;
326  return 1;
327  } else {
328  return -1;
329  }
330  }
331  break;
332  case kPOLYGON:
333  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 3 < max_pos &&
334  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
335  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
336  sig_types[sig_pos + 3] == ExtArgumentType::Int64) {
337  penalty_score += 1000;
338  return 4;
339  } else if (sig_type == ExtArgumentType::GeoPolygon) {
340  penalty_score += 1000;
341  return 1;
342  }
343  break;
344  case kMULTIPOLYGON:
345  if (sig_type == ExtArgumentType::PInt8 && sig_pos + 5 < max_pos &&
346  sig_types[sig_pos + 1] == ExtArgumentType::Int64 &&
347  sig_types[sig_pos + 2] == ExtArgumentType::PInt32 &&
348  sig_types[sig_pos + 3] == ExtArgumentType::Int64 &&
349  sig_types[sig_pos + 4] == ExtArgumentType::PInt32 &&
350  sig_types[sig_pos + 5] == ExtArgumentType::Int64) {
351  penalty_score += 1000;
352  return 6;
353  } else if (sig_type == ExtArgumentType::GeoMultiPolygon) {
354  penalty_score += 1000;
355  return 1;
356  }
357  break;
358  case kNULLT: // NULL maps to a pointer and size argument
359  if ((sig_type == ExtArgumentType::PInt8 || sig_type == ExtArgumentType::PInt16 ||
360  sig_type == ExtArgumentType::PInt32 || sig_type == ExtArgumentType::PInt64 ||
361  sig_type == ExtArgumentType::PFloat || sig_type == ExtArgumentType::PDouble ||
362  sig_type == ExtArgumentType::PBool) &&
363  sig_pos < max_pos && sig_types[sig_pos + 1] == ExtArgumentType::Int64) {
364  penalty_score += 1000;
365  return 2;
366  }
367  break;
368  case kCOLUMN:
369  if (is_ext_arg_type_column(sig_type)) {
370  // column arguments must match exactly
371  const auto sig_type_ti =
373  if (arg_type.get_elem_type().get_type() == kARRAY &&
374  sig_type_ti.get_type() == kARRAY) {
375  if (arg_type.get_elem_type().get_elem_type().get_type() ==
376  sig_type_ti.get_elem_type().get_type()) {
377  penalty_score += 1000;
378  return 1;
379  } else {
380  return -1;
381  }
382  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
383  sig_type_ti.get_type() == kTINYINT) {
384  /* Boolean column has the same low-level structure as Int8 column. */
385  penalty_score += 1000;
386  return 1;
387  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
388  penalty_score += 1000;
389  return 1;
390  } else {
391  return -1;
392  }
393  }
394  break;
395  case kCOLUMN_LIST:
396  if (is_ext_arg_type_column_list(sig_type)) {
397  // column_list arguments must match exactly
398  const auto sig_type_ti =
400  if (arg_type.get_elem_type().get_type() == kARRAY &&
401  sig_type_ti.get_type() == kARRAY) {
402  if (arg_type.get_elem_type().get_elem_type().get_type() ==
403  sig_type_ti.get_elem_type().get_type()) {
404  penalty_score += 1000;
405  return 1;
406  } else {
407  return -1;
408  }
409  } else if (arg_type.get_elem_type().get_type() == kBOOLEAN &&
410  sig_type_ti.get_type() == kTINYINT) {
411  /* Boolean column_list has the same low-level structure as Int8 column_list. */
412  penalty_score += 10000;
413  return 1;
414  } else if (arg_type.get_elem_type().get_type() == sig_type_ti.get_type()) {
415  penalty_score += 10000;
416  return 1;
417  } else {
418  return -1;
419  }
420  }
421  break;
422  case kVARCHAR:
423  if (sig_type != ExtArgumentType::TextEncodingNone) {
424  return -1;
425  }
426  switch (arg_type.get_compression()) {
427  case kENCODING_NONE:
428  penalty_score += 1000;
429  return 1;
430  case kENCODING_DICT:
431  return -1;
432  // Todo (todd): Evaluate when and where we can translate to dictionary-encoded
433  default:
434  UNREACHABLE();
435  }
436  case kTEXT:
437  switch (arg_type.get_compression()) {
438  case kENCODING_NONE:
439  if (sig_type == ExtArgumentType::TextEncodingNone) {
440  penalty_score += 1000;
441  return 1;
442  }
443  return -1;
444  case kENCODING_DICT:
445  if (sig_type == ExtArgumentType::TextEncodingDict) {
446  penalty_score += 1000;
447  return 1;
448  }
449  return -1;
450  default:
451  UNREACHABLE();
452  }
453  case kTIMESTAMP:
454  if (sig_type == ExtArgumentType::Timestamp) {
455  penalty_score += 1000;
456  return 1;
457  }
458  break;
459  case kINTERVAL_DAY_TIME:
460  if (sig_type == ExtArgumentType::DayTimeInterval) {
461  penalty_score += 1000;
462  return 1;
463  }
464  break;
465 
467  if (sig_type == ExtArgumentType::YearMonthTimeInterval) {
468  penalty_score += 1000;
469  return 1;
470  }
471  break;
472 
473  /* Not implemented types:
474  kCHAR
475  kTIME
476  kDATE
477  kGEOMETRY
478  kGEOGRAPHY
479  kEVAL_CONTEXT_TYPE
480  kVOID
481  kCURSOR
482  */
483  default:
484  throw std::runtime_error(std::string(__FILE__) + "#" + std::to_string(__LINE__) +
485  ": support for " + arg_type.get_type_name() +
486  "(type=" + std::to_string(arg_type.get_type()) + ")" +
487  +" not implemented: \n pos=" + std::to_string(sig_pos) +
488  " max_pos=" + std::to_string(max_pos) + "\n sig_types=(" +
489  ExtensionFunctionsWhitelist::toString(sig_types) + ")");
490  }
491  return -1;
492 }
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
#define UNREACHABLE()
Definition: Logger.h:337
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:381
std::string to_string(char const *&&v)
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
Definition: sqltypes.h:69
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:389
std::string get_type_name() const
Definition: sqltypes.h:507
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:291
Definition: sqltypes.h:62
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:963
bool is_array() const
Definition: sqltypes.h:588
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

static int anonymous_namespace{ExtensionFunctionsBinding.cpp}::match_numeric_argument ( const SQLTypeInfo arg_type_info,
const bool  is_arg_literal,
const ExtArgumentType sig_ext_arg_type,
int32_t &  penalty_score 
)
static

Definition at line 143 of file ExtensionFunctionsBinding.cpp.

References CHECK, CHECK_GE, CHECK_GT, CHECK_LE, ext_arg_type_to_type_info(), SQLTypeInfo::get_numeric_scalar_scale(), SQLTypeInfo::get_type(), SQLTypeInfo::is_numeric_scalar_auto_castable(), kBIGINT, kBOOLEAN, kDECIMAL, kDOUBLE, kFLOAT, kINT, kNUMERIC, kSMALLINT, and kTINYINT.

Referenced by match_arguments().

// match_numeric_argument: scores how well a numeric scalar SQL argument type
// fits a single numeric extension-function signature type.
//
// Parameters (as listed in this function's documented signature):
//   arg_type_info    - SQL type of the call argument; its type must be one of
//                      the numeric kinds asserted by the first CHECK below.
//   is_arg_literal   - selects the literal scoring formula instead of the
//                      non-literal one (see the branch near the end).
//   sig_ext_arg_type - signature argument type; converted to a SQLTypeInfo via
//                      ext_arg_type_to_type_info().
//   penalty_score    - in/out accumulator; a strictly positive cast penalty is
//                      added to it when the argument matches.
//
// Returns -1 (leaving penalty_score untouched) when
// is_numeric_scalar_auto_castable() rejects the signature type; otherwise
// returns 1, i.e. one signature argument is consumed.
146  {
147  const auto arg_type = arg_type_info.get_type();
148  CHECK(arg_type == kBOOLEAN || arg_type == kTINYINT || arg_type == kSMALLINT ||
149  arg_type == kINT || arg_type == kBIGINT || arg_type == kFLOAT ||
150  arg_type == kDOUBLE || arg_type == kDECIMAL || arg_type == kNUMERIC);
151  // Todo (todd): Add support for timestamp, date, and time types
152  const auto sig_type_info = ext_arg_type_to_type_info(sig_ext_arg_type);
153  const auto sig_type = sig_type_info.get_type();
154 
155  // If we can't legally auto-cast to sig_type, abort
156  if (!arg_type_info.is_numeric_scalar_auto_castable(sig_type_info)) {
157  return -1;
158  }
159 
160  // We now compare a measure of the scale of the sig_type with the
161  // arg_type, which provides a basis for scoring the match between
162  // the two. Note that get_numeric_scalar_scale for the most part
163  // returns the logical byte width of the type, with a few caveats
164  // for decimals and timestamps described in more depth in comments
165  // in the function itself. Also even though for example float and
166  // int types return 4 (as in 4 bytes), and double and bigint types
167  // return 8, a fp32 type cannot express every 32-bit integer (even
168  // if it can cover a larger absolute range), and an fp64 type
169  // likewise cannot express every 64-bit integer. With the aim to
170  // minimize the precision loss from casting (always precise) integer
171  // value to (imprecise) floating point value, in the case of integer
172  // inputs, we'll penalize wider floating point argument types least
173  // by a specific scale transformation (see the implementation
174  // below). For instance, casting tinyint to fp64 is preferred over
175  // casting it to fp32 to minimize precision loss.
176  const bool is_integer_to_fp_cast = (arg_type == kTINYINT || arg_type == kSMALLINT ||
177  arg_type == kINT || arg_type == kBIGINT) &&
178  (sig_type == kFLOAT || sig_type == kDOUBLE);
179 
180  const auto arg_type_relative_scale = arg_type_info.get_numeric_scalar_scale();
181  CHECK_GE(arg_type_relative_scale, 1);
182  CHECK_LE(arg_type_relative_scale, 8);
183  auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
184  CHECK_GE(sig_type_relative_scale, 1);
185  CHECK_LE(sig_type_relative_scale, 8);
186 
187  if (is_integer_to_fp_cast) {
188  // transform fp scale: 4 becomes 16, 8 remains 8
189  sig_type_relative_scale = (3 - (sig_type_relative_scale >> 2)) << 3;
190  }
191 
192  // We do not allow auto-casting to types with less scale/precision
193  // within the same type family.
194  CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
195 
196  // Calculate the ratio of the sig_type by the arg_type, per the above check will be >= 1
// Both operands are integers, so this is an integer division: any fractional
// part of the ratio is dropped.
197  const auto sig_type_scale_gain_ratio =
198  sig_type_relative_scale / arg_type_relative_scale;
199  CHECK_GE(sig_type_scale_gain_ratio, 1);
200 
201  // Following the old bespoke scoring logic this function replaces, we heavily penalize
202  // any casts that move ints to floats/doubles for the precision-loss reasons above
203  // Arguably all integers in the tinyint and smallint can be fully specified with both
204  // float and double types, but we treat them the same as int and bigint types here.
// 1001000 is the regular 1000 plus an extra 1000000 for integer -> floating point casts.
205  const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
206 
207  int32_t scale_cast_penalty_score;
208 
209  // The following logic is new. Basically there are strong reasons to
210  // prefer the promotion of constant literals to the most precise type possible, as
211  // rather than the type being inherent in the data - that is a column or columns where
212  // a user specified a type (and with any expressions on those columns following our
213  // standard sql casting logic), literal types are given to us by Calcite and do not
214  // necessarily convey any semantic intent (i.e. 10 will be an int, but 10.0 a decimal)
215  // Hence it is better to promote these types to the most precise sig_type available,
216  // while at the same time keeping column expressions as close as possible to the input
217  // types (mainly for performance, we have many float versions of various functions
218  // to allow for greater performance when the underlying data is not of double precision,
219  // and hence there is little benefit of the extra cost of computing double precision
220  // operators on this data)
221  if (is_arg_literal) {
// Literal: the penalty shrinks as the gain ratio grows, so wider signature
// types score better. This term can be zero or negative; the CHECK_GT below
// guards the combined score.
222  scale_cast_penalty_score =
223  (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
224  } else {
// Non-literal: the penalty grows with the gain ratio, so the signature type
// closest in scale to the argument scores best.
225  scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
226  }
227 
228  const auto cast_penalty_score =
229  type_family_cast_penalty_score + scale_cast_penalty_score;
230  CHECK_GT(cast_penalty_score, 0);
231  penalty_score += cast_penalty_score;
232  return 1;
233 }
#define CHECK_GE(x, y)
Definition: Logger.h:306
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:381
#define CHECK_GT(x, y)
Definition: Logger.h:305
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
Definition: sqltypes.h:749
#define CHECK_LE(x, y)
Definition: Logger.h:304
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
Definition: sqltypes.h:810
#define CHECK(condition)
Definition: Logger.h:291
Definition: sqltypes.h:62
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)

+ Here is the call graph for this function:

+ Here is the caller graph for this function: