OmniSciDB  6686921089
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
generate_TableFunctionsFactory_init.py
Go to the documentation of this file.
1 """Given a list of input files, scan for lines containing UDTF
2 specification statements in the following form:
3 
4  UDTF: function_name(<arguments>) -> <output column types> (, <template type specifications>)?
5 
6 where <arguments> is a comma-separated list of argument types. The
7 argument type specifications are:
8 
9 - scalar types:
10  Int8, Int16, Int32, Int64, Float, Double, Bool, TextEncodingDict, etc
11 - column types:
12  ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble, ColumnBool, etc
13 - column list types:
14  ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool, etc
15 - cursor type:
16  Cursor<t0, t1, ...>
17  where t0, t1 are column or column list types
18 - output buffer size parameter type:
19  RowMultiplier<i>, ConstantParameter<i>, Constant<i>, TableFunctionSpecifiedParameter<i>
20  where i is a literal integer.
21 
22 The output column types are given as a comma-separated list of column types (see above).
23 
24 In addition, the following equivalents are supported:
25 
26  Column<T> == ColumnT
27  ColumnList<T> == ColumnListT
28  Cursor<T, V, ...> == Cursor<ColumnT, ColumnV, ...>
29  int8 == int8_t == Int8, etc
30  float == Float, double == Double, bool == Bool
31  T == ColumnT for output column types
32  RowMultiplier == RowMultiplier<i> where i is the one-based position of the sizer argument
33  when no sizer argument is provided, Constant<1> is assumed
34 
35 Argument types can be annotated using `|' (bar) symbol after an
36 argument type specification. An annotation is specified by a label and
37 a value separated by `=' (equal) symbol. Multiple annotations can be
38 specified by using `|` (bar) symbol as the annotations separator.
39 Supported annotation labels are:
40 
41 - name: to specify argument name
42 - input_id: to specify the dict id mapping for output TextEncodingDict columns.
43 
44 If an argument type is followed by an identifier, the identifier is mapped to a name
45 annotation. For example, the following argument type specifications
46 are equivalent:
47 
48  Int8 a
49  Int8 | name=a
50 
51 Template type specifications are a comma-separated list of template
52 type assignments whose values are lists of argument type names. For
53 instance:
54 
55  T = [Int8, Int16, Int32, Float], V = [Float, Double]
56 
57 """
58 # Author: Pearu Peterson
59 # Created: January 2021
60 
61 
62 import os
63 import sys
64 import itertools
65 import copy
66 from abc import abstractmethod
67 
68 from collections import deque, namedtuple
69 
70 if sys.version_info > (3, 0):
71  from abc import ABC
72  from collections.abc import Iterable
73 else:
74  from abc import ABCMeta as ABC
75  from collections import Iterable
76 
# fmt: off
Signature = namedtuple('Signature', ['name', 'inputs', 'outputs', 'input_annotations', 'output_annotations', 'function_annotations'])


def _split_names(text):
    """Return the comma-separated names in *text* with all whitespace removed."""
    return ''.join(text.split()).split(',')


# ExtArgumentType enumerator names understood by the generator.
ExtArgumentTypes = _split_names(''' Int8, Int16, Int32, Int64, Float, Double, Void, PInt8, PInt16,
PInt32, PInt64, PFloat, PDouble, PBool, Bool, ArrayInt8, ArrayInt16,
ArrayInt32, ArrayInt64, ArrayFloat, ArrayDouble, ArrayBool, GeoPoint,
GeoLineString, Cursor, GeoPolygon, GeoMultiPolygon, ColumnInt8,
ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble,
ColumnBool, ColumnTextEncodingDict, TextEncodingNone, TextEncodingDict,
ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64,
ColumnListFloat, ColumnListDouble, ColumnListBool, ColumnListTextEncodingDict''')

# Canonical names of the output buffer size parameter kinds.
OutputBufferSizeTypes = _split_names('''
kConstant, kUserSpecifiedConstantParameter, kUserSpecifiedRowMultiplier, kTableFunctionSpecifiedParameter
''')

# Labels accepted in argument annotations (`TYPE | label=value`).
SupportedAnnotations = _split_names('''
input_id, name, fields, require
''')

# TODO: support `gpu`, `cpu`, `template` as function annotations
SupportedFunctionAnnotations = _split_names('''
filter_table_function_transpose
''')

# Alias spellings accepted in signatures -> canonical spelling.
translate_map = dict(
    Constant='kConstant',
    ConstantParameter='kUserSpecifiedConstantParameter',
    RowMultiplier='kUserSpecifiedRowMultiplier',
    UserSpecifiedConstantParameter='kUserSpecifiedConstantParameter',
    UserSpecifiedRowMultiplier='kUserSpecifiedRowMultiplier',
    TableFunctionSpecifiedParameter='kTableFunctionSpecifiedParameter',
    short='Int16',
    int='Int32',
    long='Int64',
)
# Lower-case scalar aliases (int8, float, ...) plus the C-style intN_t forms.
for t in ['Int8', 'Int16', 'Int32', 'Int64', 'Float', 'Double', 'Bool',
          'TextEncodingDict', 'TextEncodingNone']:
    translate_map[t.lower()] = t
    if t.startswith('Int'):
        translate_map['%s_t' % t.lower()] = t
118 
119 
121  """Holds a `TYPE | ANNOTATIONS`-like structure.
122  """
123  def __init__(self, type, annotations=[]):
124  self.type = type
125  self.annotations = annotations
126 
127  def __repr__(self):
128  return 'Declaration(%r, %r)' % (self.type, self.annotations)
129 
130  def __str__(self):
131  if not self.annotations:
132  return str(self.type)
133  return '%s | %s' % (self.type, ' | '.join(map(str, self.annotations)))
134 
135  def tostring(self):
136  return self.type.tostring()
137 
138  def apply_column(self):
139  return self.__class__(self.type.apply_column(), self.annotations)
140 
141  def apply_namespace(self, ns='ExtArgumentType'):
142  return self.__class__(self.type.apply_namespace(ns), self.annotations)
143 
144  def get_cpp_type(self):
145  return self.type.get_cpp_type()
146 
147  def format_cpp_type(self, idx):
148  return self.type.format_cpp_type(idx)
149 
150  def __getattr__(self, name):
151  if name.startswith('is_'):
152  return getattr(self.type, name)()
153  raise AttributeError(name)
154 
155 
def tostring(obj):
    """Return ``obj.tostring()``; convenient as a ``map`` callback."""
    render = obj.tostring
    return render()
158 
159 
160 class Bracket:
161  """Holds a `NAME<ARGS>`-like structure.
162  """
163 
164  def __init__(self, name, args=None):
165  assert isinstance(name, str)
166  assert isinstance(args, tuple) or args is None, args
167  self.name = name
168  self.args = args
169 
170  def __repr__(self):
171  return 'Bracket(%r, %r)' % (self.name, self.args)
172 
173  def __str__(self):
174  if not self.args:
175  return self.name
176  return '%s<%s>' % (self.name, ', '.join(map(str, self.args)))
177 
178  def tostring(self):
179  if not self.args:
180  return self.name
181  return '%s<%s>' % (self.name, ', '.join(map(tostring, self.args)))
182 
183  def normalize(self, kind='input'):
184  """Normalize bracket for given kind
185  """
186  assert kind in ['input', 'output'], kind
187  if self.is_column_any() and self.args:
188  return Bracket(self.name + ''.join(map(str, self.args)))
189  if kind == 'input':
190  if self.name == 'Cursor':
191  args = [(a if a.is_column_any() else Bracket('Column', args=(a,))).normalize(kind=kind) for a in self.args]
192  return Bracket(self.name, tuple(args))
193  if kind == 'output':
194  if not self.is_column_any():
195  return Bracket('Column', args=(self,)).normalize(kind=kind)
196  return self
197 
198  def apply_cursor(self):
199  """Apply cursor to a non-cursor column argument type.
200  TODO: this method is currently unused but we should apply
201  cursor to all input column arguments in order to distingush
202  signatures like:
203  foo(Cursor(Column<int32>, Column<float>)) -> Column<int32>
204  foo(Cursor(Column<int32>), Cursor(Column<float>)) -> Column<int32>
205  that at the moment are treated as the same :(
206  """
207  if self.is_column():
208  return Bracket('Cursor', args=(self,))
209  return self
210 
211  def apply_column(self):
212  if not self.is_column() and not self.is_column_list():
213  return Bracket('Column' + self.name)
214  return self
215 
216  def apply_namespace(self, ns='ExtArgumentType'):
217  if self.name == 'Cursor':
218  return Bracket(ns + '::' + self.name, args=tuple(a.apply_namespace(ns=ns) for a in self.args))
219  if not self.name.startswith(ns + '::'):
220  return Bracket(ns + '::' + self.name)
221  return self
222 
223  def is_cursor(self):
224  return self.name.rsplit("::", 1)[-1] == 'Cursor'
225 
226  def is_column_any(self):
227  return self.name.rsplit("::", 1)[-1].startswith('Column')
228 
229  def is_column_list(self):
230  return self.name.rsplit("::", 1)[-1].startswith('ColumnList')
231 
232  def is_column(self):
233  return self.name.rsplit("::", 1)[-1].startswith('Column') and not self.is_column_list()
234 
236  return self.name.rsplit("::", 1)[-1].endswith('TextEncodedDict')  # NOTE(review): checks the suffix 'TextEncodedDict', but every type name in ExtArgumentTypes spells it 'TextEncodingDict', so this looks like it can never be true - confirm and fix
237 
239  return self.name.rsplit("::", 1)[-1] == 'ColumnTextEncodedDict'  # NOTE(review): ExtArgumentTypes defines 'ColumnTextEncodingDict', not 'ColumnTextEncodedDict'; this comparison looks like it never matches - confirm
240 
242  return self.name.rsplit("::", 1)[-1] == 'ColumnListTextEncodedDict'  # NOTE(review): ExtArgumentTypes defines 'ColumnListTextEncodingDict', not 'ColumnListTextEncodedDict'; this comparison looks like it never matches - confirm
243 
245  return self.name.rsplit("::", 1)[-1] in OutputBufferSizeTypes
246 
247  def is_row_multiplier(self):
248  return self.name.rsplit("::", 1)[-1] == 'kUserSpecifiedRowMultiplier'
249 
250  def is_user_specified(self):
251  # Return True if given argument cannot specified by user
252  if self.is_output_buffer_sizer():
253  return self.name.rsplit("::", 1)[-1] not in ('kConstant', 'kTableFunctionSpecifiedParameter')
254  return True
255 
256  def get_cpp_type(self):
257  name = self.name.rsplit("::", 1)[-1]
258  clsname = None
259  if name.startswith('ColumnList'):
260  name = name.lstrip('ColumnList')
261  clsname = 'ColumnList'
262  elif name.startswith('Column'):
263  name = name.lstrip('Column')
264  clsname = 'Column'
265  if name.startswith('Bool'):
266  ctype = name.lower()
267  elif name.startswith('Int'):
268  ctype = name.lower() + '_t'
269  elif name in ['Double', 'Float']:
270  ctype = name.lower()
271  elif name == 'TextEncodingDict':
272  ctype = name
273  elif name == 'TextEncodingNone':
274  ctype = name
275  else:
276  raise NotImplementedError(self)
277  if clsname is None:
278  return ctype
279  return '%s<%s>' % (clsname, ctype)
280 
281  def format_cpp_type(self, idx, arg_name=None, is_input=True):
282  # Perhaps integrate this to Bracket?
283  col_typs = ('Column', 'ColumnList')
284  literal_ref_typs = ('TextEncodingNone',)
285  # TODO: use name in annotations when present?
286  if arg_name is None:
287  arg_name = 'input' + str(idx) if is_input else 'output' + str(idx)
288  const = 'const ' if is_input else ''
289  cpp_type = self.get_cpp_type()
290 
291  if any(cpp_type.startswith(t) for t in col_typs + literal_ref_typs):
292  return '%s%s& %s' % (const, cpp_type, arg_name), arg_name
293  else:
294  return '%s %s' % (cpp_type, arg_name), arg_name
295 
296  @classmethod
297  def parse(cls, typ):
298  """typ is a string in format NAME<ARGS> or NAME
299 
300  Returns Bracket instance.
301  """
302  i = typ.find('<')
303  if i == -1:
304  name = typ.strip()
305  args = None
306  else:
307  assert typ.endswith('>'), typ
308  name = typ[:i].strip()
309  args = []
310  rest = typ[i + 1:-1].strip()
311  while rest:
312  i = find_comma(rest)
313  if i == -1:
314  a, rest = rest, ''
315  else:
316  a, rest = rest[:i].rstrip(), rest[i + 1:].lstrip()
317  args.append(cls.parse(a))
318  args = tuple(args)
319 
320  name = translate_map.get(name, name)
321  return cls(name, args)
322 
323 
def find_comma(line):
    """Return the index of the first top-level comma in *line*, or -1.

    A comma is top-level when it is not nested inside any of the bracket
    pairs ``<>``, ``()``, ``[]`` or ``{}``. The previous version listed
    ``{`` instead of ``}`` among the closing brackets, so braces never
    closed and commas after a ``{...}`` group were missed.
    """
    depth = 0
    for i, c in enumerate(line):
        if c in '<([{':
            depth += 1
        elif c in '>)]}':
            depth -= 1
        elif depth == 0 and c == ',':
            return i
    return -1
334 
335 
337  # TODO: try to parse the line to be certain about completeness.
338  # `!` is used to separate the UDTF signature and the expected result
339  return line.endswith(',') or line.endswith('->') or line.endswith('!')  # True when the line ends with ',', '->' or '!', i.e. the declaration is presumably continued on the next line
340 
341 
def is_identifier_cursor(identifier):
    """Return True if *identifier* spells ``cursor`` in any letter case."""
    normalized = identifier.lower()
    return normalized == 'cursor'
344 
345 
346 # fmt: on
347 
348 
class TokenizeException(Exception):
    """Raised when a signature line cannot be split into tokens."""
351 
352 
class ParserException(Exception):
    """Raised when a token stream does not form a valid UDTF signature."""
355 
356 
class TransformerException(Exception):
    """Raised by AST transformation passes (e.g. on unsupported annotations)."""
359 
360 
class Token:
    """A lexical token of a UDTF signature line: a type code plus its text."""

    LESS = 1  # <
    GREATER = 2  # >
    COMMA = 3  # ,
    EQUAL = 4  # =
    RARROW = 5  # ->
    STRING = 6  # reserved for string constants
    NUMBER = 7  #
    VBAR = 8  # |
    BANG = 9  # !
    LPAR = 10  # (
    RPAR = 11  # )
    LSQB = 12  # [
    RSQB = 13  # ]
    IDENTIFIER = 14  #
    COLON = 15  # :

    def __init__(self, type, lexeme):
        """
        Parameters
        ----------
        type : int
            One of the token type codes defined on this class
        lexeme : str
            Corresponding string in the text
        """
        self.type = type
        self.lexeme = lexeme

    @classmethod
    def tok_name(cls, token):
        """Return the symbolic name of token type code *token*, or None if unknown."""
        for label in ('LESS', 'GREATER', 'COMMA', 'EQUAL', 'RARROW', 'STRING',
                      'NUMBER', 'VBAR', 'BANG', 'LPAR', 'RPAR', 'LSQB', 'RSQB',
                      'IDENTIFIER', 'COLON'):
            if getattr(Token, label) == token:
                return label
        return None

    def __str__(self):
        return 'Token(%s, "%s")' % (Token.tok_name(self.type), self.lexeme)

    __repr__ = __str__
415 
416 
class Tokenize:
    """Split one UDTF signature line into a list of ``Token`` objects.

    Tokenizing happens eagerly in the constructor; read the result from
    ``tokens``. During scanning, ``start`` is the index of the first
    character of the token being built and ``curr`` is the index of the
    character currently examined.

    Raises
    ------
    TokenizeException
        On a character that starts no known token, or on an unterminated
        string literal (previously the latter crashed with IndexError).
    """

    def __init__(self, line):
        self._line = line
        self._tokens = []
        self.start = 0
        self.curr = 0
        self.tokenize()

    @property
    def line(self):
        return self._line

    @property
    def tokens(self):
        return self._tokens

    def tokenize(self):
        while not self.is_at_end():
            self.start = self.curr

            if self.is_token_whitespace():
                self.consume_whitespace()
            elif self.is_digit():
                self.consume_number()
            elif self.is_token_string():
                self.consume_string()
            elif self.is_token_identifier():
                self.consume_identifier()
            elif self.can_token_be_double_char():
                self.consume_double_char()
            else:
                self.consume_single_char()

    def is_at_end(self):
        return len(self.line) == self.curr

    def current_token(self):
        return self.line[self.start:self.curr + 1]

    def add_token(self, type):
        # The lexeme spans from `start` up to and including `curr`.
        lexeme = self.line[self.start:self.curr + 1]
        self._tokens.append(Token(type, lexeme))

    def lookahead(self):
        """Return the character after `curr`, or None at the end of the line."""
        if self.curr + 1 >= len(self.line):
            return None
        return self.line[self.curr + 1]

    def advance(self):
        self.curr += 1

    def peek(self):
        return self.line[self.curr]

    def can_token_be_double_char(self):
        # '-' can only start the two-character token '->'.
        char = self.peek()
        return char in ("-",)

    def consume_double_char(self):
        ahead = self.lookahead()
        if ahead == ">":
            self.advance()
            self.add_token(Token.RARROW)  # ->
            self.advance()
        else:
            self.raise_tokenize_error()

    def consume_single_char(self):
        char = self.peek()
        if char == "(":
            self.add_token(Token.LPAR)
        elif char == ")":
            self.add_token(Token.RPAR)
        elif char == "<":
            self.add_token(Token.LESS)
        elif char == ">":
            self.add_token(Token.GREATER)
        elif char == ",":
            self.add_token(Token.COMMA)
        elif char == "=":
            self.add_token(Token.EQUAL)
        elif char == "|":
            self.add_token(Token.VBAR)
        elif char == "!":
            self.add_token(Token.BANG)
        elif char == "[":
            self.add_token(Token.LSQB)
        elif char == "]":
            self.add_token(Token.RSQB)
        elif char == ":":
            self.add_token(Token.COLON)
        else:
            self.raise_tokenize_error()
        self.advance()

    def consume_whitespace(self):
        # Whitespace produces no token.
        self.advance()

    def consume_string(self):
        """
        STRING: \".*?\"
        """
        while True:
            char = self.lookahead()
            if char is None:
                # Unterminated string literal. Report it as a tokenize error
                # instead of indexing past the end of the line.
                raise TokenizeException(
                    'Unterminated string starting at pos %d on line\n %s' % (self.start, self.line)
                )
            curr = self.peek()
            # A quote preceded by a backslash is escaped and does not end the string.
            if char == '"' and curr != '\\':
                self.advance()
                break
            self.advance()
        self.add_token(Token.STRING)
        self.advance()

    def consume_number(self):
        """
        NUMBER: [0-9]+
        """
        while True:
            char = self.lookahead()
            if char and char.isdigit():
                self.advance()
            else:
                break
        self.add_token(Token.NUMBER)
        self.advance()

    def consume_identifier(self):
        """
        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*

        (str.isalpha/isalnum also accept non-ASCII letters and digits.)
        """
        while True:
            char = self.lookahead()
            if char is not None and (char.isalnum() or char == "_"):
                self.advance()
            else:
                break
        self.add_token(Token.IDENTIFIER)
        self.advance()

    def is_token_identifier(self):
        return self.peek().isalpha() or self.peek() == "_"

    def is_token_string(self):
        return self.peek() == '"'

    def is_digit(self):
        return self.peek().isdigit()

    def is_alpha(self):
        return self.peek().isalpha()

    def is_token_whitespace(self):
        return self.peek().isspace()

    def raise_tokenize_error(self):
        curr = self.curr
        char = self.peek()
        raise TokenizeException(
            'Could not match char "%s" at pos %d on line\n %s' % (char, curr, self.line)
        )
576 
577 
class AstVisitor(object):
    """Visitor interface over the UDTF signature AST: one hook per node kind.

    NOTE(review): the Python 2 style ``__metaclass__`` attribute is ignored
    by Python 3. On Python 3 these hooks are therefore not enforced as
    abstract, and the class can be instantiated (each hook returns None).
    Confirm this is intended.
    """
    __metaclass__ = ABC

    @abstractmethod
    def visit_udtf_node(self, node):
        """Visit a whole UDTF signature node."""

    @abstractmethod
    def visit_composed_node(self, node):
        """Visit a composed type node such as ``Column<T>`` or ``Cursor<...>``."""

    @abstractmethod
    def visit_arg_node(self, node):
        """Visit an argument node (type plus annotations)."""

    @abstractmethod
    def visit_primitive_node(self, node):
        """Visit a primitive (leaf) type node."""

    @abstractmethod
    def visit_annotation_node(self, node):
        """Visit a ``key=value`` annotation node."""

    @abstractmethod
    def visit_template_node(self, node):
        """Visit a template assignment node ``T=[...]``."""
604 
605 
class AstTransformer(AstVisitor):
    """Only overload the methods you need.

    The default hooks return shallow copies of each node with their children
    transformed. Child lists are always rebuilt, so a transformed tree never
    shares a child list with the input tree.
    """

    def visit_udtf_node(self, udtf_node):
        udtf = copy.copy(udtf_node)
        udtf.inputs = [arg.accept(self) for arg in udtf.inputs]
        udtf.outputs = [arg.accept(self) for arg in udtf.outputs]
        if udtf.templates:
            udtf.templates = [t.accept(self) for t in udtf.templates]
        udtf.annotations = [annot.accept(self) for annot in udtf.annotations]
        return udtf

    def visit_composed_node(self, composed_node):
        c = copy.copy(composed_node)
        c.inner = [i.accept(self) for i in c.inner]
        return c

    def visit_arg_node(self, arg_node):
        arg_node = copy.copy(arg_node)
        arg_node.type = arg_node.type.accept(self)
        if arg_node.annotations is not None:
            # Rebuild the list even when it is empty. copy.copy is shallow,
            # so otherwise the copy shares the original's (empty) list, and
            # a subclass that appends annotations would mutate the input AST.
            arg_node.annotations = [a.accept(self) for a in arg_node.annotations]
        return arg_node

    def visit_primitive_node(self, primitive_node):
        return copy.copy(primitive_node)

    def visit_template_node(self, template_node):
        return copy.copy(template_node)

    def visit_annotation_node(self, annotation_node):
        return copy.copy(annotation_node)
638 
639 
641  """Returns a line formatted. Useful for testing"""
642 
643  def visit_udtf_node(self, udtf_node):
644  name = udtf_node.name
645  inputs = ", ".join([arg.accept(self) for arg in udtf_node.inputs])
646  outputs = ", ".join([arg.accept(self) for arg in udtf_node.outputs])
647  annotations = "| ".join([annot.accept(self) for annot in udtf_node.annotations])
648  if annotations:
649  annotations = ' | ' + annotations
650  if udtf_node.templates:
651  templates = ", ".join([t.accept(self) for t in udtf_node.templates])
652  return "%s(%s)%s -> %s, %s" % (name, inputs, annotations, outputs, templates)
653  else:
654  return "%s(%s)%s -> %s" % (name, inputs, annotations, outputs)
655 
656  def visit_template_node(self, template_node):
657  # T=[T1, T2, ..., TN]
658  key = template_node.key
659  types = ['"%s"' % typ for typ in template_node.types]
660  return "%s=[%s]" % (key, ", ".join(types))
661 
662  def visit_annotation_node(self, annotation_node):
663  # key=value
664  key = annotation_node.key
665  value = annotation_node.value
666  if isinstance(value, list):
667  return "%s=[%s]" % (key, ','.join([v.accept(self) for v in value]))
668  return "%s=%s" % (key, value)
669 
670  def visit_arg_node(self, arg_node):
671  # type | annotation
672  typ = arg_node.type.accept(self)
673  if arg_node.annotations:
674  ann = " | ".join([a.accept(self) for a in arg_node.annotations])
675  s = "%s | %s" % (typ, ann)
676  else:
677  s = "%s" % (typ,)
678  # insert input_id=args<0> if input_id is not specified
679  if s == "ColumnTextEncodingDict" and arg_node.kind == "output":
680  return s + " | input_id=args<0>"
681  return s
682 
683  def visit_composed_node(self, composed_node):
684  T = composed_node.inner[0].accept(self)
685  if composed_node.is_column():
686  # Column<T>
687  assert len(composed_node.inner) == 1
688  return "Column" + T
689  if composed_node.is_column_list():
690  # ColumnList<T>
691  assert len(composed_node.inner) == 1
692  return "ColumnList" + T
693  if composed_node.is_output_buffer_sizer():
694  # kConstant<N>
695  N = T
696  assert len(composed_node.inner) == 1
697  return translate_map.get(composed_node.type) + "<%s>" % (N,)
698  if composed_node.is_cursor():
699  # Cursor<T1, T2, ..., TN>
700  Ts = ", ".join([i.accept(self) for i in composed_node.inner])
701  return "Cursor<%s>" % (Ts)
702  raise ValueError(composed_node)
703 
704  def visit_primitive_node(self, primitive_node):
705  t = primitive_node.type
706  if primitive_node.is_output_buffer_sizer():
707  # arg_pos is zero-based
708  return translate_map.get(t, t) + "<%d>" % (
709  primitive_node.get_parent(ArgNode).arg_pos + 1,
710  )
711  return translate_map.get(t, t)
712 
713 
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the kwargs values.

    >>> list(product_dict(a=[1, 2], b=['x']))
    [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'x'}]
    """
    names = list(kwargs)
    pools = [kwargs[name] for name in names]
    for combo in itertools.product(*pools):
        yield {name: value for name, value in zip(names, combo)}
719 
720 
722  """Expand template definition into multiple inputs"""
723 
724  def visit_udtf_node(self, udtf_node):
725  if not udtf_node.templates:
726  return udtf_node
727 
728  udtfs = []
729 
730  d = dict([(node.key, node.types) for node in udtf_node.templates])
731  name = udtf_node.name
732 
733  for product in product_dict(**d):
734  self.mapping_dict = product
735  inputs = [input_arg.accept(self) for input_arg in udtf_node.inputs]
736  outputs = [output_arg.accept(self) for output_arg in udtf_node.outputs]
737  udtfs.append(UdtfNode(name, inputs, outputs, udtf_node.annotations, None, udtf_node.line))
738  self.mapping_dict = {}
739 
740  if len(udtfs) == 1:
741  return udtfs[0]
742 
743  return udtfs
744 
745  def visit_composed_node(self, composed_node):
746  typ = composed_node.type
747  typ = self.mapping_dict.get(typ, typ)
748 
749  inner = [i.accept(self) for i in composed_node.inner]
750  return composed_node.copy(typ, inner)
751 
752  def visit_primitive_node(self, primitive_node):
753  typ = primitive_node.type
754  typ = self.mapping_dict.get(typ, typ)
755  return primitive_node.copy(typ)
756 
757 
759 
760  def visit_udtf_node(self, udtf_node):
761  """
762  * Add default_input_id to Column(List)<TextEncodingDict> without one
763  * Generate fields annotation to Cursor if non-existing
764  """
765  udtf_node = super(type(self), self).visit_udtf_node(udtf_node)
766 
767  # add default input_id
768  default_input_id = None
769  for idx, t in enumerate(udtf_node.inputs):
770 
771  if not isinstance(t.type, ComposedNode):
772  continue
773  if default_input_id is not None:
774  pass
775  elif t.type.is_column_text_encoding_dict():
776  default_input_id = AnnotationNode('input_id', 'args<%s>' % (idx,))
777  elif t.type.is_column_list_text_encoding_dict():
778  default_input_id = AnnotationNode('input_id', 'args<%s, 0>' % (idx,))
779 
780  if t.type.is_cursor() and t.get_annotation('fields') is None:
781  fields = list(PrimitiveNode(a.get_annotation('name', 'field%s' % i)) for i, a in enumerate(t.type.inner))
782  t.annotations.append(AnnotationNode('fields', fields))
783 
784  for t in udtf_node.outputs:
785  if isinstance(t.type, ComposedNode) and t.type.is_any_text_encoding_dict():
786  for a in t.annotations:
787  if a.key == 'input_id':
788  break
789  else:
790  if default_input_id is None:
791  raise TypeError('Cannot parse line "%s".\n'
792  'Missing TextEncodingDict input?' %
793  (udtf_node.line))
794  t.annotations.append(default_input_id)
795 
796  return udtf_node
797 
798  def visit_primitive_node(self, primitive_node):
799  """
800  * Rename nodes using translate_map as dictionary
801  int -> Int32
802  float -> Float
803 
804  * Fix kUserSpecifiedRowMultiplier without a pos arg
805  """
806  t = primitive_node.type
807 
808  if primitive_node.is_output_buffer_sizer():
809  pos = PrimitiveNode(str(primitive_node.get_parent(ArgNode).arg_pos + 1))
810  node = ComposedNode(t, inner=[pos])
811  return node
812 
813  return primitive_node.copy(translate_map.get(t, t))
814 
815 
817  """
818  * Checks for supported annotations in a UDTF
819  """
820  def visit_udtf_node(self, udtf_node):
821  for idx, t in enumerate(udtf_node.inputs):
822  for a in t.annotations:
823  if a.key not in SupportedAnnotations:
824  raise TransformerException('unknown input annotation: `%s`' % (a.key))
825  for t in udtf_node.outputs:
826  for a in t.annotations:
827  if a.key not in SupportedAnnotations:
828  raise TransformerException('unknown output annotation: `%s`' % (a.key))
829  for annot in udtf_node.annotations:
830  if annot.key not in SupportedFunctionAnnotations:
831  raise TransformerException('unknown function annotation: `%s`' % (annot.key))
832  if annot.value.lower() in ['enable', 'on', '1', 'true']:
833  annot.value = '1'
834  elif annot.value.lower() in ['disable', 'off', '0', 'false']:
835  annot.value = '0'
836  return udtf_node
837 
838 
840 
841  def visit_udtf_node(self, udtf_node):
842  name = udtf_node.name
843  inputs = []
844  input_annotations = []
845  outputs = []
846  output_annotations = []
847  function_annotations = []
848 
849  for i in udtf_node.inputs:
850  decl = i.accept(self)
851  inputs.append(decl.type)
852  input_annotations.append(decl.annotations)
853 
854  for o in udtf_node.outputs:
855  decl = o.accept(self)
856  outputs.append(decl.type)
857  output_annotations.append(decl.annotations)
858 
859  for annot in udtf_node.annotations:
860  annot = annot.accept(self)
861  function_annotations.append(annot)
862 
863  return Signature(name, inputs, outputs, input_annotations, output_annotations, function_annotations)
864 
865  def visit_arg_node(self, arg_node):
866  t = arg_node.type.accept(self)
867  anns = [a.accept(self) for a in arg_node.annotations]
868  return Declaration(t, anns)
869 
870  def visit_composed_node(self, composed_node):
871  typ = translate_map.get(composed_node.type, composed_node.type)
872  inner = [i.accept(self) for i in composed_node.inner]
873  if composed_node.is_cursor():
874  inner = list(map(lambda x: x.apply_column(), inner))
875  return Bracket(typ, args=tuple(inner))
876  elif composed_node.is_output_buffer_sizer():
877  return Bracket(typ, args=tuple(inner))
878  else:
879  return Bracket(typ + str(inner[0]))
880 
881  def visit_primitive_node(self, primitive_node):
882  t = primitive_node.type
883  return Bracket(translate_map.get(t, t))
884 
885  def visit_annotation_node(self, annotation_node):
886  key = annotation_node.key
887  value = annotation_node.value
888  return (key, value)
889 
890 
class Node(object):
    """Base class of the UDTF signature AST nodes.

    Nodes may carry optional ``parent`` and ``arg_pos`` attributes. Both
    are set outside this class, and ``copy`` carries them over.
    """

    __metaclass__ = ABC

    @abstractmethod
    def accept(self, visitor):
        """Dispatch to the matching ``visitor.visit_*`` hook."""

    @abstractmethod
    def __str__(self):
        """Human-readable rendering of the node."""

    def get_parent(self, cls):
        """Return the nearest node (self included) up the parent chain that is a *cls*.

        Raises
        ------
        ValueError
            If no such node exists. A node whose ``parent`` attribute was
            never set is treated as a root. Previously this raised
            AttributeError instead of ValueError.
        """
        if isinstance(self, cls):
            return self

        parent = getattr(self, 'parent', None)
        if parent is not None:
            return parent.get_parent(cls)

        raise ValueError("could not find parent with given class %s" % (cls))

    def copy(self, *args):
        """Construct a new node of the same class from *args*, keeping parent and arg_pos."""
        other = self.__class__(*args)

        # copy parent and arg_pos
        for attr in ['parent', 'arg_pos']:
            if attr in self.__dict__:
                setattr(other, attr, getattr(self, attr))

        return other
921 
922 
class IterableNode(Iterable):
    """Marker base for AST nodes that can iterate over their child nodes."""
925 
926 
class UdtfNode(Node, IterableNode):
    """Root AST node describing one UDTF signature."""

    def __init__(self, name, inputs, outputs, annotations, templates, line):
        """
        Parameters
        ----------
        name : str
        inputs : list[ArgNode]
        outputs : list[ArgNode]
        annotations : Optional[List[AnnotationNode]]
        templates : Optional[list[TemplateNode]]
        line: str
            Source line of the signature (quoted in error messages).
        """
        self.name = name
        self.inputs = inputs
        self.outputs = outputs
        self.annotations = annotations
        self.templates = templates
        self.line = line

    def accept(self, visitor):
        return visitor.visit_udtf_node(self)

    def __str__(self):
        inputs = [str(i) for i in self.inputs]
        outputs = [str(o) for o in self.outputs]
        # annotations is documented as Optional; treat None like [].
        annotations = [str(a) for a in (self.annotations or [])]
        text = "UDTF: %s (%s)" % (self.name, inputs)
        if annotations:
            text = "%s | %s" % (text, annotations)
        text = "%s -> %s" % (text, outputs)
        if self.templates:
            templates = [str(t) for t in self.templates]
            text = "%s, %s" % (text, templates)
        return text

    def __iter__(self):
        """Yield inputs, outputs, annotations, then templates (if any)."""
        for i in self.inputs:
            yield i
        for o in self.outputs:
            yield o
        for a in self.annotations or ():
            yield a
        for t in self.templates or ():
            yield t

    __repr__ = __str__
979 
980 
982 
983  def __init__(self, type, annotations):
984  """
985  Parameters
986  ----------
987  type : TypeNode
988  annotations : List[AnnotationNode]
989  """
990  self.type = type
991  self.annotations = annotations
992  self.arg_pos = None
993 
994  def accept(self, visitor):
995  return visitor.visit_arg_node(self)
996 
997  def __str__(self):
998  t = str(self.type)
999  anns = ""
1000  if self.annotations:
1001  anns = "| ".join([str(a) for a in self.annotations])
1002  return "ArgNode(%s %s)" % (t, anns)
1003  return "ArgNode(%s)" % (t)
1004 
1005  def __iter__(self):
1006  yield self.type
1007  for a in self.annotations:
1008  yield a
1009 
1010  __repr__ = __str__  # repr() shows the same text as str()
1011 
1012  def get_annotation(self, key, default=None):
1013  for a in self.annotations:
1014  if a.key == key:
1015  return a.value
1016  return default
1017 
1018 
1020  def is_column(self):
1021  return self.type == "Column"
1022 
1023  def is_column_list(self):
1024  return self.type == "ColumnList"
1025 
1026  def is_cursor(self):
1027  return self.type == "Cursor"
1028 
1030  t = self.type
1031  return translate_map.get(t, t) in OutputBufferSizeTypes
1032 
1033 
1035 
1036  def __init__(self, type):
1037  """
1038  Parameters
1039  ----------
1040  type : str
1041  """
1042  self.type = type
1043 
1044  def accept(self, visitor):
1045  return visitor.visit_primitive_node(self)
1046 
1047  def __str__(self):
1048  return self.accept(AstPrinter())
1049 
1051  return self.type == 'TextEncodingDict'
1052 
1053  __repr__ = __str__  # repr() shows the same text as str()
1054 
1055 
1057 
1058  def __init__(self, type, inner):
1059  """
1060  Parameters
1061  ----------
1062  type : str
1063  inner : list[TypeNode]
1064  """
1065  self.type = type
1066  self.inner = inner
1067 
1068  def accept(self, visitor):
1069  return visitor.visit_composed_node(self)
1070 
1071  def cursor_length(self):
1072  assert self.is_cursor()
1073  return len(self.inner)
1074 
1075  def __str__(self):
1076  i = ", ".join([str(i) for i in self.inner])
1077  return "Composed(%s<%s>)" % (self.type, i)
1078 
1079  def __iter__(self):
1080  for i in self.inner:
1081  yield i
1082 
1084  return self.is_column() and self.inner[0].is_text_encoding_dict()
1085 
1087  return self.is_column_list() and self.inner[0].is_text_encoding_dict()
1088 
1090  return self.inner[0].is_text_encoding_dict()
1091 
1092  __repr__ = __str__  # repr() shows the same text as str()
1093 
1094 
1096 
1097  def __init__(self, key, value):
1098  """
1099  Parameters
1100  ----------
1101  key : str
1102  value : {str, list}
1103  """
1104  self.key = key
1105  self.value = value
1106 
1107  def accept(self, visitor):
1108  return visitor.visit_annotation_node(self)
1109 
1110  def __str__(self):
1111  printer = AstPrinter()
1112  return self.accept(printer)
1113 
1114  __repr__ = __str__  # repr() shows the same text as str()
1115 
1116 
1118 
1119  def __init__(self, key, types):
1120  """
1121  Parameters
1122  ----------
1123  key : str
1124  types : tuple[str]
1125  """
1126  self.key = key
1127  self.types = types
1128 
1129  def accept(self, visitor):
1130  return visitor.visit_template_node(self)
1131 
1132  def __str__(self):
1133  printer = AstPrinter()
1134  return self.accept(printer)
1135 
1136  __repr__ = __str__
1137 
1138 
class Pipeline(object):
    """Apply a sequence of visitor passes to one or more AST nodes.

    Each pass is a visitor class that is instantiated anew for every node it
    visits (``node.accept(PassClass())``). A pass may return a single node or
    a list of nodes; lists are spliced in so the next pass sees a flat
    sequence. The result of the last pass is returned as a list.
    """

    def __init__(self, *passes):
        self.passes = passes

    def __call__(self, ast_list):
        nodes = ast_list if isinstance(ast_list, list) else [ast_list]

        for visitor_cls in self.passes:
            flattened = []
            for node in nodes:
                produced = node.accept(visitor_cls())
                # A transformer may expand one node into several
                # (e.g. template expansion); splice those in.
                if isinstance(produced, list):
                    flattened.extend(produced)
                else:
                    flattened.append(produced)
            nodes = flattened

        return list(nodes)
1153 
1154 
class Parser:
    """Recursive-descent parser for a single UDTF specification line.

    The full grammar is documented on `parse`. The parser keeps a cursor
    (`_curr`) into the token list produced by `Tokenize(line)`; some rules
    (`parse_type`, `parse_args`) look ahead and then rewind the cursor.
    """

    def __init__(self, line):
        self._tokens = Tokenize(line).tokens
        self._curr = 0
        self.line = line

    @property
    def tokens(self):
        return self._tokens

    def is_at_end(self):
        return self._curr >= len(self._tokens)

    def current_token(self):
        """Return the token under the cursor.

        Raises ParserException (rather than a bare IndexError) when every
        token has been consumed, so a truncated specification produces a
        readable parser error.
        """
        if self.is_at_end():
            raise ParserException(
                "\n\nUnexpected end of input at pos %d.\nTokens: %s" % (self._curr, self._tokens))
        return self._tokens[self._curr]

    def advance(self):
        self._curr += 1

    def expect(self, expected_type):
        """Skip the current token, which must have type `expected_type`.

        Raises ParserException on a mismatch. An explicit check is used
        instead of `assert` so validation survives `python -O`.
        """
        curr_token = self.current_token()
        if curr_token.type != expected_type:
            # Expected type first, offending token second, matching the wording.
            self.raise_parser_error(
                "Expected token %s but got %s at pos %d.\n Tokens: %s" % (
                    Token.tok_name(expected_type),
                    curr_token,
                    self._curr,
                    self._tokens,
                ))
        self.advance()

    def consume(self, expected_type):
        """consumes the current token iff its type matches the
        expected_type. Otherwise, an error is raised
        """
        curr_token = self.current_token()
        if curr_token.type == expected_type:
            self.advance()
            return curr_token
        else:
            expected_token = Token.tok_name(expected_type)
            self.raise_parser_error(
                'Token mismatch at function consume. '
                'Expected type "%s" but got token "%s"\n\n'
                'Tokens: %s\n' % (expected_token, curr_token, self._tokens)
            )

    def current_pos(self):
        return self._curr

    def raise_parser_error(self, msg=None):
        """Raise ParserException with `msg`, or a default message naming the
        current token and position."""
        if not msg:
            token = self.current_token()
            pos = self.current_pos()
            tokens = self.tokens
            msg = "\n\nError while trying to parse token %s at pos %d.\n" "Tokens: %s" % (
                token,
                pos,
                tokens,
            )
        raise ParserException(msg)

    def match(self, expected_type):
        curr_token = self.current_token()
        return curr_token.type == expected_type

    def parse_udtf(self):
        """fmt: off

        udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)?

        fmt: on
        """
        name = self.parse_identifier()
        self.expect(Token.LPAR)  # (
        input_args = []
        if not self.match(Token.RPAR):
            input_args = self.parse_args()
        self.expect(Token.RPAR)  # )
        annotations = []
        while not self.is_at_end() and self.match(Token.VBAR):  # |
            self.consume(Token.VBAR)
            annotations.append(self.parse_annotation())
        self.expect(Token.RARROW)  # ->
        output_args = self.parse_args()

        templates = None
        if not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            templates = self.parse_templates()

        # set arg_pos: a Cursor input occupies one position per inner column
        i = 0
        for arg in input_args:
            arg.arg_pos = i
            arg.kind = "input"
            i += arg.type.cursor_length() if arg.type.is_cursor() else 1

        for i, arg in enumerate(output_args):
            arg.arg_pos = i
            arg.kind = "output"

        return UdtfNode(name, input_args, output_args, annotations, templates, self.line)

    def parse_args(self):
        """fmt: off

        args: arg ("," arg)*

        fmt: on
        """
        args = []
        args.append(self.parse_arg())
        while not self.is_at_end() and self.match(Token.COMMA):
            curr = self._curr
            self.consume(Token.COMMA)
            self.parse_type()  # assuming that we are not ending with COMMA
            if not self.is_at_end() and self.match(Token.EQUAL):
                # arg type cannot be assigned, so this must be a template specification
                self._curr = curr  # step back and let the code below parse the templates
                break
            else:
                self._curr = curr + 1  # step back from self.parse_type(), parse_arg will parse it again
                args.append(self.parse_arg())
        return args

    def parse_arg(self):
        """fmt: off

        arg: type IDENTIFIER? ("|" annotation)*

        fmt: on
        """
        typ = self.parse_type()

        annotations = []

        if not self.is_at_end() and self.match(Token.IDENTIFIER):
            name = self.parse_identifier()
            annotations.append(AnnotationNode('name', name))

        while not self.is_at_end() and self.match(Token.VBAR):
            self.consume(Token.VBAR)
            annotations.append(self.parse_annotation())

        return ArgNode(typ, annotations)

    def parse_type(self):
        """fmt: off

        type: composed
            | primitive

        fmt: on
        """
        curr = self._curr  # save state
        primitive = self.parse_primitive()
        if self.is_at_end():
            return primitive

        if not self.match(Token.LESS):
            return primitive

        self._curr = curr  # return state

        return self.parse_composed()

    def parse_composed(self):
        """fmt: off

        composed: "Cursor" "<" arg ("," arg)* ">"
                | IDENTIFIER "<" type ("," type)* ">"

        fmt: on
        """
        idtn = self.parse_identifier()
        self.consume(Token.LESS)
        if is_identifier_cursor(idtn):
            inner = [self.parse_arg()]
            while not self.is_at_end() and self.match(Token.COMMA):
                self.consume(Token.COMMA)
                inner.append(self.parse_arg())
        else:
            inner = [self.parse_type()]
            while not self.is_at_end() and self.match(Token.COMMA):
                self.consume(Token.COMMA)
                inner.append(self.parse_type())
        self.consume(Token.GREATER)
        return ComposedNode(idtn, inner)

    def parse_primitive(self):
        """fmt: off

        primitive: IDENTIFIER
                 | NUMBER

        fmt: on
        """
        if self.match(Token.IDENTIFIER):
            lexeme = self.parse_identifier()
        elif self.match(Token.NUMBER):
            lexeme = self.parse_number()
        else:
            # raise_parser_error always raises; no extra `raise` needed.
            self.raise_parser_error()
        return PrimitiveNode(lexeme)

    def parse_templates(self):
        """fmt: off

        templates: template ("," template)*

        fmt: on
        """
        T = []
        T.append(self.parse_template())
        while not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            T.append(self.parse_template())
        return T

    def parse_template(self):
        """fmt: off

        template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"

        fmt: on
        """
        key = self.parse_identifier()
        types = []
        self.consume(Token.EQUAL)
        self.consume(Token.LSQB)
        types.append(self.parse_identifier())
        while not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            types.append(self.parse_identifier())
        self.consume(Token.RSQB)
        return TemplateNode(key, tuple(types))

    def parse_annotation(self):
        """fmt: off

        annotation: IDENTIFIER "=" IDENTIFIER ("<" NUMBER ("," NUMBER)? ">")?
                  | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
                  | "require" "=" STRING

        fmt: on
        """
        key = self.parse_identifier()
        self.consume(Token.EQUAL)

        if key == "require":
            value = self.parse_string()
        elif not self.is_at_end() and self.match(Token.LSQB):
            value = []
            self.consume(Token.LSQB)
            if not self.match(Token.RSQB):
                value.append(self.parse_primitive())
                while not self.is_at_end() and self.match(Token.COMMA):
                    self.consume(Token.COMMA)
                    value.append(self.parse_primitive())
            self.consume(Token.RSQB)
        else:
            value = self.parse_identifier()
            if not self.is_at_end() and self.match(Token.LESS):
                self.consume(Token.LESS)
                num1 = self.parse_number()
                if self.match(Token.COMMA):
                    self.consume(Token.COMMA)
                    num2 = self.parse_number()
                    value += "<%s,%s>" % (num1, num2)
                else:
                    value += "<%s>" % (num1)
                self.consume(Token.GREATER)
        return AnnotationNode(key, value)

    def parse_identifier(self):
        """ fmt: off

        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*

        fmt: on
        """
        token = self.consume(Token.IDENTIFIER)
        return token.lexeme

    def parse_string(self):
        """ fmt: off

        STRING: \".*?\"

        fmt: on
        """
        token = self.consume(Token.STRING)
        return token.lexeme

    def parse_number(self):
        """ fmt: off

        NUMBER: [0-9]+

        fmt: on
        """
        token = self.consume(Token.NUMBER)
        return token.lexeme

    def parse(self):
        """fmt: off

        udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)?

        args: arg ("," arg)*

        arg: type IDENTIFIER? ("|" annotation)*

        type: composed
            | primitive

        composed: "Cursor" "<" arg ("," arg)* ">"
                | IDENTIFIER "<" type ("," type)* ">"

        primitive: IDENTIFIER
                 | NUMBER

        annotation: IDENTIFIER "=" IDENTIFIER ("<" NUMBER ("," NUMBER)? ">")?
                  | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
                  | "require" "=" STRING

        templates: template ("," template)*
        template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"

        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
        NUMBER: [0-9]+
        STRING: \".*?\"

        fmt: on
        """
        self._curr = 0
        udtf = self.parse_udtf()

        # set parent: walk the tree (stack order) and point every child
        # yielded by an iterable node back at that node
        udtf.parent = None
        d = deque()
        d.append(udtf)
        while d:
            node = d.pop()
            if isinstance(node, Iterable):
                for child in node:
                    child.parent = node
                    d.append(child)
        return udtf
1504 
1505 
1506 # fmt: off
def find_signatures(input_file):
    """Returns a list of parsed UDTF signatures.

    Scans `input_file` for lines that start with ``UDTF:`` after stripping
    surrounding whitespace. A specification that `line_is_incomplete`
    reports as unfinished is joined (without a separator) with the following
    line(s) before parsing.

    Text after the first ``!`` is a ``!``-separated list of expected
    AstPrinter outputs. It is asserted against the transformed AST. When that
    transformation raised a TransformerException, the signature is skipped.
    """
    signatures = []

    # Read eagerly inside `with` so the file handle is closed promptly.
    with open(input_file) as f:
        lines = f.readlines()

    last_line = None  # pending, not-yet-complete UDTF specification
    for line in lines:
        line = line.strip()
        if last_line is not None:
            line = last_line + line
            last_line = None
        if not line.startswith('UDTF:'):
            continue
        if line_is_incomplete(line):
            last_line = line
            continue
        line = line[5:].lstrip()  # drop the 'UDTF:' prefix
        i = line.find('(')
        j = line.find(')')
        if i == -1 or j == -1:
            sys.stderr.write('Invalid UDTF specification: `%s`. Skipping.\n' % (line))
            continue

        expected_result = None
        if '!' in line:
            line, expected_result = line.split('!', 1)
            expected_result = [s.strip() for s in expected_result.strip().split('!')]

        ast = Parser(line).parse()

        if expected_result is not None:
            # Template transformer expands templates into multiple lines
            skip_signature = False
            try:
                result = Pipeline(TemplateTransformer, NormalizeTransformer, SupportedAnnotationsTransformer, AstPrinter)(ast)
            except TransformerException as msg:
                result = ['%s: %s' % (type(msg).__name__, msg)]
                skip_signature = True
            assert set(result) == set(expected_result), "\n\tresult: %s != \n\texpected: %s" % (
                result,
                expected_result,
            )
            if skip_signature:
                continue

        signature = Pipeline(TemplateTransformer,
                             NormalizeTransformer,
                             SupportedAnnotationsTransformer,
                             SignatureTransformer)(ast)

        signatures.extend(signature)

    if last_line is not None:
        # The file ended inside a specification; report it rather than
        # dropping it silently.
        sys.stderr.write('Incomplete UDTF specification at end of `%s`: `%s`. Skipping.\n'
                         % (input_file, last_line))

    return signatures
1561 
1562 
def build_template_function_call(caller, callee, input_types, output_types, uses_manager):
    """Return C++ source for a wrapper function `caller` that forwards to `callee`.

    The wrapper's parameter list is an optional leading
    ``TableFunctionManager& mgr`` (when `uses_manager` is true), then one
    parameter per input type, then one per output type. Each declaration and
    forwarded name comes from the type's `format_cpp_type`.
    """
    params = ["TableFunctionManager& mgr"] if uses_manager else []
    forwarded = ["mgr"] if uses_manager else []

    for position, in_type in enumerate(input_types):
        decl, ident = in_type.format_cpp_type(position)
        params.append(decl)
        forwarded.append(ident)

    for position, out_type in enumerate(output_types):
        decl, ident = out_type.format_cpp_type(position, is_input=False)
        params.append(decl)
        forwarded.append(ident)

    return ("EXTENSION_NOINLINE int32_t\n"
            "%s(%s) {\n"
            " return %s(%s);\n"
            "}\n") % (caller, ', '.join(params), callee, ', '.join(forwarded))
1592 
1593 
def build_require_check_function(fn_name, input_types, input_annotations, uses_manager):
    """Return C++ source for ``<fn_name lowercased>__require_check``.

    The generated function takes the same user-facing inputs as the UDTF.
    A leading ``TableFunctionManager& mgr`` is added when `uses_manager` is
    true. Each input is named by its ``name`` annotation, or ``input<idx>``
    when it has none.

    For every ``require`` annotation, the function emits an ``if (!(cond))``
    guard. The guard returns either ``mgr.error_message(...)`` (manager
    UDTFs) or ``table_function_error(...)``. The function returns 0 when all
    constraints hold.

    Only inputs paired by ``zip(input_types, input_annotations)`` are
    considered.
    """

    def format_error_msg(err_msg):
        # Manager-based UDTFs report through the manager; others use the
        # free function.
        if uses_manager:
            return " return mgr.error_message(%s);\n" % (err_msg,)
        return " return table_function_error(%s);\n" % (err_msg,)

    input_cpp_args = []
    paired_annotations = []
    for idx, (typ, ann) in enumerate(zip(input_types, input_annotations)):
        arg_name = dict(ann).get('name', 'input%s' % (idx))
        cpp_arg, _ = typ.format_cpp_type(idx, arg_name=arg_name)
        input_cpp_args.append(cpp_arg)
        paired_annotations.append(ann)

    if uses_manager:
        input_cpp_args.insert(0, "TableFunctionManager& mgr")

    fn = "EXTENSION_NOINLINE int32_t\n"
    fn += "%s(%s) {\n" % (fn_name.lower() + "__require_check", ', '.join(input_cpp_args))

    for ann in paired_annotations:
        for key, value in ann:
            if key == 'require':
                # `value` carries its surrounding quotes; strip them. The C++
                # condition additionally drops backslash escapes.
                err_msg = '"Constraint `%s` is not satisfied."' % (value[1:-1])

                fn += " if (!(%s)) {\n" % (value[1:-1].replace('\\', ''),)
                fn += format_error_msg(err_msg)
                fn += " }\n"

    fn += " return 0;\n"
    fn += "}\n\n"

    return fn
1632 
1633 
1635  for arg_annotations in sig.input_annotations:
1636  d = dict(arg_annotations)
1637  if 'require' in d.keys():
1638  return True
1639  return False
1640 
1641 
def format_annotations(annotations_):
    """Render a list of per-argument annotation lists as a C++ initializer.

    Each element of `annotations_` is an iterable of ``(key, value)`` pairs
    and becomes one ``{{"key", "value"}, ...}`` map literal inside a
    ``std::vector<std::map<std::string, std::string>>{...}`` expression.
    """

    def render_value(key, value):
        # `require` values keep their surrounding quotes in the AST; drop them.
        # Other values are interpolated with %s as-is (they are not always str).
        return value[1:-1] if key == 'require' else value

    rendered_maps = []
    for annotation in annotations_:
        entries = ['{"%s", "%s"}' % (key, render_value(key, value)) for key, value in annotation]
        rendered_maps.append('{%s}' % ', '.join(entries))

    return "std::vector<std::map<std::string, std::string>>{%s}" % ', '.join(rendered_maps)
1653 
1654 
1656  i = sig.name.rfind('_template')
1657  return i >= 0 and '__' in sig.name[:i + 1]
1658 
1659 
def uses_manager(sig):
    """Return True iff the first input of `sig` is named ``TableFunctionManager``.

    Always returns a bool. A signature without inputs yields False rather
    than the empty inputs sequence itself.
    """
    return bool(sig.inputs) and sig.inputs[0].name == 'TableFunctionManager'
1662 
1663 
1665  # Any function that does not have _gpu_ suffix is a cpu function.
1666  i = sig.name.rfind('_gpu_')
1667  if i >= 0 and '__' in sig.name[:i + 1]:
1668  if uses_manager(sig):
1669  raise ValueError('Table function {} with gpu execution target cannot have TableFunctionManager argument'.format(sig.name))
1670  return False
1671  return True
1672 
1673 
1675  # A function with TableFunctionManager argument is a cpu-only function
1676  if uses_manager(sig):
1677  return False
1678  # Any function that does not have _cpu_ suffix is a gpu function.
1679  i = sig.name.rfind('_cpu_')
1680  return not (i >= 0 and '__' in sig.name[:i + 1])
1681 
1682 
def parse_annotations(input_files):
    """Collect C++ code-generation fragments for every UDTF in `input_files`.

    Returns a 6-tuple of lists:
      add_stmts -- ``TableFunctionsFactory::add(...)`` registration statements;
      cpu_template_functions, gpu_template_functions -- wrapper definitions
        generated for template UDTFs (see build_template_function_call);
      cpu_function_address_expressions, gpu_function_address_expressions --
        ``avoid_opt_address(...)`` expressions for each registered function;
      cond_fns -- ``__require_check`` functions for signatures that carry
        ``require`` annotations.

    Raises ValueError when a TableFunctionManager argument is not the first
    input; is_cpu_function may also raise for GPU functions that take one.
    """
    counter = 0  # numbers template wrappers so their names are unique

    add_stmts = []
    cpu_template_functions = []
    gpu_template_functions = []
    cpu_function_address_expressions = []
    gpu_function_address_expressions = []
    cond_fns = []

    for input_file in input_files:
        for sig in find_signatures(input_file):

            # Compute sql_types, input_types, and sizer
            sql_types_ = []
            input_types_ = []
            input_annotations = []
            sizer = None
            # Named so it does not shadow the module-level uses_manager().
            sig_uses_manager = False
            for i, (t, annot) in enumerate(zip(sig.inputs, sig.input_annotations)):
                if t.is_output_buffer_sizer():
                    if t.is_user_specified():
                        sql_types_.append(Bracket.parse('int32').normalize(kind='input'))
                        input_types_.append(sql_types_[-1])
                        input_annotations.append(annot)
                    assert sizer is None  # exactly one sizer argument is allowed
                    assert len(t.args) == 1, t
                    sizer = 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (t.name, t.args[0])
                elif t.name == 'Cursor':
                    # Every column of a cursor is a separate implementation
                    # argument, but the cursor is one SQL-level argument.
                    for t_ in t.args:
                        input_types_.append(t_)
                        input_annotations.append(annot)
                    sql_types_.append(Bracket('Cursor', args=()))
                elif t.name == 'TableFunctionManager':
                    if i != 0:
                        raise ValueError('{} must appear as a first argument of {}, but found it at position {}.'.format(t, sig.name, i))
                    sig_uses_manager = True
                else:
                    input_types_.append(t)
                    input_annotations.append(annot)
                    if t.is_column_any():
                        # XXX: let Bracket handle mapping of column to cursor(column)
                        sql_types_.append(Bracket('Cursor', args=()))
                    else:
                        sql_types_.append(t)

            if sizer is None:
                sizer_kind = 'kTableFunctionSpecifiedParameter'
                sizer_idx = 1  # this sizer is not actually materialized in the UDTF
                sizer = 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (sizer_kind, sizer_idx)

            assert sizer is not None
            ns_output_types = tuple([a.apply_namespace(ns='ExtArgumentType') for a in sig.outputs])
            ns_input_types = tuple([t.apply_namespace(ns='ExtArgumentType') for t in input_types_])
            ns_sql_types = tuple([t.apply_namespace(ns='ExtArgumentType') for t in sql_types_])

            input_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_input_types)))
            output_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_output_types)))
            sql_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_sql_types)))
            annotations = format_annotations(input_annotations + sig.output_annotations + [sig.function_annotations])

            # Notice that input_types and sig.input_types, (and
            # similarly, input_annotations and sig.input_annotations)
            # have different lengths when the sizer argument is
            # Constant or TableFunctionSpecifiedParameter. That is,
            # input_types contains all the user-specified arguments
            # while sig.input_types contains all arguments of the
            # implementation of an UDTF.

            if contains_require_check(sig):
                check_fn = build_require_check_function(sig.name, input_types_, input_annotations, sig_uses_manager)
                cond_fns.append(check_fn)

            if is_template_function(sig):
                # Template UDTFs are registered under a numbered wrapper that
                # forwards to the template implementation.
                registered_name = sig.name + '_' + str(counter)
                counter += 1
                wrapper = build_template_function_call(registered_name, sig.name, input_types_, sig.outputs, sig_uses_manager)
            else:
                registered_name = sig.name
                wrapper = None

            address_expression = ('avoid_opt_address(reinterpret_cast<void*>(%s))' % registered_name)
            if is_cpu_function(sig):
                if wrapper is not None:
                    cpu_template_functions.append(wrapper)
                cpu_function_address_expressions.append(address_expression)
            if is_gpu_function(sig):
                if wrapper is not None:
                    gpu_template_functions.append(wrapper)
                gpu_function_address_expressions.append(address_expression)

            add = ('TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false, /*uses_manager:*/%s);'
                   % (registered_name, sizer, input_types, output_types, sql_types, annotations, str(sig_uses_manager).lower()))
            add_stmts.append(add)

    return add_stmts, cpu_template_functions, gpu_template_functions, cpu_function_address_expressions, gpu_function_address_expressions, cond_fns
1784 
1785 
# Without explicit input and output paths: run the bundled signature file
# through parse_annotations (find_signatures asserts any `!` expected
# results in it), print usage, and exit with status 1.
if len(sys.argv) < 3:
    input_files = [os.path.join(os.path.dirname(__file__), 'test_udtf_signatures.hpp')]
    print('Running tests from %s' % (', '.join(input_files)))
    add_stmts = parse_annotations(input_files)[0]

    print('Usage:\n %s %s input1.hpp input2.hpp ... output.hpp' % (sys.executable, sys.argv[0], ))
    sys.exit(1)

# Command line: input headers first, the generated .cpp path last. The CPU and
# GPU wrapper headers are placed next to it with _cpu/_gpu suffixes.
input_files, output_filename = sys.argv[1:-1], sys.argv[-1]
cpu_output_header = os.path.splitext(output_filename)[0] + '_cpu.hpp'
gpu_output_header = os.path.splitext(output_filename)[0] + '_gpu.hpp'
assert input_files, sys.argv

(add_stmts, cpu_template_functions, gpu_template_functions,
 cpu_address_expressions, gpu_address_expressions, cond_fns) = parse_annotations(input_files)

# Keep everything after the leading "/" of "/QueryEngine/". When the marker
# is absent, find() returns -1, so the slice starts at 0 and the path is
# used unchanged.
canonical_input_files = [path[path.find("/QueryEngine/") + 1:] for path in input_files]
header_includes = ['#include "%s"' % (path,) for path in canonical_input_files]
1805 
# C++ source of the generated init translation unit. Placeholders, in order:
# generator script path, #include lines for the input headers, the CPU
# address-check expression, the GPU address-check expression, the
# TableFunctionsFactory::add registration statements, and the require-check
# functions.
content = '''
/*
 This file is generated by %s. Do no edit!
*/

#include "QueryEngine/TableFunctions/TableFunctionsFactory.h"
%s

/*
 Include the UDTF template initiations:
*/
#ifdef __CUDACC__
#include "TableFunctionsFactory_init_gpu.hpp"
#else
#include "TableFunctionsFactory_init_cpu.hpp"
#endif

extern bool g_enable_table_functions;

namespace table_functions {

std::once_flag init_flag;

// volatile+noinline prevents compiler optimization
__attribute__((noinline))
volatile bool avoid_opt_address(void *address) {
 return address != nullptr;
}

bool functions_exist() {
 bool ret = true;

 ret &= (%s);

#ifdef __CUDACC__
 ret &= (%s);
#endif

 return ret;
}

void TableFunctionsFactory::init() {
 if (!g_enable_table_functions) {
 return;
 }

 if (!functions_exist()) {
 UNREACHABLE();
 return;
 }

 std::call_once(init_flag, []() {
 %s
 });
}

// conditional check functions
%s

} // namespace table_functions

''' % (sys.argv[0],
       '\n'.join(header_includes),
       ' &&\n'.join(cpu_address_expressions),
       ' &&\n'.join(gpu_address_expressions),
       '\n '.join(add_stmts),
       ''.join(cond_fns))

# Template for the generated _cpu.hpp / _gpu.hpp headers. Placeholders:
# generator script path, #include lines for the input headers, and the
# template wrapper definitions for that target.
header_content = '''
/*
 This file is generated by %s. Do no edit!
*/
%s

%s
'''
1882 
# Create the output directory if needed. Ignoring EEXIST tolerates another
# process creating it between the exists() test and makedirs().
dirname = os.path.dirname(output_filename)

if dirname and not os.path.exists(dirname):
    try:
        os.makedirs(dirname)
    except OSError as e:
        import errno
        if e.errno != errno.EEXIST:
            raise

# Write the init translation unit and the CPU/GPU wrapper headers. `with`
# closes each file even when a write fails.
with open(output_filename, 'w') as f:
    f.write(content)

with open(cpu_output_header, 'w') as f:
    f.write(header_content % (sys.argv[0], '\n'.join(header_includes), '\n'.join(cpu_template_functions)))

with open(gpu_output_header, 'w') as f:
    f.write(header_content % (sys.argv[0], '\n'.join(header_includes), '\n'.join(gpu_template_functions)))
int open(const char *path, int flags, int mode)
Definition: omnisci_fs.cpp:64
std::string strip(std::string_view str)
trim any whitespace from the left and right ends of a string
std::string join(T const &container, std::string const &delim)
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings