OmniSciDB  21ac014ffc
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
generate_TableFunctionsFactory_init.py
Go to the documentation of this file.
1 """Given a list of input files, scan for lines containing UDTF
2 specification statements in the following form:
3 
4  UDTF: function_name(<arguments>) -> <output column types>
5 
6 where <arguments> is a comma-separated list of argument types. The
7 argument types specifications are:
8 
9 - scalar types:
10  Int8, Int16, Int32, Int64, Float, Double, Bool, TextEncodingDict, etc
11 - column types:
12  ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble, ColumnBool, etc
13 - column list types:
14  ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool, etc
15 - cursor type:
16  Cursor<t0, t1, ...>
17  where t0, t1 are column or column list types
18 - output buffer size parameter type:
19  RowMultiplier<i>, ConstantParameter<i>, Constant<i>
20  where i is literal integer
21 
22 The output column types is a comma-separated list of column types, see above.
23 
24 In addition, the following equivalents are supported:
25  Column<T> == ColumnT
26  ColumnList<T> == ColumnListT
27  Cursor<T, V, ...> == Cursor<ColumnT, ColumnV, ...>
28  int8 == int8_t == Int8, etc
29  float == Float, double == Double, bool == Bool
30  T == ColumnT for output column types
31  RowMultiplier == RowMultiplier<i> where i is the one-based position of the sizer argument
32  when no sizer argument is provided, TableFunctionSpecifiedParameter<1> is assumed
33 
34 Argument types can be annotated using `|' (bar) symbol after an
35 argument type specification. An annotation is specified by a label and
36 a value separated by `=' (equal) symbol. Multiple annotations can be
37 specified by using `|` (bar) symbol as the annotations separator.
38 Supported annotation labels are:
39 
40 - name: to specify argument name
41 - input_id: to specify the dict id mapping for output TextEncodingDict columns.
42 """
43 # Author: Pearu Peterson
44 # Created: January 2021
45 
46 import os
47 import re
48 import sys
49 from collections import namedtuple
50 
# Parsed UDTF signature: function name, normalized input/output Bracket
# tuples, and the per-argument annotation lists ((label, value) pairs).
# (A dead four-field definition that was immediately overwritten by this
# one has been removed.)
Signature = namedtuple('Signature', ['name', 'inputs', 'outputs', 'input_annotations', 'output_annotations'])

# Names of the known extension argument types, as a list of strings.
ExtArgumentTypes = ''' Int8, Int16, Int32, Int64, Float, Double, Void, PInt8, PInt16,
PInt32, PInt64, PFloat, PDouble, PBool, Bool, ArrayInt8, ArrayInt16,
ArrayInt32, ArrayInt64, ArrayFloat, ArrayDouble, ArrayBool, GeoPoint,
GeoLineString, Cursor, GeoPolygon, GeoMultiPolygon, ColumnInt8,
ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble,
ColumnBool, ColumnTextEncodingDict, TextEncodingNone, TextEncodingDict,
ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64,
ColumnListFloat, ColumnListDouble, ColumnListBool, ColumnListTextEncodingDict'''.strip().replace(' ', '').replace('\n', '').split(',')

# Names that identify an output buffer sizer argument.
OutputBufferSizeTypes = '''
kConstant, kUserSpecifiedConstantParameter, kUserSpecifiedRowMultiplier, kTableFunctionSpecifiedParameter
'''.strip().replace(' ', '').split(',')

# Annotation labels accepted after the `|` separator.
SupportedAnnotations = '''
input_id, name
'''.strip().replace(' ', '').split(',')

# Maps user-facing spellings of type names to their canonical names.
translate_map = dict(
    Constant='kConstant',
    ConstantParameter='kUserSpecifiedConstantParameter',
    RowMultiplier='kUserSpecifiedRowMultiplier',
    UserSpecifiedConstantParameter='kUserSpecifiedConstantParameter',
    UserSpecifiedRowMultiplier='kUserSpecifiedRowMultiplier',
    TableFunctionSpecifiedParameter='kTableFunctionSpecifiedParameter',
    short='Int16',
    int='Int32',
    long='Int64',
)
# Add lower-case aliases (int8, float, ...) and, for integer types, the
# C-style `_t` aliases (int8_t, ...).
for t in ['Int8', 'Int16', 'Int32', 'Int64', 'Float', 'Double', 'Bool',
          'TextEncodingDict']:
    translate_map[t.lower()] = t
    if t.startswith('Int'):
        translate_map[t.lower() + '_t'] = t


_is_int = re.compile(r'\d+').match
91 
92 
class Bracket:
    """Holds a `NAME<ARGS>`-like structure.

    `name` is a string and `args` is either None or a tuple of nested
    Bracket instances. Most predicates look only at the last `::`
    separated component of `name`, so they work with or without a
    namespace prefix.
    """

    def __init__(self, name, args=None):
        assert isinstance(name, str)
        assert isinstance(args, tuple) or args is None, args
        self.name = name
        self.args = args

    def __repr__(self):
        return 'Bracket(%r, %r)' % (self.name, self.args)

    def __str__(self):
        if not self.args:
            return self.name
        return '%s<%s>' % (self.name, ', '.join(map(str, self.args)))

    def normalize(self, kind='input'):
        """Normalize bracket for given kind

        `Column<T>`-like brackets collapse to a single `ColumnT` name;
        for inputs, Cursor arguments are wrapped into columns; for
        outputs, a non-column type is wrapped into a column.
        """
        assert kind in ['input', 'output'], kind
        if self.is_column_any() and self.args:
            return Bracket(self.name + ''.join(map(str, self.args)))
        if kind == 'input':
            if self.name == 'Cursor':
                args = [(a if a.is_column_any() else Bracket('Column', args=(a,))).normalize(kind=kind) for a in self.args]
                return Bracket(self.name, tuple(args))
        if kind == 'output':
            if not self.is_column_any():
                return Bracket('Column', args=(self,)).normalize(kind=kind)
        return self

    def apply_cursor(self):
        """Apply cursor to a non-cursor column argument type.

        TODO: this method is currently unused but we should apply
        cursor to all input column arguments in order to distinguish
        signatures like:

          foo(Cursor(Column<int32>, Column<float>)) -> Column<int32>
          foo(Cursor(Column<int32>), Cursor(Column<float>)) -> Column<int32>

        that at the moment are treated as the same :(
        """
        if self.is_column():
            return Bracket('Cursor', args=(self,))
        return self

    def apply_namespace(self, ns='ExtArgumentType'):
        """Return a bracket whose name (and Cursor arguments) carry the `ns::` prefix."""
        if self.name == 'Cursor':
            return Bracket(ns + '::' + self.name, args=tuple(a.apply_namespace(ns=ns) for a in self.args))
        if not self.name.startswith(ns + '::'):
            return Bracket(ns + '::' + self.name)
        return self

    def is_cursor(self):
        return self.name.rsplit("::", 1)[-1] == 'Cursor'

    def is_column_any(self):
        return self.name.rsplit("::", 1)[-1].startswith('Column')

    def is_column_list(self):
        return self.name.rsplit("::", 1)[-1].startswith('ColumnList')

    def is_column(self):
        return self.name.rsplit("::", 1)[-1].startswith('Column') and not self.is_column_list()

    # The three predicates below used to compare against 'TextEncodedDict',
    # a spelling that no type name in this file has; the type is spelled
    # 'TextEncodingDict' (see ExtArgumentTypes), so they never matched.
    def is_any_text_encoded_dict(self):
        return self.name.rsplit("::", 1)[-1].endswith('TextEncodingDict')

    def is_column_text_encoded_dict(self):
        return self.name.rsplit("::", 1)[-1] == 'ColumnTextEncodingDict'

    def is_column_list_text_encoded_dict(self):
        return self.name.rsplit("::", 1)[-1] == 'ColumnListTextEncodingDict'

    def is_output_buffer_sizer(self):
        return self.name.rsplit("::", 1)[-1] in OutputBufferSizeTypes

    def is_row_multiplier(self):
        return self.name.rsplit("::", 1)[-1] == 'kUserSpecifiedRowMultiplier'

    def is_user_specified(self):
        # Return True if the given argument is specified by the user.
        # Among sizers, kConstant and kTableFunctionSpecifiedParameter
        # are the ones that are not.
        if self.is_output_buffer_sizer():
            return self.name.rsplit("::", 1)[-1] not in ('kConstant', 'kTableFunctionSpecifiedParameter')
        return True

    def get_cpp_type(self):
        """Return the C++ spelling of this type, e.g. `Column<int32_t>`.

        Raises NotImplementedError for element types other than IntN,
        Double, Float and TextEncodingDict.
        """
        name = self.name.rsplit("::", 1)[-1]
        clsname = None
        # 'ColumnList' must be tried first because it also starts with
        # 'Column'. The prefix is removed by slicing: str.lstrip() strips
        # a set of characters, not a prefix.
        for prefix in ('ColumnList', 'Column'):
            if name.startswith(prefix):
                name = name[len(prefix):]
                clsname = prefix
                break
        if name.startswith('Int'):
            ctype = name.lower() + '_t'
        elif name in ['Double', 'Float']:
            ctype = name.lower()
        elif name == 'TextEncodingDict':
            ctype = name
        else:
            raise NotImplementedError(self)
        if clsname is None:
            return ctype
        return '%s<%s>' % (clsname, ctype)

    @classmethod
    def parse(cls, typ):
        """typ is a string in format NAME<ARGS> or NAME

        Returns Bracket instance.
        """
        i = typ.find('<')
        if i == -1:
            name = typ.strip()
            args = None
        else:
            assert typ.endswith('>'), typ
            name = typ[:i].strip()
            args = []
            rest = typ[i+1:-1].strip()
            # Split ARGS on top-level commas only; nested brackets are
            # parsed recursively.
            while rest:
                i = find_comma(rest)
                if i == -1:
                    a, rest = rest, ''
                else:
                    a, rest = rest[:i].rstrip(), rest[i+1:].lstrip()
                args.append(cls.parse(a))
            args = tuple(args)

        name = translate_map.get(name, name)
        return cls(name, args)
229 
230 
def find_comma(line):
    """Return the index of the first comma in `line` that is not nested
    inside any of the bracket pairs <>, (), [], {}; -1 if there is none.
    """
    d = 0  # current bracket nesting depth
    for i, c in enumerate(line):
        if c in '<([{':
            d += 1
        elif c in '>)]}':
            # The closing set previously listed '{' instead of '}', so a
            # closing brace never reduced the depth.
            d -= 1
        elif d == 0 and c == ',':
            return i
    return -1
241 
242 
def line_is_incomplete(line):
    """Return True when `line` ends with a token after which a UDTF
    statement must continue on the next line.
    """
    # TODO: try to parse the line to be certain about completeness.
    # `!' is used to separate the UDTF signature and the expected result
    return line.endswith((',', '->', '!'))
247 
248 
def find_signatures(input_file):
    """Returns a list of parsed UDTF signatures.

    Lines of `input_file` that start with `UDTF:` are parsed into
    Signature instances; a statement that ends with `,`, `->` or `!`
    is joined with the following line first. A statement may carry an
    expected normalized form after a `!` separator, in which case the
    parse result is asserted to equal it.
    """

    def get_function_name(line):
        return line.split('(')[0]

    def get_types_and_annotations(line):
        """Line is a comma separated string of types.

        Returns the list of type strings and a parallel list holding,
        for each type, its list of (label, value) annotation pairs.
        """
        rest = line.strip()
        types, annotations = [], []
        while rest:
            i = find_comma(rest)
            if i == -1:
                type_annot, rest = rest, ''
            else:
                type_annot, rest = rest[:i].rstrip(), rest[i+1:].lstrip()
            if '|' in type_annot:
                typ, annots = type_annot.split('|', 1)
                typ, annots = typ.rstrip(), annots.lstrip().split('|')
            else:
                typ, annots = type_annot, []
            types.append(typ)
            pairs = []
            for annot in annots:
                label, value = annot.strip().split('=', 1)
                label, value = label.rstrip(), value.lstrip()
                pairs.append((label, value))
            annotations.append(pairs)
        return types, annotations

    def get_input_types_and_annotations(line):
        # Check the raw search results: after adding the offset a failed
        # search would no longer be -1 and the assert could never fire.
        start = line.rfind('(')
        end = line.find(')')
        assert -1 not in [start, end], line
        return get_types_and_annotations(line[start + 1:end])

    def get_output_types_and_annotations(line):
        start = line.rfind('->')
        assert start != -1, line
        return get_types_and_annotations(line[start + 2:])

    def parse_signature(line):
        name = get_function_name(line)
        input_types, input_annotations = get_input_types_and_annotations(line)
        output_types, output_annotations = get_output_types_and_annotations(line)
        input_types = tuple([Bracket.parse(typ).normalize(kind='input') for typ in input_types])
        output_types = tuple([Bracket.parse(typ).normalize(kind='output') for typ in output_types])
        return name, input_types, input_annotations, output_types, output_annotations

    def format_types(types, annotations):
        return ', '.join([' | '.join([str(t)] + [k + '=' + v for k, v in a]) for t, a in zip(types, annotations)])

    def format_signature(name, input_types, input_annotations, output_types, output_annotations):
        return (name + '(' + format_types(input_types, input_annotations)
                + ') -> ' + format_types(output_types, output_annotations))

    signatures = []

    # Read everything up front so that the file is closed even if parsing
    # of some statement fails.
    with open(input_file) as f:
        lines = f.readlines()

    last_line = None
    for line in lines:
        line = line.strip()
        if last_line is not None:
            # Continuation of a statement started on a previous line.
            line = last_line + line
            last_line = None
        if not line.startswith('UDTF:'):
            continue
        if line_is_incomplete(line):
            last_line = line
            continue
        line = line[5:].lstrip()
        i = line.find('(')
        j = line.find(')')
        if i == -1 or j == -1:
            sys.stderr.write('Invalid UDTF specification: `%s`. Skipping.\n' % (line))
            continue

        expected_result = None
        if '!' in line:
            line, expected_result = line.split('!', 1)
            expected_result = expected_result.strip()

        name, input_types, input_annotations, output_types, output_annotations = parse_signature(line)

        # Apply default sizer
        has_sizer = False
        consumed_nargs = 0
        for t in input_types:
            if t.is_output_buffer_sizer():
                has_sizer = True
                if t.is_row_multiplier():
                    if not t.args:
                        # A bare RowMultiplier gets its one-based position.
                        t.args = Bracket.parse('RowMultiplier<%s>' % (consumed_nargs + 1)).args
            elif t.is_cursor():
                consumed_nargs += len(t.args)
            else:
                consumed_nargs += 1
        if not has_sizer:
            t = Bracket.parse('kTableFunctionSpecifiedParameter<1>')
            input_types += (t,)

        # Apply default input_id to output TextEncodedDict columns
        default_input_id = None
        for i, t in enumerate(input_types):
            if t.is_column_text_encoded_dict():
                default_input_id = 'args<%s>' % (i,)
                break
            elif t.is_column_list_text_encoded_dict():
                default_input_id = 'args<%s, 0>' % (i,)
                break
        for t, annots in zip(output_types, output_annotations):
            if t.is_any_text_encoded_dict():
                if not any(a[0] == 'input_id' for a in annots):
                    assert default_input_id is not None
                    annots.append(('input_id', default_input_id))

        result = format_signature(name, input_types, input_annotations, output_types, output_annotations)

        if expected_result is not None:
            assert result == expected_result, (result, expected_result)
            # Make sure that we have stable parsing result
            # NOTE(review): the nesting of this round-trip check under the
            # expected-result branch was reconstructed from a listing with
            # collapsed indentation -- confirm against the original file.
            name, input_types, input_annotations, output_types, output_annotations = parse_signature(result)
            result2 = format_signature(name, input_types, input_annotations, output_types, output_annotations)
            assert result == result2, (result, result2)
        signatures.append(Signature(name, input_types, output_types, input_annotations, output_annotations))

    return signatures
386 
387 
def is_template_function(sig):
    """Return True when the name of signature `sig` contains `_template`."""
    return '_template' in sig.name
390 
391 
def build_template_function_call(name, input_types, output_types):
    """Return C++ source of a wrapper function called `name`.

    The wrapper takes one parameter per input type (`input0`, ...)
    followed by one per output type (`out0`, ...) and forwards all of
    them to the function named by the part of `name` before the first
    `__`. Each type object must provide `get_cpp_type()`.
    """
    column_prefixes = ('Column', 'ColumnList')

    def declare(cpp_type, arg_name, is_input):
        # Perhaps integrate this to Bracket?
        # Columns are passed by reference, inputs as const; everything
        # else is passed by value.
        if cpp_type.startswith(column_prefixes):
            qualifier = 'const ' if is_input else ''
            return '%s%s& %s' % (qualifier, cpp_type, arg_name)
        return '%s %s' % (cpp_type, arg_name)

    declarations = []
    forwarded = []
    # TODO: use name in annotations when present?
    for stem, types, is_input in (('input', input_types, True),
                                  ('out', output_types, False)):
        for position, typ in enumerate(types):
            arg_name = stem + str(position)
            declarations.append(declare(typ.get_cpp_type(), arg_name, is_input))
            forwarded.append(arg_name)

    target = name.split('__')[0]
    return ("EXTENSION_NOINLINE int32_t\n"
            "%s(%s) {\n"
            " return %s(%s);\n"
            "}\n") % (name, ', '.join(declarations), target, ', '.join(forwarded))
433 
434 
def format_annotations(annotations_):
    """Render annotations as a C++ initializer expression.

    `annotations_` is a sequence with one entry per argument, each a
    sequence of (label, value) pairs. The result is a
    `std::vector<std::map<std::string, std::string>>{...}` expression
    with one map per argument.
    """
    maps = []
    for pairs in annotations_:
        entries = ['{"%s", "%s"}' % (label, value) for label, value in pairs]
        maps.append('{' + ', '.join(entries) + '}')
    return "std::vector<std::map<std::string, std::string>>{" + ', '.join(maps) + "}"
440 
441 
def parse_annotations(input_files):
    """Build the generated C++ pieces for all UDTFs in `input_files`.

    Returns a pair of lists: the `TableFunctionsFactory::add(...)`
    statements, one per signature, and the wrapper function sources for
    the signatures that `is_template_function` accepts.
    """

    def as_ext_vector(types):
        # Render types as a namespaced std::vector<ExtArgumentType> literal.
        namespaced = tuple([t.apply_namespace(ns='ExtArgumentType') for t in types])
        return 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(str, namespaced)))

    add_stmts = []
    template_functions = []

    for input_file in input_files:
        for sig in find_signatures(input_file):

            # Compute sql_types, input_types, and sizer
            sql_types_ = []
            input_types_ = []
            sizer = None
            for t in sig.inputs:
                if t.is_output_buffer_sizer():
                    if t.is_user_specified():
                        # A user-specified sizer is an int32 argument.
                        int32_arg = Bracket.parse('int32').normalize(kind='input')
                        sql_types_.append(int32_arg)
                        input_types_.append(int32_arg)
                    assert sizer is None  # exactly one sizer argument is allowed
                    assert len(t.args) == 1, t
                    sizer = 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (t.name, t.args[0])
                    continue
                if t.name == 'Cursor':
                    # The cursor columns are separate inputs behind a
                    # single Cursor SQL argument.
                    input_types_.extend(t.args)
                    sql_types_.append(Bracket('Cursor', args=()))
                    continue
                input_types_.append(t)
                # XXX: let Bracket handle mapping of column to cursor(column)
                sql_types_.append(Bracket('Cursor', args=()) if t.is_column_any() else t)
            assert sizer is not None

            output_types = as_ext_vector(sig.outputs)
            input_types = as_ext_vector(input_types_)
            sql_types = as_ext_vector(sql_types_)
            annotations = format_annotations(sig.input_annotations + sig.output_annotations)

            add_stmts.append(
                'TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s);'
                % (sig.name, sizer, input_types, output_types, sql_types, annotations))

            if is_template_function(sig):
                template_functions.append(
                    build_template_function_call(sig.name, input_types_, sig.outputs))

    return add_stmts, template_functions
492 
493 
if len(sys.argv) < 3:
    # Not enough arguments for a real run: parse the test signatures
    # that live next to this script, print the usage and exit with a
    # non-zero status.
    input_files = [os.path.join(os.path.dirname(__file__), 'test_udtf_signatures.hpp')]
    print('Running tests from %s' % (', '.join(input_files)))
    add_stmts, template_functions = parse_annotations(input_files)
    print('Usage:\n %s %s input1.hpp input2.hpp ... output.hpp' % (sys.executable, sys.argv[0], ))

    sys.exit(1)

# All arguments but the last are input files; the last is the output file.
input_files, output_filename = sys.argv[1:-1], sys.argv[-1]
assert input_files, sys.argv

add_stmts, template_functions = parse_annotations(input_files)

content = '''
/*
 This file is generated by %s. Do no edit!
*/

#include "QueryEngine/TableFunctions/TableFunctionsFactory.h"
#include "QueryEngine/TableFunctions/TableFunctions.hpp"
#include "QueryEngine/OmniSciTypes.h"

extern bool g_enable_table_functions;

namespace table_functions {

std::once_flag init_flag;

void TableFunctionsFactory::init() {
 if (!g_enable_table_functions) {
 return;
 }
 std::call_once(init_flag, []() {
 %s
 });
}

%s

} // namespace table_functions
''' % (sys.argv[0], '\n '.join(add_stmts), '\n'.join(template_functions))


dirname = os.path.dirname(output_filename)
if dirname and not os.path.exists(dirname):
    os.makedirs(dirname)

# The context manager closes the file even when the write fails.
with open(output_filename, 'w') as f:
    f.write(content)
int open(const char *path, int flags, int mode)
Definition: omnisci_fs.cpp:64
std::string strip(std::string_view str)
trim any whitespace from the left and right ends of a string
std::string join(T const &container, std::string const &delim)
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings