OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionsFactory.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string.hpp>
20 #include <mutex>
21 #include <unordered_set>
22 
23 extern bool g_enable_table_functions;
25 
26 namespace table_functions {
27 
28 namespace {
29 
31  switch (ext_arg_type) {
33  return SQLTypeInfo(kTINYINT, false);
35  return SQLTypeInfo(kSMALLINT, false);
37  return SQLTypeInfo(kINT, false);
39  return SQLTypeInfo(kBIGINT, false);
41  return SQLTypeInfo(kFLOAT, false);
43  return SQLTypeInfo(kDOUBLE, false);
45  return SQLTypeInfo(kBOOLEAN, false);
46  default:
47  return ext_arg_type_to_type_info(ext_arg_type);
48  }
49  UNREACHABLE();
50  return SQLTypeInfo(kNULLT, false);
51 }
52 
54  switch (ext_arg_type) {
56  return SQLTypeInfo(kTINYINT, false);
58  return SQLTypeInfo(kSMALLINT, false);
60  return SQLTypeInfo(kINT, false);
62  return SQLTypeInfo(kBIGINT, false);
64  return SQLTypeInfo(kFLOAT, false);
66  return SQLTypeInfo(kDOUBLE, false);
68  return SQLTypeInfo(kBOOLEAN, false);
69  default:
70  return ext_arg_type_to_type_info(ext_arg_type).get_elem_type();
71  }
72  UNREACHABLE();
73  return SQLTypeInfo(kNULLT, false);
74 }
75 
76 } // namespace
77 
79  CHECK_LT(idx, input_args_.size());
81 }
82 
84  CHECK_LT(idx, output_args_.size());
85  // TODO(adb): conditionally handle nulls
87 }
88 
90  int32_t scalar_args = 0;
91  for (const auto& ext_arg : input_args_) {
92  if (is_ext_arg_type_scalar(ext_arg)) {
93  scalar_args += 1;
94  }
95  }
96  return scalar_args;
97 }
98 
100  if (hasPreFlightOutputSizer()) {
101  return true;
102  }
103  // workaround for default args
104  for (size_t idx = 0; idx < std::min(input_args_.size(), annotations_.size()); idx++) {
105  const auto& ann = getInputAnnotations(idx);
106  if (ann.find("require") != ann.end()) {
107  return true;
108  }
109  }
110  return false;
111 }
112 
113 const std::map<std::string, std::string> TableFunction::getAnnotations(
114  const size_t idx) const {
115  CHECK_LE(idx, sql_args_.size() + output_args_.size());
116  if (annotations_.empty() || idx >= annotations_.size()) {
117  static const std::map<std::string, std::string> empty = {};
118  return empty;
119  }
120  return annotations_[idx];
121 }
122 
123 const std::map<std::string, std::string> TableFunction::getInputAnnotations(
124  const size_t input_arg_idx) const {
125  CHECK_LT(input_arg_idx, input_args_.size());
126  return getAnnotations(input_arg_idx);
127 }
128 
129 const std::string TableFunction::getInputAnnotation(const size_t input_arg_idx,
130  const std::string& key,
131  const std::string& default_) const {
132  const std::map<std::string, std::string> ann = getInputAnnotations(input_arg_idx);
133  const auto& it = ann.find(key);
134  if (it != ann.end()) {
135  return it->second;
136  }
137  return default_;
138 }
139 
140 const std::map<std::string, std::string> TableFunction::getOutputAnnotations(
141  const size_t output_arg_idx) const {
142  CHECK_LT(output_arg_idx, output_args_.size());
143  return getAnnotations(output_arg_idx + sql_args_.size());
144 }
145 
146 const std::string TableFunction::getOutputAnnotation(const size_t output_arg_idx,
147  const std::string& key,
148  const std::string& default_) const {
149  const std::map<std::string, std::string> ann = getOutputAnnotations(output_arg_idx);
150  const auto& it = ann.find(key);
151  if (it != ann.end()) {
152  return it->second;
153  }
154  return default_;
155 }
156 
157 const std::map<std::string, std::string> TableFunction::getFunctionAnnotations() const {
158  return getAnnotations(sql_args_.size() + output_args_.size());
159 }
160 
162  const std::string& key,
163  const std::string& default_) const {
164  const std::map<std::string, std::string> ann = getFunctionAnnotations();
165  const auto& it = ann.find(key);
166  if (it != ann.end()) {
167  return it->second;
168  }
169  return default_;
170 }
171 
172 const std::vector<std::string> TableFunction::getCursorFields(
173  const size_t sql_idx) const {
174  std::vector<std::string> fields;
175  const std::string& line = getInputAnnotation(sql_idx, "fields", "");
176  if (line.empty()) {
177  static const std::vector<std::string> empty = {};
178  return empty;
179  }
180  std::string substr = line.substr(1, line.size() - 2);
181  boost::split(fields, substr, boost::is_any_of(", "), boost::token_compress_on);
182  return fields;
183 }
184 
185 const std::string TableFunction::getArgTypes(bool use_input_args) const {
186  if (use_input_args) {
187  std::vector<std::string> arg_types;
188  size_t arg_idx = 0;
189  for (size_t sql_idx = 0; sql_idx < sql_args_.size(); sql_idx++) {
190  const std::vector<std::string> cursor_fields = getCursorFields(sql_idx);
191  if (cursor_fields.empty()) {
192  // fields => {}
193  arg_types.emplace_back(
195  } else {
196  std::vector<std::string> vec;
197  for (size_t i = 0; i < cursor_fields.size(); i++) {
198  vec.emplace_back(ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]));
199  }
200  arg_types.emplace_back("Cursor<" + boost::algorithm::join(vec, ", ") + ">");
201  }
202  }
203  return "[" + boost::algorithm::join(arg_types, ", ") + "]";
204  } else {
206  }
207 }
208 
209 const std::string TableFunction::getArgNames(bool use_input_args) const {
210  std::vector<std::string> names;
211  if (use_input_args) {
212  for (size_t idx = 0; idx < sql_args_.size(); idx++) {
213  const std::vector<std::string> cursor_fields = getCursorFields(idx);
214  if (cursor_fields.empty()) {
215  const std::string& name = getInputAnnotation(idx, "name", "''");
216  names.emplace_back(name);
217  } else {
218  names.emplace_back("Cursor<" + boost::algorithm::join(cursor_fields, ", ") + ">");
219  }
220  }
221  } else {
222  for (size_t idx = 0; idx < output_args_.size(); idx++) {
223  const std::string& name = getOutputAnnotation(idx, "name", "''");
224  names.emplace_back(name);
225  }
226  }
227 
228  return "[" + boost::algorithm::join(names, ", ") + "]";
229 }
230 
231 std::pair<int32_t, int32_t> TableFunction::getInputID(const size_t idx) const {
232  // if the annotation is of the form args<INT,INT>, it is refering to a column list
233 #define PREFIX_LENGTH 5
234  const auto& annotation = getOutputAnnotations(idx);
235  auto annot = annotation.find("input_id");
236  if (annot == annotation.end()) {
237  size_t lo = 0;
238  for (const auto& ext_arg : input_args_) {
239  switch (ext_arg) {
243  return std::make_pair(lo, 0);
244  default:
245  lo++;
246  }
247  }
248  UNREACHABLE();
249  }
250 
251  const std::string& input_id = annot->second;
252 
253  if (input_id == "args<-1>") {
254  // empty input id! -1 seems to be the magic number used in RelAlgExecutor.cpp
255  return {-1, -1};
256  }
257 
258  size_t comma = input_id.find(",");
259  int32_t gt = input_id.size() - 1;
260  int32_t lo = std::stoi(input_id.substr(PREFIX_LENGTH, comma - 1));
261 
262  if (comma == std::string::npos) {
263  return std::make_pair(lo, 0);
264  }
265  int32_t hi = std::stoi(input_id.substr(comma + 1, gt - comma - 1));
266  return std::make_pair(lo, hi);
267 }
268 
270  /*
271  This function differs from getOutputRowSizeParameter() since it returns the correct
272  index for the sizer in the sql_args list. For instance, consider the example below:
273 
274  RowMultiplier=4
275  input_args=[{i32*, i64}, {i32*, i64}, {i32*, i64}, i32, {i32*, i64}, {i32*, i64},
276  i32] sql_args=[cursor, i32, cursor, i32]
277 
278  Non-scalar args are aggregated in a cursor inside the sql_args list and the new
279  sizer index is 2 rather than 4 originally specified.
280  */
281 
283  size_t sizer = getOutputRowSizeParameter(); // lookup until reach the sizer arg
284  int32_t ext_arg_index = 0, sql_arg_index = 0;
285 
286  auto same_kind = [&](const ExtArgumentType& ext_arg, const ExtArgumentType& sql_arg) {
287  return ((is_ext_arg_type_scalar(ext_arg) && is_ext_arg_type_scalar(sql_arg)) ||
289  };
290 
291  while ((size_t)ext_arg_index < sizer) {
292  if ((size_t)ext_arg_index == sizer - 1)
293  return sql_arg_index;
294 
295  const auto& ext_arg = input_args_[ext_arg_index];
296  const auto& sql_arg = sql_args_[sql_arg_index];
297 
298  if (same_kind(ext_arg, sql_arg)) {
299  ++ext_arg_index;
300  ++sql_arg_index;
301  } else {
302  CHECK(same_kind(ext_arg, sql_args_[sql_arg_index - 1]));
303  ext_arg_index += 1;
304  }
305  }
306 
307  CHECK(false);
308  }
309 
310  return getOutputRowSizeParameter();
311 }
312 
/*
  Returns true when `function_name` is on the list of load-time table functions
  that are registered even when dev table functions are disabled.
  The match is exact and case-sensitive.

  Any table function meant to be available by default (not only for testing)
  must be listed here.
*/
bool is_table_function_whitelisted(const std::string& function_name) {
  static const std::unordered_set<std::string> kWhitelistedTableFunctions{
      "generate_series",
      "generate_random_strings",
      "tf_mandelbrot",
      "tf_mandelbrot_float",
      "tf_mandelbrot_cuda",
      "tf_mandelbrot_cuda_float",
      "tf_geo_rasterize",
      "tf_geo_rasterize_slope",
      "tf_compute_dwell_times",
      "tf_feature_similarity",
      "tf_feature_self_similarity",
      "tf_graph_shortest_path",
      "tf_graph_shortest_paths_distances",
      "tf_raster_graph_shortest_slope_weighted_path",
      "supported_ml_frameworks",
      "kmeans",
      "dbscan",
      "linear_reg_fit",
      "linear_reg_predict",
      "linear_reg_fit_predict",
      "tf_point_cloud_metadata",
      "tf_load_point_cloud"};

  return kWhitelistedTableFunctions.count(function_name) > 0;
}
343 
345  const std::string& name,
346  const TableFunctionOutputRowSizer sizer,
347  const std::vector<ExtArgumentType>& input_args,
348  const std::vector<ExtArgumentType>& output_args,
349  const std::vector<ExtArgumentType>& sql_args,
350  const std::vector<std::map<std::string, std::string>>& annotations,
351  bool is_runtime) {
352  static const std::map<std::string, std::string> empty = {};
353 
354  auto func_annotations =
355  (annotations.size() == sql_args.size() + output_args.size() + 1 ? annotations.back()
356  : empty);
357  auto mgr_annotation = func_annotations.find("uses_manager");
358  bool uses_manager = mgr_annotation != func_annotations.end() &&
359  boost::algorithm::to_lower_copy(mgr_annotation->second) == "true";
360 
361  auto tf = TableFunction(name,
362  sizer,
363  input_args,
364  output_args,
365  sql_args,
366  annotations,
367  is_runtime,
368  uses_manager);
369  const auto tf_name = tf.getName(true /* drop_suffix */, true /* lower */);
370  if (!g_enable_dev_table_functions && !is_runtime &&
371  !is_table_function_whitelisted(tf_name)) {
372  return;
373  }
374  auto sig = tf.getSignature(/* include_name */ true, /* include_output */ false);
375  for (auto it = functions_.begin(); it != functions_.end();) {
376  if (it->second.getName() == name) {
377  if (it->second.isRuntime()) {
378  LOG(WARNING)
379  << "Overriding existing run-time table function (reset not called?): "
380  << name;
381  it = functions_.erase(it);
382  } else {
383  throw std::runtime_error("Will not override existing load-time table function: " +
384  name);
385  }
386  } else {
387  if (sig == it->second.getSignature(/* include_name */ true,
388  /* include_output */ false) &&
389  ((tf.isCPU() && it->second.isCPU()) || (tf.isGPU() && it->second.isGPU()))) {
390  LOG(WARNING)
391  << "The existing (1) and added (2) table functions have the same signature `"
392  << sig << "`:\n"
393  << " 1: " << it->second.toString() << "\n 2: " << tf.toString() << "\n";
394  }
395  ++it;
396  }
397  }
398 
399  functions_.emplace(name, tf);
401  auto input_args2 = input_args;
402  input_args2.erase(input_args2.begin() + sizer.val - 1);
403 
404  auto sql_args2 = sql_args;
405  auto sql_sizer_pos = tf.getSqlOutputRowSizeParameter();
406  sql_args2.erase(sql_args2.begin() + sql_sizer_pos);
407 
408  auto annotations2 = annotations;
409  annotations2.erase(annotations2.begin() + sql_sizer_pos);
410 
412  sizer,
413  input_args2,
414  output_args,
415  sql_args2,
416  annotations2,
417  is_runtime,
418  uses_manager);
419  auto sig = tf2.getSignature(/* include_name */ true, /* include_output */ false);
420  for (auto it = functions_.begin(); it != functions_.end();) {
421  if (sig == it->second.getSignature(/* include_name */ true,
422  /* include_output */ false) &&
423  ((tf2.isCPU() && it->second.isCPU()) || (tf2.isGPU() && it->second.isGPU()))) {
424  LOG(WARNING)
425  << "The existing (1) and added (2) table functions have the same signature `"
426  << sig << "`:\n"
427  << " 1: " << it->second.toString() << "\n 2: " << tf2.toString() << "\n";
428  }
429  ++it;
430  }
431  functions_.emplace(name + DEFAULT_ROW_MULTIPLIER_SUFFIX, tf2);
432  }
433 }
434 
435 /*
436  The implementation for `void TableFunctionsFactory::init()` is
437  generated by QueryEngine/scripts/generate_TableFunctionsFactory_init.py
438 */
439 
440 // removes existing runtime table functions
443  return;
444  }
445  for (auto it = functions_.begin(); it != functions_.end();) {
446  if (it->second.isRuntime()) {
447  it = functions_.erase(it);
448  } else {
449  ++it;
450  }
451  }
452 }
453 
454 namespace {
455 
456 std::string drop_suffix_impl(const std::string& str) {
457  const auto idx = str.find("__");
458  if (idx == std::string::npos) {
459  return str;
460  }
461  CHECK_GT(idx, std::string::size_type(0));
462  return str.substr(0, idx);
463 }
464 
465 } // namespace
466 
467 std::string TableFunction::getName(const bool drop_suffix, const bool lower) const {
468  std::string result = name_;
469  if (drop_suffix) {
470  result = drop_suffix_impl(result);
471  }
472  if (lower) {
474  }
475  return result;
476 }
477 
478 std::string TableFunction::getSignature(const bool include_name,
479  const bool include_output) const {
480  std::string sig;
481  if (include_name) {
482  sig += getName(/*drop_suffix=*/true, /*lower=*/true);
483  }
484 
485  size_t arg_idx = 0;
486  std::vector<std::string> args;
487  for (size_t sql_idx = 0; sql_idx < sql_args_.size(); sql_idx++) {
488  const std::vector<std::string> cursor_fields = getCursorFields(sql_idx);
489  if (cursor_fields.empty()) {
490  const auto& type = ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]);
491  const auto& name = getInputAnnotation(sql_idx, "name", "");
492  args.emplace_back(name.empty() ? type : (type + " " + name));
493  } else {
494  std::vector<std::string> vec;
495  for (size_t i = 0; i < cursor_fields.size(); i++) {
496  const auto& type = ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]);
497  const auto& name = cursor_fields[i];
498  vec.emplace_back((name.empty() ? type : type + " " + name));
499  }
500  args.emplace_back("Cursor<" + boost::algorithm::join(vec, ", ") + ">");
501  }
502  }
503  sig += "(" + boost::algorithm::join(args, ", ") + ")";
504  if (include_output) {
506  }
507  return sig;
508 }
509 
511  // gets the name of the pre flight function associated with this table function
512  return getName(false, true) + PREFLIGHT_SUFFIX;
513 }
514 
515 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(const std::string& name,
516  const bool is_gpu) {
517  std::vector<TableFunction> table_funcs;
518  for (const auto& tf : get_table_funcs(name)) {
519  if (is_gpu ? tf.isGPU() : tf.isCPU()) {
520  table_funcs.emplace_back(tf);
521  }
522  }
523  return table_funcs;
524 }
525 
526 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(
527  const std::string& name) {
528  std::vector<TableFunction> table_funcs;
529  auto table_func_name = name;
530  boost::algorithm::to_lower(table_func_name);
531  for (const auto& pair : functions_) {
532  auto fname = drop_suffix_impl(pair.first);
533  if (fname == table_func_name) {
534  table_funcs.push_back(pair.second);
535  }
536  }
537  return table_funcs;
538 }
539 
540 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(const bool is_runtime) {
541  std::vector<TableFunction> table_funcs;
542  for (const auto& pair : functions_) {
543  if (pair.second.isRuntime() == is_runtime) {
544  table_funcs.push_back(pair.second);
545  }
546  }
547  return table_funcs;
548 }
549 
550 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs() {
551  std::vector<TableFunction> table_funcs;
552  for (const auto& pair : functions_) {
553  table_funcs.push_back(pair.second);
554  }
555  return table_funcs;
556 }
557 
558 std::unordered_map<std::string, TableFunction> TableFunctionsFactory::functions_;
559 
560 } // namespace table_functions
SQLTypeInfo getOutputSQLType(const size_t idx) const
std::string to_lower(const std::string &str)
const std::string getOutputAnnotation(const size_t output_arg_idx, const std::string &key, const std::string &default_) const
static std::vector< TableFunction > get_table_funcs()
bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type)
#define PREFIX_LENGTH
static void add(const std::string &name, const TableFunctionOutputRowSizer sizer, const std::vector< ExtArgumentType > &input_args, const std::vector< ExtArgumentType > &output_args, const std::vector< ExtArgumentType > &sql_args, const std::vector< std::map< std::string, std::string >> &annotations, bool is_runtime=false)
#define LOG(tag)
Definition: Logger.h:283
std::string join(T const &container, std::string const &delim)
SQLTypeInfo ext_arg_pointer_type_to_type_info(const ExtArgumentType ext_arg_type)
#define UNREACHABLE()
Definition: Logger.h:333
#define PREFLIGHT_SUFFIX
SQLTypeInfo ext_arg_type_to_type_info_output(const ExtArgumentType ext_arg_type)
const std::map< std::string, std::string > getFunctionAnnotations() const
const std::vector< std::map< std::string, std::string > > annotations_
const std::vector< ExtArgumentType > output_args_
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
std::pair< int32_t, int32_t > getInputID(const size_t idx) const
const std::string getFunctionAnnotation(const std::string &key, const std::string &default_) const
#define CHECK_GT(x, y)
Definition: Logger.h:301
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings
std::string getSignature(const bool include_name, const bool include_output) const
bool is_ext_arg_type_nonscalar(const ExtArgumentType ext_arg_type)
const std::vector< ExtArgumentType > sql_args_
SQLTypeInfo getInputSQLType(const size_t idx) const
const std::string getArgNames(const bool use_input_args) const
std::string drop_suffix_impl(const std::string &str)
const std::map< std::string, std::string > getOutputAnnotations(const size_t output_arg_idx) const
const std::string getInputAnnotation(const size_t input_arg_idx, const std::string &key, const std::string &default_) const
bool g_enable_dev_table_functions
Definition: Execute.cpp:113
std::string getName(const bool drop_suffix=false, const bool lower=false) const
#define CHECK_LT(x, y)
Definition: Logger.h:299
const std::string getArgTypes(const bool use_input_args) const
#define CHECK_LE(x, y)
Definition: Logger.h:300
tuple line
Definition: parse_ast.py:10
const std::vector< std::string > getCursorFields(const size_t sql_idx) const
const std::map< std::string, std::string > getInputAnnotations(const size_t input_arg_idx) const
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
bool is_table_function_whitelisted(const std::string &function_name)
#define CHECK(condition)
Definition: Logger.h:289
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:60
const std::vector< ExtArgumentType > input_args_
string name
Definition: setup.in.py:72
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:957
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
bool g_enable_table_functions
Definition: Execute.cpp:112
const std::vector< std::map< std::string, std::string > > & getAnnotations() const