OmniSciDB  91042dcc5b
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionsFactory.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string.hpp>
20 #include <mutex>
21 #include <unordered_set>
22 
23 extern bool g_enable_table_functions;
25 
26 namespace table_functions {
27 
28 namespace {
29 
31  switch (ext_arg_type) {
33  return SQLTypeInfo(kTINYINT, false);
35  return SQLTypeInfo(kSMALLINT, false);
37  return SQLTypeInfo(kINT, false);
39  return SQLTypeInfo(kBIGINT, false);
41  return SQLTypeInfo(kFLOAT, false);
43  return SQLTypeInfo(kDOUBLE, false);
45  return SQLTypeInfo(kBOOLEAN, false);
51  return generate_column_type(kINT);
74  default:
75  LOG(WARNING) << "ext_arg_pointer_type_to_type_info: ExtArgumentType `"
77  << "` conversion to SQLTypeInfo not implemented.";
78  UNREACHABLE();
79  }
80  UNREACHABLE();
81  return SQLTypeInfo(kNULLT, false);
82 }
83 
85  switch (ext_arg_type) {
90  return SQLTypeInfo(kTINYINT, false);
95  return SQLTypeInfo(kSMALLINT, false);
100  return SQLTypeInfo(kINT, false);
105  return SQLTypeInfo(kBIGINT, false);
110  return SQLTypeInfo(kFLOAT, false);
115  return SQLTypeInfo(kDOUBLE, false);
120  return SQLTypeInfo(kBOOLEAN, false);
124  return SQLTypeInfo(kTEXT, false, kENCODING_DICT);
125  default:
126  LOG(WARNING) << "ext_arg_pointer_type_to_type_info: ExtArgumentType `"
128  << "` conversion to SQLTypeInfo not implemented.";
129  UNREACHABLE();
130  }
131  UNREACHABLE();
132  return SQLTypeInfo(kNULLT, false);
133 }
134 
135 } // namespace
136 
138  CHECK_LT(idx, input_args_.size());
140 }
141 
143  CHECK_LT(idx, output_args_.size());
144  // TODO(adb): conditionally handle nulls
146 }
147 
149  int32_t scalar_args = 0;
150  for (const auto& ext_arg : input_args_) {
151  if (is_ext_arg_type_scalar(ext_arg)) {
152  scalar_args += 1;
153  }
154  }
155  return scalar_args;
156 }
157 
159  if (hasPreFlightOutputSizer()) {
160  return true;
161  }
162  // workaround for default args
163  for (size_t idx = 0; idx < std::min(input_args_.size(), annotations_.size()); idx++) {
164  const auto& ann = getInputAnnotation(idx);
165  if (ann.find("require") != ann.end()) {
166  return true;
167  }
168  }
169  return false;
170 }
171 
172 const std::map<std::string, std::string>& TableFunction::getAnnotation(
173  const size_t idx) const {
174  CHECK_LE(idx, sql_args_.size() + output_args_.size());
175  if (annotations_.empty() || idx >= annotations_.size()) {
176  static const std::map<std::string, std::string> empty = {};
177  return empty;
178  }
179  return annotations_[idx];
180 }
181 
182 const std::map<std::string, std::string>& TableFunction::getInputAnnotation(
183  const size_t input_arg_idx) const {
184  CHECK_LT(input_arg_idx, input_args_.size());
185  return getAnnotation(input_arg_idx);
186 }
187 
188 const std::map<std::string, std::string>& TableFunction::getOutputAnnotation(
189  const size_t output_arg_idx) const {
190  CHECK_LT(output_arg_idx, output_args_.size());
191  return getAnnotation(output_arg_idx + sql_args_.size());
192 }
193 
194 const std::map<std::string, std::string>& TableFunction::getFunctionAnnotation() const {
195  return getAnnotation(sql_args_.size() + output_args_.size());
196 }
197 
198 std::pair<int32_t, int32_t> TableFunction::getInputID(const size_t idx) const {
199  // if the annotation is of the form args<INT,INT>, it is refering to a column list
200 #define PREFIX_LENGTH 5
201  const auto& annotation = getOutputAnnotation(idx);
202  auto annot = annotation.find("input_id");
203  if (annot == annotation.end()) {
204  size_t lo = 0;
205  for (const auto& ext_arg : input_args_) {
206  switch (ext_arg) {
210  return std::make_pair(lo, 0);
211  default:
212  lo++;
213  }
214  }
215  UNREACHABLE();
216  }
217 
218  const std::string& input_id = annot->second;
219 
220  size_t comma = input_id.find(",");
221  int32_t gt = input_id.size() - 1;
222  int32_t lo = std::stoi(input_id.substr(PREFIX_LENGTH, comma - 1));
223 
224  if (comma == std::string::npos) {
225  return std::make_pair(lo, 0);
226  }
227  int32_t hi = std::stoi(input_id.substr(comma + 1, gt - comma - 1));
228  return std::make_pair(lo, hi);
229 }
230 
232  /*
233  This function differs from getOutputRowSizeParameter() since it returns the correct
234  index for the sizer in the sql_args list. For instance, consider the example below:
235 
236  RowMultiplier=4
237  input_args=[{i32*, i64}, {i32*, i64}, {i32*, i64}, i32, {i32*, i64}, {i32*, i64},
238  i32] sql_args=[cursor, i32, cursor, i32]
239 
240  Non-scalar args are aggregated in a cursor inside the sql_args list and the new
241  sizer index is 2 rather than 4 originally specified.
242  */
243 
245  size_t sizer = getOutputRowSizeParameter(); // lookup until reach the sizer arg
246  int32_t ext_arg_index = 0, sql_arg_index = 0;
247 
248  auto same_kind = [&](const ExtArgumentType& ext_arg, const ExtArgumentType& sql_arg) {
249  return ((is_ext_arg_type_scalar(ext_arg) && is_ext_arg_type_scalar(sql_arg)) ||
251  };
252 
253  while ((size_t)ext_arg_index < sizer) {
254  if ((size_t)ext_arg_index == sizer - 1)
255  return sql_arg_index;
256 
257  const auto& ext_arg = input_args_[ext_arg_index];
258  const auto& sql_arg = sql_args_[sql_arg_index];
259 
260  if (same_kind(ext_arg, sql_arg)) {
261  ++ext_arg_index;
262  ++sql_arg_index;
263  } else {
264  CHECK(same_kind(ext_arg, sql_args_[sql_arg_index - 1]));
265  ext_arg_index += 1;
266  }
267  }
268 
269  CHECK(false);
270  }
271 
272  return getOutputRowSizeParameter();
273 }
274 
/**
 * Reports whether `function_name` is one of the table functions that are
 * enabled by default (as opposed to being available only for testing).
 *
 * The comparison is exact and case-sensitive.
 *
 * @param function_name name to look up.
 * @return true iff the name appears in the whitelist below.
 */
bool is_table_function_whitelisted(const std::string& function_name) {
  // Any table function that should be on by default must be listed here.
  static const std::unordered_set<std::string> default_enabled{
      "generate_series",
      "tf_mandelbrot",
      "tf_mandelbrot_float",
      "tf_mandelbrot_cuda",
      "tf_mandelbrot_cuda_float",
      "tf_geo_rasterize",
      "tf_geo_rasterize_slope",
      "tf_rf_prop",
      "tf_rf_prop_max_signal",
  };
  return default_enabled.count(function_name) > 0;
}
292 
294  const std::string& name,
295  const TableFunctionOutputRowSizer sizer,
296  const std::vector<ExtArgumentType>& input_args,
297  const std::vector<ExtArgumentType>& output_args,
298  const std::vector<ExtArgumentType>& sql_args,
299  const std::vector<std::map<std::string, std::string>>& annotations,
300  bool is_runtime) {
301  static const std::map<std::string, std::string> empty = {};
302 
303  auto func_annotations =
304  (annotations.size() == sql_args.size() + output_args.size() + 1 ? annotations.back()
305  : empty);
306  auto mgr_annotation = func_annotations.find("uses_manager");
307  bool uses_manager = mgr_annotation != func_annotations.end() &&
308  boost::algorithm::to_lower_copy(mgr_annotation->second) == "true";
309 
310  auto tf = TableFunction(name,
311  sizer,
312  input_args,
313  output_args,
314  sql_args,
315  annotations,
316  is_runtime,
317  uses_manager);
318  if (!g_enable_dev_table_functions && !is_runtime &&
320  tf.getName(true /* drop_suffix */, true /* lower */))) {
321  return;
322  }
323  auto sig = tf.getSignature();
324  for (auto it = functions_.begin(); it != functions_.end();) {
325  if (it->second.getName() == name) {
326  if (it->second.isRuntime()) {
327  LOG(WARNING)
328  << "Overriding existing run-time table function (reset not called?): "
329  << name;
330  it = functions_.erase(it);
331  } else {
332  throw std::runtime_error("Will not override existing load-time table function: " +
333  name);
334  }
335  } else {
336  if (sig == it->second.getSignature() &&
337  ((tf.isCPU() && it->second.isCPU()) || (tf.isGPU() && it->second.isGPU()))) {
338  LOG(WARNING)
339  << "The existing (1) and added (2) table functions have the same signature `"
340  << sig << "`:\n"
341  << " 1: " << it->second.toString() << "\n 2: " << tf.toString() << "\n";
342  }
343  ++it;
344  }
345  }
346 
347  functions_.emplace(name, tf);
349  auto input_args2 = input_args;
350  input_args2.erase(input_args2.begin() + sizer.val - 1);
351 
352  auto sql_args2 = sql_args;
353  auto sql_sizer_pos = tf.getSqlOutputRowSizeParameter();
354  sql_args2.erase(sql_args2.begin() + sql_sizer_pos);
355 
357  sizer,
358  input_args2,
359  output_args,
360  sql_args2,
361  annotations,
362  is_runtime,
363  uses_manager);
364  auto sig = tf2.getSignature();
365  for (auto it = functions_.begin(); it != functions_.end();) {
366  if (sig == it->second.getSignature() &&
367  ((tf2.isCPU() && it->second.isCPU()) || (tf2.isGPU() && it->second.isGPU()))) {
368  LOG(WARNING)
369  << "The existing (1) and added (2) table functions have the same signature `"
370  << sig << "`:\n"
371  << " 1: " << it->second.toString() << "\n 2: " << tf2.toString() << "\n";
372  }
373  ++it;
374  }
375  functions_.emplace(name + DEFAULT_ROW_MULTIPLIER_SUFFIX, tf2);
376  }
377 }
378 
379 /*
380  The implementation for `void TableFunctionsFactory::init()` is
381  generated by QueryEngine/scripts/generate_TableFunctionsFactory_init.py
382 */
383 
384 // removes existing runtime table functions
387  return;
388  }
389  for (auto it = functions_.begin(); it != functions_.end();) {
390  if (it->second.isRuntime()) {
391  it = functions_.erase(it);
392  } else {
393  ++it;
394  }
395  }
396 }
397 
398 namespace {
399 
400 std::string drop_suffix_impl(const std::string& str) {
401  const auto idx = str.find("__");
402  if (idx == std::string::npos) {
403  return str;
404  }
405  CHECK_GT(idx, std::string::size_type(0));
406  return str.substr(0, idx);
407 }
408 
409 } // namespace
410 
411 std::string TableFunction::getName(const bool drop_suffix, const bool lower) const {
412  std::string result = name_;
413  if (drop_suffix) {
414  result = drop_suffix_impl(result);
415  }
416  if (lower) {
418  }
419  return result;
420 }
421 
423  // gets the name of the pre flight function associated with this table function
424  return getName(false, true) + PREFLIGHT_SUFFIX;
425 }
426 
427 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(const std::string& name,
428  const bool is_gpu) {
429  std::vector<TableFunction> table_funcs;
430  auto table_func_name = name;
431  boost::algorithm::to_lower(table_func_name);
432  for (const auto& pair : functions_) {
433  auto fname = drop_suffix_impl(pair.first);
434  if (fname == table_func_name &&
435  (is_gpu ? pair.second.isGPU() : pair.second.isCPU())) {
436  table_funcs.push_back(pair.second);
437  }
438  }
439  return table_funcs;
440 }
441 
442 std::vector<TableFunction> TableFunctionsFactory::get_table_funcs(const bool is_runtime) {
443  std::vector<TableFunction> table_funcs;
444  for (const auto& pair : functions_) {
445  if (pair.second.isRuntime() == is_runtime) {
446  table_funcs.push_back(pair.second);
447  }
448  }
449  return table_funcs;
450 }
451 
452 std::unordered_map<std::string, TableFunction> TableFunctionsFactory::functions_;
453 
454 } // namespace table_functions
SQLTypeInfo getOutputSQLType(const size_t idx) const
std::string to_lower(const std::string &str)
bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type)
#define PREFIX_LENGTH
static void add(const std::string &name, const TableFunctionOutputRowSizer sizer, const std::vector< ExtArgumentType > &input_args, const std::vector< ExtArgumentType > &output_args, const std::vector< ExtArgumentType > &sql_args, const std::vector< std::map< std::string, std::string >> &annotations, bool is_runtime=false)
#define LOG(tag)
Definition: Logger.h:205
auto generate_column_type(const SQLTypes subtype)
Definition: sqltypes.h:1101
string name
Definition: setup.in.py:72
SQLTypeInfo ext_arg_pointer_type_to_type_info(const ExtArgumentType ext_arg_type)
#define UNREACHABLE()
Definition: Logger.h:255
#define PREFLIGHT_SUFFIX
SQLTypeInfo ext_arg_type_to_type_info_output(const ExtArgumentType ext_arg_type)
const std::map< std::string, std::string > & getInputAnnotation(const size_t input_arg_idx) const
const std::vector< std::map< std::string, std::string > > annotations_
const std::vector< ExtArgumentType > output_args_
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
std::pair< int32_t, int32_t > getInputID(const size_t idx) const
#define CHECK_GT(x, y)
Definition: Logger.h:223
const std::map< std::string, std::string > & getFunctionAnnotation() const
bool is_ext_arg_type_nonscalar(const ExtArgumentType ext_arg_type)
const std::vector< ExtArgumentType > sql_args_
SQLTypeInfo getInputSQLType(const size_t idx) const
auto generate_column_list_type(const SQLTypes subtype)
Definition: sqltypes.h:1115
bool g_enable_dev_table_functions
Definition: Execute.cpp:110
const std::map< std::string, std::string > & getOutputAnnotation(const size_t output_arg_idx) const
std::string getName(const bool drop_suffix=false, const bool lower=false) const
#define CHECK_LT(x, y)
Definition: Logger.h:221
Definition: sqltypes.h:52
const std::map< std::string, std::string > & getAnnotation(const size_t idx) const
#define CHECK_LE(x, y)
Definition: Logger.h:222
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
bool is_table_function_whitelisted(const std::string &function_name)
#define CHECK(condition)
Definition: Logger.h:211
static std::vector< TableFunction > get_table_funcs(const std::string &name, const bool is_gpu)
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:45
const std::vector< ExtArgumentType > input_args_
bool g_enable_table_functions
Definition: Execute.cpp:109