OmniSciDB  085a039ca4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
generate_TableFunctionsFactory_init.py
Go to the documentation of this file.
1 """Given a list of input files, scan for lines containing UDTF
2 specification statements in the following form:
3 
4  UDTF: function_name(<arguments>) -> <output column types> (, <template type specifications>)?
5 
6 where <arguments> is a comma-separated list of argument types. The
7 argument types specifications are:
8 
9 - scalar types:
10  Int8, Int16, Int32, Int64, Float, Double, Bool, TextEncodingDict, etc
11 - column types:
12  ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble, ColumnBool, etc
13 - column list types:
14  ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool, etc
15 - cursor type:
16  Cursor<t0, t1, ...>
17  where t0, t1 are column or column list types
18 - output buffer size parameter type:
19  RowMultiplier<i>, ConstantParameter<i>, Constant<i>, TableFunctionSpecifiedParameter<i>
20  where i is a literal integer.
21 
22 The output column types is a comma-separated list of column types, see above.
23 
24 In addition, the following equivalents are supported:
25 
26  Column<T> == ColumnT
27  ColumnList<T> == ColumnListT
28  Cursor<T, V, ...> == Cursor<ColumnT, ColumnV, ...>
29  int8 == int8_t == Int8, etc
30  float == Float, double == Double, bool == Bool
31  T == ColumnT for output column types
32  RowMultiplier == RowMultiplier<i> where i is the one-based position of the sizer argument
33  when no sizer argument is provided, Constant<1> is assumed
34 
35 Argument types can be annotated using `|' (bar) symbol after an
36 argument type specification. An annotation is specified by a label and
37 a value separated by `=' (equal) symbol. Multiple annotations can be
38 specified by using `|` (bar) symbol as the annotations separator.
39 Supported annotation labels are:
40 
41 - name: to specify argument name
42 - input_id: to specify the dict id mapping for output TextEncodingDict columns.
43 
44 If an identifier follows an argument type, it is mapped to a name
45 annotation. For example, the following argument type specifications
46 are equivalent:
47 
48  Int8 a
49  Int8 | name=a
50 
51 Template type specifications is a comma separated list of template
52 type assignments where values are lists of argument type names. For
53 instance:
54 
55  T = [Int8, Int16, Int32, Float], V = [Float, Double]
56 
57 """
58 # Author: Pearu Peterson
59 # Created: January 2021
60 
61 
62 import os
63 import sys
64 import itertools
65 import copy
66 from abc import abstractmethod
67 
68 from collections import deque, namedtuple
69 
70 if sys.version_info > (3, 0):
71  from abc import ABC
72  from collections.abc import Iterable
73 else:
74  from abc import ABCMeta as ABC
75  from collections import Iterable
76 
77 # fmt: off
# Marker that separates a UDTF signature from its expected result.
separator = '$=>$'

# Record produced for every parsed UDTF signature.
Signature = namedtuple(
    'Signature',
    ['name', 'inputs', 'outputs', 'input_annotations',
     'output_annotations', 'function_annotations', 'sizer'])


def _split_names(listing):
    """Return the comma-separated names of ``listing`` with all whitespace removed."""
    return ''.join(listing.split()).split(',')


ExtArgumentTypes = _split_names('''
    Int8, Int16, Int32, Int64, Float, Double, Void, PInt8, PInt16,
    PInt32, PInt64, PFloat, PDouble, PBool, Bool, ArrayInt8, ArrayInt16,
    ArrayInt32, ArrayInt64, ArrayFloat, ArrayDouble, ArrayBool, GeoPoint,
    GeoLineString, Cursor, GeoPolygon, GeoMultiPolygon, ColumnInt8,
    ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble,
    ColumnBool, ColumnTextEncodingDict, ColumnTimestamp, TextEncodingNone,
    TextEncodingDict, Timestamp, ColumnListInt8, ColumnListInt16, ColumnListInt32,
    ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool,
    ColumnListTextEncodingDict
''')

OutputBufferSizeTypes = _split_names('''
    kConstant, kUserSpecifiedConstantParameter, kUserSpecifiedRowMultiplier,
    kTableFunctionSpecifiedParameter, kPreFlightParameter
''')

SupportedAnnotations = _split_names('input_id, name, fields, require')

# TODO: support `gpu`, `cpu`, `template` as function annotations
SupportedFunctionAnnotations = _split_names(
    'filter_table_function_transpose, uses_manager')

# Maps user-facing spellings to canonical sizer / type names.
translate_map = {
    'Constant': 'kConstant',
    'PreFlight': 'kPreFlightParameter',
    'ConstantParameter': 'kUserSpecifiedConstantParameter',
    'RowMultiplier': 'kUserSpecifiedRowMultiplier',
    'UserSpecifiedConstantParameter': 'kUserSpecifiedConstantParameter',
    'UserSpecifiedRowMultiplier': 'kUserSpecifiedRowMultiplier',
    'TableFunctionSpecifiedParameter': 'kTableFunctionSpecifiedParameter',
    'short': 'Int16',
    'int': 'Int32',
    'long': 'Int64',
}
# Lower-case aliases for every scalar type; integer types additionally get
# the C-style `<name>_t` alias (e.g. `int8_t`).
for t in ('Int8', 'Int16', 'Int32', 'Int64', 'Float', 'Double', 'Bool',
          'TextEncodingDict', 'TextEncodingNone'):
    translate_map[t.lower()] = t
    if t[:3] == 'Int':
        translate_map[t.lower() + '_t'] = t
122 
123 
125  """Holds a `TYPE | ANNOTATIONS`-like structure.
126  """
def __init__(self, type, annotations=None):
    """Store a type together with its annotations.

    Parameters
    ----------
    type : object
      The wrapped type; the other methods delegate to it.
    annotations : list, optional
      Annotation entries. When omitted, each instance receives its own
      empty list (a literal ``[]`` default would be shared by every
      instance created without annotations).
    """
    self.type = type
    self.annotations = [] if annotations is None else annotations
130 
@property
def name(self):
    """Name of the wrapped type (read from ``self.type.name``)."""
    wrapped = self.type
    return wrapped.name
134 
@property
def args(self):
    """Arguments of the wrapped type (read from ``self.type.args``)."""
    wrapped = self.type
    return wrapped.args
138 
def format_sizer(self):
    """Return the sizer text produced by the wrapped type."""
    render = self.type.format_sizer
    return render()
141 
def __repr__(self):
    """Debug representation showing the wrapped type and the annotations."""
    parts = (self.type, self.annotations)
    return 'Declaration(%r, ann=%r)' % parts
144 
def __str__(self):
    """Render as ``TYPE`` or, with annotations, ``TYPE | ANN | ANN ...``."""
    if self.annotations:
        rendered = ' | '.join(str(ann) for ann in self.annotations)
        return '%s | %s' % (self.type, rendered)
    return str(self.type)
149 
def tostring(self):
    """Return ``tostring()`` of the wrapped type; annotations are not included."""
    wrapped = self.type
    return wrapped.tostring()
152 
def apply_column(self):
    """Return a new instance whose type is ``self.type.apply_column()``.

    The annotations object is passed on as is (not copied).
    """
    make = self.__class__
    return make(self.type.apply_column(), self.annotations)
155 
def apply_namespace(self, ns='ExtArgumentType'):
    """Return a new instance whose type is ``self.type.apply_namespace(ns)``.

    The annotations object is passed on as is (not copied).
    """
    make = self.__class__
    qualified = self.type.apply_namespace(ns)
    return make(qualified, self.annotations)
158 
def get_cpp_type(self):
    """Return the C++ type spelling reported by the wrapped type."""
    wrapped = self.type
    return wrapped.get_cpp_type()
161 
def format_cpp_type(self, idx, use_generic_arg_name=False, is_input=True):
    """Delegate to ``self.type.format_cpp_type``.

    The value of the ``name`` annotation (``None`` when absent) is passed
    along as ``real_arg_name``.
    """
    annotated_name = dict(self.annotations).get('name', None)
    return self.type.format_cpp_type(
        idx,
        use_generic_arg_name=use_generic_arg_name,
        real_arg_name=annotated_name,
        is_input=is_input)
168 
def __getattr__(self, name):
    """Forward lookups of ``is_*`` attributes to the wrapped type.

    Any other missing attribute raises ``AttributeError``.
    """
    if not name.startswith('is_'):
        raise AttributeError(name)
    return getattr(self.type, name)
173 
174 
def tostring(obj):
    """Call ``obj.tostring()``; a plain-function form usable with ``map``."""
    render = obj.tostring
    return render()
177 
178 
179 class Bracket:
180  """Holds a `NAME<ARGS>`-like structure.
181  """
182 
183  def __init__(self, name, args=None):
184  assert isinstance(name, str)
185  assert isinstance(args, tuple) or args is None, args
186  self.name = name
187  self.args = args
188 
189  def __repr__(self):
190  return 'Bracket(%r, args=%r)' % (self.name, self.args)
191 
192  def __str__(self):
193  if not self.args:
194  return self.name
195  return '%s<%s>' % (self.name, ', '.join(map(str, self.args)))
196 
def tostring(self):
    """Render as ``NAME`` or ``NAME<ARG, ...>`` using ``tostring()`` on each arg."""
    if self.args:
        inner = ', '.join(arg.tostring() for arg in self.args)
        return '%s<%s>' % (self.name, inner)
    return self.name
201 
def normalize(self, kind='input'):
    """Return the normalized bracket for an ``'input'`` or ``'output'`` position.

    * ``Column...<ARGS>`` collapses into a single name with the args appended.
    * Input ``Cursor`` args are wrapped in ``Column`` unless they already
      are column types, and are normalized recursively.
    * Output non-column types are wrapped in ``Column``.
    * Anything else is returned unchanged.
    """
    assert kind in ['input', 'output'], kind
    is_col = self.is_column_any()
    if is_col and self.args:
        return Bracket(self.name + ''.join(map(str, self.args)))
    if kind == 'input' and self.name == 'Cursor':
        normalized = []
        for arg in self.args:
            if not arg.is_column_any():
                arg = Bracket('Column', args=(arg,))
            normalized.append(arg.normalize(kind=kind))
        return Bracket(self.name, tuple(normalized))
    if kind == 'output' and not is_col:
        return Bracket('Column', args=(self,)).normalize(kind=kind)
    return self
216 
def apply_cursor(self):
    """Wrap a plain column type in ``Cursor<...>``; return other types as is.

    TODO: this method is currently unused, but cursor should be applied
    to all input column arguments in order to distinguish signatures like:
      foo(Cursor(Column<int32>, Column<float>)) -> Column<int32>
      foo(Cursor(Column<int32>), Cursor(Column<float>)) -> Column<int32>
    that at the moment are treated as the same :(
    """
    if not self.is_column():
        return self
    return Bracket('Cursor', args=(self,))
229 
def apply_column(self):
    """Prefix the name with ``Column`` unless this already is a column or column list."""
    if self.is_column() or self.is_column_list():
        return self
    return Bracket('Column' + self.name)
234 
def apply_namespace(self, ns='ExtArgumentType'):
    """Return a bracket whose name is qualified with ``ns::``.

    ``Cursor`` is qualified together with all of its args; a name that
    already starts with ``ns::`` is returned unchanged.
    """
    prefix = ns + '::'
    if self.name == 'Cursor':
        qualified_args = tuple(arg.apply_namespace(ns=ns) for arg in self.args)
        return Bracket(prefix + self.name, args=qualified_args)
    if self.name.startswith(prefix):
        return self
    return Bracket(prefix + self.name)
241 
def is_cursor(self):
    """True when the name, ignoring any ``ns::`` qualifier, is ``Cursor``."""
    base = self.name.rsplit("::", 1)[-1]
    return base == 'Cursor'
244 
def is_column_any(self):
    """True when the unqualified name starts with ``Column`` (this includes ``ColumnList``)."""
    base = self.name.rsplit("::", 1)[-1]
    return base.startswith('Column')
247 
def is_column_list(self):
    """True when the unqualified name starts with ``ColumnList``."""
    base = self.name.rsplit("::", 1)[-1]
    return base.startswith('ColumnList')
250 
def is_column(self):
    """True for a ``Column...`` name that is not a ``ColumnList...`` name."""
    base = self.name.rsplit("::", 1)[-1]
    if not base.startswith('Column'):
        return False
    return not self.is_column_list()
253 
255  return self.name.rsplit("::", 1)[-1].endswith('TextEncodingDict')
256 
258  return self.name.rsplit("::", 1)[-1] == 'ColumnTextEncodingDict'
259 
261  return self.name.rsplit("::", 1)[-1] == 'ColumnListTextEncodingDict'
262 
264  return self.name.rsplit("::", 1)[-1] in OutputBufferSizeTypes
265 
def is_row_multiplier(self):
    """True when the unqualified name is ``kUserSpecifiedRowMultiplier``."""
    base = self.name.rsplit("::", 1)[-1]
    return base == 'kUserSpecifiedRowMultiplier'
268 
def is_arg_sizer(self):
    """True when the unqualified name is ``kPreFlightParameter``."""
    base = self.name.rsplit("::", 1)[-1]
    return base == 'kPreFlightParameter'
271 
def is_user_specified(self):
    """Return False only for output buffer sizers that the user does not supply.

    Those are ``kConstant``, ``kTableFunctionSpecifiedParameter`` and
    ``kPreFlightParameter``; every other argument yields True.
    """
    if not self.is_output_buffer_sizer():
        return True
    not_user_supplied = ('kConstant', 'kTableFunctionSpecifiedParameter', 'kPreFlightParameter')
    base = self.name.rsplit("::", 1)[-1]
    return base not in not_user_supplied
277 
def format_sizer(self):
    """Return the ``TableFunctionOutputRowSizer{...}`` initializer text.

    The value is 0 for an argument sizer (see ``is_arg_sizer``), otherwise
    the first bracket argument.
    """
    if self.is_arg_sizer():
        val = 0
    else:
        val = self.args[0]
    return 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (self.name, val)
281 
def get_cpp_type(self):
    """Return the C++ spelling of this type.

    Scalars map to ``bool``, ``intN_t``, ``double``, ``float`` or to the
    unchanged ``TextEncodingDict`` / ``TextEncodingNone`` / ``Timestamp``
    names. ``Column...`` and ``ColumnList...`` names are rendered as
    ``Column<ctype>`` and ``ColumnList<ctype>``.

    Raises
    ------
    NotImplementedError
      When the (prefix-stripped) name is not one of the types above.
    """
    name = self.name.rsplit("::", 1)[-1]
    clsname = None
    # Try the longer prefix first: every ColumnList name also starts with
    # Column.
    for prefix in ('ColumnList', 'Column'):
        if name.startswith(prefix):
            # Slice the prefix off. str.lstrip(prefix) would strip any run
            # of the prefix's *characters*, which can eat into the element
            # type name.
            name = name[len(prefix):]
            clsname = prefix
            break
    if name.startswith('Bool'):
        ctype = name.lower()
    elif name.startswith('Int'):
        ctype = name.lower() + '_t'
    elif name in ('Double', 'Float'):
        ctype = name.lower()
    elif name in ('TextEncodingDict', 'TextEncodingNone', 'Timestamp'):
        ctype = name
    else:
        raise NotImplementedError(self)
    if clsname is None:
        return ctype
    return '%s<%s>' % (clsname, ctype)
308 
def format_cpp_type(self, idx, use_generic_arg_name=False, real_arg_name=None, is_input=True):
    """Return ``(declaration, arg_name)`` for a C++ parameter of this type.

    The argument name is ``real_arg_name`` when it is given and generic
    names are not requested; otherwise it is ``input<idx>`` or
    ``output<idx>``. Types whose C++ spelling starts with ``Column``,
    ``ColumnList`` or ``TextEncodingNone`` are passed by reference
    (``const`` for inputs); all other types are passed by value.
    """
    by_ref_prefixes = ('Column', 'ColumnList', 'TextEncodingNone')
    if real_arg_name is not None and not use_generic_arg_name:
        arg_name = real_arg_name
    else:
        # also covers the case where the real arg name is not specified
        arg_name = ('input' if is_input else 'output') + str(idx)
    cpp_type = self.get_cpp_type()
    if cpp_type.startswith(by_ref_prefixes):
        qualifier = 'const ' if is_input else ''
        declaration = '%s%s& %s' % (qualifier, cpp_type, arg_name)
    else:
        declaration = '%s %s' % (cpp_type, arg_name)
    return declaration, arg_name
326 
@classmethod
def parse(cls, typ):
    """Parse a string of the form ``NAME`` or ``NAME<ARGS>``.

    Top-level commas separate the args, which are parsed recursively.
    The name is translated through ``translate_map`` when it has an
    entry there.

    Returns an instance of ``cls``.
    """
    open_pos = typ.find('<')
    if open_pos == -1:
        name, args = typ.strip(), None
    else:
        assert typ.endswith('>'), typ
        name = typ[:open_pos].strip()
        parsed = []
        remaining = typ[open_pos + 1:-1].strip()
        while remaining:
            comma = find_comma(remaining)
            if comma == -1:
                piece, remaining = remaining, ''
            else:
                piece = remaining[:comma].rstrip()
                remaining = remaining[comma + 1:].lstrip()
            parsed.append(cls.parse(piece))
        args = tuple(parsed)

    return cls(translate_map.get(name, name), args)
353 
354 
def find_comma(line):
    """Return the index of the first comma outside any bracket pair.

    Nesting is tracked for ``<>``, ``()``, ``[]`` and ``{}``.

    Parameters
    ----------
    line : str
      Text to scan.

    Returns
    -------
    int
      Index of the first top-level comma, or -1 when there is none.
    """
    depth = 0
    for pos, char in enumerate(line):
        if char in '<([{':
            depth += 1
        elif char in '>)]}':
            # `}` must close a level too, otherwise everything after an
            # opening `{` would be treated as nested forever.
            depth -= 1
        elif depth == 0 and char == ',':
            return pos
    return -1
365 
366 
368  # TODO: try to parse the line to be certain about completeness.
369  # `$=>$' is used to separate the UDTF signature and the expected result
370  return line.endswith(',') or line.endswith('->') or line.endswith(separator) or line.endswith('|')
371 
372 
def is_identifier_cursor(identifier):
    """Return True when ``identifier`` spells ``cursor``, ignoring case."""
    lowered = identifier.lower()
    return lowered == 'cursor'
375 
376 
377 # fmt: on
378 
379 
class TokenizeException(Exception):
    """Raised when the tokenizer cannot match a character of the input line."""
382 
383 
class ParserException(Exception):
    """Exception type for parser failures (raised outside this excerpt, if at all)."""
386 
387 
class TransformerException(Exception):
    """Raised by AST transformers, e.g. when an unknown annotation is found."""
390 
391 
class Token:
    """A lexical token: an integer type code plus the matched text."""

    LESS = 1         # <
    GREATER = 2      # >
    COMMA = 3        # ,
    EQUAL = 4        # =
    RARROW = 5       # ->
    STRING = 6       # reserved for string constants
    NUMBER = 7       #
    VBAR = 8         # |
    BANG = 9         # !
    LPAR = 10        # (
    RPAR = 11        # )
    LSQB = 12        # [
    RSQB = 13        # ]
    IDENTIFIER = 14  #
    COLON = 15       # :

    def __init__(self, type, lexeme):
        """
        Parameters
        ----------
        type : int
          One of the token codes defined on this class
        lexeme : str
          Corresponding string in the text
        """
        self.type = type
        self.lexeme = lexeme

    @classmethod
    def tok_name(cls, token):
        """Return the symbolic name of a token code, or None when unknown."""
        known = (
            "LESS", "GREATER", "COMMA", "EQUAL", "RARROW", "STRING",
            "NUMBER", "VBAR", "BANG", "LPAR", "RPAR", "LSQB", "RSQB",
            "IDENTIFIER", "COLON",
        )
        by_code = {getattr(Token, label): label for label in known}
        return by_code.get(token)

    def __str__(self):
        label = Token.tok_name(self.type)
        return 'Token(%s, "%s")' % (label, self.lexeme)

    __repr__ = __str__
446 
447 
448 class Tokenize:
449  def __init__(self, line):
450  self._line = line
451  self._tokens = []
452  self.start = 0
453  self.curr = 0
454  self.tokenize()
455 
@property
def line(self):
    """The text being tokenized (read-only view of ``_line``)."""
    text = self._line
    return text
459 
@property
def tokens(self):
    """The list of tokens collected so far (the internal list, not a copy)."""
    collected = self._tokens
    return collected
463 
def tokenize(self):
    """Scan the whole line, dispatching on the character at the cursor.

    Each round marks the token start, then runs the consumer of the first
    matching predicate, tried in this order: whitespace, number, string,
    identifier, double-char token. Anything else is handled as a
    single-char token.
    """
    dispatch = (
        (self.is_token_whitespace, self.consume_whitespace),
        (self.is_digit, self.consume_number),
        (self.is_token_string, self.consume_string),
        (self.is_token_identifier, self.consume_identifier),
        (self.can_token_be_double_char, self.consume_double_char),
    )
    while not self.is_at_end():
        self.start = self.curr
        for matches, consume in dispatch:
            if matches():
                consume()
                break
        else:
            self.consume_single_char()
480 
def is_at_end(self):
    """True when the cursor equals the length of the line."""
    total = len(self.line)
    return total == self.curr
483 
def current_token(self):
    """Return the text from the token start through the cursor (inclusive)."""
    stop = self.curr + 1
    return self.line[self.start:stop]
486 
def add_token(self, type):
    """Append a ``Token`` of the given type for the text ``start..curr`` (inclusive)."""
    stop = self.curr + 1
    token = Token(type, self.line[self.start:stop])
    self._tokens.append(token)
490 
def lookahead(self):
    """Return the character after the cursor, or None at the end of the line."""
    nxt = self.curr + 1
    return self.line[nxt] if nxt < len(self.line) else None
495 
def advance(self):
    """Move the cursor one character forward."""
    self.curr = self.curr + 1
498 
def peek(self):
    """Return the character at the cursor without moving it."""
    position = self.curr
    return self.line[position]
501 
503  char = self.peek()
504  return char in ("-",)
505 
507  ahead = self.lookahead()
508  if ahead == ">":
509  self.advance()
510  self.add_token(Token.RARROW) # ->
511  self.advance()
512  else:
513  self.raise_tokenize_error()
514 
516  char = self.peek()
517  if char == "(":
518  self.add_token(Token.LPAR)
519  elif char == ")":
520  self.add_token(Token.RPAR)
521  elif char == "<":
522  self.add_token(Token.LESS)
523  elif char == ">":
524  self.add_token(Token.GREATER)
525  elif char == ",":
526  self.add_token(Token.COMMA)
527  elif char == "=":
528  self.add_token(Token.EQUAL)
529  elif char == "|":
530  self.add_token(Token.VBAR)
531  elif char == "!":
532  self.add_token(Token.BANG)
533  elif char == "[":
534  self.add_token(Token.LSQB)
535  elif char == "]":
536  self.add_token(Token.RSQB)
537  elif char == ":":
538  self.add_token(Token.COLON)
539  else:
540  self.raise_tokenize_error()
541  self.advance()
542 
544  self.advance()
545 
def consume_string(self):
    """
    STRING: \\".*?\\"

    Consume a double-quoted string starting at the cursor, emit a STRING
    token for it (quotes included) and step past the closing quote. A
    quote preceded by a backslash does not terminate the string.

    An unterminated string is reported through
    ``raise_tokenize_error()`` instead of running past the end of the
    line.
    """
    while True:
        char = self.lookahead()
        if char is None:
            # End of line reached without a closing quote. Report it while
            # the cursor is still on a valid character.
            self.raise_tokenize_error()
        curr = self.peek()
        if char == '"' and curr != '\\':
            self.advance()
            break
        self.advance()
    self.add_token(Token.STRING)
    self.advance()
559 
def consume_number(self):
    """
    NUMBER: [0-9]+

    Extend the token over all following digits, emit a NUMBER token and
    step past it.
    """
    nxt = self.lookahead()
    while nxt and nxt.isdigit():
        self.advance()
        nxt = self.lookahead()
    self.add_token(Token.NUMBER)
    self.advance()
572 
574  """
575  IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
576  """
577  while True:
578  char = self.lookahead()
579  if char and char.isalnum() or char == "_":
580  self.advance()
581  else:
582  break
583  self.add_token(Token.IDENTIFIER)
584  self.advance()
585 
587  return self.peek().isalpha() or self.peek() == "_"
588 
def is_token_string(self):
    """True when the character at the cursor is a double quote."""
    current = self.peek()
    return current == '"'
591 
def is_digit(self):
    """True when the character at the cursor is a digit."""
    current = self.peek()
    return current.isdigit()
594 
def is_alpha(self):
    """True when the character at the cursor is alphabetic."""
    current = self.peek()
    return current.isalpha()
597 
599  return self.peek().isspace()
600 
602  curr = self.curr
603  char = self.peek()
604  raise TokenizeException(
605  'Could not match char "%s" at pos %d on line\n %s' % (char, curr, self.line)
606  )
607 
608 
class AstVisitor(object):
    """Visitor interface with one ``visit_*`` hook per AST node kind.

    NOTE(review): ``__metaclass__`` is only honoured by Python 2; under
    Python 3 the ``abstractmethod`` markers here are presumably not
    enforced -- confirm before relying on them.
    """
    __metaclass__ = ABC

    @abstractmethod
    def visit_udtf_node(self, node):
        """Handle a UDTF node."""

    @abstractmethod
    def visit_composed_node(self, node):
        """Handle a composed type node."""

    @abstractmethod
    def visit_arg_node(self, node):
        """Handle an argument node."""

    @abstractmethod
    def visit_primitive_node(self, node):
        """Handle a primitive type node."""

    @abstractmethod
    def visit_annotation_node(self, node):
        """Handle an annotation node."""

    @abstractmethod
    def visit_template_node(self, node):
        """Handle a template node."""
635 
636 
class AstTransformer(AstVisitor):
    """Base transformer: returns shallow copies of the visited nodes with
    their children transformed. Only overload the methods you need."""

    def visit_udtf_node(self, udtf_node):
        """Copy the UDTF node and transform inputs, outputs, templates and annotations."""
        result = copy.copy(udtf_node)
        result.inputs = [node.accept(self) for node in result.inputs]
        result.outputs = [node.accept(self) for node in result.outputs]
        if result.templates:
            result.templates = [node.accept(self) for node in result.templates]
        result.annotations = [node.accept(self) for node in result.annotations]
        return result

    def visit_composed_node(self, composed_node):
        """Copy the composed node and transform each inner node."""
        result = copy.copy(composed_node)
        transformed = []
        for child in result.inner:
            transformed.append(child.accept(self))
        result.inner = transformed
        return result

    def visit_arg_node(self, arg_node):
        """Copy the arg node and transform its type and (non-empty) annotations."""
        result = copy.copy(arg_node)
        result.type = result.type.accept(self)
        if result.annotations:
            result.annotations = [ann.accept(self) for ann in result.annotations]
        return result

    def visit_primitive_node(self, primitive_node):
        """Return a shallow copy of the primitive node."""
        return copy.copy(primitive_node)

    def visit_template_node(self, template_node):
        """Return a shallow copy of the template node."""
        return copy.copy(template_node)

    def visit_annotation_node(self, annotation_node):
        """Return a shallow copy of the annotation node."""
        return copy.copy(annotation_node)
669 
670 
672  """Returns a line formatted. Useful for testing"""
673 
def visit_udtf_node(self, udtf_node):
    """Format a UDTF node as a signature line.

    Layout: ``name(inputs)[ | annotations] -> outputs[, templates][ | sizer]``.
    """
    inputs = ", ".join(arg.accept(self) for arg in udtf_node.inputs)
    outputs = ", ".join(arg.accept(self) for arg in udtf_node.outputs)
    annotations = "| ".join(annot.accept(self) for annot in udtf_node.annotations)
    if udtf_node.sizer:
        sizer = " | " + udtf_node.sizer.accept(self)
    else:
        sizer = ""
    if annotations:
        annotations = ' | ' + annotations
    head = (udtf_node.name, inputs, annotations, outputs)
    if not udtf_node.templates:
        return "%s(%s)%s -> %s%s" % (head + (sizer,))
    templates = ", ".join(t.accept(self) for t in udtf_node.templates)
    return "%s(%s)%s -> %s, %s%s" % (head + (templates, sizer))
687 
def visit_template_node(self, template_node):
    """Format as ``KEY=["T1", "T2", ...]`` with every type double-quoted."""
    quoted = ", ".join('"%s"' % typ for typ in template_node.types)
    return "%s=[%s]" % (template_node.key, quoted)
693 
def visit_annotation_node(self, annotation_node):
    """Format as ``key=value``; a list value becomes ``key=[v1,v2,...]``."""
    key, value = annotation_node.key, annotation_node.value
    if not isinstance(value, list):
        return "%s=%s" % (key, value)
    rendered = ','.join([item.accept(self) for item in value])
    return "%s=[%s]" % (key, rendered)
701 
def visit_arg_node(self, arg_node):
    """Format as ``type`` or ``type | ann | ann ...``.

    An output argument that renders to a bare ``ColumnTextEncodingDict``
    gets a default ``input_id=args<0>`` annotation appended.
    """
    text = "%s" % (arg_node.type.accept(self),)
    if arg_node.annotations:
        rendered = " | ".join([ann.accept(self) for ann in arg_node.annotations])
        text = "%s | %s" % (text, rendered)
    # insert input_id=args<0> if input_id is not specified
    if text == "ColumnTextEncodingDict" and arg_node.kind == "output":
        text += " | input_id=args<0>"
    return text
714 
def visit_composed_node(self, composed_node):
    """Format a composed type node.

    Column and column-list nodes become ``Column<T>`` / ``ColumnList<T>``
    spelled without brackets, sizers become ``kName<N>`` and cursors
    become ``Cursor<T1, T2, ...>``.

    Raises ValueError for any other composed node.
    """
    first = composed_node.inner[0].accept(self)
    column_kinds = (
        (composed_node.is_column, "Column"),            # Column<T>
        (composed_node.is_column_list, "ColumnList"),   # ColumnList<T>
    )
    for is_kind, prefix in column_kinds:
        if is_kind():
            assert len(composed_node.inner) == 1
            return prefix + first
    if composed_node.is_output_buffer_sizer():
        # kConstant<N>
        assert len(composed_node.inner) == 1
        return translate_map.get(composed_node.type) + "<%s>" % (first,)
    if composed_node.is_cursor():
        # Cursor<T1, T2, ..., TN>
        rendered = ", ".join([child.accept(self) for child in composed_node.inner])
        return "Cursor<%s>" % (rendered)
    raise ValueError(composed_node)
735 
def visit_primitive_node(self, primitive_node):
    """Format a primitive node using its translated type name.

    A bare output buffer sizer gets ``<pos>`` appended, where pos is the
    one-based position of the enclosing argument.
    """
    translated = translate_map.get(primitive_node.type, primitive_node.type)
    if not primitive_node.is_output_buffer_sizer():
        return translated
    # arg_pos is zero-based
    position = primitive_node.get_parent(ArgNode).arg_pos + 1
    return translated + "<%d>" % (position,)
744 
745 
747  """Like AstPrinter but returns a node instead of a string
748  """
749 
750 
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the keyword values.

    Every yielded dict maps each keyword to one item of its iterable.
    """
    names = list(kwargs)
    choices = [kwargs[name] for name in names]
    for combo in itertools.product(*choices):
        yield dict(zip(names, combo))
756 
757 
759  """Expand template definition into multiple inputs"""
760 
def visit_udtf_node(self, udtf_node):
    """Expand a templated UDTF into one node per template type combination.

    Returns the node itself when it has no templates, a single
    ``UdtfNode`` when exactly one combination exists, and otherwise a
    list of ``UdtfNode`` objects.
    """
    if not udtf_node.templates:
        return udtf_node

    choices = {node.key: node.types for node in udtf_node.templates}
    expanded = []
    for assignment in product_dict(**choices):
        # The type visitors read the active assignment from mapping_dict.
        self.mapping_dict = assignment
        inputs = [arg.accept(self) for arg in udtf_node.inputs]
        outputs = [arg.accept(self) for arg in udtf_node.outputs]
        expanded.append(UdtfNode(udtf_node.name, inputs, outputs,
                                 udtf_node.annotations, None,
                                 udtf_node.sizer, udtf_node.line))
        self.mapping_dict = {}

    return expanded[0] if len(expanded) == 1 else expanded
781 
def visit_composed_node(self, composed_node):
    """Copy the node, substituting its type through ``mapping_dict`` and
    transforming every inner node."""
    original = composed_node.type
    substituted = self.mapping_dict.get(original, original)
    children = [child.accept(self) for child in composed_node.inner]
    return composed_node.copy(substituted, children)
788 
def visit_primitive_node(self, primitive_node):
    """Copy the node, substituting its type through ``mapping_dict``."""
    original = primitive_node.type
    return primitive_node.copy(self.mapping_dict.get(original, original))
793 
794 
def visit_primitive_node(self, primitive_node):
    """
    * Fix kUserSpecifiedRowMultiplier without a pos arg

    A bare output buffer sizer is replaced by a composed node that
    carries the one-based position of its enclosing argument; any other
    primitive node is returned unchanged.
    """
    if not primitive_node.is_output_buffer_sizer():
        return primitive_node

    one_based = primitive_node.get_parent(ArgNode).arg_pos + 1
    position = PrimitiveNode(str(one_based))
    return ComposedNode(primitive_node.type, inner=[position])
808 
809 
def visit_primitive_node(self, primitive_node):
    """
    * Rename nodes using translate_map as dictionary
      int -> Int32
      float -> Float

    Names without an entry in translate_map are kept as they are.
    """
    current = primitive_node.type
    renamed = translate_map.get(current, current)
    return primitive_node.copy(renamed)
819 
820 
def visit_udtf_node(self, udtf_node):
    """
    * Add default_input_id to Column(List)<TextEncodingDict> without one

    The default refers to the first input that is a Column or ColumnList
    of TextEncodingDict. Raises TypeError when an output needs the
    default but no such input exists.
    """
    # NOTE(review): super(type(self), self) recurses endlessly if this
    # class is ever subclassed; naming the class explicitly would avoid it.
    udtf_node = super(type(self), self).visit_udtf_node(udtf_node)

    # locate the default input_id: the first text-encoding-dict input wins
    default_input_id = None
    for idx, arg in enumerate(udtf_node.inputs):
        if not isinstance(arg.type, ComposedNode):
            continue
        if arg.type.is_column_text_encoding_dict():
            default_input_id = AnnotationNode('input_id', 'args<%s>' % (idx,))
        elif arg.type.is_column_list_text_encoding_dict():
            default_input_id = AnnotationNode('input_id', 'args<%s, 0>' % (idx,))
        if default_input_id is not None:
            break

    for arg in udtf_node.outputs:
        if not (isinstance(arg.type, ComposedNode) and arg.type.is_any_text_encoding_dict()):
            continue
        if any(ann.key == 'input_id' for ann in arg.annotations):
            continue
        if default_input_id is None:
            raise TypeError('Cannot parse line "%s".\n'
                            'Missing TextEncodingDict input?' %
                            (udtf_node.line))
        arg.annotations.append(default_input_id)

    return udtf_node
854 
855 
857 
def visit_udtf_node(self, udtf_node):
    """
    * Generate fields annotation to Cursor if non-existing

    Each field takes the ``name`` annotation of the corresponding inner
    node, falling back to ``field<i>``.
    """
    # NOTE(review): super(type(self), self) recurses endlessly if this
    # class is ever subclassed; naming the class explicitly would avoid it.
    udtf_node = super(type(self), self).visit_udtf_node(udtf_node)

    for arg in udtf_node.inputs:
        if not isinstance(arg.type, ComposedNode):
            continue
        if not arg.type.is_cursor() or arg.get_annotation('fields') is not None:
            continue
        fields = []
        for i, inner in enumerate(arg.type.inner):
            fields.append(PrimitiveNode(inner.get_annotation('name', 'field%s' % i)))
        arg.annotations.append(AnnotationNode('fields', fields))

    return udtf_node
874 
875 
877  """
878  * Checks for supported annotations in a UDTF
879  """
def visit_udtf_node(self, udtf_node):
    """Validate all annotations and normalize function annotation values.

    Raises TransformerException for an input/output annotation outside
    SupportedAnnotations or a function annotation outside
    SupportedFunctionAnnotations. Function annotation values spelled
    enable/on/1/true become '1' and disable/off/0/false become '0'
    (case-insensitive); other values are left untouched.
    """
    checks = (
        (udtf_node.inputs, 'unknown input annotation: `%s`'),
        (udtf_node.outputs, 'unknown output annotation: `%s`'),
    )
    for args, message in checks:
        for arg in args:
            for ann in arg.annotations:
                if ann.key not in SupportedAnnotations:
                    raise TransformerException(message % (ann.key))
    for annot in udtf_node.annotations:
        if annot.key not in SupportedFunctionAnnotations:
            raise TransformerException('unknown function annotation: `%s`' % (annot.key))
        lowered = annot.value.lower()
        if lowered in ['enable', 'on', '1', 'true']:
            annot.value = '1'
        elif lowered in ['disable', 'off', '0', 'false']:
            annot.value = '0'
    return udtf_node
897 
898 
900 
def visit_udtf_node(self, udtf_node):
    """Convert a UDTF node into a ``Signature`` tuple.

    Inputs are stored as the full declarations, outputs as the bare
    declaration types; the annotations of both go into parallel lists.
    """
    inputs, input_annotations = [], []
    for arg in udtf_node.inputs:
        decl = arg.accept(self)
        inputs.append(decl)
        input_annotations.append(decl.annotations)

    outputs, output_annotations = [], []
    for arg in udtf_node.outputs:
        decl = arg.accept(self)
        outputs.append(decl.type)
        output_annotations.append(decl.annotations)

    function_annotations = [annot.accept(self) for annot in udtf_node.annotations]

    return Signature(udtf_node.name, inputs, outputs, input_annotations,
                     output_annotations, function_annotations, udtf_node.sizer)
925 
def visit_arg_node(self, arg_node):
    """Build a ``Declaration`` from the converted type and annotations."""
    converted_type = arg_node.type.accept(self)
    converted_anns = []
    for ann in arg_node.annotations:
        converted_anns.append(ann.accept(self))
    return Declaration(converted_type, converted_anns)
930 
def visit_composed_node(self, composed_node):
    """Convert a composed node into a ``Bracket``.

    Cursors keep their args with ``apply_column()`` applied to each,
    sizers keep their args unchanged, and every other node collapses to
    a single name (the type followed by the first inner type).
    """
    typ = translate_map.get(composed_node.type, composed_node.type)
    inner = [child.accept(self) for child in composed_node.inner]
    if composed_node.is_cursor():
        columns = tuple(item.apply_column() for item in inner)
        return Bracket(typ, args=columns)
    if composed_node.is_output_buffer_sizer():
        return Bracket(typ, args=tuple(inner))
    return Bracket(typ + str(inner[0]))
941 
def visit_primitive_node(self, primitive_node):
    """Wrap the primitive type name in an argument-less ``Bracket``."""
    return Bracket(primitive_node.type)
945 
def visit_annotation_node(self, annotation_node):
    """Return the annotation as a ``(key, value)`` tuple."""
    return (annotation_node.key, annotation_node.value)
950 
951 
class Node(object):
    """Base class of the AST nodes."""

    __metaclass__ = ABC

    @abstractmethod
    def accept(self, visitor):
        """Dispatch to the matching ``visit_*`` method of ``visitor``."""

    @abstractmethod
    def __str__(self):
        """Return a text rendering of the node."""

    def get_parent(self, cls):
        """Return the nearest node of class ``cls``, starting with ``self``
        and following the ``parent`` links.

        Raises ValueError when the chain ends without a match.
        """
        if isinstance(self, cls):
            return self

        parent = self.parent
        if parent is None:
            raise ValueError("could not find parent with given class %s" % (cls))

        return parent.get_parent(cls)

    def copy(self, *args):
        """Create a node of the same class from ``args`` and carry over the
        ``parent`` and ``arg_pos`` attributes set on this instance."""
        duplicate = self.__class__(*args)

        # copy parent and arg_pos
        own_attrs = self.__dict__
        for attr in ('parent', 'arg_pos'):
            if attr in own_attrs:
                setattr(duplicate, attr, getattr(self, attr))

        return duplicate
982 
983 
class IterableNode(Iterable):
    """Marker base class for nodes that can be iterated over."""
986 
987 
class UdtfNode(Node, IterableNode):
    """Root AST node describing one UDTF signature."""

    def __init__(self, name, inputs, outputs, annotations, templates, sizer, line):
        """
        Parameters
        ----------
        name : str
        inputs : list[ArgNode]
        outputs : list[ArgNode]
        annotations : Optional[List[AnnotationNode]]
        templates : Optional[list[TemplateNode]]
        sizer : Optional[str]
        line: str
        """
        self.name = name
        self.line = line
        self.inputs, self.outputs = inputs, outputs
        self.annotations = annotations
        self.templates = templates
        self.sizer = sizer

    def accept(self, visitor):
        return visitor.visit_udtf_node(self)

    def __str__(self):
        """Debug rendering; annotations and templates appear only when present."""
        name = self.name
        inputs = list(map(str, self.inputs))
        outputs = list(map(str, self.outputs))
        annotations = list(map(str, self.annotations))
        if self.sizer:
            sizer = "| %s" % str(self.sizer)
        else:
            sizer = ""
        if not self.templates:
            if annotations:
                return "UDTF: %s (%s) | %s -> %s %s" % (name, inputs, annotations, outputs, sizer)
            return "UDTF: %s (%s) -> %s %s" % (name, inputs, outputs, sizer)
        templates = list(map(str, self.templates))
        if annotations:
            return "UDTF: %s (%s) | %s -> %s, %s %s" % (name, inputs, annotations, outputs, templates, sizer)
        return "UDTF: %s (%s) -> %s, %s %s" % (name, inputs, outputs, templates, sizer)

    def __iter__(self):
        """Yield inputs, outputs, annotations and then templates (if any)."""
        for group in ('inputs', 'outputs', 'annotations'):
            for child in getattr(self, group):
                yield child
        for template in self.templates or ():
            yield template

    __repr__ = __str__
1043 
1044 
1046 
def __init__(self, type, annotations):
    """
    Parameters
    ----------
    type : TypeNode
    annotations : List[AnnotationNode]

    ``arg_pos`` starts out as None.
    """
    self.arg_pos = None
    self.type, self.annotations = type, annotations
1057 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_arg_node``."""
    handler = visitor.visit_arg_node
    return handler(self)
1060 
def __str__(self):
    """Debug rendering: ``ArgNode(type)`` or ``ArgNode(type ann| ann...)``."""
    rendered_type = str(self.type)
    if not self.annotations:
        return "ArgNode(%s)" % (rendered_type)
    rendered_anns = "| ".join(str(ann) for ann in self.annotations)
    return "ArgNode(%s %s)" % (rendered_type, rendered_anns)
1068 
def __iter__(self):
    """Yield the type node followed by every annotation."""
    yield self.type
    for annotation in self.annotations:
        yield annotation
1073 
1074  __repr__ = __str__
1075 
def get_annotation(self, key, default=None):
    """Return the value of the first annotation with the given key, else ``default``."""
    matches = (ann.value for ann in self.annotations if ann.key == key)
    return next(matches, default)
1081 
1082 
1084  def is_column(self):
1085  return self.type == "Column"
1086 
1087  def is_column_list(self):
1088  return self.type == "ColumnList"
1089 
1090  def is_cursor(self):
1091  return self.type == "Cursor"
1092 
1094  t = self.type
1095  return translate_map.get(t, t) in OutputBufferSizeTypes
1096 
1097 
1099 
1100  def __init__(self, type):
1101  """
1102  Parameters
1103  ----------
1104  type : str
1105  """
1106  self.type = type
1107 
1108  def accept(self, visitor):
1109  return visitor.visit_primitive_node(self)
1110 
1111  def __str__(self):
1112  return self.accept(AstPrinter())
1113 
1115  return self.type == 'TextEncodingDict'
1116 
1117  __repr__ = __str__
1118 
1119 
1121 
1122  def __init__(self, type, inner):
1123  """
1124  Parameters
1125  ----------
1126  type : str
1127  inner : list[TypeNode]
1128  """
1129  self.type = type
1130  self.inner = inner
1131 
1132  def accept(self, visitor):
1133  return visitor.visit_composed_node(self)
1134 
1135  def cursor_length(self):
1136  assert self.is_cursor()
1137  return len(self.inner)
1138 
1139  def __str__(self):
1140  i = ", ".join([str(i) for i in self.inner])
1141  return "Composed(%s<%s>)" % (self.type, i)
1142 
1143  def __iter__(self):
1144  for i in self.inner:
1145  yield i
1146 
1148  return self.is_column() and self.inner[0].is_text_encoding_dict()
1149 
1151  return self.is_column_list() and self.inner[0].is_text_encoding_dict()
1152 
1154  return self.inner[0].is_text_encoding_dict()
1155 
1156  __repr__ = __str__
1157 
1158 
1160 
1161  def __init__(self, key, value):
1162  """
1163  Parameters
1164  ----------
1165  key : str
1166  value : {str, list}
1167  """
1168  self.key = key
1169  self.value = value
1170 
1171  def accept(self, visitor):
1172  return visitor.visit_annotation_node(self)
1173 
1174  def __str__(self):
1175  printer = AstPrinter()
1176  return self.accept(printer)
1177 
1178  __repr__ = __str__
1179 
1180 
1182 
1183  def __init__(self, key, types):
1184  """
1185  Parameters
1186  ----------
1187  key : str
1188  types : tuple[str]
1189  """
1190  self.key = key
1191  self.types = types
1192 
1193  def accept(self, visitor):
1194  return visitor.visit_template_node(self)
1195 
1196  def __str__(self):
1197  printer = AstPrinter()
1198  return self.accept(printer)
1199 
1200  __repr__ = __str__
1201 
1202 
class Pipeline(object):
    """Apply a sequence of visitor classes to one or more AST roots.

    Each pass is given as a class (or any zero-argument factory); a fresh
    visitor is created for every node it is applied to. A pass may return a
    single node or a list of nodes; lists are flattened one level before the
    next pass runs.
    """

    def __init__(self, *passes):
        self.passes = passes

    def __call__(self, ast_list):
        """Run every pass in order and return the resulting nodes as a list."""
        nodes = ast_list if isinstance(ast_list, list) else [ast_list]

        for make_visitor in self.passes:
            flattened = []
            for node in nodes:
                outcome = node.accept(make_visitor())
                if isinstance(outcome, list):
                    flattened.extend(outcome)
                else:
                    flattened.append(outcome)
            nodes = flattened

        return list(nodes)
1217 
1218 
class Parser:
    """Recursive-descent parser for a single UDTF signature line.

    The line is tokenized once in the constructor. ``parse`` walks the token
    list and returns a ``UdtfNode`` whose descendants have their ``parent``
    attribute set. The full grammar is documented on ``parse``.
    """

    def __init__(self, line):
        self._tokens = Tokenize(line).tokens
        self._curr = 0  # index of the token currently being examined
        self.line = line

    @property
    def tokens(self):
        """The full token list produced for the line."""
        return self._tokens

    def is_at_end(self):
        """Return True once every token has been consumed."""
        return self._curr >= len(self._tokens)

    def current_token(self):
        """Return the token at the current position (IndexError past the end)."""
        return self._tokens[self._curr]

    def advance(self):
        """Move to the next token."""
        self._curr += 1

    def expect(self, expected_type):
        """Skip the current token, which must be of ``expected_type``.

        Raises ``AssertionError`` otherwise. The check is an explicit raise
        rather than an ``assert`` statement so that it still runs under
        ``python -O``.
        """
        curr_token = self.current_token()
        if curr_token.type != expected_type:
            # Report the expected token type first and the offending token second.
            msg = "Expected token %s but got %s at pos %d.\n Tokens: %s" % (
                Token.tok_name(expected_type),
                curr_token,
                self._curr,
                self._tokens,
            )
            raise AssertionError(msg)
        self.advance()

    def consume(self, expected_type):
        """consumes the current token iff its type matches the
        expected_type. Otherwise, an error is raised

        Returns the consumed token.
        """
        curr_token = self.current_token()
        if curr_token.type == expected_type:
            self.advance()
            return curr_token
        else:
            expected_token = Token.tok_name(expected_type)
            self.raise_parser_error(
                'Token mismatch at function consume. '
                'Expected type "%s" but got token "%s"\n\n'
                'Tokens: %s\n' % (expected_token, curr_token, self._tokens)
            )

    def current_pos(self):
        """Return the index of the current token."""
        return self._curr

    def raise_parser_error(self, msg=None):
        """Raise ``ParserException``; a default message names the current token."""
        if not msg:
            token = self.current_token()
            pos = self.current_pos()
            tokens = self.tokens
            msg = "\n\nError while trying to parse token %s at pos %d.\n" "Tokens: %s" % (
                token,
                pos,
                tokens,
            )
        raise ParserException(msg)

    def match(self, expected_type):
        """Return True when the current token has ``expected_type`` (does not consume)."""
        curr_token = self.current_token()
        return curr_token.type == expected_type

    def lookahead(self):
        """Return the token after the current one (IndexError if there is none)."""
        return self._tokens[self._curr + 1]

    def parse_udtf(self):
        """Parse a complete signature and return a ``UdtfNode``.

        udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?
        """
        name = self.parse_identifier()
        self.expect(Token.LPAR)  # (
        input_args = []
        if not self.match(Token.RPAR):
            input_args = self.parse_args()
        self.expect(Token.RPAR)  # )
        annotations = []
        while not self.is_at_end() and self.match(Token.VBAR):  # |
            self.consume(Token.VBAR)
            annotations.append(self.parse_annotation())
        self.expect(Token.RARROW)  # ->
        output_args = self.parse_args()

        templates = None
        if not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            templates = self.parse_templates()

        sizer = None
        if not self.is_at_end() and self.match(Token.VBAR):
            self.consume(Token.VBAR)
            idtn = self.parse_identifier()
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            if idtn != "output_row_size":
                raise AssertionError(
                    'Expected "output_row_size" but got "%s"' % (idtn,))
            self.consume(Token.EQUAL)
            node = self.parse_primitive()
            key = "kPreFlightParameter"
            sizer = AnnotationNode(key, value=node.type)

        # set arg_pos; a cursor occupies as many positions as it has members
        i = 0
        for arg in input_args:
            arg.arg_pos = i
            arg.kind = "input"
            i += arg.type.cursor_length() if arg.type.is_cursor() else 1

        for i, arg in enumerate(output_args):
            arg.arg_pos = i
            arg.kind = "output"

        return UdtfNode(name, input_args, output_args, annotations, templates, sizer, self.line)

    def parse_args(self):
        """Parse a comma-separated argument list.

        args: arg ("," arg)*

        Parsing stops before a comma that introduces a template
        specification (``NAME = [...]``) so that the caller can handle it.
        """
        args = []
        args.append(self.parse_arg())
        while not self.is_at_end() and self.match(Token.COMMA):
            curr = self._curr
            self.consume(Token.COMMA)
            self.parse_type()  # assuming that we are not ending with COMMA
            if not self.is_at_end() and self.match(Token.EQUAL):
                # arg type cannot be assigned, so this must be a template specification
                self._curr = curr  # step back and let the code below parse the templates
                break
            else:
                self._curr = curr + 1  # step back from self.parse_type(), parse_arg will parse it again
                args.append(self.parse_arg())
        return args

    def parse_arg(self):
        """Parse one argument; an optional name is stored as a ``name`` annotation.

        arg: type IDENTIFIER? ("|" annotation)*
        """
        typ = self.parse_type()

        annotations = []

        if not self.is_at_end() and self.match(Token.IDENTIFIER):
            name = self.parse_identifier()
            annotations.append(AnnotationNode('name', name))

        while not self.is_at_end() and self.match(Token.VBAR):
            ahead = self.lookahead()
            # "| output_row_size" belongs to the UDTF as a whole, not to this argument
            if ahead.type == Token.IDENTIFIER and ahead.lexeme == 'output_row_size':
                break
            self.consume(Token.VBAR)
            annotations.append(self.parse_annotation())

        return ArgNode(typ, annotations)

    def parse_type(self):
        """Parse a type, choosing between the composed and primitive forms.

        type: composed
            | primitive
        """
        curr = self._curr  # save state
        primitive = self.parse_primitive()
        if self.is_at_end():
            return primitive

        if not self.match(Token.LESS):
            return primitive

        self._curr = curr  # return state

        return self.parse_composed()

    def parse_composed(self):
        """Parse a type with angle-bracket parameters.

        composed: "Cursor" "<" arg ("," arg)* ">"
                | IDENTIFIER "<" type ("," type)* ">"
        """
        idtn = self.parse_identifier()
        self.consume(Token.LESS)
        if is_identifier_cursor(idtn):
            inner = [self.parse_arg()]
            while self.match(Token.COMMA):
                self.consume(Token.COMMA)
                inner.append(self.parse_arg())
        else:
            inner = [self.parse_type()]
            while self.match(Token.COMMA):
                self.consume(Token.COMMA)
                inner.append(self.parse_type())
        self.consume(Token.GREATER)
        return ComposedNode(idtn, inner)

    def parse_primitive(self):
        """Parse a single identifier, number or string into a ``PrimitiveNode``.

        primitive: IDENTIFIER
                 | NUMBER
                 | STRING
        """
        if self.match(Token.IDENTIFIER):
            lexeme = self.parse_identifier()
        elif self.match(Token.NUMBER):
            lexeme = self.parse_number()
        elif self.match(Token.STRING):
            lexeme = self.parse_string()
        else:
            # raise_parser_error always raises ParserException itself
            self.raise_parser_error()
        return PrimitiveNode(lexeme)

    def parse_templates(self):
        """Parse a comma-separated list of template specifications.

        templates: template ("," template)*
        """
        T = []
        T.append(self.parse_template())
        while not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            T.append(self.parse_template())
        return T

    def parse_template(self):
        """Parse one template specification into a ``TemplateNode``.

        template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"
        """
        key = self.parse_identifier()
        types = []
        self.consume(Token.EQUAL)
        self.consume(Token.LSQB)
        types.append(self.parse_identifier())
        while self.match(Token.COMMA):
            self.consume(Token.COMMA)
            types.append(self.parse_identifier())
        self.consume(Token.RSQB)
        return TemplateNode(key, tuple(types))

    def parse_annotation(self):
        """Parse one ``key=value`` annotation into an ``AnnotationNode``.

        annotation: IDENTIFIER "=" IDENTIFIER ("<" (NUMBER ("," NUMBER)?)? ">")?
                  | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
                  | "require" "=" STRING

        An empty ``<>`` is recorded as ``<-1>``.
        """
        key = self.parse_identifier()
        self.consume(Token.EQUAL)

        if key == "require":
            value = self.parse_string()
        elif not self.is_at_end() and self.match(Token.LSQB):
            value = []
            self.consume(Token.LSQB)
            if not self.match(Token.RSQB):
                value.append(self.parse_primitive())
                while self.match(Token.COMMA):
                    self.consume(Token.COMMA)
                    value.append(self.parse_primitive())
            self.consume(Token.RSQB)
        else:
            value = self.parse_identifier()
            if not self.is_at_end() and self.match(Token.LESS):
                self.consume(Token.LESS)
                if self.match(Token.GREATER):
                    value += "<%s>" % (-1)  # Signifies no input
                else:
                    num1 = self.parse_number()
                    if self.match(Token.COMMA):
                        self.consume(Token.COMMA)
                        num2 = self.parse_number()
                        value += "<%s,%s>" % (num1, num2)
                    else:
                        value += "<%s>" % (num1)
                self.consume(Token.GREATER)
        return AnnotationNode(key, value)

    def parse_identifier(self):
        """Consume an IDENTIFIER token and return its lexeme.

        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
        """
        token = self.consume(Token.IDENTIFIER)
        return token.lexeme

    def parse_string(self):
        """Consume a STRING token and return its lexeme.

        STRING: \".*?\"
        """
        token = self.consume(Token.STRING)
        return token.lexeme

    def parse_number(self):
        """Consume a NUMBER token and return its lexeme.

        NUMBER: [0-9]+
        """
        token = self.consume(Token.NUMBER)
        return token.lexeme

    def parse(self):
        """Parse the whole line from the first token and return the ``UdtfNode``.

        udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?

        args: arg ("," arg)*

        arg: type IDENTIFIER? ("|" annotation)*

        type: composed
            | primitive

        composed: "Cursor" "<" arg ("," arg)* ">"
                | IDENTIFIER "<" type ("," type)* ">"

        primitive: IDENTIFIER
                 | NUMBER
                 | STRING

        annotation: IDENTIFIER "=" IDENTIFIER ("<" (NUMBER ("," NUMBER)?)? ">")?
                  | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
                  | "require" "=" STRING

        templates: template ("," template)*
        template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"

        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
        NUMBER: [0-9]+
        STRING: \".*?\"
        """
        self._curr = 0
        udtf = self.parse_udtf()

        # set parent on every node reachable from the root (depth-first)
        udtf.parent = None
        d = deque()
        d.append(udtf)
        while d:
            node = d.pop()
            if isinstance(node, Iterable):
                for child in node:
                    child.parent = node
                    d.append(child)
        return udtf
1591 
1592 
1593 # fmt: off
def find_signatures(input_file):
    """Returns a list of parsed UDTF signatures.

    Lines starting with ``UDTF:`` are collected from ``input_file``; a line
    that ``line_is_incomplete`` reports as unfinished is joined with the
    following line. Each complete specification is parsed and run through
    the transformer pipeline ending in ``DeclBracketTransformer``.

    When a line contains the module-level ``separator``, the text after it
    is a list of expected printed results: the pipeline output (or the
    ``TransformerException`` it raises) is asserted against them, and a
    specification that is expected to fail is not added to the result.
    """
    # Passes shared by the self-check pipeline and the real one.
    common_passes = (TemplateTransformer,
                     FieldAnnotationTransformer,
                     TextEncodingDictTransformer,
                     SupportedAnnotationsTransformer,
                     FixRowMultiplierPosArgTransformer,
                     RenameNodesTransformer)

    # Read everything up front so the file is closed before any parsing starts.
    with open(input_file) as stream:
        lines = stream.readlines()

    signatures = []
    last_line = None
    for line in lines:
        line = line.strip()
        if last_line is not None:
            line = last_line + ' ' + line
            last_line = None
        if not line.startswith('UDTF:'):
            continue
        if line_is_incomplete(line):
            last_line = line
            continue
        line = line[5:].lstrip()
        if '(' not in line or ')' not in line:
            sys.stderr.write('Invalid UDTF specification: `%s`. Skipping.\n' % (line))
            continue

        expected_result = None
        if separator in line:
            line, expected_result = line.split(separator, 1)
            expected_result = [s.strip() for s in expected_result.strip().split(separator)]

        ast = Parser(line).parse()

        if expected_result is not None:
            # Template transformer expands templates into multiple lines
            skip_signature = False
            try:
                result = Pipeline(*(common_passes + (AstPrinter,)))(ast)
            except TransformerException as msg:
                result = ['%s: %s' % (type(msg).__name__, msg)]
                skip_signature = True
            assert set(result) == set(expected_result), "\n\tresult: %s != \n\texpected: %s" % (
                result,
                expected_result,
            )
            if skip_signature:
                continue

        signature = Pipeline(*(common_passes + (DeclBracketTransformer,)))(ast)

        signatures.extend(signature)

    return signatures
1657 
1658 
def format_function_args(input_types, output_types, uses_manager, use_generic_arg_name, emit_output_args):
    """Build the C++ parameter list and the matching argument-name list.

    Every type's ``format_cpp_type`` must return a ``(declaration, name)``
    pair. A leading ``TableFunctionManager& mgr`` parameter is added when
    ``uses_manager`` is true; outputs are included only when
    ``emit_output_args`` is true. Both results are ``', '``-joined strings.
    """
    pairs = []
    if uses_manager:
        pairs.append(('TableFunctionManager& mgr', 'mgr'))

    groups = [(input_types, True)]
    if emit_output_args:
        groups.append((output_types, False))

    for types_, is_input in groups:
        for position, typ in enumerate(types_):
            declaration, arg_name = typ.format_cpp_type(
                position,
                use_generic_arg_name=use_generic_arg_name,
                is_input=is_input)
            pairs.append((declaration, arg_name))

    cpp_args = ', '.join(declaration for declaration, _ in pairs)
    name_args = ', '.join(arg_name for _, arg_name in pairs)
    return cpp_args, name_args
1685 
1686 
def build_template_function_call(caller, called, input_types, output_types, uses_manager):
    """Return C++ source for a function ``caller`` that forwards its arguments to ``called``.

    The parameter list uses generic argument names and includes the output
    arguments.
    """
    signature, forwarded = format_function_args(input_types,
                                                output_types,
                                                uses_manager,
                                                use_generic_arg_name=True,
                                                emit_output_args=True)

    pieces = [
        "EXTENSION_NOINLINE int32_t\n",
        "%s(%s) {\n" % (caller, signature),
        " return %s(%s);\n" % (called, forwarded),
        "}\n",
    ]
    return "".join(pieces)
1699 
1700 
def build_preflight_function(fn_name, sizer, input_types, output_types, uses_manager):
    """Return C++ source of the ``<fn_name lowercased>__preflight`` function.

    The generated function checks every ``require`` annotation found on the
    input types and returns an error when one is not satisfied. If the sizer
    is an argument sizer, its first argument is emitted as the output-size
    expression, which is rejected when negative and returned otherwise; for
    any other sizer the function returns 0. Only input arguments appear in
    the generated signature.
    """
    # Errors are reported through the manager when the UDTF takes one.
    error_call = "mgr.error_message" if uses_manager else "table_function_error"

    def error_return(err_msg):
        return " return %s(%s);\n" % (error_call, err_msg)

    cpp_args, _ = format_function_args(input_types,
                                       output_types,
                                       uses_manager,
                                       use_generic_arg_name=False,
                                       emit_output_args=False)

    parts = [
        "EXTENSION_NOINLINE int32_t\n",
        "%s(%s) {\n" % (fn_name.lower() + "__preflight", cpp_args),
    ]

    for typ in input_types:
        for key, value in typ.annotations:
            if key != 'require':
                continue
            constraint = value[1:-1]  # drop the surrounding quote characters
            err_msg = '"Constraint `%s` is not satisfied."' % (constraint)
            parts.append(" if (!(%s)) {\n" % (constraint.replace('\\', ''),))
            parts.append(error_return(err_msg))
            parts.append(" }\n")

    if sizer.is_arg_sizer():
        precomputed_nrows = str(sizer.args[0])
        if '"' in precomputed_nrows:
            precomputed_nrows = precomputed_nrows[1:-1]
        # check to see if the precomputed number of rows > 0
        err_msg = '"Output size expression `%s` evaluated in a negative value."' % (precomputed_nrows)
        parts.append(" auto _output_size = %s;\n" % (precomputed_nrows))
        parts.append(" if (_output_size < 0) {\n")
        parts.append(error_return(err_msg))
        parts.append(" }\n")
        parts.append(" return _output_size;\n")
    else:
        parts.append(" return 0;\n")
    parts.append("}\n\n")

    return "".join(parts)
1748 
1749 
1751  if sizer.is_arg_sizer():
1752  return True
1753  for arg_annotations in sig.input_annotations:
1754  d = dict(arg_annotations)
1755  if 'require' in d.keys():
1756  return True
1757  return False
1758 
1759 
def format_annotations(annotations_):
    """Render annotations as a C++ ``std::vector<std::map<std::string, std::string>>`` initializer.

    ``annotations_`` is a sequence of sequences of ``(key, value)`` pairs;
    each inner sequence becomes one map. The first and last character of a
    ``require`` value are dropped before it is emitted.
    """
    maps = []
    for pairs in annotations_:
        entries = []
        for key, value in pairs:
            # type(value) is not always 'str'; only 'require' values are sliced
            shown = value[1:-1] if key == 'require' else value
            entries.append('{"%s", "%s"}' % (key, shown))
        maps.append('{' + ', '.join(entries) + '}')

    return "std::vector<std::map<std::string, std::string>>{" + ', '.join(maps) + "}"
1771 
1772 
1774  i = sig.name.rfind('_template')
1775  return i >= 0 and '__' in sig.name[:i + 1]
1776 
1777 
def uses_manager(sig):
    """Tell whether the signature's first input is named ``TableFunctionManager``.

    When ``sig.inputs`` is empty (falsy), that falsy value itself is returned.
    """
    inputs = sig.inputs
    if not inputs:
        return inputs
    return inputs[0].name == 'TableFunctionManager'
1780 
1781 
1783  # Any function that does not have _gpu_ suffix is a cpu function.
1784  i = sig.name.rfind('_gpu_')
1785  if i >= 0 and '__' in sig.name[:i + 1]:
1786  if uses_manager(sig):
1787  raise ValueError('Table function {} with gpu execution target cannot have TableFunctionManager argument'.format(sig.name))
1788  return False
1789  return True
1790 
1791 
1793  # A function with TableFunctionManager argument is a cpu-only function
1794  if uses_manager(sig):
1795  return False
1796  # Any function that does not have _cpu_ suffix is a gpu function.
1797  i = sig.name.rfind('_cpu_')
1798  return not (i >= 0 and '__' in sig.name[:i + 1])
1799 
1800 
def parse_annotations(input_files):
    """Collect registration statements and helper code for all UDTF signatures.

    Every signature returned by ``find_signatures`` for every file in
    ``input_files`` is processed in order.

    Returns a 6-tuple of lists:
    ``(add_stmts, cpu_template_functions, gpu_template_functions,
    cpu_function_address_expressions, gpu_function_address_expressions,
    cond_fns)``.

    Raises ``ValueError`` when a ``TableFunctionManager`` input is not the
    first argument.
    """

    # Running suffix that makes the names of template function wrappers
    # unique; only incremented for template functions.
    counter = 0

    add_stmts = []
    cpu_template_functions = []
    gpu_template_functions = []
    cpu_function_address_expressions = []
    gpu_function_address_expressions = []
    cond_fns = []

    for input_file in input_files:
        for sig in find_signatures(input_file):

            # Compute sql_types, input_types, and sizer
            sql_types_ = []
            input_types_ = []
            input_annotations = []

            sizer = None
            if sig.sizer is not None:
                expr = sig.sizer.value
                sizer = Bracket('kPreFlightParameter', (expr,))

            # NOTE: this local flag shadows the module-level uses_manager()
            # function for the rest of this function body.
            uses_manager = False
            for i, (t, annot) in enumerate(zip(sig.inputs, sig.input_annotations)):
                if t.is_output_buffer_sizer():
                    if t.is_user_specified():
                        sql_types_.append(Bracket.parse('int32').normalize(kind='input'))
                        input_types_.append(sql_types_[-1])
                        input_annotations.append(annot)
                    assert sizer is None  # exactly one sizer argument is allowed
                    assert len(t.args) == 1, t
                    sizer = t
                elif t.name == 'Cursor':
                    # A cursor contributes one input type per member but a
                    # single SQL type and a single annotation entry.
                    for t_ in t.args:
                        input_types_.append(t_)
                    input_annotations.append(annot)
                    sql_types_.append(Bracket('Cursor', args=()))
                elif t.name == 'TableFunctionManager':
                    if i != 0:
                        raise ValueError('{} must appear as a first argument of {}, but found it at position {}.'.format(t, sig.name, i))
                    uses_manager = True
                else:
                    input_types_.append(t)
                    input_annotations.append(annot)
                    if t.is_column_any():
                        # XXX: let Bracket handle mapping of column to cursor(column)
                        sql_types_.append(Bracket('Cursor', args=()))
                    else:
                        sql_types_.append(t)

            if sizer is None:
                name = 'kTableFunctionSpecifiedParameter'
                idx = 1  # this sizer is not actually materialized in the UDTF
                sizer = Bracket(name, (idx,))

            assert sizer is not None
            ns_output_types = tuple([a.apply_namespace(ns='ExtArgumentType') for a in sig.outputs])
            ns_input_types = tuple([t.apply_namespace(ns='ExtArgumentType') for t in input_types_])
            ns_sql_types = tuple([t.apply_namespace(ns='ExtArgumentType') for t in sql_types_])

            sig.function_annotations.append(('uses_manager', str(uses_manager).lower()))

            input_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_input_types)))
            output_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_output_types)))
            sql_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_sql_types)))
            annotations = format_annotations(input_annotations + sig.output_annotations + [sig.function_annotations])

            # Notice that input_types and sig.input_types, (and
            # similarly, input_annotations and sig.input_annotations)
            # have different lengths when the sizer argument is
            # Constant or TableFunctionSpecifiedParameter. That is,
            # input_types contains all the user-specified arguments
            # while sig.input_types contains all arguments of the
            # implementation of an UDTF.

            if must_emit_preflight_function(sig, sizer):
                # Uses the counter value before it is incremented below, so a
                # template function's preflight name matches its wrapper name.
                fn_name = '%s_%s' % (sig.name, str(counter)) if is_template_function(sig) else sig.name
                check_fn = build_preflight_function(fn_name, sizer, input_types_, sig.outputs, uses_manager)
                cond_fns.append(check_fn)

            if is_template_function(sig):
                # Template functions are registered under a numbered wrapper
                # name; the wrapper forwards to the original function.
                name = sig.name + '_' + str(counter)
                counter += 1
                t = build_template_function_call(name, sig.name, input_types_, sig.outputs, uses_manager)
                address_expression = ('avoid_opt_address(reinterpret_cast<void*>(%s))' % name)
                if is_cpu_function(sig):
                    cpu_template_functions.append(t)
                    cpu_function_address_expressions.append(address_expression)
                if is_gpu_function(sig):
                    gpu_template_functions.append(t)
                    gpu_function_address_expressions.append(address_expression)
                add = ('TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
                       % (name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
                add_stmts.append(add)

            else:
                add = ('TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
                       % (sig.name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
                add_stmts.append(add)
                address_expression = ('avoid_opt_address(reinterpret_cast<void*>(%s))' % sig.name)

                if is_cpu_function(sig):
                    cpu_function_address_expressions.append(address_expression)
                if is_gpu_function(sig):
                    gpu_function_address_expressions.append(address_expression)

    return add_stmts, cpu_template_functions, gpu_template_functions, cpu_function_address_expressions, gpu_function_address_expressions, cond_fns
1910 
1911 
# Called without at least one input file and an output file: run the parser
# over the bundled test signatures, print the usage text and exit with 1.
if len(sys.argv) < 3:

    input_files = [os.path.join(os.path.dirname(__file__), 'test_udtf_signatures.hpp')]
    print('Running tests from {}'.format(', '.join(input_files)))
    add_stmts, _, _, _, _, _ = parse_annotations(input_files)

    print('Usage:\n {} {} input1.hpp input2.hpp ... output.hpp'.format(sys.executable, sys.argv[0]))

    raise SystemExit(1)
1921 
# All arguments but the last are input headers; the last is the output file.
input_files = sys.argv[1:-1]
output_filename = sys.argv[-1]
# The CPU/GPU headers share the output file's name without its extension.
cpu_output_header, gpu_output_header = [
    os.path.splitext(output_filename)[0] + suffix for suffix in ('_cpu.hpp', '_gpu.hpp')]
assert input_files, sys.argv

(add_stmts,
 cpu_template_functions,
 gpu_template_functions,
 cpu_address_expressions,
 gpu_address_expressions,
 cond_fns) = parse_annotations(sys.argv[1:-1])

# Include paths start at "QueryEngine/"; a path without "/QueryEngine/" is
# used whole, because find() returns -1 and the slice then starts at 0.
canonical_input_files = [path[path.find("/QueryEngine/") + 1:] for path in input_files]
header_includes = ['#include "%s"' % (canonical,) for canonical in canonical_input_files]

# Split up calls to TableFunctionsFactory::add() into chunks
ADD_FUNC_CHUNK_SIZE = 100
1934 
def add_method(i, chunk):
    """Return C++ source of method ``add_table_functions_<i>`` whose body is the statements in ``chunk``."""
    statements = '\n '.join(chunk)
    lines = [
        '',
        ' NO_OPT_ATTRIBUTE void add_table_functions_%d() const {' % (i,),
        ' ' + statements,
        ' }',
        '',
    ]
    return '\n'.join(lines)
1941 
def add_methods(add_stmts):
    """Return one generated method per chunk of at most ADD_FUNC_CHUNK_SIZE statements."""
    methods = []
    for start in range(0, len(add_stmts), ADD_FUNC_CHUNK_SIZE):
        chunk = add_stmts[start:start + ADD_FUNC_CHUNK_SIZE]
        # Methods are numbered consecutively from 0 in chunk order.
        methods.append(add_method(len(methods), chunk))
    return methods
1945 
def call_methods(add_stmts):
    """Return one ``add_table_functions_<i>();`` call per chunk of statements."""
    full_chunks, leftover = divmod(len(add_stmts), ADD_FUNC_CHUNK_SIZE)
    # A partially filled last chunk still needs its own call.
    n_methods = full_chunks + (1 if leftover > 0 else 0)
    calls = []
    for index in range(n_methods):
        calls.append('add_table_functions_%d();' % (index))
    return calls
1949 
# Text of the main generated C++ file. The six %s placeholders are filled,
# in order, with: the name of this script, the #include lines for the input
# headers, the CPU function address expressions joined by " &&", the
# generated add_table_functions_<i>() methods, the calls to those methods,
# and the generated preflight ("conditional check") functions.
# NOTE(review): with no CPU address expressions the join is empty and the
# generated line reads `ret &= ();` - confirm this cannot happen in practice.
content = '''
/*
 This file is generated by %s. Do no edit!
*/

#include "QueryEngine/TableFunctions/TableFunctionsFactory.h"
%s

/*
 Include the UDTF template initiations:
*/
#include "TableFunctionsFactory_init_cpu.hpp"

// volatile+noinline prevents compiler optimization
#ifdef _WIN32
__declspec(noinline)
#else
 __attribute__((noinline))
#endif
volatile
bool avoid_opt_address(void *address) {
 return address != nullptr;
}

bool functions_exist() {
 bool ret = true;

 ret &= (%s);

 return ret;
}

extern bool g_enable_table_functions;

namespace table_functions {

std::once_flag init_flag;

#if defined(__clang__)
#define NO_OPT_ATTRIBUTE __attribute__((optnone))

#elif defined(__GNUC__) || defined(__GNUG__)
#define NO_OPT_ATTRIBUTE __attribute((optimize("O0")))

#elif defined(_MSC_VER)
#define NO_OPT_ATTRIBUTE

#endif

#if defined(_MSC_VER)
#pragma optimize("", off)
#endif

struct AddTableFunctions {
%s
 NO_OPT_ATTRIBUTE void operator()() {
 %s
 }
};

void TableFunctionsFactory::init() {
 if (!g_enable_table_functions) {
 return;
 }

 if (!functions_exist()) {
 UNREACHABLE();
 return;
 }

 std::call_once(init_flag, AddTableFunctions{});
}
#if defined(_MSC_VER)
#pragma optimize("", on)
#endif

// conditional check functions
%s

} // namespace table_functions

''' % (sys.argv[0],
       '\n'.join(header_includes),
       ' &&\n'.join(cpu_address_expressions),
       ''.join(add_methods(add_stmts)),
       '\n '.join(call_methods(add_stmts)),
       ''.join(cond_fns))
2037 
# Text of the generated CPU/GPU template headers. The placeholders are, in
# order: the name of this script, the #include lines, and the generated
# template function wrappers.
header_content = '''
/*
 This file is generated by %s. Do no edit!
*/
%s

%s
'''

dirname = os.path.dirname(output_filename)

if dirname and not os.path.exists(dirname):
    try:
        os.makedirs(dirname)
    except OSError as e:
        import errno
        # Another process may have created the directory in the meantime.
        if e.errno != errno.EEXIST:
            raise

# Context managers close each output file even when a write fails.
with open(output_filename, 'w') as f:
    f.write(content)

with open(cpu_output_header, 'w') as f:
    f.write(header_content % (sys.argv[0], '\n'.join(header_includes), '\n'.join(cpu_template_functions)))

with open(gpu_output_header, 'w') as f:
    f.write(header_content % (sys.argv[0], '\n'.join(header_includes), '\n'.join(gpu_template_functions)))
std::string strip(std::string_view str)
trim any whitespace from the left and right ends of a string
std::string join(T const &container, std::string const &delim)
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings
int open(const char *path, int flags, int mode)
Definition: omnisci_fs.cpp:64