OmniSciDB  c0231cc57d
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
generate_TableFunctionsFactory_init.py
Go to the documentation of this file.
1 """Given a list of input files, scan for lines containing UDTF
2 specification statements in the following form:
3 
4  UDTF: function_name(<arguments>) -> <output column types> (, <template type specifications>)?
5 
6 where <arguments> is a comma-separated list of argument types. The
7 argument types specifications are:
8 
9 - scalar types:
10  Int8, Int16, Int32, Int64, Float, Double, Bool, TextEncodingDict, etc
11 - column types:
12  ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble, ColumnBool, etc
13 - column list types:
14  ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool, etc
15 - cursor type:
16  Cursor<t0, t1, ...>
17  where t0, t1 are column or column list types
18 - output buffer size parameter type:
19  RowMultiplier<i>, ConstantParameter<i>, Constant<i>, TableFunctionSpecifiedParameter<i>
20  where i is a literal integer.
21 
22 The output column types is a comma-separated list of column types, see above.
23 
24 In addition, the following equivalents are supported:
25 
26  Column<T> == ColumnT
27  ColumnList<T> == ColumnListT
28  Cursor<T, V, ...> == Cursor<ColumnT, ColumnV, ...>
29  int8 == int8_t == Int8, etc
30  float == Float, double == Double, bool == Bool
31  T == ColumnT for output column types
32  RowMultiplier == RowMultiplier<i> where i is the one-based position of the sizer argument
33  when no sizer argument is provided, Constant<1> is assumed
34 
35 Argument types can be annotated using `|' (bar) symbol after an
36 argument type specification. An annotation is specified by a label and
37 a value separated by `=' (equal) symbol. Multiple annotations can be
38 specified by using `|` (bar) symbol as the annotations separator.
39 Supported annotation labels are:
40 
41 - name: to specify argument name
42 - input_id: to specify the dict id mapping for output TextEncodingDict columns.
43 
44 If argument type follows an identifier, it will be mapped to name
45 annotations. For example, the following argument type specifications
46 are equivalent:
47 
48  Int8 a
49  Int8 | name=a
50 
51 Template type specifications is a comma separated list of template
52 type assignments where values are lists of argument type names. For
53 instance:
54 
55  T = [Int8, Int16, Int32, Float], V = [Float, Double]
56 
57 """
58 # Author: Pearu Peterson
59 # Created: January 2021
60 
61 
62 import os
63 import sys
64 import itertools
65 import copy
66 from abc import abstractmethod
67 
68 from collections import deque, namedtuple
69 
70 if sys.version_info > (3, 0):
71  from abc import ABC
72  from collections.abc import Iterable
73 else:
74  from abc import ABCMeta as ABC
75  from collections import Iterable
76 
77 # fmt: off
separator = '$=>$'

# Named record bundling a UDTF's name, inputs, outputs, annotations and sizer.
Signature = namedtuple(
    'Signature',
    [
        'name',
        'inputs',
        'outputs',
        'input_annotations',
        'output_annotations',
        'function_annotations',
        'sizer',
    ],
)

# Names accepted as output-buffer sizer kinds.
OutputBufferSizeTypes = [
    'kConstant',
    'kUserSpecifiedConstantParameter',
    'kUserSpecifiedRowMultiplier',
    'kTableFunctionSpecifiedParameter',
    'kPreFlightParameter',
]

# Annotation keys accepted on input/output arguments.
SupportedAnnotations = ['input_id', 'name', 'fields', 'require', 'range']

# TODO: support `gpu`, `cpu`, `template` as function annotations
SupportedFunctionAnnotations = ['filter_table_function_transpose', 'uses_manager']

# Maps user-facing spellings to canonical type / sizer names.
translate_map = {
    'Constant': 'kConstant',
    'PreFlight': 'kPreFlightParameter',
    'ConstantParameter': 'kUserSpecifiedConstantParameter',
    'RowMultiplier': 'kUserSpecifiedRowMultiplier',
    'UserSpecifiedConstantParameter': 'kUserSpecifiedConstantParameter',
    'UserSpecifiedRowMultiplier': 'kUserSpecifiedRowMultiplier',
    'TableFunctionSpecifiedParameter': 'kTableFunctionSpecifiedParameter',
    'short': 'Int16',
    'int': 'Int32',
    'long': 'Int64',
}
# Lower-case aliases (``int8`` -> ``Int8``) and, for integer types only,
# the C-style ``_t`` aliases (``int8_t`` -> ``Int8``).
for t in ['Int8', 'Int16', 'Int32', 'Int64', 'Float', 'Double', 'Bool',
          'TextEncodingDict', 'TextEncodingNone']:
    translate_map[t.lower()] = t
    if t.startswith('Int'):
        translate_map[t.lower() + '_t'] = t
112 
113 
class Declaration:
    """Holds a `TYPE | ANNOTATIONS`-like structure.

    ``type`` is the wrapped type object; most methods simply forward to
    it.  ``annotations`` is a list of ``(key, value)`` pairs (it is passed
    to ``dict`` in :meth:`format_cpp_type`).
    """

    def __init__(self, type, annotations=None):
        self.type = type
        # Use a fresh list per instance: a shared ``[]`` default would leak
        # annotations between declarations created without the argument.
        self.annotations = [] if annotations is None else annotations

    @property
    def name(self):
        """Name of the wrapped type."""
        return self.type.name

    @property
    def args(self):
        """Arguments of the wrapped type."""
        return self.type.args

    def format_sizer(self):
        """Forward to the wrapped type's ``format_sizer``."""
        return self.type.format_sizer()

    def __repr__(self):
        return 'Declaration(%r, ann=%r)' % (self.type, self.annotations)

    def __str__(self):
        if not self.annotations:
            return str(self.type)
        return '%s | %s' % (self.type, ' | '.join(map(str, self.annotations)))

    def tostring(self):
        """Forward to the wrapped type's ``tostring`` (annotations are omitted)."""
        return self.type.tostring()

    def apply_column(self):
        """Return a new declaration whose type had ``apply_column`` applied."""
        return self.__class__(self.type.apply_column(), self.annotations)

    def apply_namespace(self, ns='ExtArgumentType'):
        """Return a new declaration whose type had ``apply_namespace(ns)`` applied."""
        return self.__class__(self.type.apply_namespace(ns), self.annotations)

    def get_cpp_type(self):
        """Forward to the wrapped type's ``get_cpp_type``."""
        return self.type.get_cpp_type()

    def format_cpp_type(self, idx, use_generic_arg_name=False, is_input=True):
        """Format the C++ parameter, passing the ``name`` annotation (if any)
        to the wrapped type as ``real_arg_name``."""
        real_arg_name = dict(self.annotations).get('name', None)
        return self.type.format_cpp_type(idx,
                                         use_generic_arg_name=use_generic_arg_name,
                                         real_arg_name=real_arg_name,
                                         is_input=is_input)

    def __getattr__(self, name):
        # Only ``is_*`` predicates are delegated to the wrapped type; any
        # other unknown attribute is a genuine error.
        if name.startswith('is_'):
            return getattr(self.type, name)
        raise AttributeError(name)
163 
164 
def tostring(obj):
    """Return ``obj.tostring()``; convenient as a ``map`` callback."""
    text = obj.tostring()
    return text
167 
168 
class Bracket:
    """Holds a `NAME<ARGS>`-like structure.

    ``name`` is a string (optionally prefixed with a ``ns::`` namespace) and
    ``args`` is either ``None`` or a tuple of nested items.
    """

    def __init__(self, name, args=None):
        assert isinstance(name, str)
        assert isinstance(args, tuple) or args is None, args
        self.name = name
        self.args = args

    def __repr__(self):
        return 'Bracket(%r, args=%r)' % (self.name, self.args)

    def __str__(self):
        if not self.args:
            return self.name
        return '%s<%s>' % (self.name, ', '.join(map(str, self.args)))

    def tostring(self):
        """Like ``str(self)`` but renders the arguments via their ``tostring``."""
        if not self.args:
            return self.name
        return '%s<%s>' % (self.name, ', '.join(map(tostring, self.args)))

    def normalize(self, kind='input'):
        """Normalize bracket for given kind

        * ``Column<T>``-like brackets collapse into a single name ``ColumnT``.
        * For inputs, every ``Cursor`` argument becomes a normalized column.
        * For outputs, a non-column bracket is wrapped into a column.
        """
        assert kind in ['input', 'output'], kind
        if self.is_column_any() and self.args:
            return Bracket(self.name + ''.join(map(str, self.args)))
        if kind == 'input':
            if self.name == 'Cursor':
                args = [(a if a.is_column_any() else Bracket('Column', args=(a,))).normalize(kind=kind)
                        for a in self.args]
                return Bracket(self.name, tuple(args))
        if kind == 'output':
            if not self.is_column_any():
                return Bracket('Column', args=(self,)).normalize(kind=kind)
        return self

    def apply_cursor(self):
        """Apply cursor to a non-cursor column argument type.
        TODO: this method is currently unused but we should apply
        cursor to all input column arguments in order to distinguish
        signatures like:
          foo(Cursor(Column<int32>, Column<float>)) -> Column<int32>
          foo(Cursor(Column<int32>), Cursor(Column<float>)) -> Column<int32>
        that at the moment are treated as the same :(
        """
        if self.is_column():
            return Bracket('Cursor', args=(self,))
        return self

    def apply_column(self):
        """Prefix the name with ``Column`` unless it is already a column or column list."""
        if not self.is_column() and not self.is_column_list():
            return Bracket('Column' + self.name)
        return self

    def apply_namespace(self, ns='ExtArgumentType'):
        """Return a bracket whose name is qualified with ``ns::``.

        Cursors are qualified recursively; for any other bracket the
        arguments are not carried over to the returned object.
        """
        if self.name == 'Cursor':
            return Bracket(ns + '::' + self.name, args=tuple(a.apply_namespace(ns=ns) for a in self.args))
        if not self.name.startswith(ns + '::'):
            return Bracket(ns + '::' + self.name)
        return self

    # The predicates below look only at the part of the name that follows
    # the last ``::``.

    def is_cursor(self):
        return self.name.rsplit("::", 1)[-1] == 'Cursor'

    def is_array(self):
        return self.name.rsplit("::", 1)[-1].startswith('Array')

    def is_column_any(self):
        return self.name.rsplit("::", 1)[-1].startswith('Column')

    def is_column_list(self):
        return self.name.rsplit("::", 1)[-1].startswith('ColumnList')

    def is_column(self):
        return self.name.rsplit("::", 1)[-1].startswith('Column') and not self.is_column_list()

    # NOTE(review): the ``def`` lines of the next five predicates were lost
    # in the listing this file was recovered from; their names are restored
    # by analogy with the node API used elsewhere in this file
    # (``is_any_text_encoding_dict`` etc.) -- confirm against upstream.

    def is_any_text_encoding_dict(self):
        return self.name.rsplit("::", 1)[-1].endswith('TextEncodingDict')

    def is_array_text_encoding_dict(self):
        return self.name.rsplit("::", 1)[-1] == 'ArrayTextEncodingDict'

    def is_column_text_encoding_dict(self):
        return self.name.rsplit("::", 1)[-1] == 'ColumnTextEncodingDict'

    def is_column_array_text_encoding_dict(self):
        return self.name.rsplit("::", 1)[-1] == 'ColumnArrayTextEncodingDict'

    def is_column_list_text_encoding_dict(self):
        return self.name.rsplit("::", 1)[-1] == 'ColumnListTextEncodingDict'

    def is_output_buffer_sizer(self):
        return self.name.rsplit("::", 1)[-1] in OutputBufferSizeTypes

    def is_row_multiplier(self):
        return self.name.rsplit("::", 1)[-1] == 'kUserSpecifiedRowMultiplier'

    def is_arg_sizer(self):
        return self.name.rsplit("::", 1)[-1] == 'kPreFlightParameter'

    def is_user_specified(self):
        """Return True if the argument is specified by the user.

        Output buffer sizers of kind ``kConstant``,
        ``kTableFunctionSpecifiedParameter`` and ``kPreFlightParameter`` are
        not; every non-sizer argument is.
        """
        if self.is_output_buffer_sizer():
            return self.name.rsplit("::", 1)[-1] not in (
                'kConstant', 'kTableFunctionSpecifiedParameter', 'kPreFlightParameter')
        return True

    def format_sizer(self):
        """Render the C++ ``TableFunctionOutputRowSizer`` initializer.

        The value is 0 for ``kPreFlightParameter`` and ``args[0]`` otherwise.
        """
        val = 0 if self.is_arg_sizer() else self.args[0]
        return 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (self.name, val)

    def get_cpp_type(self):
        """Return the C++ spelling of this type, e.g. ``Column<int32_t>``.

        Raises NotImplementedError when the element type is not one of
        Bool, Int*, Float, Double, TextEncodingDict, TextEncodingNone or
        Timestamp.
        """
        name = self.name.rsplit("::", 1)[-1]
        clsname = None
        subclsname = None
        # Prefixes are removed by slicing: ``str.lstrip`` strips a *set of
        # characters*, not a prefix, and would e.g. silently turn
        # ``ColumnColumnInt32`` into ``Int32``.
        if name.startswith('ColumnList'):
            name = name[len('ColumnList'):]
            clsname = 'ColumnList'
        elif name.startswith('Column'):
            name = name[len('Column'):]
            clsname = 'Column'
        if name.startswith('Array'):
            name = name[len('Array'):]
            if clsname is None:
                clsname = 'Array'
            else:
                subclsname = 'Array'

        if name.startswith('Bool'):
            ctype = name.lower()
        elif name.startswith('Int'):
            ctype = name.lower() + '_t'
        elif name in ['Double', 'Float']:
            ctype = name.lower()
        elif name == 'TextEncodingDict':
            ctype = name
        elif name == 'TextEncodingNone':
            ctype = name
        elif name == 'Timestamp':
            ctype = name
        else:
            raise NotImplementedError(self)
        if clsname is None:
            return ctype
        if subclsname is None:
            return '%s<%s>' % (clsname, ctype)
        return '%s<%s<%s>>' % (clsname, subclsname, ctype)

    def format_cpp_type(self, idx, use_generic_arg_name=False, real_arg_name=None, is_input=True):
        """Return ``(declaration, arg_name)`` for a C++ parameter.

        Column, ColumnList and TextEncodingNone types are passed by
        reference (``const`` for inputs); everything else by value.  The
        argument name is ``real_arg_name`` when given and
        ``use_generic_arg_name`` is false, otherwise ``input<idx>`` /
        ``output<idx>``.
        """
        col_typs = ('Column', 'ColumnList')
        literal_ref_typs = ('TextEncodingNone',)
        if use_generic_arg_name:
            arg_name = 'input' + str(idx) if is_input else 'output' + str(idx)
        elif real_arg_name is not None:
            arg_name = real_arg_name
        else:
            # in some cases, the real arg name is not specified
            arg_name = 'input' + str(idx) if is_input else 'output' + str(idx)
        const = 'const ' if is_input else ''
        cpp_type = self.get_cpp_type()
        if any(cpp_type.startswith(t) for t in col_typs + literal_ref_typs):
            return '%s%s& %s' % (const, cpp_type, arg_name), arg_name
        else:
            return '%s %s' % (cpp_type, arg_name), arg_name

    @classmethod
    def parse(cls, typ):
        """typ is a string in format NAME<ARGS> or NAME

        Returns Bracket instance.  Arguments are split on top-level commas
        and parsed recursively; the name is passed through translate_map.
        """
        i = typ.find('<')
        if i == -1:
            name = typ.strip()
            args = None
        else:
            assert typ.endswith('>'), typ
            name = typ[:i].strip()
            args = []
            rest = typ[i + 1:-1].strip()
            while rest:
                i = find_comma(rest)
                if i == -1:
                    a, rest = rest, ''
                else:
                    a, rest = rest[:i].rstrip(), rest[i + 1:].lstrip()
                args.append(cls.parse(a))
            args = tuple(args)

        name = translate_map.get(name, name)
        return cls(name, args)
361 
362 
def find_comma(line):
    """Return the index of the first top-level comma in `line`, or -1.

    A comma is top-level when it is not enclosed in any of the bracket
    pairs ``<>``, ``()``, ``[]`` or ``{}``.
    """
    depth = 0
    for i, c in enumerate(line):
        if c in '<([{':
            depth += 1
        elif c in '>)]}':
            # The closing set must contain '}' (not '{'); otherwise a brace
            # group never closes and every later comma is ignored.
            depth -= 1
        elif depth == 0 and c == ',':
            return i
    return -1
373 
374 
376  # TODO: try to parse the line to be certain about completeness.
377  # `$=>$' is used to separate the UDTF signature and the expected result
378  return line.endswith(',') or line.endswith('->') or line.endswith(separator) or line.endswith('|')
379 
380 
def is_identifier_cursor(identifier):
    """Return True when `identifier` spells ``cursor``, ignoring case."""
    lowered = identifier.lower()
    return 'cursor' == lowered
383 
384 
385 # fmt: on
386 
387 
class TokenizeException(Exception):
    """Error raised while splitting a line into tokens."""
390 
391 
class ParserException(Exception):
    """Error type reserved for parser failures."""
394 
395 
class TransformerException(Exception):
    """Error raised by AST transformers, e.g. for unknown annotations."""
398 
399 
class Token:
    """A lexical token: an integer kind plus the matched source text."""

    LESS = 1         # <
    GREATER = 2      # >
    COMMA = 3        # ,
    EQUAL = 4        # =
    RARROW = 5       # ->
    STRING = 6       # reserved for string constants
    NUMBER = 7       #
    VBAR = 8         # |
    BANG = 9         # !
    LPAR = 10        # (
    RPAR = 11        # )
    LSQB = 12        # [
    RSQB = 13        # ]
    IDENTIFIER = 14  #
    COLON = 15       # :

    def __init__(self, type, lexeme):
        """
        Parameters
        ----------
        type : int
          One of the token kinds defined on this class
        lexeme : str
          Corresponding string in the text
        """
        self.type = type
        self.lexeme = lexeme

    @classmethod
    def tok_name(cls, token):
        """Return the symbolic name of a token kind, or None if unknown."""
        labels = (
            "LESS", "GREATER", "COMMA", "EQUAL", "RARROW",
            "STRING", "NUMBER", "VBAR", "BANG", "LPAR",
            "RPAR", "LSQB", "RSQB", "IDENTIFIER", "COLON",
        )
        by_value = {getattr(Token, label): label for label in labels}
        return by_value.get(token)

    def __str__(self):
        kind = Token.tok_name(self.type)
        return 'Token(%s, "%s")' % (kind, self.lexeme)

    __repr__ = __str__
454 
455 
class Tokenize:
    """Split a single line of text into a list of :class:`Token` objects.

    Tokenization runs eagerly in the constructor; the result is available
    through the ``tokens`` property.  ``start`` marks the first character
    of the token being scanned and ``curr`` the character under the cursor.
    Raises TokenizeException on characters that match no token rule.
    """

    def __init__(self, line):
        self._line = line
        self._tokens = []
        self.start = 0
        self.curr = 0
        self.tokenize()

    @property
    def line(self):
        return self._line

    @property
    def tokens(self):
        return self._tokens

    def tokenize(self):
        """Scan the whole line, dispatching on the character under the cursor."""
        while not self.is_at_end():
            self.start = self.curr

            if self.is_token_whitespace():
                self.consume_whitespace()
            elif self.is_digit():
                self.consume_number()
            elif self.is_token_string():
                self.consume_string()
            elif self.is_token_identifier():
                self.consume_identifier()
            elif self.can_token_be_double_char():
                self.consume_double_char()
            else:
                self.consume_single_char()

    def is_at_end(self):
        return len(self.line) == self.curr

    def current_token(self):
        return self.line[self.start:self.curr + 1]

    def add_token(self, type):
        """Append a token of kind `type` spanning ``start..curr`` inclusive."""
        lexeme = self.line[self.start:self.curr + 1]
        self._tokens.append(Token(type, lexeme))

    def lookahead(self):
        """Return the character after the cursor, or None at end of line."""
        if self.curr + 1 >= len(self.line):
            return None
        return self.line[self.curr + 1]

    def advance(self):
        self.curr += 1

    def peek(self):
        return self.line[self.curr]

    def can_token_be_double_char(self):
        char = self.peek()
        return char in ("-",)

    def consume_double_char(self):
        """Consume ``->``; a ``-`` not followed by ``>`` is an error."""
        ahead = self.lookahead()
        if ahead == ">":
            self.advance()
            self.add_token(Token.RARROW)  # ->
            self.advance()
        else:
            self.raise_tokenize_error()

    def consume_single_char(self):
        """Consume one punctuation character or raise TokenizeException."""
        char = self.peek()
        if char == "(":
            self.add_token(Token.LPAR)
        elif char == ")":
            self.add_token(Token.RPAR)
        elif char == "<":
            self.add_token(Token.LESS)
        elif char == ">":
            self.add_token(Token.GREATER)
        elif char == ",":
            self.add_token(Token.COMMA)
        elif char == "=":
            self.add_token(Token.EQUAL)
        elif char == "|":
            self.add_token(Token.VBAR)
        elif char == "!":
            self.add_token(Token.BANG)
        elif char == "[":
            self.add_token(Token.LSQB)
        elif char == "]":
            self.add_token(Token.RSQB)
        elif char == ":":
            self.add_token(Token.COLON)
        else:
            self.raise_tokenize_error()
        self.advance()

    def consume_whitespace(self):
        self.advance()

    def consume_string(self):
        """
        STRING: \".*?\"

        The lexeme includes both quotes.  A quote preceded by a backslash
        does not terminate the string.
        """
        while True:
            char = self.lookahead()
            if char is None:
                # Unterminated string literal: report it as a tokenize
                # error instead of running past the end of the line.
                self.raise_tokenize_error()
            curr = self.peek()
            if char == '"' and curr != '\\':
                self.advance()
                break
            self.advance()
        self.add_token(Token.STRING)
        self.advance()

    def consume_number(self):
        """
        NUMBER: [0-9]+
        """
        while True:
            char = self.lookahead()
            if char and char.isdigit():
                self.advance()
            else:
                break
        self.add_token(Token.NUMBER)
        self.advance()

    def consume_identifier(self):
        """
        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
        """
        while True:
            char = self.lookahead()
            if char and char.isalnum() or char == "_":
                self.advance()
            else:
                break
        self.add_token(Token.IDENTIFIER)
        self.advance()

    def is_token_identifier(self):
        return self.peek().isalpha() or self.peek() == "_"

    def is_token_string(self):
        return self.peek() == '"'

    def is_digit(self):
        return self.peek().isdigit()

    def is_alpha(self):
        return self.peek().isalpha()

    def is_token_whitespace(self):
        return self.peek().isspace()

    def raise_tokenize_error(self):
        """Raise TokenizeException describing the character under the cursor."""
        curr = self.curr
        char = self.peek()
        raise TokenizeException(
            'Could not match char "%s" at pos %d on line\n  %s' % (char, curr, self.line)
        )
615 
616 
class AstVisitor(object):
    """Visitor interface: one ``visit_*`` hook per AST node kind.

    Every hook in this base class does nothing and returns None.
    """

    __metaclass__ = ABC

    @abstractmethod
    def visit_udtf_node(self, node):
        """Hook for UDTF nodes."""

    @abstractmethod
    def visit_composed_node(self, node):
        """Hook for composed type nodes."""

    @abstractmethod
    def visit_arg_node(self, node):
        """Hook for argument nodes."""

    @abstractmethod
    def visit_primitive_node(self, node):
        """Hook for primitive type nodes."""

    @abstractmethod
    def visit_annotation_node(self, node):
        """Hook for annotation nodes."""

    @abstractmethod
    def visit_template_node(self, node):
        """Hook for template nodes."""
643 
644 
class AstTransformer(AstVisitor):
    """Only overload the methods you need

    Each default hook returns a shallow copy of the visited node; container
    nodes additionally rebuild their child lists by visiting every child.
    """

    def visit_udtf_node(self, udtf_node):
        clone = copy.copy(udtf_node)
        clone.inputs = [child.accept(self) for child in clone.inputs]
        clone.outputs = [child.accept(self) for child in clone.outputs]
        if clone.templates:
            clone.templates = [tmpl.accept(self) for tmpl in clone.templates]
        clone.annotations = [ann.accept(self) for ann in clone.annotations]
        return clone

    def visit_composed_node(self, composed_node):
        clone = copy.copy(composed_node)
        clone.inner = [child.accept(self) for child in clone.inner]
        return clone

    def visit_arg_node(self, arg_node):
        clone = copy.copy(arg_node)
        clone.type = clone.type.accept(self)
        # An empty/absent annotation container is kept as it was copied.
        if clone.annotations:
            clone.annotations = [ann.accept(self) for ann in clone.annotations]
        return clone

    def visit_primitive_node(self, primitive_node):
        return copy.copy(primitive_node)

    def visit_template_node(self, template_node):
        return copy.copy(template_node)

    def visit_annotation_node(self, annotation_node):
        return copy.copy(annotation_node)
677 
678 
680  """Returns a line formatted. Useful for testing"""
681 
def visit_udtf_node(self, udtf_node):
    """Render ``name(inputs) | annotations -> outputs, templates | sizer``.

    The annotation, template and sizer parts are omitted when absent.
    """
    shown_inputs = ", ".join([node.accept(self) for node in udtf_node.inputs])
    shown_outputs = ", ".join([node.accept(self) for node in udtf_node.outputs])
    shown_annotations = "| ".join([node.accept(self) for node in udtf_node.annotations])
    if udtf_node.sizer:
        shown_sizer = " | " + udtf_node.sizer.accept(self)
    else:
        shown_sizer = ""
    if shown_annotations:
        shown_annotations = ' | ' + shown_annotations
    text = "%s(%s)%s -> %s" % (udtf_node.name, shown_inputs, shown_annotations, shown_outputs)
    if udtf_node.templates:
        shown_templates = ", ".join([node.accept(self) for node in udtf_node.templates])
        text = "%s, %s" % (text, shown_templates)
    return "%s%s" % (text, shown_sizer)
695 
def visit_template_node(self, template_node):
    """Render a template as ``KEY=["T1", "T2", ...]`` with quoted type names."""
    # T=[T1, T2, ..., TN]
    quoted = ", ".join('"%s"' % typ for typ in template_node.types)
    return "%s=[%s]" % (template_node.key, quoted)
701 
def visit_annotation_node(self, annotation_node):
    """Render an annotation as ``key=value``.

    A list value is rendered as ``key=[v1,v2,...]`` where each element is
    visited with this visitor.
    """
    # key=value
    value = annotation_node.value
    if not isinstance(value, list):
        return "%s=%s" % (annotation_node.key, value)
    rendered = ','.join([item.accept(self) for item in value])
    return "%s=[%s]" % (annotation_node.key, rendered)
709 
def visit_arg_node(self, arg_node):
    """Render an argument as ``type`` or ``type | ann1 | ann2``.

    An output argument that renders as exactly ``ColumnTextEncodingDict``
    gets the default ``input_id=args<0>`` annotation appended.
    """
    # type | annotation
    text = "%s" % (arg_node.type.accept(self),)
    if arg_node.annotations:
        shown = " | ".join([ann.accept(self) for ann in arg_node.annotations])
        text = "%s | %s" % (text, shown)
    # insert input_id=args<0> if input_id is not specified
    if text == "ColumnTextEncodingDict" and arg_node.kind == "output":
        text = text + " | input_id=args<0>"
    return text
722 
def visit_composed_node(self, composed_node):
    """Render a composed type node.

    ``Array<T>``, ``Column<T>`` and ``ColumnList<T>`` collapse to
    ``ArrayT`` / ``ColumnT`` / ``ColumnListT``; an output buffer sizer is
    rendered as ``kName<N>``; a cursor as ``Cursor<T1, ..., TN>``.
    Raises ValueError for any other composed node.
    """
    T = composed_node.inner[0].accept(self)
    if composed_node.is_array():
        # Array<T>
        assert len(composed_node.inner) == 1
        return "Array" + T
    if composed_node.is_column():
        # Column<T>
        assert len(composed_node.inner) == 1
        return "Column" + T
    if composed_node.is_column_list():
        # ColumnList<T>
        assert len(composed_node.inner) == 1
        return "ColumnList" + T
    if composed_node.is_output_buffer_sizer():
        # kConstant<N>
        N = T
        assert len(composed_node.inner) == 1
        # Fall back to the type itself: is_output_buffer_sizer() also
        # accepts names that are already translated (e.g. ``kConstant``),
        # for which a bare ``.get(type)`` returns None and the string
        # concatenation fails.
        typ = composed_node.type
        return translate_map.get(typ, typ) + "<%s>" % (N,)
    if composed_node.is_cursor():
        # Cursor<T1, T2, ..., TN>
        Ts = ", ".join([i.accept(self) for i in composed_node.inner])
        return "Cursor<%s>" % (Ts)
    raise ValueError(composed_node)
747 
def visit_primitive_node(self, primitive_node):
    """Render a primitive node using its translated name.

    A bare output buffer sizer additionally gets ``<pos>`` appended, where
    ``pos`` is the one-based position of the enclosing argument.
    """
    raw = primitive_node.type
    translated = translate_map.get(raw, raw)
    if not primitive_node.is_output_buffer_sizer():
        return translated
    # arg_pos is zero-based
    one_based = primitive_node.get_parent(ArgNode).arg_pos + 1
    return translated + "<%d>" % (one_based,)
756 
757 
759  """Like AstPrinter but returns a node instead of a string
760  """
761 
762 
def product_dict(**kwargs):
    """Yield one dict per element of the cartesian product of the values.

    Each yielded dict maps every keyword to one item of its iterable.
    """
    names = kwargs.keys()
    choices = kwargs.values()
    for combination in itertools.product(*choices):
        yield dict(zip(names, combination))
768 
769 
771  """Expand template definition into multiple inputs"""
772 
def visit_udtf_node(self, udtf_node):
    """Expand a templated UDTF into one node per template assignment.

    Returns the node unchanged when it has no templates, a single
    UdtfNode when the expansion yields exactly one distinct signature,
    and a list of UdtfNode otherwise.
    """
    if not udtf_node.templates:
        return udtf_node

    # Keyed by the string form so that identical expansions collapse.
    expanded = dict()

    choices = dict([(node.key, node.types) for node in udtf_node.templates])
    name = udtf_node.name

    for assignment in product_dict(**choices):
        # The per-type visit methods read the active assignment from here.
        self.mapping_dict = assignment
        new_inputs = [arg.accept(self) for arg in udtf_node.inputs]
        new_outputs = [arg.accept(self) for arg in udtf_node.outputs]
        instance = UdtfNode(name, new_inputs, new_outputs, udtf_node.annotations,
                            None, udtf_node.sizer, udtf_node.line)
        expanded[str(instance)] = instance
        self.mapping_dict = {}

    unique = list(expanded.values())
    return unique[0] if len(unique) == 1 else unique
796 
def visit_composed_node(self, composed_node):
    """Copy the node, substituting its type via ``mapping_dict`` and
    visiting every inner node."""
    original = composed_node.type
    substituted = self.mapping_dict.get(original, original)
    children = [child.accept(self) for child in composed_node.inner]
    return composed_node.copy(substituted, children)
803 
def visit_primitive_node(self, primitive_node):
    """Copy the node, substituting its type via ``mapping_dict``."""
    original = primitive_node.type
    return primitive_node.copy(self.mapping_dict.get(original, original))
808 
809 
def visit_primitive_node(self, primitive_node):
    """
    * Fix kUserSpecifiedRowMultiplier without a pos arg

    A bare output buffer sizer is replaced by a ComposedNode of the same
    type whose single inner node holds the one-based position of the
    enclosing argument; any other node is returned unchanged.
    """
    sizer_type = primitive_node.type

    if not primitive_node.is_output_buffer_sizer():
        return primitive_node

    position = primitive_node.get_parent(ArgNode).arg_pos + 1
    return ComposedNode(sizer_type, inner=[PrimitiveNode(str(position))])
823 
824 
def visit_primitive_node(self, primitive_node):
    """
    * Rename nodes using translate_map as dictionary
      int -> Int32
      float -> Float

    Names without an entry in translate_map are kept as they are.
    """
    current = primitive_node.type
    renamed = translate_map.get(current, current)
    return primitive_node.copy(renamed)
834 
835 
def visit_udtf_node(self, udtf_node):
    """
    * Add default_input_id to Column(List)<TextEncodingDict> without one

    The default refers to the first text-encoding-dict column (or column
    list) among the inputs.  Raises TypeError when an output needs the
    default but no such input exists.
    """
    # NOTE(review): super(type(self), self) recurses forever if this class
    # is ever subclassed; prefer naming the class explicitly.
    udtf_node = super(type(self), self).visit_udtf_node(udtf_node)

    # add default input_id: the first matching input wins
    default_input_id = None
    for position, arg in enumerate(udtf_node.inputs):
        if not isinstance(arg.type, ComposedNode):
            continue
        if arg.type.is_column_text_encoding_dict() or arg.type.is_column_array_text_encoding_dict():
            default_input_id = AnnotationNode('input_id', 'args<%s>' % (position,))
            break
        if arg.type.is_column_list_text_encoding_dict():
            default_input_id = AnnotationNode('input_id', 'args<%s, 0>' % (position,))
            break

    for arg in udtf_node.outputs:
        if not (isinstance(arg.type, ComposedNode) and arg.type.is_any_text_encoding_dict()):
            continue
        if any(ann.key == 'input_id' for ann in arg.annotations):
            continue
        if default_input_id is None:
            raise TypeError('Cannot parse line "%s".\n'
                            'Missing TextEncodingDict input?' %
                            (udtf_node.line))
        arg.annotations.append(default_input_id)

    return udtf_node
868 
869 
871 
def visit_udtf_node(self, udtf_node):
    """
    * Generate fields annotation to Cursor if non-existing

    Each field is named after the ``name`` annotation of the corresponding
    inner node, defaulting to ``field<i>``.
    """
    # NOTE(review): super(type(self), self) recurses forever if this class
    # is ever subclassed; prefer naming the class explicitly.
    udtf_node = super(type(self), self).visit_udtf_node(udtf_node)

    for arg in udtf_node.inputs:
        if not isinstance(arg.type, ComposedNode):
            continue
        if not arg.type.is_cursor():
            continue
        if arg.get_annotation('fields') is not None:
            continue
        fields = [PrimitiveNode(inner.get_annotation('name', 'field%s' % index))
                  for index, inner in enumerate(arg.type.inner)]
        arg.annotations.append(AnnotationNode('fields', fields))

    return udtf_node
888 
890  """
891  * Checks for supported annotations in a UDTF
892  """
def visit_udtf_node(self, udtf_node):
    """Validate annotation keys and normalize function annotation values.

    Raises TransformerException for an annotation key that is not listed
    in SupportedAnnotations (arguments) or SupportedFunctionAnnotations
    (function).  Boolean-like function annotation values are rewritten
    in place to '1' or '0'.  Returns the same node.
    """
    truthy = ['enable', 'on', '1', 'true']
    falsy = ['disable', 'off', '0', 'false']

    for arg in udtf_node.inputs:
        for ann in arg.annotations:
            if ann.key not in SupportedAnnotations:
                raise TransformerException('unknown input annotation: `%s`' % (ann.key))
    for arg in udtf_node.outputs:
        for ann in arg.annotations:
            if ann.key not in SupportedAnnotations:
                raise TransformerException('unknown output annotation: `%s`' % (ann.key))
    for ann in udtf_node.annotations:
        if ann.key not in SupportedFunctionAnnotations:
            raise TransformerException('unknown function annotation: `%s`' % (ann.key))
        lowered = ann.value.lower()
        if lowered in truthy:
            ann.value = '1'
        elif lowered in falsy:
            ann.value = '0'
    return udtf_node
910 
911 
913  """
914  * Append require annotation if range is used
915  """
def visit_arg_node(self, arg_node):
    """Translate every ``range`` annotation into a ``require`` annotation.

    ``range=[lo, hi]`` on an argument named ``x`` sets
    ``require="lo <= x && x <= hi"``.  Raises TransformerException when the
    argument has no ``name`` annotation or the range is not a pair.
    Returns the same (mutated) node.
    """
    for ann in arg_node.annotations:
        if ann.key != 'range':
            continue

        name = arg_node.get_annotation('name')
        if name is None:
            raise TransformerException('"range" requires a named argument')

        l = ann.value
        if len(l) != 2:
            raise TransformerException('"range" requires an interval. Got {l}'.format(l=l))

        lo, hi = ann.value
        value = '"{lo} <= {name} && {name} <= {hi}"'.format(lo=lo, hi=hi, name=name)
        arg_node.set_annotation('require', value)
    return arg_node
931 
932 
934 
def visit_udtf_node(self, udtf_node):
    """Build a Signature from a UDTF node.

    Inputs are kept as the visited declarations; for outputs only the
    declaration's type is kept.  Annotations of inputs and outputs are
    collected in parallel lists.
    """
    name = udtf_node.name
    sizer = udtf_node.sizer

    input_decls = [node.accept(self) for node in udtf_node.inputs]
    input_annotations = [decl.annotations for decl in input_decls]

    output_decls = [node.accept(self) for node in udtf_node.outputs]
    output_types = [decl.type for decl in output_decls]
    output_annotations = [decl.annotations for decl in output_decls]

    function_annotations = [ann.accept(self) for ann in udtf_node.annotations]

    return Signature(name, input_decls, output_types, input_annotations,
                     output_annotations, function_annotations, sizer)
959 
def visit_arg_node(self, arg_node):
    """Wrap the visited type and visited annotations into a Declaration."""
    visited_type = arg_node.type.accept(self)
    visited_annotations = [ann.accept(self) for ann in arg_node.annotations]
    return Declaration(visited_type, visited_annotations)
964 
def visit_composed_node(self, composed_node):
    """Convert a composed node into a Bracket.

    * cursor: ``Bracket(type, args)`` with ``apply_column`` applied to
      every visited inner item;
    * output buffer sizer: ``Bracket(type, args)`` with the visited inner
      items as they are;
    * anything else: a single-name bracket ``type + str(first inner)``.
    """
    kind = composed_node.type
    typ = translate_map.get(kind, kind)
    inner = [child.accept(self) for child in composed_node.inner]
    if composed_node.is_cursor():
        columns = [item.apply_column() for item in inner]
        return Bracket(typ, args=tuple(columns))
    if composed_node.is_output_buffer_sizer():
        return Bracket(typ, args=tuple(inner))
    return Bracket(typ + str(inner[0]))
975 
def visit_primitive_node(self, primitive_node):
    """Convert a primitive node into an argument-less Bracket of the same name."""
    return Bracket(primitive_node.type)
979 
def visit_annotation_node(self, annotation_node):
    """Convert an annotation node into a plain ``(key, value)`` tuple."""
    pair = (annotation_node.key, annotation_node.value)
    return pair
984 
985 
class Node(object):
    """Base class of AST nodes: visitor entry point plus parent lookup and
    copy helpers."""

    __metaclass__ = ABC

    @abstractmethod
    def accept(self, visitor):
        """Dispatch to the visitor hook matching this node kind."""

    @abstractmethod
    def __str__(self):
        """Return the textual form of the node."""

    def get_parent(self, cls):
        """Return the nearest node of type `cls`, starting with self and
        following ``parent`` links.

        Raises ValueError when the chain ends (``parent is None``) without
        a match.
        """
        if isinstance(self, cls):
            return self

        ancestor = self.parent
        if ancestor is None:
            raise ValueError("could not find parent with given class %s" % (cls))
        return ancestor.get_parent(cls)

    def copy(self, *args):
        """Construct a new node of the same class from `args` and carry
        over ``parent`` and ``arg_pos`` when this instance has them set."""
        duplicate = self.__class__(*args)

        # copy parent and arg_pos
        own_attrs = vars(self)
        for attr in ('parent', 'arg_pos'):
            if attr in own_attrs:
                setattr(duplicate, attr, getattr(self, attr))

        return duplicate
1016 
1017 
class IterableNode(Iterable):
    """Marker base for nodes that can be iterated over."""
1020 
1021 
class UdtfNode(Node, IterableNode):
    """Root AST node describing one UDTF signature."""

    def __init__(self, name, inputs, outputs, annotations, templates, sizer, line):
        """
        Parameters
        ----------
        name : str
        inputs : list[ArgNode]
        outputs : list[ArgNode]
        annotations : List[AnnotationNode]
            Iterated unconditionally by ``__str__`` and ``__iter__``.
        templates : Optional[list[TemplateNode]]
        sizer : optional
            Shown via ``str()``; may be None.
        line: str
        """
        self.name = name
        self.inputs = inputs
        self.outputs = outputs
        self.annotations = annotations
        self.templates = templates
        self.sizer = sizer
        self.line = line

    def accept(self, visitor):
        return visitor.visit_udtf_node(self)

    def __str__(self):
        shown_inputs = [str(node) for node in self.inputs]
        shown_outputs = [str(node) for node in self.outputs]
        shown_annotations = [str(node) for node in self.annotations]
        shown_sizer = "| %s" % str(self.sizer) if self.sizer else ""

        # Assemble left to right; optional parts are skipped when empty.
        text = "UDTF: %s (%s)" % (self.name, shown_inputs)
        if shown_annotations:
            text = "%s | %s" % (text, shown_annotations)
        text = "%s -> %s" % (text, shown_outputs)
        if self.templates:
            shown_templates = [str(node) for node in self.templates]
            text = "%s, %s" % (text, shown_templates)
        return "%s %s" % (text, shown_sizer)

    def __iter__(self):
        """Yield inputs, outputs, annotations and then templates (if any)."""
        for node in self.inputs:
            yield node
        for node in self.outputs:
            yield node
        for node in self.annotations:
            yield node
        if self.templates:
            for node in self.templates:
                yield node

    __repr__ = __str__
1077 
1078 
1080 
def __init__(self, type, annotations):
    """
    Parameters
    ----------
    type : TypeNode
    annotations : List[AnnotationNode]

    ``arg_pos`` starts out as None.
    """
    self.arg_pos = None
    self.annotations = annotations
    self.type = type
1091 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_arg_node`` and return its result."""
    handler = visitor.visit_arg_node
    return handler(self)
1094 
def __str__(self):
    """Render as ``ArgNode(type)`` or ``ArgNode(type | ann | ...)``."""
    type_text = str(self.type)
    if not self.annotations:
        return "ArgNode(%s)" % (type_text)
    joined = " | ".join(str(a) for a in self.annotations)
    return "ArgNode(%s | %s)" % (type_text, joined)
1102 
def __iter__(self):
    """Yield the type node first, then every annotation."""
    yield self.type
    yield from self.annotations
1107 
1108  __repr__ = __str__
1109 
def get_annotation(self, key, default=None):
    """Return the value of the first annotation whose key equals ``key``.

    ``default`` is returned when no annotation matches.
    """
    matching_values = (a.value for a in self.annotations if a.key == key)
    return next(matching_values, default)
1115 
def set_annotation(self, key, value):
    """Replace the annotation stored under ``key``, or append a new one.

    An AssertionError is raised if more than one existing annotation
    uses ``key``.
    """
    replaced = False
    for pos, existing in enumerate(self.annotations):
        if existing.key != key:
            continue
        assert not replaced, (pos, existing)  # annotations with the same key not supported
        self.annotations[pos] = AnnotationNode(key, value)
        replaced = True
    if replaced:
        return
    self.annotations.append(AnnotationNode(key, value))
1125 
1126 
def is_array(self):
    """Tell whether this node's type name is ``Array``."""
    kind = self.type
    return kind == "Array"
1130 
def is_column_any(self):
    """Tell whether this node is a column or a column list."""
    as_column = self.is_column()
    return as_column if as_column else self.is_column_list()
1133 
def is_column(self):
    """Tell whether this node's type name is ``Column``."""
    kind = self.type
    return kind == "Column"
1136 
def is_column_list(self):
    """Tell whether this node's type name is ``ColumnList``."""
    kind = self.type
    return kind == "ColumnList"
1139 
def is_cursor(self):
    """Tell whether this node's type name is ``Cursor``."""
    kind = self.type
    return kind == "Cursor"
1142 
1144  t = self.type
1145  return translate_map.get(t, t) in OutputBufferSizeTypes
1146 
1147 
1149 
def __init__(self, type):
    """Create a primitive type node.

    Parameters
    ----------
    type : str
        Name (lexeme) of the primitive.
    """
    setattr(self, "type", type)
1157 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_primitive_node`` and return its result."""
    handler = visitor.visit_primitive_node
    return handler(self)
1160 
def __str__(self):
    """Return the text produced by visiting this node with a fresh AstPrinter."""
    printer = AstPrinter()
    return self.accept(printer)
1163 
1165  return self.type == 'TextEncodingDict'
1166 
1168  return self.type == 'ArrayTextEncodingDict'
1169 
1170  __repr__ = __str__
1171 
1172 
1174 
def __init__(self, type, inner):
    """Create a composed type node such as ``Outer<inner, ...>``.

    Parameters
    ----------
    type : str
        Name of the outer type.
    inner : list[TypeNode]
        Nodes found between the angle brackets.
    """
    self.inner = inner
    self.type = type
1184 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_composed_node`` and return its result."""
    handler = visitor.visit_composed_node
    return handler(self)
1187 
def cursor_length(self):
    """Return the number of inner nodes; only valid on a cursor node."""
    node_is_cursor = self.is_cursor()
    assert node_is_cursor
    return len(self.inner)
1191 
def __str__(self):
    """Render as ``Composed(type<inner0, inner1, ...>)``."""
    inner_text = ", ".join(map(str, self.inner))
    return "Composed(%s<%s>)" % (self.type, inner_text)
1195 
def __iter__(self):
    """Yield each inner node in order."""
    yield from self.inner
1199 
1201  return False
1202 
1204  return self.is_array() and self.inner[0].is_text_encoding_dict()
1205 
1207  return self.is_column() and self.inner[0].is_text_encoding_dict()
1208 
1210  return self.is_column_list() and self.inner[0].is_text_encoding_dict()
1211 
1213  return self.is_column() and self.inner[0].is_array_text_encoding_dict()
1214 
1216  return self.inner[0].is_text_encoding_dict() or self.inner[0].is_array_text_encoding_dict()
1217 
1218  __repr__ = __str__
1219 
1220 
1222 
def __init__(self, key, value):
    """Create a ``key=value`` annotation node.

    Parameters
    ----------
    key : str
        Annotation label.
    value : {str, list}
        Annotation value.
    """
    self.key, self.value = key, value
1232 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_annotation_node`` and return its result."""
    handler = visitor.visit_annotation_node
    return handler(self)
1235 
def __str__(self):
    """Return the text produced by visiting this node with a fresh AstPrinter."""
    return self.accept(AstPrinter())
1239 
1240  __repr__ = __str__
1241 
1242 
1244 
def __init__(self, key, types):
    """Create a template specification node.

    Parameters
    ----------
    key : str
        Template parameter name.
    types : tuple[str]
        Type names listed for the parameter.
    """
    self.key, self.types = key, types
1254 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_template_node`` and return its result."""
    handler = visitor.visit_template_node
    return handler(self)
1257 
def __str__(self):
    """Return the text produced by visiting this node with a fresh AstPrinter."""
    return self.accept(AstPrinter())
1261 
1262  __repr__ = __str__
1263 
1264 
class Pipeline(object):
    """Ordered sequence of visitor classes applied to one or more AST nodes."""

    def __init__(self, *passes):
        self.passes = passes

    def __call__(self, ast_list):
        """Run every pass over the nodes and return the resulting list.

        A single node is treated as a one-element list. Each pass is
        instantiated once per node. When a pass returns a list for a node,
        its items are spliced into the list handed to the next pass.
        """
        nodes = ast_list if isinstance(ast_list, list) else [ast_list]

        for visitor_cls in self.passes:
            flattened = []
            for node in nodes:
                outcome = node.accept(visitor_cls())
                if isinstance(outcome, list):
                    flattened.extend(outcome)
                else:
                    flattened.append(outcome)
            nodes = flattened

        return list(nodes)
1279 
1280 
class Parser:
    """Recursive-descent parser for a single UDTF signature line.

    The line is tokenized on construction; `parse` walks the token list and
    returns the root UdtfNode of the resulting AST.
    """

    def __init__(self, line):
        """Tokenize ``line`` and place the cursor on the first token."""
        self._tokens = Tokenize(line).tokens
        self._curr = 0
        self.line = line

    @property
    def tokens(self):
        """The token list produced for the input line."""
        return self._tokens

    def is_at_end(self):
        """Return True once the cursor has moved past the last token."""
        return self._curr >= len(self._tokens)

    def current_token(self):
        """Return the token under the cursor (IndexError when past the end)."""
        return self._tokens[self._curr]

    def advance(self):
        """Move the cursor to the next token."""
        self._curr += 1

    def expect(self, expected_type):
        """Advance past the current token, which must be of ``expected_type``.

        Raises AssertionError on a mismatch. The check is an explicit raise
        so it still runs under ``python -O``, and the message is only built
        on failure.
        """
        curr_token = self.current_token()
        if not (curr_token.type == expected_type):
            msg = "Expected token %s but got %s at pos %d.\n Tokens: %s" % (
                Token.tok_name(expected_type),
                curr_token,
                self._curr,
                self._tokens,
            )
            raise AssertionError(msg)
        self.advance()

    def consume(self, expected_type):
        """consumes the current token iff its type matches the
        expected_type. Otherwise, an error is raised
        """
        curr_token = self.current_token()
        if curr_token.type == expected_type:
            self.advance()
            return curr_token
        else:
            expected_token = Token.tok_name(expected_type)
            self.raise_parser_error(
                'Token mismatch at function consume. '
                'Expected type "%s" but got token "%s"\n\n'
                'Tokens: %s\n' % (expected_token, curr_token, self._tokens)
            )

    def current_pos(self):
        """Return the index of the token under the cursor."""
        return self._curr

    def raise_parser_error(self, msg=None):
        """Raise ParserException; a default message names the current token."""
        if not msg:
            token = self.current_token()
            pos = self.current_pos()
            tokens = self.tokens
            msg = "\n\nError while trying to parse token %s at pos %d.\n" "Tokens: %s" % (
                token,
                pos,
                tokens,
            )
        raise ParserException(msg)

    def match(self, expected_type):
        """Return True if the current token has type ``expected_type``."""
        curr_token = self.current_token()
        return curr_token.type == expected_type

    def lookahead(self):
        """Return the token right after the current one without consuming it."""
        return self._tokens[self._curr + 1]

    def parse_udtf(self):
        """fmt: off

        udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?

        fmt: on
        """
        name = self.parse_identifier()
        self.expect(Token.LPAR)  # (
        input_args = []
        if not self.match(Token.RPAR):
            input_args = self.parse_args()
        self.expect(Token.RPAR)  # )
        annotations = []
        while not self.is_at_end() and self.match(Token.VBAR):  # |
            self.consume(Token.VBAR)
            annotations.append(self.parse_annotation())
        self.expect(Token.RARROW)  # ->
        output_args = self.parse_args()

        templates = None
        if not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            templates = self.parse_templates()

        sizer = None
        if not self.is_at_end() and self.match(Token.VBAR):
            self.consume(Token.VBAR)
            idtn = self.parse_identifier()
            assert idtn == "output_row_size", idtn
            self.consume(Token.EQUAL)
            node = self.parse_primitive()
            key = "kPreFlightParameter"
            sizer = AnnotationNode(key, value=node.type)

        # set arg_pos; a cursor argument occupies one position per inner argument
        i = 0
        for arg in input_args:
            arg.arg_pos = i
            arg.kind = "input"
            i += arg.type.cursor_length() if arg.type.is_cursor() else 1

        for i, arg in enumerate(output_args):
            arg.arg_pos = i
            arg.kind = "output"

        return UdtfNode(name, input_args, output_args, annotations, templates, sizer, self.line)

    def parse_args(self):
        """fmt: off

        args: arg ("," arg)*

        fmt: on
        """
        args = []
        args.append(self.parse_arg())
        while not self.is_at_end() and self.match(Token.COMMA):
            curr = self._curr
            self.consume(Token.COMMA)
            self.parse_type()  # assuming that we are not ending with COMMA
            if not self.is_at_end() and self.match(Token.EQUAL):
                # arg type cannot be assigned, so this must be a template specification
                self._curr = curr  # step back and let the code below parse the templates
                break
            else:
                self._curr = curr + 1  # step back from self.parse_type(), parse_arg will parse it again
                args.append(self.parse_arg())
        return args

    def parse_arg(self):
        """fmt: off

        arg: type IDENTIFIER? ("|" annotation)*

        fmt: on
        """
        typ = self.parse_type()

        annotations = []

        # An identifier right after the type is stored as a 'name' annotation.
        if not self.is_at_end() and self.match(Token.IDENTIFIER):
            name = self.parse_identifier()
            annotations.append(AnnotationNode('name', name))

        while not self.is_at_end() and self.match(Token.VBAR):
            ahead = self.lookahead()
            # "| output_row_size" belongs to the udtf rule, not to this argument
            if ahead.type == Token.IDENTIFIER and ahead.lexeme == 'output_row_size':
                break
            self.consume(Token.VBAR)
            annotations.append(self.parse_annotation())

        return ArgNode(typ, annotations)

    def parse_type(self):
        """fmt: off

        type: composed
            | primitive

        fmt: on
        """
        curr = self._curr  # save state
        primitive = self.parse_primitive()
        if self.is_at_end():
            return primitive

        if not self.match(Token.LESS):
            return primitive

        self._curr = curr  # return state

        return self.parse_composed()

    def parse_composed(self):
        """fmt: off

        composed: "Cursor" "<" arg ("," arg)* ">"
                | IDENTIFIER "<" type ("," type)* ">"

        fmt: on
        """
        idtn = self.parse_identifier()
        self.consume(Token.LESS)
        if is_identifier_cursor(idtn):
            inner = [self.parse_arg()]
            while self.match(Token.COMMA):
                self.consume(Token.COMMA)
                inner.append(self.parse_arg())
        else:
            inner = [self.parse_type()]
            while self.match(Token.COMMA):
                self.consume(Token.COMMA)
                inner.append(self.parse_type())
        self.consume(Token.GREATER)
        return ComposedNode(idtn, inner)

    def parse_primitive(self):
        """fmt: off

        primitive: IDENTIFIER
                 | NUMBER
                 | STRING

        fmt: on
        """
        if self.match(Token.IDENTIFIER):
            lexeme = self.parse_identifier()
        elif self.match(Token.NUMBER):
            lexeme = self.parse_number()
        elif self.match(Token.STRING):
            lexeme = self.parse_string()
        else:
            # raise_parser_error() raises by itself; it never returns.
            self.raise_parser_error()
        return PrimitiveNode(lexeme)

    def parse_templates(self):
        """fmt: off

        templates: template ("," template)*

        fmt: on
        """
        T = []
        T.append(self.parse_template())
        while not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            T.append(self.parse_template())
        return T

    def parse_template(self):
        """fmt: off

        template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"

        fmt: on
        """
        key = self.parse_identifier()
        types = []
        self.consume(Token.EQUAL)
        self.consume(Token.LSQB)
        types.append(self.parse_identifier())
        while self.match(Token.COMMA):
            self.consume(Token.COMMA)
            types.append(self.parse_identifier())
        self.consume(Token.RSQB)
        return TemplateNode(key, tuple(types))

    def parse_annotation(self):
        """fmt: off

        annotation: IDENTIFIER "=" IDENTIFIER ("<" (NUMBER ("," NUMBER)?)? ">")?
                  | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
                  | "require" "=" STRING

        fmt: on
        """
        key = self.parse_identifier()
        self.consume(Token.EQUAL)

        if key == "require":
            value = self.parse_string()
        elif not self.is_at_end() and self.match(Token.LSQB):
            value = []
            self.consume(Token.LSQB)
            if not self.match(Token.RSQB):
                value.append(self.parse_primitive())
                while self.match(Token.COMMA):
                    self.consume(Token.COMMA)
                    value.append(self.parse_primitive())
            self.consume(Token.RSQB)
        else:
            value = self.parse_identifier()
            if not self.is_at_end() and self.match(Token.LESS):
                self.consume(Token.LESS)
                if self.match(Token.GREATER):
                    value += "<%s>" % (-1)  # Signifies no input
                else:
                    num1 = self.parse_number()
                    if self.match(Token.COMMA):
                        self.consume(Token.COMMA)
                        num2 = self.parse_number()
                        value += "<%s,%s>" % (num1, num2)
                    else:
                        value += "<%s>" % (num1)
                self.consume(Token.GREATER)
        return AnnotationNode(key, value)

    def parse_identifier(self):
        """ fmt: off

        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*

        fmt: on
        """
        token = self.consume(Token.IDENTIFIER)
        return token.lexeme

    def parse_string(self):
        """ fmt: off

        STRING: \".*?\"

        fmt: on
        """
        token = self.consume(Token.STRING)
        return token.lexeme

    def parse_number(self):
        """ fmt: off

        NUMBER: [0-9]+

        fmt: on
        """
        token = self.consume(Token.NUMBER)
        return token.lexeme

    def parse(self):
        """fmt: off

        udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?

        args: arg ("," arg)*

        arg: type IDENTIFIER? ("|" annotation)*

        type: composed
            | primitive

        composed: "Cursor" "<" arg ("," arg)* ">"
                | IDENTIFIER "<" type ("," type)* ">"

        primitive: IDENTIFIER
                 | NUMBER
                 | STRING

        annotation: IDENTIFIER "=" IDENTIFIER ("<" (NUMBER ("," NUMBER)?)? ">")?
                  | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
                  | "require" "=" STRING

        templates: template ("," template)*
        template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"

        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
        NUMBER: [0-9]+
        STRING: \".*?\"

        fmt: on
        """
        self._curr = 0
        udtf = self.parse_udtf()

        # set parent: walk the tree and point every child at its parent node
        udtf.parent = None
        d = deque()
        d.append(udtf)
        while d:
            node = d.pop()
            if isinstance(node, Iterable):
                for child in node:
                    child.parent = node
                    d.append(child)
        return udtf
1653 
1654 
1655 # fmt: off
def find_signatures(input_file):
    """Returns a list of parsed UDTF signatures.

    Lines starting with ``UDTF:`` are parsed; an incomplete line is joined
    with the following line before parsing. When a line carries expected
    results (text after ``separator``), the parsed signature is only
    checked against them and is not added to the returned list.
    """
    # Passes shared by the checking and the generating pipelines.
    common_passes = (TemplateTransformer,
                     FieldAnnotationTransformer,
                     TextEncodingDictTransformer,
                     SupportedAnnotationsTransformer,
                     RangeAnnotationTransformer,
                     FixRowMultiplierPosArgTransformer,
                     RenameNodesTransformer)

    signatures = []

    # Read everything up front so the file handle is closed promptly.
    with open(input_file) as f:
        lines = f.readlines()

    last_line = None
    for line in lines:
        line = line.strip()
        if last_line is not None:
            line = last_line + ' ' + line
            last_line = None
        if not line.startswith('UDTF:'):
            continue
        if line_is_incomplete(line):
            last_line = line
            continue
        last_line = None
        line = line[5:].lstrip()
        i = line.find('(')
        j = line.find(')')
        if i == -1 or j == -1:
            sys.stderr.write('Invalid UDTF specification: `%s`. Skipping.\n' % (line))
            continue

        expected_result = None
        if separator in line:
            line, expected_result = line.split(separator, 1)
            expected_result = expected_result.strip().split(separator)
            expected_result = list(map(lambda s: s.strip(), expected_result))

        ast = Parser(line).parse()

        if expected_result is not None:
            # Template transformer expands templates into multiple lines
            try:
                result = Pipeline(*(common_passes + (AstPrinter,)))(ast)
            except TransformerException as msg:
                result = ['%s: %s' % (type(msg).__name__, msg)]
            assert len(result) == len(expected_result), "\n\tresult: %s \n!= \n\texpected: %s" % (
                '\n\t\t '.join(result),
                '\n\t\t '.join(expected_result)
            )
            assert set(result) == set(expected_result), "\n\tresult: %s != \n\texpected: %s" % (
                '\n\t\t '.join(result),
                '\n\t\t '.join(expected_result),
            )

        else:
            signature = Pipeline(*(common_passes + (DeclBracketTransformer,)))(ast)

            signatures.extend(signature)

    return signatures
1722 
1723 
def format_function_args(input_types, output_types, uses_manager, use_generic_arg_name, emit_output_args):
    """Build the C++ parameter list and the matching argument-name list.

    Returns a pair of comma-separated strings ``(cpp_args, name_args)``.
    A ``TableFunctionManager& mgr`` parameter comes first when
    ``uses_manager`` is true; output arguments are only included when
    ``emit_output_args`` is true.
    """
    cpp_parts = ['TableFunctionManager& mgr'] if uses_manager else []
    name_parts = ['mgr'] if uses_manager else []

    groups = [(input_types, True)]
    if emit_output_args:
        groups.append((output_types, False))

    for types, is_input in groups:
        for position, typ in enumerate(types):
            declaration, arg_name = typ.format_cpp_type(position,
                                                        use_generic_arg_name=use_generic_arg_name,
                                                        is_input=is_input)
            cpp_parts.append(declaration)
            name_parts.append(arg_name)

    return ', '.join(cpp_parts), ', '.join(name_parts)
1750 
1751 
def build_template_function_call(caller, called, input_types, output_types, uses_manager):
    """Return C++ source of a function ``caller`` that forwards its arguments to ``called``.

    The wrapper takes generically named input and output arguments.
    """
    signature_args, forwarded_names = format_function_args(input_types,
                                                           output_types,
                                                           uses_manager,
                                                           use_generic_arg_name=True,
                                                           emit_output_args=True)

    pieces = [
        "EXTENSION_NOINLINE int32_t\n",
        "%s(%s) {\n" % (caller, signature_args),
        " return %s(%s);\n" % (called, forwarded_names),
        "}\n",
    ]
    return "".join(pieces)
1764 
1765 
def build_preflight_function(fn_name, sizer, input_types, output_types, uses_manager):
    """Return C++ source of the ``<fn_name>__preflight`` function.

    The generated function checks every ``require`` annotation found on the
    Declaration inputs. When ``sizer`` is an argument sizer it then evaluates
    the output size expression and returns it; otherwise it returns 0.
    """

    def format_error_msg(err_msg, uses_manager):
        # The error is reported through the manager when one is available.
        if uses_manager:
            template = " return mgr.error_message(%s);\n"
        else:
            template = " return table_function_error(%s);\n"
        return template % (err_msg,)

    cpp_args, _ = format_function_args(input_types,
                                       output_types,
                                       uses_manager,
                                       use_generic_arg_name=False,
                                       emit_output_args=False)

    # The function header is the same with or without a manager argument.
    pieces = ["EXTENSION_NOINLINE int32_t\n",
              "%s(%s) {\n" % (fn_name.lower() + "__preflight", cpp_args)]

    for typ in input_types:
        if not isinstance(typ, Declaration):
            continue
        for key, value in typ.annotations:
            if key != 'require':
                continue
            # value[1:-1] drops the first and last character of the annotation value
            err_msg = '"Constraint `%s` is not satisfied."' % (value[1:-1])

            pieces.append(" if (!(%s)) {\n" % (value[1:-1].replace('\\', ''),))
            pieces.append(format_error_msg(err_msg, uses_manager))
            pieces.append(" }\n")

    if sizer.is_arg_sizer():
        precomputed_nrows = str(sizer.args[0])
        if '"' in precomputed_nrows:
            precomputed_nrows = precomputed_nrows[1:-1]
        # check to see if the precomputed number of rows > 0
        err_msg = '"Output size expression `%s` evaluated in a negative value."' % (precomputed_nrows)
        pieces.append(" auto _output_size = %s;\n" % (precomputed_nrows))
        pieces.append(" if (_output_size < 0) {\n")
        pieces.append(format_error_msg(err_msg, uses_manager))
        pieces.append(" }\n")
        pieces.append(" return _output_size;\n")
    else:
        pieces.append(" return 0;\n")
    pieces.append("}\n\n")

    return "".join(pieces)
1814 
1815 
1817  if sizer.is_arg_sizer():
1818  return True
1819  for arg_annotations in sig.input_annotations:
1820  d = dict(arg_annotations)
1821  if 'require' in d.keys():
1822  return True
1823  return False
1824 
1825 
def format_annotations(annotations_):
    """Render annotations as a C++ ``std::vector<std::map<std::string, std::string>>`` literal.

    ``annotations_`` is a sequence of ``(key, value)`` pair sequences; each
    inner sequence becomes one map. The value of a ``require`` annotation
    loses its first and last character.
    """
    def render_value(key, value):
        # type(value) is not always 'str'
        return value[1:-1] if key == 'require' else value

    def render_map(pairs):
        entries = ['{"%s", "%s"}' % (k, render_value(k, v)) for k, v in pairs]
        return '{' + ', '.join(entries) + '}'

    body = ', '.join(render_map(pairs) for pairs in annotations_)
    return "std::vector<std::map<std::string, std::string>>{" + body + "}"
1837 
1838 
1840  i = sig.name.rfind('_template')
1841  return i >= 0 and '__' in sig.name[:i + 1]
1842 
1843 
def uses_manager(sig):
    """Return True if the first input of ``sig`` is named ``TableFunctionManager``.

    Always returns a bool, including when ``sig.inputs`` is empty or None.
    """
    inputs = sig.inputs
    return bool(inputs) and inputs[0].name == 'TableFunctionManager'
1846 
1847 
1849  # Any function that does not have _gpu_ suffix is a cpu function.
1850  i = sig.name.rfind('_gpu_')
1851  if i >= 0 and '__' in sig.name[:i + 1]:
1852  if uses_manager(sig):
1853  raise ValueError('Table function {} with gpu execution target cannot have TableFunctionManager argument'.format(sig.name))
1854  return False
1855  return True
1856 
1857 
1859  # A function with TableFunctionManager argument is a cpu-only function
1860  if uses_manager(sig):
1861  return False
1862  # Any function that does not have _cpu_ suffix is a gpu function.
1863  i = sig.name.rfind('_cpu_')
1864  return not (i >= 0 and '__' in sig.name[:i + 1])
1865 
1866 
1867 def parse_annotations(input_files):
 """Build the registration data for all UDTF signatures in ``input_files``.

 Every signature returned by find_signatures() contributes one
 ``TableFunctionsFactory::add(...)`` statement. Template functions also
 get a numbered wrapper function, and signatures for which
 must_emit_preflight_function() is true get a preflight function.

 Returns a 6-tuple: (add_stmts, cpu_template_functions,
 gpu_template_functions, cpu_function_address_expressions,
 gpu_function_address_expressions, cond_fns).
 """
1868 
 # counter numbers the wrappers generated for template functions
1869  counter = 0
1870 
1871  add_stmts = []
1872  cpu_template_functions = []
1873  gpu_template_functions = []
1874  cpu_function_address_expressions = []
1875  gpu_function_address_expressions = []
1876  cond_fns = []
1877 
1878  for input_file in input_files:
1879  for sig in find_signatures(input_file):
1880 
1881  # Compute sql_types, input_types, and sizer
1882  sql_types_ = []
1883  input_types_ = []
1884  input_annotations = []
1885 
1886  sizer = None
1887  if sig.sizer is not None:
1888  expr = sig.sizer.value
1889  sizer = Bracket('kPreFlightParameter', (expr,))
1890 
 # NOTE: within parse_annotations this local shadows the module-level uses_manager() function
1891  uses_manager = False
1892  for i, (t, annot) in enumerate(zip(sig.inputs, sig.input_annotations)):
1893  if t.is_output_buffer_sizer():
1894  if t.is_user_specified():
1895  sql_types_.append(Bracket.parse('int32').normalize(kind='input'))
1896  input_types_.append(sql_types_[-1])
1897  input_annotations.append(annot)
1898  assert sizer is None # exactly one sizer argument is allowed
1899  assert len(t.args) == 1, t
1900  sizer = t
1901  elif t.name == 'Cursor':
1902  for t_ in t.args:
1903  input_types_.append(t_)
1904  input_annotations.append(annot)
1905  sql_types_.append(Bracket('Cursor', args=()))
1906  elif t.name == 'TableFunctionManager':
1907  if i != 0:
1908  raise ValueError('{} must appear as a first argument of {}, but found it at position {}.'.format(t, sig.name, i))
1909  uses_manager = True
1910  else:
1911  input_types_.append(t)
1912  input_annotations.append(annot)
1913  if t.is_column_any():
1914  # XXX: let Bracket handle mapping of column to cursor(column)
1915  sql_types_.append(Bracket('Cursor', args=()))
1916  else:
1917  sql_types_.append(t)
1918 
 # fall back to a default sizer when neither the signature nor its inputs provided one
1919  if sizer is None:
1920  name = 'kTableFunctionSpecifiedParameter'
1921  idx = 1 # this sizer is not actually materialized in the UDTF
1922  sizer = Bracket(name, (idx,))
1923 
1924  assert sizer is not None
1925  ns_output_types = tuple([a.apply_namespace(ns='ExtArgumentType') for a in sig.outputs])
1926  ns_input_types = tuple([t.apply_namespace(ns='ExtArgumentType') for t in input_types_])
1927  ns_sql_types = tuple([t.apply_namespace(ns='ExtArgumentType') for t in sql_types_])
1928 
1929  sig.function_annotations.append(('uses_manager', str(uses_manager).lower()))
1930 
 # C++ initializer expressions passed to TableFunctionsFactory::add
1931  input_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_input_types)))
1932  output_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_output_types)))
1933  sql_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_sql_types)))
1934  annotations = format_annotations(input_annotations + sig.output_annotations + [sig.function_annotations])
1935 
1936  # Notice that input_types and sig.input_types, (and
1937  # similarly, input_annotations and sig.input_annotations)
1938  # have different lengths when the sizer argument is
1939  # Constant or TableFunctionSpecifiedParameter. That is,
1940  # input_types contains all the user-specified arguments
1941  # while sig.input_types contains all arguments of the
1942  # implementation of an UDTF.
1943 
1944  if must_emit_preflight_function(sig, sizer):
1945  fn_name = '%s_%s' % (sig.name, str(counter)) if is_template_function(sig) else sig.name
1946  check_fn = build_preflight_function(fn_name, sizer, input_types_, sig.outputs, uses_manager)
1947  cond_fns.append(check_fn)
1948 
1949  if is_template_function(sig):
 # template functions are registered under a numbered wrapper name
1950  name = sig.name + '_' + str(counter)
1951  counter += 1
1952  t = build_template_function_call(name, sig.name, input_types_, sig.outputs, uses_manager)
1953  address_expression = ('avoid_opt_address(reinterpret_cast<void*>(%s))' % name)
1954  if is_cpu_function(sig):
1955  cpu_template_functions.append(t)
1956  cpu_function_address_expressions.append(address_expression)
1957  if is_gpu_function(sig):
1958  gpu_template_functions.append(t)
1959  gpu_function_address_expressions.append(address_expression)
1960  add = ('TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
1961  % (name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
1962  add_stmts.append(add)
1963 
1964  else:
 # non-template functions are registered under their own name
1965  add = ('TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
1966  % (sig.name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
1967  add_stmts.append(add)
1968  address_expression = ('avoid_opt_address(reinterpret_cast<void*>(%s))' % sig.name)
1969 
1970  if is_cpu_function(sig):
1971  cpu_function_address_expressions.append(address_expression)
1972  if is_gpu_function(sig):
1973  gpu_function_address_expressions.append(address_expression)
1974 
1975  return add_stmts, cpu_template_functions, gpu_template_functions, cpu_function_address_expressions, gpu_function_address_expressions, cond_fns
1976 
1977 
# With fewer than two command-line arguments, run the bundled signature
# tests, print the usage text and exit with status 1.
if len(sys.argv) < 3:

    input_files = [os.path.join(os.path.dirname(__file__), 'test_udtf_signatures.hpp')]
    print('Running tests from %s' % (', '.join(input_files)))
    add_stmts, _, _, _, _, _ = parse_annotations(input_files)

    print('Usage:\n %s %s input1.hpp input2.hpp ... output.hpp' % (sys.executable, sys.argv[0], ))

    sys.exit(1)

# Normal mode: every argument but the last is an input header, the last one
# is the output file; the cpu/gpu headers are written next to it.
input_files = sys.argv[1:-1]
output_filename = sys.argv[-1]
cpu_output_header = os.path.splitext(output_filename)[0] + '_cpu.hpp'
gpu_output_header = os.path.splitext(output_filename)[0] + '_gpu.hpp'
assert input_files, sys.argv

(add_stmts,
 cpu_template_functions,
 gpu_template_functions,
 cpu_address_expressions,
 gpu_address_expressions,
 cond_fns) = parse_annotations(sys.argv[1:-1])

# Include paths start at "QueryEngine/"; a path without "/QueryEngine/" is used whole.
canonical_input_files = [input_file[input_file.find("/QueryEngine/") + 1:] for input_file in input_files]
header_includes = ['#include "' + canonical_input_file + '"' for canonical_input_file in canonical_input_files]

# Split up calls to TableFunctionsFactory::add() into chunks
ADD_FUNC_CHUNK_SIZE = 100
2000 
def add_method(i, chunk):
    """Return C++ source of method ``add_table_functions_<i>`` whose body is the statements in ``chunk``."""
    body = '\n '.join(chunk)
    template = '''
 NO_OPT_ATTRIBUTE void add_table_functions_%d() const {
 %s
 }
'''
    return template % (i, body)
2007 
def add_methods(add_stmts):
    """Return one ``add_table_functions_<i>`` method per chunk of ADD_FUNC_CHUNK_SIZE statements."""
    methods = []
    for start in range(0, len(add_stmts), ADD_FUNC_CHUNK_SIZE):
        chunk = add_stmts[start:start + ADD_FUNC_CHUNK_SIZE]
        methods.append(add_method(len(methods), chunk))
    return methods
2011 
def call_methods(add_stmts):
    """Return one ``add_table_functions_<i>();`` call per chunk of ADD_FUNC_CHUNK_SIZE statements."""
    full_chunks, leftover = divmod(len(add_stmts), ADD_FUNC_CHUNK_SIZE)
    # a partially filled last chunk still needs its own call
    n_chunks = full_chunks + 1 if 0 < leftover else full_chunks
    return ['add_table_functions_%d();' % (i) for i in range(n_chunks)]
2015 
# Text of the generated registration source file. The six %s placeholders
# are filled, in order, with: the generator script name, the input header
# #include lines, the cpu function address expressions joined with &&,
# the add_table_functions_<i> method definitions, the calls to those
# methods, and the preflight check functions.
2016 content = '''
2017 /*
2018  This file is generated by %s. Do no edit!
2019 */
2020 
2021 #include "QueryEngine/TableFunctions/TableFunctionsFactory.h"
2022 %s
2023 
2024 /*
2025  Include the UDTF template initiations:
2026 */
2027 #include "TableFunctionsFactory_init_cpu.hpp"
2028 
2029 // volatile+noinline prevents compiler optimization
2030 #ifdef _WIN32
2031 __declspec(noinline)
2032 #else
2033  __attribute__((noinline))
2034 #endif
2035 volatile
2036 bool avoid_opt_address(void *address) {
2037  return address != nullptr;
2038 }
2039 
2040 bool functions_exist() {
2041  bool ret = true;
2042 
2043  ret &= (%s);
2044 
2045  return ret;
2046 }
2047 
2048 extern bool g_enable_table_functions;
2049 
2050 namespace table_functions {
2051 
2052 std::once_flag init_flag;
2053 
2054 #if defined(__clang__)
2055 #define NO_OPT_ATTRIBUTE __attribute__((optnone))
2056 
2057 #elif defined(__GNUC__) || defined(__GNUG__)
2058 #define NO_OPT_ATTRIBUTE __attribute((optimize("O0")))
2059 
2060 #elif defined(_MSC_VER)
2061 #define NO_OPT_ATTRIBUTE
2062 
2063 #endif
2064 
2065 #if defined(_MSC_VER)
2066 #pragma optimize("", off)
2067 #endif
2068 
2069 struct AddTableFunctions {
2070 %s
2071  NO_OPT_ATTRIBUTE void operator()() {
2072  %s
2073  }
2074 };
2075 
2076 void TableFunctionsFactory::init() {
2077  if (!g_enable_table_functions) {
2078  return;
2079  }
2080 
2081  if (!functions_exist()) {
2082  UNREACHABLE();
2083  return;
2084  }
2085 
2086  std::call_once(init_flag, AddTableFunctions{});
2087 }
2088 #if defined(_MSC_VER)
2089 #pragma optimize("", on)
2090 #endif
2091 
2092 // conditional check functions
2093 %s
2094 
2095 } // namespace table_functions
2096 
2097 ''' % (sys.argv[0],
2098  '\n'.join(header_includes),
2099  ' &&\n'.join(cpu_address_expressions),
2100  ''.join(add_methods(add_stmts)),
2101  '\n '.join(call_methods(add_stmts)),
2102  ''.join(cond_fns))
2103 
# Template shared by the cpu and gpu headers. Its three %s placeholders are
# filled when the headers are written: generator script name, input header
# #include lines, and the template wrapper functions.
2104 header_content = '''
2105 /*
2106  This file is generated by %s. Do no edit!
2107 */
2108 %s
2109 
2110 %s
2111 '''
2112 
dirname = os.path.dirname(output_filename)

# Create the output directory when it is missing; an EEXIST error (the
# directory appeared in the meantime) is ignored.
if dirname and not os.path.exists(dirname):
    try:
        os.makedirs(dirname)
    except OSError as e:
        import errno
        if e.errno != errno.EEXIST:
            raise

# `with` closes each file even when a write raises.
with open(output_filename, 'w') as f:
    f.write(content)

with open(cpu_output_header, 'w') as f:
    f.write(header_content % (sys.argv[0], '\n'.join(header_includes), '\n'.join(cpu_template_functions)))

with open(gpu_output_header, 'w') as f:
    f.write(header_content % (sys.argv[0], '\n'.join(header_includes), '\n'.join(gpu_template_functions)))
std::string strip(std::string_view str)
trim any whitespace from the left and right ends of a string
std::string join(T const &container, std::string const &delim)
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings
int open(const char *path, int flags, int mode)
Definition: heavyai_fs.cpp:66