OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionsFactory.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string.hpp>
20 #include <mutex>
21 #include <unordered_set>
22 
23 extern bool g_enable_table_functions;
25 
26 namespace table_functions {
27 
28 namespace {
29 
31  switch (ext_arg_type) {
33  return SQLTypeInfo(kTINYINT, false);
35  return SQLTypeInfo(kSMALLINT, false);
37  return SQLTypeInfo(kINT, false);
39  return SQLTypeInfo(kBIGINT, false);
41  return SQLTypeInfo(kFLOAT, false);
43  return SQLTypeInfo(kDOUBLE, false);
45  return SQLTypeInfo(kBOOLEAN, false);
46  default:
47  return ext_arg_type_to_type_info(ext_arg_type);
48  }
49  UNREACHABLE();
50  return SQLTypeInfo(kNULLT, false);
51 }
52 
54  switch (ext_arg_type) {
56  return SQLTypeInfo(kTINYINT, false);
58  return SQLTypeInfo(kSMALLINT, false);
60  return SQLTypeInfo(kINT, false);
62  return SQLTypeInfo(kBIGINT, false);
64  return SQLTypeInfo(kFLOAT, false);
66  return SQLTypeInfo(kDOUBLE, false);
68  return SQLTypeInfo(kBOOLEAN, false);
69  default:
70  return ext_arg_type_to_type_info(ext_arg_type).get_elem_type();
71  }
72  UNREACHABLE();
73  return SQLTypeInfo(kNULLT, false);
74 }
75 
76 } // namespace
77 
79  CHECK_LT(idx, input_args_.size());
81 }
82 
84  CHECK_LT(idx, output_args_.size());
85  // TODO(adb): conditionally handle nulls
87 }
88 
90  int32_t scalar_args = 0;
91  for (const auto& ext_arg : input_args_) {
92  if (is_ext_arg_type_scalar(ext_arg)) {
93  scalar_args += 1;
94  }
95  }
96  return scalar_args;
97 }
98 
100  if (hasPreFlightOutputSizer()) {
101  return true;
102  }
103  // workaround for default args
104  for (size_t idx = 0; idx < std::min(input_args_.size(), annotations_.size()); idx++) {
105  const auto& ann = getInputAnnotations(idx);
106  if (ann.find("require") != ann.end()) {
107  return true;
108  }
109  }
110  return false;
111 }
112 
113 const std::map<std::string, std::string> TableFunction::getAnnotations(
114  const size_t idx) const {
115  CHECK_LE(idx, sql_args_.size() + output_args_.size());
116  if (annotations_.empty() || idx >= annotations_.size()) {
117  static const std::map<std::string, std::string> empty = {};
118  return empty;
119  }
120  return annotations_[idx];
121 }
122 
123 const std::map<std::string, std::string> TableFunction::getInputAnnotations(
124  const size_t input_arg_idx) const {
125  CHECK_LT(input_arg_idx, input_args_.size());
126  return getAnnotations(input_arg_idx);
127 }
128 
129 const std::string TableFunction::getInputAnnotation(const size_t input_arg_idx,
130  const std::string& key,
131  const std::string& default_) const {
132  const std::map<std::string, std::string> ann = getInputAnnotations(input_arg_idx);
133  const auto& it = ann.find(key);
134  if (it != ann.end()) {
135  return it->second;
136  }
137  return default_;
138 }
139 
140 const std::map<std::string, std::string> TableFunction::getOutputAnnotations(
141  const size_t output_arg_idx) const {
142  CHECK_LT(output_arg_idx, output_args_.size());
143  return getAnnotations(output_arg_idx + sql_args_.size());
144 }
145 
146 const std::string TableFunction::getOutputAnnotation(const size_t output_arg_idx,
147  const std::string& key,
148  const std::string& default_) const {
149  const std::map<std::string, std::string> ann = getOutputAnnotations(output_arg_idx);
150  const auto& it = ann.find(key);
151  if (it != ann.end()) {
152  return it->second;
153  }
154  return default_;
155 }
156 
157 const std::map<std::string, std::string> TableFunction::getFunctionAnnotations() const {
158  return getAnnotations(sql_args_.size() + output_args_.size());
159 }
160 
162  const std::string& key,
163  const std::string& default_) const {
164  const std::map<std::string, std::string> ann = getFunctionAnnotations();
165  const auto& it = ann.find(key);
166  if (it != ann.end()) {
167  return it->second;
168  }
169  return default_;
170 }
171 
172 const std::vector<std::string> TableFunction::getCursorFields(
173  const size_t sql_idx) const {
174  std::vector<std::string> fields;
175  const std::string& line = getInputAnnotation(sql_idx, "fields", "");
176  if (line.empty()) {
177  static const std::vector<std::string> empty = {};
178  return empty;
179  }
180  std::string substr = line.substr(1, line.size() - 2);
181  boost::split(fields, substr, boost::is_any_of(", "), boost::token_compress_on);
182  return fields;
183 }
184 
185 const std::string TableFunction::getArgTypes(bool use_input_args) const {
186  if (use_input_args) {
187  std::vector<std::string> arg_types;
188  size_t arg_idx = 0;
189  for (size_t sql_idx = 0; sql_idx < sql_args_.size(); sql_idx++) {
190  const std::vector<std::string> cursor_fields = getCursorFields(sql_idx);
191  if (cursor_fields.empty()) {
192  // fields => {}
193  arg_types.emplace_back(
195  } else {
196  std::vector<std::string> vec;
197  for (size_t i = 0; i < cursor_fields.size(); i++) {
198  vec.emplace_back(ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]));
199  }
200  arg_types.emplace_back("Cursor<" + boost::algorithm::join(vec, ", ") + ">");
201  }
202  }
203  return "[" + boost::algorithm::join(arg_types, ", ") + "]";
204  } else {
206  }
207 }
208 
209 const std::string TableFunction::getArgNames(bool use_input_args) const {
210  std::vector<std::string> names;
211  if (use_input_args) {
212  for (size_t idx = 0; idx < sql_args_.size(); idx++) {
213  const std::vector<std::string> cursor_fields = getCursorFields(idx);
214  if (cursor_fields.empty()) {
215  const std::string& name = getInputAnnotation(idx, "name", "''");
216  names.emplace_back(name);
217  } else {
218  names.emplace_back("Cursor<" + boost::algorithm::join(cursor_fields, ", ") + ">");
219  }
220  }
221  } else {
222  for (size_t idx = 0; idx < output_args_.size(); idx++) {
223  const std::string& name = getOutputAnnotation(idx, "name", "''");
224  names.emplace_back(name);
225  }
226  }
227 
228  return "[" + boost::algorithm::join(names, ", ") + "]";
229 }
230 
231 std::pair<int32_t, int32_t> TableFunction::getInputID(const size_t idx) const {
232  // if the annotation is of the form args<INT,INT>, it is referring to a column list
233 #define PREFIX_LENGTH 5
234  const auto& annotation = getOutputAnnotations(idx);
235  auto annot = annotation.find("input_id");
236  if (annot == annotation.end()) {
237  size_t lo = 0;
238  for (const auto& ext_arg : input_args_) {
239  switch (ext_arg) {
243  return std::make_pair(lo, 0);
244  default:
245  lo++;
246  }
247  }
248  UNREACHABLE();
249  }
250 
251  const std::string& input_id = annot->second;
252 
253  if (input_id == "args<-1>") {
254  // empty input id! -1 seems to be the magic number used in RelAlgExecutor.cpp
255  return {-1, -1};
256  }
257 
258  size_t comma = input_id.find(",");
259  int32_t gt = input_id.size() - 1;
260  int32_t lo = std::stoi(input_id.substr(PREFIX_LENGTH, comma - 1));
261 
262  if (comma == std::string::npos) {
263  return std::make_pair(lo, 0);
264  }
265  int32_t hi = std::stoi(input_id.substr(comma + 1, gt - comma - 1));
266  return std::make_pair(lo, hi);
267 }
268 
270  /*
271  This function differs from getOutputRowSizeParameter() since it returns the correct
272  index for the sizer in the sql_args list. For instance, consider the example below:
273 
274  RowMultiplier=4
275  input_args=[{i32*, i64}, {i32*, i64}, {i32*, i64}, i32, {i32*, i64}, {i32*, i64},
276  i32] sql_args=[cursor, i32, cursor, i32]
277 
278  Non-scalar args are aggregated in a cursor inside the sql_args list and the new
279  sizer index is 2 rather than 4 originally specified.
280  */
281 
283  size_t sizer = getOutputRowSizeParameter(); // lookup until reach the sizer arg
284  int32_t ext_arg_index = 0, sql_arg_index = 0;
285 
286  auto same_kind = [&](const ExtArgumentType& ext_arg, const ExtArgumentType& sql_arg) {
287  return ((is_ext_arg_type_scalar(ext_arg) && is_ext_arg_type_scalar(sql_arg)) ||
289  };
290 
291  while ((size_t)ext_arg_index < sizer) {
292  if ((size_t)ext_arg_index == sizer - 1)
293  return sql_arg_index;
294 
295  const auto& ext_arg = input_args_[ext_arg_index];
296  const auto& sql_arg = sql_args_[sql_arg_index];
297 
298  if (same_kind(ext_arg, sql_arg)) {
299  ++ext_arg_index;
300  ++sql_arg_index;
301  } else {
302  CHECK(same_kind(ext_arg, sql_args_[sql_arg_index - 1]));
303  ext_arg_index += 1;
304  }
305  }
306 
307  CHECK(false);
308  }
309 
310  return getOutputRowSizeParameter();
311 }
312 
// Tells whether `function_name` is one of the table functions that are enabled
// by default. The comparison is an exact, case-sensitive string match.
bool is_table_function_whitelisted(const std::string& function_name) {
  // Every table function that should be on by default (and not just for
  // testing) has to be listed in the set below.
  static const std::unordered_set<std::string> enabled_by_default{
      "generate_series",
      "generate_random_strings",
      "tf_mandelbrot",
      "tf_mandelbrot_float",
      "tf_mandelbrot_cuda",
      "tf_mandelbrot_cuda_float",
      "tf_geo_rasterize",
      "tf_geo_rasterize_slope",
      "tf_compute_dwell_times",
      "tf_feature_similarity",
      "tf_feature_self_similarity",
      "tf_graph_shortest_path",
      "tf_graph_shortest_paths_distances",
      "tf_raster_graph_shortest_slope_weighted_path",
      "supported_ml_frameworks",
      "kmeans",
      "dbscan",
      "linear_reg_fit",
      "linear_reg_predict",
      "linear_reg_fit_predict",
      "tf_point_cloud_metadata",
      "tf_load_point_cloud"
#ifdef HAVE_OMNIVERSE_CONNECTOR
      ,
      "tf_export_ov_terrain_texture",
      "tf_export_ov_buildings_texture",
      "tf_export_ov_polygons_2d",
      "tf_export_ov_polygons_3d",
      "tf_export_ov_grid_mesh"
#endif
  };

  return enabled_by_default.count(function_name) > 0;
}
352 
354  const std::string& name,
355  const TableFunctionOutputRowSizer sizer,
356  const std::vector<ExtArgumentType>& input_args,
357  const std::vector<ExtArgumentType>& output_args,
358  const std::vector<ExtArgumentType>& sql_args,
359  const std::vector<std::map<std::string, std::string>>& annotations,
360  bool is_runtime) {
361  static const std::map<std::string, std::string> empty = {};
362 
363  auto func_annotations =
364  (annotations.size() == sql_args.size() + output_args.size() + 1 ? annotations.back()
365  : empty);
366  auto mgr_annotation = func_annotations.find("uses_manager");
367  bool uses_manager = mgr_annotation != func_annotations.end() &&
368  boost::algorithm::to_lower_copy(mgr_annotation->second) == "true";
369 
370  auto tf = TableFunction(name,
371  sizer,
372  input_args,
373  output_args,
374  sql_args,
375  annotations,
376  is_runtime,
377  uses_manager);
378  const auto tf_name = tf.getName(true /* drop_suffix */, true /* lower */);
379  if (!g_enable_dev_table_functions && !is_runtime &&
380  !is_table_function_whitelisted(tf_name)) {
381  return;
382  }
383  auto sig = tf.getSignature(/* include_name */ true, /* include_output */ false);
384  for (auto it = functions_.begin(); it != functions_.end();) {
385  if (it->second.getName() == name) {
386  if (it->second.isRuntime()) {
387  LOG(WARNING)
388  << "Overriding existing run-time table function (reset not called?): "
389  << name;
390  it = functions_.erase(it);
391  } else {
392  throw std::runtime_error("Will not override existing load-time table function: " +
393  name);
394  }
395  } else {
396  if (sig == it->second.getSignature(/* include_name */ true,
397  /* include_output */ false) &&
398  ((tf.isCPU() && it->second.isCPU()) || (tf.isGPU() && it->second.isGPU()))) {
399  LOG(WARNING)
400  << "The existing (1) and added (2) table functions have the same signature `"
401  << sig << "`:\n"
402  << " 1: " << it->second.toString() << "\n 2: " << tf.toString() << "\n";
403  }
404  ++it;
405  }
406  }
407 
408  functions_.emplace(name, tf);
410  auto input_args2 = input_args;
411  input_args2.erase(input_args2.begin() + sizer.val - 1);
412 
413  auto sql_args2 = sql_args;
414  auto sql_sizer_pos = tf.getSqlOutputRowSizeParameter();
415  sql_args2.erase(sql_args2.begin() + sql_sizer_pos);
416 
417  auto annotations2 = annotations;
418  annotations2.erase(annotations2.begin() + sql_sizer_pos);
419 
421  sizer,
422  input_args2,
423  output_args,
424  sql_args2,
425  annotations2,
426  is_runtime,
427  uses_manager);
428  auto sig = tf2.getSignature(/* include_name */ true, /* include_output */ false);
429  for (auto it = functions_.begin(); it != functions_.end();) {
430  if (sig == it->second.getSignature(/* include_name */ true,
431  /* include_output */ false) &&
432  ((tf2.isCPU() && it->second.isCPU()) || (tf2.isGPU() && it->second.isGPU()))) {
433  LOG(WARNING)
434  << "The existing (1) and added (2) table functions have the same signature `"
435  << sig << "`:\n"
436  << " 1: " << it->second.toString() << "\n 2: " << tf2.toString() << "\n";
437  }
438  ++it;
439  }
440  functions_.emplace(name + DEFAULT_ROW_MULTIPLIER_SUFFIX, tf2);
441  }
442 }
443 
444 /*
445  The implementation for `void TableFunctionsFactory::init()` is
446  generated by QueryEngine/scripts/generate_TableFunctionsFactory_init.py
447 */
448 
449 // removes existing runtime table functions
452  return;
453  }
454  for (auto it = functions_.begin(); it != functions_.end();) {
455  if (it->second.isRuntime()) {
456  it = functions_.erase(it);
457  } else {
458  ++it;
459  }
460  }
461 }
462 
463 namespace {
464 
465 std::string drop_suffix_impl(const std::string& str) {
466  const auto idx = str.find("__");
467  if (idx == std::string::npos) {
468  return str;
469  }
470  CHECK_GT(idx, std::string::size_type(0));
471  return str.substr(0, idx);
472 }
473 
474 } // namespace
475 
476 std::string TableFunction::getName(const bool drop_suffix, const bool lower) const {
477  std::string result = name_;
478  if (drop_suffix) {
479  result = drop_suffix_impl(result);
480  }
481  if (lower) {
483  }
484  return result;
485 }
486 
487 std::string TableFunction::getSignature(const bool include_name,
488  const bool include_output) const {
489  std::string sig;
490  if (include_name) {
491  sig += getName(/*drop_suffix=*/true, /*lower=*/true);
492  }
493 
494  size_t arg_idx = 0;
495  std::vector<std::string> args;
496  for (size_t sql_idx = 0; sql_idx < sql_args_.size(); sql_idx++) {
497  const std::vector<std::string> cursor_fields = getCursorFields(sql_idx);
498  if (cursor_fields.empty()) {
499  const auto& type = ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]);
500  const auto& name = getInputAnnotation(sql_idx, "name", "");
501  args.emplace_back(name.empty() ? type : (type + " " + name));
502  } else {
503  std::vector<std::string> vec;
504  for (size_t i = 0; i < cursor_fields.size(); i++) {
505  const auto& type = ExtensionFunctionsWhitelist::toString(input_args_[arg_idx++]);
506  const auto& name = cursor_fields[i];
507  vec.emplace_back((name.empty() ? type : type + " " + name));
508  }
509  args.emplace_back("Cursor<" + boost::algorithm::join(vec, ", ") + ">");
510  }
511  }
512  sig += "(" + boost::algorithm::join(args, ", ") + ")";
513  if (include_output) {
515  }
516  return sig;
517 }
518 
520  // gets the name of the pre flight function associated with this table function
521  return getName(false, true) + PREFLIGHT_SUFFIX;
522 }
523 
524 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(const std::string& name,
525  const bool is_gpu) {
526  std::vector<TableFunction> table_funcs;
527  for (const auto& tf : get_table_funcs(name)) {
528  if (is_gpu ? tf.isGPU() : tf.isCPU()) {
529  table_funcs.emplace_back(tf);
530  }
531  }
532  return table_funcs;
533 }
534 
535 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(
536  const std::string& name) {
537  std::vector<TableFunction> table_funcs;
538  auto table_func_name = name;
539  boost::algorithm::to_lower(table_func_name);
540  for (const auto& pair : functions_) {
541  auto fname = drop_suffix_impl(pair.first);
542  if (fname == table_func_name) {
543  table_funcs.push_back(pair.second);
544  }
545  }
546  return table_funcs;
547 }
548 
549 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(const bool is_runtime) {
550  std::vector<TableFunction> table_funcs;
551  for (const auto& pair : functions_) {
552  if (pair.second.isRuntime() == is_runtime) {
553  table_funcs.push_back(pair.second);
554  }
555  }
556  return table_funcs;
557 }
558 
559 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs() {
560  std::vector<TableFunction> table_funcs;
561  for (const auto& pair : functions_) {
562  table_funcs.push_back(pair.second);
563  }
564  return table_funcs;
565 }
566 
// Out-of-class definition of the static registry TableFunctionsFactory::functions_,
// a map from a std::string key to a TableFunction. In this file it is filled via
// functions_.emplace(...) and read by the get_table_funcs() overloads.
567 std::unordered_map<std::string, TableFunction> TableFunctionsFactory::functions_;
568 
569 } // namespace table_functions
SQLTypeInfo getOutputSQLType(const size_t idx) const
std::string to_lower(const std::string &str)
const std::string getOutputAnnotation(const size_t output_arg_idx, const std::string &key, const std::string &default_) const
static std::vector< TableFunction > get_table_funcs()
bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type)
#define PREFIX_LENGTH
static void add(const std::string &name, const TableFunctionOutputRowSizer sizer, const std::vector< ExtArgumentType > &input_args, const std::vector< ExtArgumentType > &output_args, const std::vector< ExtArgumentType > &sql_args, const std::vector< std::map< std::string, std::string >> &annotations, bool is_runtime=false)
#define LOG(tag)
Definition: Logger.h:285
std::string join(T const &container, std::string const &delim)
SQLTypeInfo ext_arg_pointer_type_to_type_info(const ExtArgumentType ext_arg_type)
#define UNREACHABLE()
Definition: Logger.h:337
#define PREFLIGHT_SUFFIX
SQLTypeInfo ext_arg_type_to_type_info_output(const ExtArgumentType ext_arg_type)
const std::map< std::string, std::string > getFunctionAnnotations() const
const std::vector< std::map< std::string, std::string > > annotations_
const std::vector< ExtArgumentType > output_args_
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
std::pair< int32_t, int32_t > getInputID(const size_t idx) const
const std::string getFunctionAnnotation(const std::string &key, const std::string &default_) const
#define CHECK_GT(x, y)
Definition: Logger.h:305
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings
std::string getSignature(const bool include_name, const bool include_output) const
bool is_ext_arg_type_nonscalar(const ExtArgumentType ext_arg_type)
const std::vector< ExtArgumentType > sql_args_
SQLTypeInfo getInputSQLType(const size_t idx) const
const std::string getArgNames(const bool use_input_args) const
std::string drop_suffix_impl(const std::string &str)
const std::map< std::string, std::string > getOutputAnnotations(const size_t output_arg_idx) const
const std::string getInputAnnotation(const size_t input_arg_idx, const std::string &key, const std::string &default_) const
bool g_enable_dev_table_functions
Definition: Execute.cpp:113
std::string getName(const bool drop_suffix=false, const bool lower=false) const
#define CHECK_LT(x, y)
Definition: Logger.h:303
const std::string getArgTypes(const bool use_input_args) const
#define CHECK_LE(x, y)
Definition: Logger.h:304
tuple line
Definition: parse_ast.py:10
const std::vector< std::string > getCursorFields(const size_t sql_idx) const
const std::map< std::string, std::string > getInputAnnotations(const size_t input_arg_idx) const
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
bool is_table_function_whitelisted(const std::string &function_name)
#define CHECK(condition)
Definition: Logger.h:291
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:62
const std::vector< ExtArgumentType > input_args_
string name
Definition: setup.in.py:72
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:963
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
bool g_enable_table_functions
Definition: Execute.cpp:112
const std::vector< std::map< std::string, std::string > > & getAnnotations() const