OmniSciDB  cde582ebc3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionsFactory.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string.hpp>
20 #include <mutex>
21 #include <unordered_set>
22 
23 extern bool g_enable_table_functions;
25 
26 namespace table_functions {
27 
28 namespace {
29 
31  switch (ext_arg_type) {
33  return SQLTypeInfo(kTINYINT, false);
35  return SQLTypeInfo(kSMALLINT, false);
37  return SQLTypeInfo(kINT, false);
39  return SQLTypeInfo(kBIGINT, false);
41  return SQLTypeInfo(kFLOAT, false);
43  return SQLTypeInfo(kDOUBLE, false);
45  return SQLTypeInfo(kBOOLEAN, false);
51  return generate_column_type(kINT);
74  default:
75  LOG(WARNING) << "ext_arg_pointer_type_to_type_info: ExtArgumentType `"
77  << "` conversion to SQLTypeInfo not implemented.";
78  UNREACHABLE();
79  }
80  UNREACHABLE();
81  return SQLTypeInfo(kNULLT, false);
82 }
83 
85  switch (ext_arg_type) {
90  return SQLTypeInfo(kTINYINT, false);
95  return SQLTypeInfo(kSMALLINT, false);
100  return SQLTypeInfo(kINT, false);
105  return SQLTypeInfo(kBIGINT, false);
110  return SQLTypeInfo(kFLOAT, false);
115  return SQLTypeInfo(kDOUBLE, false);
120  return SQLTypeInfo(kBOOLEAN, false);
124  return SQLTypeInfo(kTEXT, false, kENCODING_DICT);
126  return SQLTypeInfo(kTIMESTAMP, 9, 0, false);
127  default:
128  LOG(WARNING) << "ext_arg_type_to_type_info_output: ExtArgumentType `"
130  << "` conversion to SQLTypeInfo not implemented.";
131  UNREACHABLE();
132  }
133  UNREACHABLE();
134  return SQLTypeInfo(kNULLT, false);
135 }
136 
137 } // namespace
138 
140  CHECK_LT(idx, input_args_.size());
142 }
143 
145  CHECK_LT(idx, output_args_.size());
146  // TODO(adb): conditionally handle nulls
148 }
149 
151  int32_t scalar_args = 0;
152  for (const auto& ext_arg : input_args_) {
153  if (is_ext_arg_type_scalar(ext_arg)) {
154  scalar_args += 1;
155  }
156  }
157  return scalar_args;
158 }
159 
161  if (hasPreFlightOutputSizer()) {
162  return true;
163  }
164  // workaround for default args
165  for (size_t idx = 0; idx < std::min(input_args_.size(), annotations_.size()); idx++) {
166  const auto& ann = getInputAnnotations(idx);
167  if (ann.find("require") != ann.end()) {
168  return true;
169  }
170  }
171  return false;
172 }
173 
174 const std::map<std::string, std::string> TableFunction::getAnnotations(
175  const size_t idx) const {
176  CHECK_LE(idx, sql_args_.size() + output_args_.size());
177  if (annotations_.empty() || idx >= annotations_.size()) {
178  static const std::map<std::string, std::string> empty = {};
179  return empty;
180  }
181  return annotations_[idx];
182 }
183 
184 const std::map<std::string, std::string> TableFunction::getInputAnnotations(
185  const size_t input_arg_idx) const {
186  CHECK_LT(input_arg_idx, input_args_.size());
187  return getAnnotations(input_arg_idx);
188 }
189 
190 const std::string TableFunction::getInputAnnotation(const size_t input_arg_idx,
191  const std::string& key,
192  const std::string& default_) const {
193  const std::map<std::string, std::string> ann = getInputAnnotations(input_arg_idx);
194  const auto& it = ann.find(key);
195  if (it != ann.end()) {
196  return it->second;
197  }
198  return default_;
199 }
200 
201 const std::map<std::string, std::string> TableFunction::getOutputAnnotations(
202  const size_t output_arg_idx) const {
203  CHECK_LT(output_arg_idx, output_args_.size());
204  return getAnnotations(output_arg_idx + sql_args_.size());
205 }
206 
207 const std::string TableFunction::getOutputAnnotation(const size_t output_arg_idx,
208  const std::string& key,
209  const std::string& default_) const {
210  const std::map<std::string, std::string> ann = getOutputAnnotations(output_arg_idx);
211  const auto& it = ann.find(key);
212  if (it != ann.end()) {
213  return it->second;
214  }
215  return default_;
216 }
217 
218 const std::map<std::string, std::string> TableFunction::getFunctionAnnotations() const {
219  return getAnnotations(sql_args_.size() + output_args_.size());
220 }
221 
223  const std::string& key,
224  const std::string& default_) const {
225  const std::map<std::string, std::string> ann = getFunctionAnnotations();
226  const auto& it = ann.find(key);
227  if (it != ann.end()) {
228  return it->second;
229  }
230  return default_;
231 }
232 
233 const std::vector<std::string> TableFunction::getCursorFields(
234  const size_t sql_idx) const {
235  std::vector<std::string> fields;
236  const std::string& line = getInputAnnotation(sql_idx, "fields", "");
237  if (line.empty()) {
238  static const std::vector<std::string> empty = {};
239  return empty;
240  }
241  std::string substr = line.substr(1, line.size() - 2);
242  boost::split(fields, substr, boost::is_any_of(", "), boost::token_compress_on);
243  return fields;
244 }
245 
246 const std::string TableFunction::getArgTypes(bool use_input_args) const {
247  if (use_input_args) {
248  std::vector<std::string> arg_types;
249  size_t arg_idx = 0;
250  for (size_t sql_idx = 0; sql_idx < sql_args_.size(); sql_idx++) {
251  const std::vector<std::string> cursor_fields = getCursorFields(sql_idx);
252  if (cursor_fields.empty()) {
253  // fields => {}
254  arg_types.emplace_back(
256  } else {
257  std::vector<std::string> vec;
258  for (size_t i = 0; i < cursor_fields.size(); i++) {
259  vec.emplace_back(ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]));
260  }
261  arg_types.emplace_back("Cursor<" + boost::algorithm::join(vec, ", ") + ">");
262  }
263  }
264  return "[" + boost::algorithm::join(arg_types, ", ") + "]";
265  } else {
267  }
268 }
269 
270 const std::string TableFunction::getArgNames(bool use_input_args) const {
271  std::vector<std::string> names;
272  if (use_input_args) {
273  for (size_t idx = 0; idx < sql_args_.size(); idx++) {
274  const std::vector<std::string> cursor_fields = getCursorFields(idx);
275  if (cursor_fields.empty()) {
276  const std::string& name = getInputAnnotation(idx, "name", "''");
277  names.emplace_back(name);
278  } else {
279  names.emplace_back("Cursor<" + boost::algorithm::join(cursor_fields, ", ") + ">");
280  }
281  }
282  } else {
283  for (size_t idx = 0; idx < output_args_.size(); idx++) {
284  const std::string& name = getOutputAnnotation(idx, "name", "''");
285  names.emplace_back(name);
286  }
287  }
288 
289  return "[" + boost::algorithm::join(names, ", ") + "]";
290 }
291 
292 std::pair<int32_t, int32_t> TableFunction::getInputID(const size_t idx) const {
293  // if the annotation is of the form args<INT,INT>, it is refering to a column list
294 #define PREFIX_LENGTH 5
295  const auto& annotation = getOutputAnnotations(idx);
296  auto annot = annotation.find("input_id");
297  if (annot == annotation.end()) {
298  size_t lo = 0;
299  for (const auto& ext_arg : input_args_) {
300  switch (ext_arg) {
304  return std::make_pair(lo, 0);
305  default:
306  lo++;
307  }
308  }
309  UNREACHABLE();
310  }
311 
312  const std::string& input_id = annot->second;
313 
314  if (input_id == "args<-1>") {
315  // empty input id! -1 seems to be the magic number used in RelAlgExecutor.cpp
316  return {-1, -1};
317  }
318 
319  size_t comma = input_id.find(",");
320  int32_t gt = input_id.size() - 1;
321  int32_t lo = std::stoi(input_id.substr(PREFIX_LENGTH, comma - 1));
322 
323  if (comma == std::string::npos) {
324  return std::make_pair(lo, 0);
325  }
326  int32_t hi = std::stoi(input_id.substr(comma + 1, gt - comma - 1));
327  return std::make_pair(lo, hi);
328 }
329 
331  /*
332  This function differs from getOutputRowSizeParameter() since it returns the correct
333  index for the sizer in the sql_args list. For instance, consider the example below:
334 
335  RowMultiplier=4
336  input_args=[{i32*, i64}, {i32*, i64}, {i32*, i64}, i32, {i32*, i64}, {i32*, i64},
337  i32] sql_args=[cursor, i32, cursor, i32]
338 
339  Non-scalar args are aggregated in a cursor inside the sql_args list and the new
340  sizer index is 2 rather than 4 originally specified.
341  */
342 
344  size_t sizer = getOutputRowSizeParameter(); // lookup until reach the sizer arg
345  int32_t ext_arg_index = 0, sql_arg_index = 0;
346 
347  auto same_kind = [&](const ExtArgumentType& ext_arg, const ExtArgumentType& sql_arg) {
348  return ((is_ext_arg_type_scalar(ext_arg) && is_ext_arg_type_scalar(sql_arg)) ||
350  };
351 
352  while ((size_t)ext_arg_index < sizer) {
353  if ((size_t)ext_arg_index == sizer - 1)
354  return sql_arg_index;
355 
356  const auto& ext_arg = input_args_[ext_arg_index];
357  const auto& sql_arg = sql_args_[sql_arg_index];
358 
359  if (same_kind(ext_arg, sql_arg)) {
360  ++ext_arg_index;
361  ++sql_arg_index;
362  } else {
363  CHECK(same_kind(ext_arg, sql_args_[sql_arg_index - 1]));
364  ext_arg_index += 1;
365  }
366  }
367 
368  CHECK(false);
369  }
370 
371  return getOutputRowSizeParameter();
372 }
373 
// Returns true when `function_name` is one of the table functions enabled by
// default (i.e. not only for testing). Matching is exact and case-sensitive.
// Every table function meant to be on by default must be listed here.
bool is_table_function_whitelisted(const std::string& function_name) {
  static const std::unordered_set<std::string> kDefaultEnabled{
      "generate_series",
      "generate_random_strings",
      "tf_mandelbrot",
      "tf_mandelbrot_float",
      "tf_mandelbrot_cuda",
      "tf_mandelbrot_cuda_float",
      "tf_geo_rasterize",
      "tf_geo_rasterize_slope",
      "tf_compute_dwell_times",
      "tf_feature_similarity",
      "tf_feature_self_similarity",
      "supported_ml_frameworks",
      "kmeans",
      "dbscan",
      "linear_reg_fit",
      "linear_reg_predict",
      "linear_reg_fit_predict",
      "tf_point_cloud_metadata",
      "tf_load_point_cloud",
  };
  return kDefaultEnabled.count(function_name) != 0;
}
401 
403  const std::string& name,
404  const TableFunctionOutputRowSizer sizer,
405  const std::vector<ExtArgumentType>& input_args,
406  const std::vector<ExtArgumentType>& output_args,
407  const std::vector<ExtArgumentType>& sql_args,
408  const std::vector<std::map<std::string, std::string>>& annotations,
409  bool is_runtime) {
410  static const std::map<std::string, std::string> empty = {};
411 
412  auto func_annotations =
413  (annotations.size() == sql_args.size() + output_args.size() + 1 ? annotations.back()
414  : empty);
415  auto mgr_annotation = func_annotations.find("uses_manager");
416  bool uses_manager = mgr_annotation != func_annotations.end() &&
417  boost::algorithm::to_lower_copy(mgr_annotation->second) == "true";
418 
419  auto tf = TableFunction(name,
420  sizer,
421  input_args,
422  output_args,
423  sql_args,
424  annotations,
425  is_runtime,
426  uses_manager);
427  const auto tf_name = tf.getName(true /* drop_suffix */, true /* lower */);
428  if (!g_enable_dev_table_functions && !is_runtime &&
429  !is_table_function_whitelisted(tf_name)) {
430  return;
431  }
432  auto sig = tf.getSignature(/* include_name */ true, /* include_output */ false);
433  for (auto it = functions_.begin(); it != functions_.end();) {
434  if (it->second.getName() == name) {
435  if (it->second.isRuntime()) {
436  LOG(WARNING)
437  << "Overriding existing run-time table function (reset not called?): "
438  << name;
439  it = functions_.erase(it);
440  } else {
441  throw std::runtime_error("Will not override existing load-time table function: " +
442  name);
443  }
444  } else {
445  if (sig == it->second.getSignature(/* include_name */ true,
446  /* include_output */ false) &&
447  ((tf.isCPU() && it->second.isCPU()) || (tf.isGPU() && it->second.isGPU()))) {
448  LOG(WARNING)
449  << "The existing (1) and added (2) table functions have the same signature `"
450  << sig << "`:\n"
451  << " 1: " << it->second.toString() << "\n 2: " << tf.toString() << "\n";
452  }
453  ++it;
454  }
455  }
456 
457  functions_.emplace(name, tf);
459  auto input_args2 = input_args;
460  input_args2.erase(input_args2.begin() + sizer.val - 1);
461 
462  auto sql_args2 = sql_args;
463  auto sql_sizer_pos = tf.getSqlOutputRowSizeParameter();
464  sql_args2.erase(sql_args2.begin() + sql_sizer_pos);
465 
467  sizer,
468  input_args2,
469  output_args,
470  sql_args2,
471  annotations,
472  is_runtime,
473  uses_manager);
474  auto sig = tf2.getSignature(/* include_name */ true, /* include_output */ false);
475  for (auto it = functions_.begin(); it != functions_.end();) {
476  if (sig == it->second.getSignature(/* include_name */ true,
477  /* include_output */ false) &&
478  ((tf2.isCPU() && it->second.isCPU()) || (tf2.isGPU() && it->second.isGPU()))) {
479  LOG(WARNING)
480  << "The existing (1) and added (2) table functions have the same signature `"
481  << sig << "`:\n"
482  << " 1: " << it->second.toString() << "\n 2: " << tf2.toString() << "\n";
483  }
484  ++it;
485  }
486  functions_.emplace(name + DEFAULT_ROW_MULTIPLIER_SUFFIX, tf2);
487  }
488 }
489 
490 /*
491  The implementation for `void TableFunctionsFactory::init()` is
492  generated by QueryEngine/scripts/generate_TableFunctionsFactory_init.py
493 */
494 
495 // removes existing runtime table functions
498  return;
499  }
500  for (auto it = functions_.begin(); it != functions_.end();) {
501  if (it->second.isRuntime()) {
502  it = functions_.erase(it);
503  } else {
504  ++it;
505  }
506  }
507 }
508 
509 namespace {
510 
511 std::string drop_suffix_impl(const std::string& str) {
512  const auto idx = str.find("__");
513  if (idx == std::string::npos) {
514  return str;
515  }
516  CHECK_GT(idx, std::string::size_type(0));
517  return str.substr(0, idx);
518 }
519 
520 } // namespace
521 
522 std::string TableFunction::getName(const bool drop_suffix, const bool lower) const {
523  std::string result = name_;
524  if (drop_suffix) {
525  result = drop_suffix_impl(result);
526  }
527  if (lower) {
529  }
530  return result;
531 }
532 
533 std::string TableFunction::getSignature(const bool include_name,
534  const bool include_output) const {
535  std::string sig;
536  if (include_name) {
537  sig += getName(/*drop_suffix=*/true, /*lower=*/true);
538  }
539 
540  size_t arg_idx = 0;
541  std::vector<std::string> args;
542  for (size_t sql_idx = 0; sql_idx < sql_args_.size(); sql_idx++) {
543  const std::vector<std::string> cursor_fields = getCursorFields(sql_idx);
544  if (cursor_fields.empty()) {
545  const auto& type = ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]);
546  const auto& name = getInputAnnotation(sql_idx, "name", "");
547  args.emplace_back(name.empty() ? type : (type + " " + name));
548  } else {
549  std::vector<std::string> vec;
550  for (size_t i = 0; i < cursor_fields.size(); i++) {
551  const auto& type = ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]);
552  const auto& name = cursor_fields[i];
553  vec.emplace_back((name.empty() ? type : type + " " + name));
554  }
555  args.emplace_back("Cursor<" + boost::algorithm::join(vec, ", ") + ">");
556  }
557  }
558  sig += "(" + boost::algorithm::join(args, ", ") + ")";
559  if (include_output) {
561  }
562  return sig;
563 }
564 
566  // gets the name of the pre flight function associated with this table function
567  return getName(false, true) + PREFLIGHT_SUFFIX;
568 }
569 
570 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(const std::string& name,
571  const bool is_gpu) {
572  std::vector<TableFunction> table_funcs;
573  for (const auto& tf : get_table_funcs(name)) {
574  if (is_gpu ? tf.isGPU() : tf.isCPU()) {
575  table_funcs.emplace_back(tf);
576  }
577  }
578  return table_funcs;
579 }
580 
581 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(
582  const std::string& name) {
583  std::vector<TableFunction> table_funcs;
584  auto table_func_name = name;
585  boost::algorithm::to_lower(table_func_name);
586  for (const auto& pair : functions_) {
587  auto fname = drop_suffix_impl(pair.first);
588  if (fname == table_func_name) {
589  table_funcs.push_back(pair.second);
590  }
591  }
592  return table_funcs;
593 }
594 
595 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(const bool is_runtime) {
596  std::vector<TableFunction> table_funcs;
597  for (const auto& pair : functions_) {
598  if (pair.second.isRuntime() == is_runtime) {
599  table_funcs.push_back(pair.second);
600  }
601  }
602  return table_funcs;
603 }
604 
605 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs() {
606  std::vector<TableFunction> table_funcs;
607  for (const auto& pair : functions_) {
608  table_funcs.push_back(pair.second);
609  }
610  return table_funcs;
611 }
612 
613 std::unordered_map<std::string, TableFunction> TableFunctionsFactory::functions_;
614 
615 } // namespace table_functions
SQLTypeInfo getOutputSQLType(const size_t idx) const
std::string to_lower(const std::string &str)
const std::string getOutputAnnotation(const size_t output_arg_idx, const std::string &key, const std::string &default_) const
static std::vector< TableFunction > get_table_funcs()
bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type)
#define PREFIX_LENGTH
static void add(const std::string &name, const TableFunctionOutputRowSizer sizer, const std::vector< ExtArgumentType > &input_args, const std::vector< ExtArgumentType > &output_args, const std::vector< ExtArgumentType > &sql_args, const std::vector< std::map< std::string, std::string >> &annotations, bool is_runtime=false)
#define LOG(tag)
Definition: Logger.h:216
auto generate_column_type(const SQLTypes subtype)
Definition: sqltypes.h:1124
std::string join(T const &container, std::string const &delim)
SQLTypeInfo ext_arg_pointer_type_to_type_info(const ExtArgumentType ext_arg_type)
#define UNREACHABLE()
Definition: Logger.h:266
#define PREFLIGHT_SUFFIX
SQLTypeInfo ext_arg_type_to_type_info_output(const ExtArgumentType ext_arg_type)
const std::map< std::string, std::string > getFunctionAnnotations() const
const std::vector< std::map< std::string, std::string > > annotations_
const std::vector< ExtArgumentType > output_args_
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
std::pair< int32_t, int32_t > getInputID(const size_t idx) const
const std::string getFunctionAnnotation(const std::string &key, const std::string &default_) const
#define CHECK_GT(x, y)
Definition: Logger.h:234
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings
std::string getSignature(const bool include_name, const bool include_output) const
bool is_ext_arg_type_nonscalar(const ExtArgumentType ext_arg_type)
const std::vector< ExtArgumentType > sql_args_
SQLTypeInfo getInputSQLType(const size_t idx) const
const std::string getArgNames(const bool use_input_args) const
auto generate_column_list_type(const SQLTypes subtype)
Definition: sqltypes.h:1138
const std::map< std::string, std::string > getOutputAnnotations(const size_t output_arg_idx) const
const std::string getInputAnnotation(const size_t input_arg_idx, const std::string &key, const std::string &default_) const
bool g_enable_dev_table_functions
Definition: Execute.cpp:113
std::string getName(const bool drop_suffix=false, const bool lower=false) const
#define CHECK_LT(x, y)
Definition: Logger.h:232
Definition: sqltypes.h:52
const std::string getArgTypes(const bool use_input_args) const
#define CHECK_LE(x, y)
Definition: Logger.h:233
tuple line
Definition: parse_ast.py:10
const std::vector< std::string > getCursorFields(const size_t sql_idx) const
const std::map< std::string, std::string > getInputAnnotations(const size_t input_arg_idx) const
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
bool is_table_function_whitelisted(const std::string &function_name)
#define CHECK(condition)
Definition: Logger.h:222
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:45
const std::vector< ExtArgumentType > input_args_
string name
Definition: setup.in.py:72
bool g_enable_table_functions
Definition: Execute.cpp:112
const std::vector< std::map< std::string, std::string > > & getAnnotations() const