OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
generate_TableFunctionsFactory_init.py
Go to the documentation of this file.
1 """Given a list of input files, scan for lines containing UDTF
2 specification statements in the following form:
3 
4  UDTF: function_name(<arguments>) -> <output column types> (, <template type specifications>)?
5 
6 where <arguments> is a comma-separated list of argument types. The
7 argument types specifications are:
8 
9 - scalar types:
10  Int8, Int16, Int32, Int64, Float, Double, Bool, TextEncodingDict, etc
11 - column types:
12  ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble, ColumnBool, etc
13 - column list types:
14  ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool, etc
15 - cursor type:
16  Cursor<t0, t1, ...>
17  where t0, t1 are column or column list types
18 - output buffer size parameter type:
19  RowMultiplier<i>, ConstantParameter<i>, Constant<i>, TableFunctionSpecifiedParameter<i>
20  where i is a literal integer.
21 
22 The output column types is a comma-separated list of column types, see above.
23 
24 In addition, the following equivalents are supported:
25 
26  Column<T> == ColumnT
27  ColumnList<T> == ColumnListT
28  Cursor<T, V, ...> == Cursor<ColumnT, ColumnV, ...>
29  int8 == int8_t == Int8, etc
30  float == Float, double == Double, bool == Bool
31  T == ColumnT for output column types
32  RowMultiplier == RowMultiplier<i> where i is the one-based position of the sizer argument
33  when no sizer argument is provided, Constant<1> is assumed
34 
35 Argument types can be annotated using `|' (bar) symbol after an
36 argument type specification. An annotation is specified by a label and
37 a value separated by `=' (equal) symbol. Multiple annotations can be
38 specified by using `|` (bar) symbol as the annotations separator.
39 Supported annotation labels are:
40 
41 - name: to specify argument name
42 - input_id: to specify the dict id mapping for output TextEncodingDict columns.
43 
44 If argument type follows an identifier, it will be mapped to name
45 annotations. For example, the following argument type specifications
46 are equivalent:
47 
48  Int8 a
49  Int8 | name=a
50 
51 Template type specifications is a comma separated list of template
52 type assignments where values are lists of argument type names. For
53 instance:
54 
55  T = [Int8, Int16, Int32, Float], V = [Float, Double]
56 
57 """
58 # Author: Pearu Peterson
59 # Created: January 2021
60 
61 
62 import os
63 import sys
64 import itertools
65 import copy
66 from abc import abstractmethod
67 
68 from collections import deque, namedtuple
69 
70 if sys.version_info > (3, 0):
71  from abc import ABC
72  from collections.abc import Iterable
73 else:
74  from abc import ABCMeta as ABC
75  from collections import Iterable
76 
77 # fmt: off
# Separates a UDTF signature from its expected result.
separator = '$=>$'

Signature = namedtuple(
    'Signature',
    ['name', 'inputs', 'outputs', 'input_annotations',
     'output_annotations', 'function_annotations', 'sizer'])

# Names of the supported output buffer size kinds.
OutputBufferSizeTypes = [
    'kConstant',
    'kUserSpecifiedConstantParameter',
    'kUserSpecifiedRowMultiplier',
    'kTableFunctionSpecifiedParameter',
    'kPreFlightParameter',
]

# Annotation labels accepted on input/output arguments.
SupportedAnnotations = ['input_id', 'name', 'fields', 'require', 'range']

# TODO: support `gpu`, `cpu`, `template` as function annotations
SupportedFunctionAnnotations = ['filter_table_function_transpose', 'uses_manager']

# Maps user-facing spellings to canonical type / sizer names.
translate_map = {
    'Constant': 'kConstant',
    'PreFlight': 'kPreFlightParameter',
    'ConstantParameter': 'kUserSpecifiedConstantParameter',
    'RowMultiplier': 'kUserSpecifiedRowMultiplier',
    'UserSpecifiedConstantParameter': 'kUserSpecifiedConstantParameter',
    'UserSpecifiedRowMultiplier': 'kUserSpecifiedRowMultiplier',
    'TableFunctionSpecifiedParameter': 'kTableFunctionSpecifiedParameter',
    'short': 'Int16',
    'int': 'Int32',
    'long': 'Int64',
}
# Lower-case aliases (`int8`, `float`, ...) and C-style `intN_t` aliases.
for t in ('Int8', 'Int16', 'Int32', 'Int64', 'Float', 'Double', 'Bool',
          'TextEncodingDict', 'TextEncodingNone'):
    alias = t.lower()
    translate_map[alias] = t
    if t.startswith('Int'):
        translate_map[alias + '_t'] = t
112 
113 
115  """Holds a `TYPE | ANNOTATIONS`-like structure.
116  """
117  def __init__(self, type, annotations=[]):
118  self.type = type
119  self.annotations = annotations
120 
@property
def name(self):
    """Name of the wrapped type."""
    wrapped = self.type
    return wrapped.name
124 
@property
def args(self):
    """Arguments of the wrapped type."""
    wrapped = self.type
    return wrapped.args
128 
def format_sizer(self):
    """Return whatever the wrapped type's ``format_sizer()`` produces."""
    render = self.type.format_sizer
    return render()
131 
132  def __repr__(self):
133  return 'Declaration(%r, ann=%r)' % (self.type, self.annotations)
134 
135  def __str__(self):
136  if not self.annotations:
137  return str(self.type)
138  return '%s | %s' % (self.type, ' | '.join(map(str, self.annotations)))
139 
def tostring(self):
    """Return the wrapped type's ``tostring()``; annotations are not included."""
    render = self.type.tostring
    return render()
142 
def apply_column(self):
    """Return a new instance whose type is ``self.type.apply_column()``.

    The annotations object is passed through as-is (not copied).
    """
    column_type = self.type.apply_column()
    return self.__class__(column_type, self.annotations)
145 
def apply_namespace(self, ns='ExtArgumentType'):
    """Return a new instance whose type is ``self.type.apply_namespace(ns)``.

    The annotations object is passed through as-is (not copied).
    """
    qualified = self.type.apply_namespace(ns)
    return self.__class__(qualified, self.annotations)
148 
def get_cpp_type(self):
    """Return the wrapped type's ``get_cpp_type()`` result."""
    resolve = self.type.get_cpp_type
    return resolve()
151 
def format_cpp_type(self, idx, use_generic_arg_name=False, is_input=True):
    """Delegate to the wrapped type's ``format_cpp_type``.

    The value of the ``name`` annotation (``None`` when absent) is
    forwarded as ``real_arg_name``.
    """
    annotations = dict(self.annotations)
    return self.type.format_cpp_type(
        idx,
        use_generic_arg_name=use_generic_arg_name,
        real_arg_name=annotations.get('name'),
        is_input=is_input)
158 
159  def __getattr__(self, name):
160  if name.startswith('is_'):
161  return getattr(self.type, name)
162  raise AttributeError(name)
163 
164 
def tostring(obj):
    """Return ``obj.tostring()``; a plain-function form usable with ``map``."""
    render = obj.tostring
    return render()
167 
168 
class Bracket:
    """Holds a `NAME<ARGS>`-like structure.

    Attributes
    ----------
    name : str
      Type name, optionally qualified as ``ns::NAME``.
    args : tuple or None
      Items rendered between ``<`` and ``>``; ``None`` when absent.
    """

    def __init__(self, name, args=None):
        assert isinstance(name, str)
        assert isinstance(args, tuple) or args is None, args
        self.name = name
        self.args = args

    def __repr__(self):
        return 'Bracket(%r, args=%r)' % (self.name, self.args)

    def __str__(self):
        if not self.args:
            return self.name
        return '%s<%s>' % (self.name, ', '.join(map(str, self.args)))

    def tostring(self):
        """Like ``str(self)`` but renders arguments via their ``tostring()``."""
        if not self.args:
            return self.name
        return '%s<%s>' % (self.name, ', '.join(a.tostring() for a in self.args))

    def _basename(self):
        """Return ``name`` with any leading ``ns::`` qualifier removed."""
        return self.name.rsplit("::", 1)[-1]

    def normalize(self, kind='input'):
        """Normalize bracket for given kind

        - ``Column<T>``-like brackets collapse into a single name (``ColumnT``).
        - For inputs, non-column ``Cursor`` arguments are wrapped in ``Column``.
        - For outputs, non-column types are wrapped in ``Column``.
        """
        assert kind in ['input', 'output'], kind
        if self.is_column_any() and self.args:
            return Bracket(self.name + ''.join(map(str, self.args)))
        if kind == 'input':
            if self.name == 'Cursor':
                args = [(a if a.is_column_any() else Bracket('Column', args=(a,))).normalize(kind=kind)
                        for a in self.args]
                return Bracket(self.name, tuple(args))
        if kind == 'output':
            if not self.is_column_any():
                return Bracket('Column', args=(self,)).normalize(kind=kind)
        return self

    def apply_cursor(self):
        """Apply cursor to a non-cursor column argument type.

        TODO: this method is currently unused but we should apply
        cursor to all input column arguments in order to distinguish
        signatures like:

          foo(Cursor(Column<int32>, Column<float>)) -> Column<int32>
          foo(Cursor(Column<int32>), Cursor(Column<float>)) -> Column<int32>

        that at the moment are treated as the same :(
        """
        if self.is_column():
            return Bracket('Cursor', args=(self,))
        return self

    def apply_column(self):
        """Prefix the name with ``Column`` unless already a column or column list."""
        if not self.is_column() and not self.is_column_list():
            return Bracket('Column' + self.name)
        return self

    def apply_namespace(self, ns='ExtArgumentType'):
        """Qualify the name (and, for ``Cursor``, its arguments) with ``ns::``."""
        if self.name == 'Cursor':
            return Bracket(ns + '::' + self.name, args=tuple(a.apply_namespace(ns=ns) for a in self.args))
        if not self.name.startswith(ns + '::'):
            return Bracket(ns + '::' + self.name)
        return self

    # The predicates below look only at the unqualified name.

    def is_cursor(self):
        return self._basename() == 'Cursor'

    def is_array(self):
        return self._basename().startswith('Array')

    def is_column_any(self):
        return self._basename().startswith('Column')

    def is_column_list(self):
        return self._basename().startswith('ColumnList')

    def is_column(self):
        return self._basename().startswith('Column') and not self.is_column_list()

    def is_any_text_encoding_dict(self):
        return self._basename().endswith('TextEncodingDict')

    def is_array_text_encoding_dict(self):
        return self._basename() == 'ArrayTextEncodingDict'

    def is_column_text_encoding_dict(self):
        return self._basename() == 'ColumnTextEncodingDict'

    def is_column_array_text_encoding_dict(self):
        return self._basename() == 'ColumnArrayTextEncodingDict'

    def is_column_list_text_encoding_dict(self):
        return self._basename() == 'ColumnListTextEncodingDict'

    def is_output_buffer_sizer(self):
        return self._basename() in OutputBufferSizeTypes

    def is_row_multiplier(self):
        return self._basename() == 'kUserSpecifiedRowMultiplier'

    def is_arg_sizer(self):
        return self._basename() == 'kPreFlightParameter'

    def is_user_specified(self):
        """Return True if the argument is one the user specifies.

        Every non-sizer argument is; among sizers, ``kConstant``,
        ``kTableFunctionSpecifiedParameter`` and ``kPreFlightParameter``
        are not.
        """
        if self.is_output_buffer_sizer():
            return self._basename() not in ('kConstant', 'kTableFunctionSpecifiedParameter', 'kPreFlightParameter')
        return True

    def format_sizer(self):
        """Render as a ``TableFunctionOutputRowSizer{...}`` initializer string."""
        # The pre-flight sizer carries no argument; 0 is emitted for it.
        val = 0 if self.is_arg_sizer() else self.args[0]
        return 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (self.name, val)

    def get_cpp_type(self):
        """Return the C++ type spelling, e.g. ``ColumnInt32`` -> ``Column<int32_t>``.

        Raises
        ------
        NotImplementedError
          If the element type name is not recognized.
        """
        name = self._basename()
        clsname = None
        subclsname = None
        # Remove prefixes by slicing: str.lstrip() takes a *set of characters*,
        # so it could eat more than the prefix and let malformed names through.
        if name.startswith('ColumnList'):
            name = name[len('ColumnList'):]
            clsname = 'ColumnList'
        elif name.startswith('Column'):
            name = name[len('Column'):]
            clsname = 'Column'
        if name.startswith('Array'):
            name = name[len('Array'):]
            if clsname is None:
                clsname = 'Array'
            else:
                subclsname = 'Array'

        if name.startswith('Bool'):
            ctype = name.lower()
        elif name.startswith('Int'):
            ctype = name.lower() + '_t'
        elif name in ('Double', 'Float'):
            ctype = name.lower()
        elif name in ('TextEncodingDict', 'TextEncodingNone', 'Timestamp',
                      'DayTimeInterval', 'YearMonthTimeInterval'):
            ctype = name
        else:
            raise NotImplementedError(self)
        if clsname is None:
            return ctype
        if subclsname is None:
            return '%s<%s>' % (clsname, ctype)
        return '%s<%s<%s>>' % (clsname, subclsname, ctype)

    def format_cpp_type(self, idx, use_generic_arg_name=False, real_arg_name=None, is_input=True):
        """Return ``(declaration, arg_name)`` for a C++ parameter of this type.

        The argument is named ``input<idx>``/``output<idx>`` when a generic
        name is requested or no real name is given. Column, column-list and
        ``TextEncodingNone`` types are passed by reference (const for inputs).
        """
        col_typs = ('Column', 'ColumnList')
        literal_ref_typs = ('TextEncodingNone',)
        if use_generic_arg_name:
            arg_name = 'input' + str(idx) if is_input else 'output' + str(idx)
        elif real_arg_name is not None:
            arg_name = real_arg_name
        else:
            # in some cases, the real arg name is not specified
            arg_name = 'input' + str(idx) if is_input else 'output' + str(idx)
        const = 'const ' if is_input else ''
        cpp_type = self.get_cpp_type()
        if any(cpp_type.startswith(t) for t in col_typs + literal_ref_typs):
            return '%s%s& %s' % (const, cpp_type, arg_name), arg_name
        else:
            return '%s %s' % (cpp_type, arg_name), arg_name

    @classmethod
    def parse(cls, typ):
        """typ is a string in format NAME<ARGS> or NAME

        Returns Bracket instance. The name is passed through
        ``translate_map``; arguments are parsed recursively.
        """
        i = typ.find('<')
        if i == -1:
            name = typ.strip()
            args = None
        else:
            assert typ.endswith('>'), typ
            name = typ[:i].strip()
            args = []
            rest = typ[i + 1:-1].strip()
            while rest:
                i = find_comma(rest)
                if i == -1:
                    a, rest = rest, ''
                else:
                    a, rest = rest[:i].rstrip(), rest[i + 1:].lstrip()
                args.append(cls.parse(a))
            args = tuple(args)

        name = translate_map.get(name, name)
        return cls(name, args)
365 
366 
def find_comma(line):
    """Return the index of the first top-level comma in `line`, or -1.

    A comma is top-level when it is not nested inside any of the
    bracket pairs ``<>``, ``()``, ``[]`` or ``{}``.
    """
    depth = 0
    for i, c in enumerate(line):
        if c in '<([{':
            depth += 1
        elif c in '>)]}':  # was '>)]{', so `}` never closed a level
            depth -= 1
        elif depth == 0 and c == ',':
            return i
    return -1
377 
378 
380  # TODO: try to parse the line to be certain about completeness.
381  # `$=>$' is used to separate the UDTF signature and the expected result
382  return line.endswith(',') or line.endswith('->') or line.endswith(separator) or line.endswith('|')
383 
384 
def is_identifier_cursor(identifier):
    """Return True if `identifier` spells ``cursor``, ignoring case."""
    lowered = identifier.lower()
    return lowered == 'cursor'
387 
388 
389 # fmt: on
390 
391 
class TokenizeException(Exception):
    """Raised by the tokenizer when it cannot match a character."""
394 
395 
class ParserException(Exception):
    """Error type for the parser (its raise sites are not in this part of the file)."""
398 
399 
class TransformerException(Exception):
    """Raised by AST transformers, e.g. on an unknown annotation or a bad range."""
402 
403 
class Token:
    """A lexical token: a numeric kind plus the matched text."""

    LESS = 1         # <
    GREATER = 2      # >
    COMMA = 3        # ,
    EQUAL = 4        # =
    RARROW = 5       # ->
    STRING = 6       # reserved for string constants
    NUMBER = 7       #
    VBAR = 8         # |
    BANG = 9         # !
    LPAR = 10        # (
    RPAR = 11        # )
    LSQB = 12        # [
    RSQB = 13        # ]
    IDENTIFIER = 14  #
    COLON = 15       # :

    def __init__(self, type, lexeme):
        """
        Parameters
        ----------
        type : int
          One of the tokens in the list above
        lexeme : str
          Corresponding string in the text
        """
        self.lexeme = lexeme
        self.type = type

    @classmethod
    def tok_name(cls, token):
        """Return the constant name for a token kind, or None if unknown."""
        labels = (
            "LESS", "GREATER", "COMMA", "EQUAL", "RARROW", "STRING",
            "NUMBER", "VBAR", "BANG", "LPAR", "RPAR", "LSQB", "RSQB",
            "IDENTIFIER", "COLON",
        )
        names = {getattr(Token, label): label for label in labels}
        return names.get(token)

    def __str__(self):
        kind = Token.tok_name(self.type)
        return 'Token(%s, "%s")' % (kind, self.lexeme)

    __repr__ = __str__
458 
459 
class Tokenize:
    """Split a line of text into `Token` objects.

    Tokenization runs in the constructor; the result is available through
    the ``tokens`` property. Whitespace produces no tokens.

    Raises
    ------
    TokenizeException
      On a character that starts no known token, a ``-`` not followed
      by ``>``, or a string literal with no closing quote.
    """

    def __init__(self, line):
        self._line = line
        self._tokens = []
        self.start = 0  # index where the token being scanned begins
        self.curr = 0   # index of the character currently examined
        self.tokenize()

    @property
    def line(self):
        return self._line

    @property
    def tokens(self):
        return self._tokens

    def tokenize(self):
        while not self.is_at_end():
            self.start = self.curr

            if self.is_token_whitespace():
                self.consume_whitespace()
            elif self.is_digit():
                self.consume_number()
            elif self.is_token_string():
                self.consume_string()
            elif self.is_token_identifier():
                self.consume_identifier()
            elif self.can_token_be_double_char():
                self.consume_double_char()
            else:
                self.consume_single_char()

    def is_at_end(self):
        return len(self.line) == self.curr

    def current_token(self):
        return self.line[self.start:self.curr + 1]

    def add_token(self, type):
        """Append a token of kind `type` spanning ``start..curr`` inclusive."""
        lexeme = self.line[self.start:self.curr + 1]
        self._tokens.append(Token(type, lexeme))

    def lookahead(self):
        """Return the character after the current one, or None at the end."""
        if self.curr + 1 >= len(self.line):
            return None
        return self.line[self.curr + 1]

    def advance(self):
        self.curr += 1

    def peek(self):
        return self.line[self.curr]

    def can_token_be_double_char(self):
        char = self.peek()
        return char in ("-",)

    def consume_double_char(self):
        ahead = self.lookahead()
        if ahead == ">":
            self.advance()
            self.add_token(Token.RARROW)  # ->
            self.advance()
        else:
            self.raise_tokenize_error()

    def consume_single_char(self):
        char = self.peek()
        single_char_tokens = {
            "(": Token.LPAR,
            ")": Token.RPAR,
            "<": Token.LESS,
            ">": Token.GREATER,
            ",": Token.COMMA,
            "=": Token.EQUAL,
            "|": Token.VBAR,
            "!": Token.BANG,
            "[": Token.LSQB,
            "]": Token.RSQB,
            ":": Token.COLON,
        }
        if char not in single_char_tokens:
            self.raise_tokenize_error()
        self.add_token(single_char_tokens[char])
        self.advance()

    def consume_whitespace(self):
        self.advance()

    def consume_string(self):
        """
        STRING: a double-quoted run; a quote preceded by a backslash
        does not terminate it.
        """
        while True:
            char = self.lookahead()
            curr = self.peek()
            if char is None:
                # No closing quote before the end of the line. Previously the
                # scan walked off the end and failed with IndexError.
                raise TokenizeException(
                    'Unterminated string starting at pos %d on line\n  %s'
                    % (self.start, self.line)
                )
            if char == '"' and curr != '\\':
                self.advance()
                break
            self.advance()
        self.add_token(Token.STRING)
        self.advance()

    def consume_number(self):
        """
        NUMBER: [0-9]+
        """
        while True:
            char = self.lookahead()
            if char and char.isdigit():
                self.advance()
            else:
                break
        self.add_token(Token.NUMBER)
        self.advance()

    def consume_identifier(self):
        """
        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
        """
        while True:
            char = self.lookahead()
            if char and char.isalnum() or char == "_":
                self.advance()
            else:
                break
        self.add_token(Token.IDENTIFIER)
        self.advance()

    def is_token_identifier(self):
        return self.peek().isalpha() or self.peek() == "_"

    def is_token_string(self):
        return self.peek() == '"'

    def is_digit(self):
        return self.peek().isdigit()

    def is_alpha(self):
        return self.peek().isalpha()

    def is_token_whitespace(self):
        return self.peek().isspace()

    def raise_tokenize_error(self):
        curr = self.curr
        char = self.peek()
        raise TokenizeException(
            'Could not match char "%s" at pos %d on line\n  %s' % (char, curr, self.line)
        )
619 
620 
class AstVisitor(object):
    """Visitor interface with one ``visit_*`` hook per AST node kind."""

    __metaclass__ = ABC

    @abstractmethod
    def visit_udtf_node(self, node):
        """Handle a UDTF node (called from ``UdtfNode.accept``)."""

    @abstractmethod
    def visit_composed_node(self, node):
        """Handle a composed node."""

    @abstractmethod
    def visit_arg_node(self, node):
        """Handle an argument node."""

    @abstractmethod
    def visit_primitive_node(self, node):
        """Handle a primitive node."""

    @abstractmethod
    def visit_annotation_node(self, node):
        """Handle an annotation node."""

    @abstractmethod
    def visit_template_node(self, node):
        """Handle a template node."""
647 
648 
class AstTransformer(AstVisitor):
    """Only overload the methods you need"""

    def visit_udtf_node(self, udtf_node):
        """Return a shallow copy whose children were re-visited by `self`."""
        def visit(child):
            return child.accept(self)

        clone = copy.copy(udtf_node)
        clone.inputs = [visit(arg) for arg in clone.inputs]
        clone.outputs = [visit(arg) for arg in clone.outputs]
        if clone.templates:
            clone.templates = [visit(tmpl) for tmpl in clone.templates]
        clone.annotations = [visit(annot) for annot in clone.annotations]
        return clone

    def visit_composed_node(self, composed_node):
        """Return a shallow copy with every inner node re-visited."""
        clone = copy.copy(composed_node)
        clone.inner = [child.accept(self) for child in clone.inner]
        return clone

    def visit_arg_node(self, arg_node):
        """Return a shallow copy with its type and annotations re-visited."""
        clone = copy.copy(arg_node)
        clone.type = clone.type.accept(self)
        if clone.annotations:
            clone.annotations = [annot.accept(self) for annot in clone.annotations]
        return clone

    def visit_primitive_node(self, primitive_node):
        """Leaf node: return a shallow copy."""
        return copy.copy(primitive_node)

    def visit_template_node(self, template_node):
        """Leaf node: return a shallow copy."""
        return copy.copy(template_node)

    def visit_annotation_node(self, annotation_node):
        """Leaf node: return a shallow copy."""
        return copy.copy(annotation_node)
681 
682 
684  """Returns a line formatted. Useful for testing"""
685 
def visit_udtf_node(self, udtf_node):
    """Render ``name(inputs) [| annotations] -> outputs[, templates][ | sizer]``."""
    def render_all(nodes):
        return [node.accept(self) for node in nodes]

    name = udtf_node.name
    inputs = ", ".join(render_all(udtf_node.inputs))
    outputs = ", ".join(render_all(udtf_node.outputs))
    annotations = "| ".join(render_all(udtf_node.annotations))
    if udtf_node.sizer:
        sizer = " | " + udtf_node.sizer.accept(self)
    else:
        sizer = ""
    if annotations:
        annotations = ' | ' + annotations
    if not udtf_node.templates:
        return "%s(%s)%s -> %s%s" % (name, inputs, annotations, outputs, sizer)
    templates = ", ".join(render_all(udtf_node.templates))
    return "%s(%s)%s -> %s, %s%s" % (name, inputs, annotations, outputs, templates, sizer)
699 
def visit_template_node(self, template_node):
    """Render as ``KEY=["T1", "T2", ...]``."""
    # T=[T1, T2, ..., TN]
    quoted = ", ".join('"%s"' % typ for typ in template_node.types)
    return "%s=[%s]" % (template_node.key, quoted)
705 
def visit_annotation_node(self, annotation_node):
    """Render as ``key=value``; list values become ``key=[v1,v2,...]``."""
    # key=value
    key, value = annotation_node.key, annotation_node.value
    if not isinstance(value, list):
        return "%s=%s" % (key, value)
    items = ','.join([item.accept(self) for item in value])
    return "%s=[%s]" % (key, items)
713 
def visit_arg_node(self, arg_node):
    """Render as ``type`` or ``type | ann | ann ...``."""
    # type | annotation
    rendered = "%s" % (arg_node.type.accept(self),)
    if arg_node.annotations:
        ann = " | ".join([a.accept(self) for a in arg_node.annotations])
        rendered = "%s | %s" % (rendered, ann)
    # insert input_id=args<0> if input_id is not specified
    needs_default_id = rendered == "ColumnTextEncodingDict" and arg_node.kind == "output"
    if needs_default_id:
        return rendered + " | input_id=args<0>"
    return rendered
726 
def visit_composed_node(self, composed_node):
    """Render a composed node in collapsed form.

    Array/Column/ColumnList wrappers become ``PREFIX + inner``, sizers
    become ``kName<N>`` and cursors ``Cursor<T1, T2, ...>``.
    Raises ``ValueError`` for any other composed node.
    """
    first = composed_node.inner[0].accept(self)
    # Array<T>, Column<T>, ColumnList<T>: predicates are looked up lazily,
    # in this order.
    for predicate, prefix in (("is_array", "Array"),
                              ("is_column", "Column"),
                              ("is_column_list", "ColumnList")):
        if getattr(composed_node, predicate)():
            assert len(composed_node.inner) == 1
            return prefix + first
    if composed_node.is_output_buffer_sizer():
        # kConstant<N>
        assert len(composed_node.inner) == 1
        return translate_map.get(composed_node.type) + "<%s>" % (first,)
    if composed_node.is_cursor():
        # Cursor<T1, T2, ..., TN>
        rendered = ", ".join([i.accept(self) for i in composed_node.inner])
        return "Cursor<%s>" % (rendered)
    raise ValueError(composed_node)
751 
def visit_primitive_node(self, primitive_node):
    """Render the translated type name.

    A bare output-buffer sizer gets ``<N>`` appended, where N is the
    one-based position of the enclosing argument.
    """
    t = primitive_node.type
    if not primitive_node.is_output_buffer_sizer():
        return translate_map.get(t, t)
    # arg_pos is zero-based
    position = primitive_node.get_parent(ArgNode).arg_pos + 1
    return translate_map.get(t, t) + "<%d>" % (position,)
760 
761 
763  """Like AstPrinter but returns a node instead of a string
764  """
765 
766 
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the values.

    Each yielded dict maps every keyword name to one item taken from
    that keyword's iterable.
    """
    names = kwargs.keys()
    choices = kwargs.values()
    for combination in itertools.product(*choices):
        yield dict(zip(names, combination))
772 
773 
775  """Expand template definition into multiple inputs"""
776 
def visit_udtf_node(self, udtf_node):
    """Expand a templated UDTF into concrete UDTF nodes.

    One node is built per combination of template type assignments;
    nodes with identical ``str()`` are kept once. Returns the input
    unchanged when it has no templates, a single node when exactly one
    results, otherwise a list of nodes.
    """
    if not udtf_node.templates:
        return udtf_node

    expanded = dict()
    assignments = {node.key: node.types for node in udtf_node.templates}
    name = udtf_node.name

    for product in product_dict(**assignments):
        # The per-node visitors read the active substitution from here.
        self.mapping_dict = product
        inputs = [arg.accept(self) for arg in udtf_node.inputs]
        outputs = [arg.accept(self) for arg in udtf_node.outputs]
        udtf = UdtfNode(name, inputs, outputs, udtf_node.annotations, None,
                        udtf_node.sizer, udtf_node.line)
        expanded[str(udtf)] = udtf
        self.mapping_dict = {}

    unique = list(expanded.values())
    return unique[0] if len(unique) == 1 else unique
800 
def visit_composed_node(self, composed_node):
    """Copy the node with its type substituted via ``self.mapping_dict``."""
    original = composed_node.type
    substituted = self.mapping_dict.get(original, original)

    children = [child.accept(self) for child in composed_node.inner]
    return composed_node.copy(substituted, children)
807 
def visit_primitive_node(self, primitive_node):
    """Copy the node with its type substituted via ``self.mapping_dict``."""
    original = primitive_node.type
    return primitive_node.copy(self.mapping_dict.get(original, original))
812 
813 
def visit_primitive_node(self, primitive_node):
    """
    * Fix kUserSpecifiedRowMultiplier without a pos arg

    A bare output-buffer sizer is replaced by a composed node whose
    single inner node holds the one-based position of the enclosing
    argument. Any other node is returned unchanged.
    """
    sizer_type = primitive_node.type

    if not primitive_node.is_output_buffer_sizer():
        return primitive_node

    one_based = primitive_node.get_parent(ArgNode).arg_pos + 1
    return ComposedNode(sizer_type, inner=[PrimitiveNode(str(one_based))])
827 
828 
def visit_primitive_node(self, primitive_node):
    """
    * Rename nodes using translate_map as dictionary
      int -> Int32
      float -> Float

    Types without an entry in translate_map are kept as they are.
    """
    original = primitive_node.type
    renamed = translate_map.get(original, original)
    return primitive_node.copy(renamed)
838 
839 
def visit_udtf_node(self, udtf_node):
    """
    * Add default_input_id to Column(List)<TextEncodingDict> without one

    The default refers to the first composed input that is a
    TextEncodingDict column, column array or column list. Raises
    ``TypeError`` when an output needs a default and no such input exists.
    """
    # NOTE(review): super(type(self), self) recurses forever if this class
    # is ever subclassed; naming the class explicitly would avoid that.
    udtf_node = super(type(self), self).visit_udtf_node(udtf_node)

    # add default input_id
    default_input_id = None
    for idx, arg in enumerate(udtf_node.inputs):
        typ = arg.type
        if not isinstance(typ, ComposedNode):
            continue
        if typ.is_column_text_encoding_dict() or typ.is_column_array_text_encoding_dict():
            default_input_id = AnnotationNode('input_id', 'args<%s>' % (idx,))
            break
        if typ.is_column_list_text_encoding_dict():
            default_input_id = AnnotationNode('input_id', 'args<%s, 0>' % (idx,))
            break

    for out in udtf_node.outputs:
        typ = out.type
        if not (isinstance(typ, ComposedNode) and typ.is_any_text_encoding_dict()):
            continue
        if any(a.key == 'input_id' for a in out.annotations):
            continue  # explicitly annotated already
        if default_input_id is None:
            raise TypeError('Cannot parse line "%s".\n'
                            'Missing TextEncodingDict input?' %
                            (udtf_node.line))
        out.annotations.append(default_input_id)

    return udtf_node
872 
873 
875 
def visit_udtf_node(self, udtf_node):
    """
    * Generate fields annotation to Cursor if non-existing

    Field names come from each inner node's ``name`` annotation, with
    ``field<i>`` as the fallback.
    """
    # NOTE(review): super(type(self), self) recurses forever if this class
    # is ever subclassed; naming the class explicitly would avoid that.
    udtf_node = super(type(self), self).visit_udtf_node(udtf_node)

    for arg in udtf_node.inputs:
        typ = arg.type
        if not isinstance(typ, ComposedNode):
            continue
        if not typ.is_cursor() or arg.get_annotation('fields') is not None:
            continue
        fields = [PrimitiveNode(inner.get_annotation('name', 'field%s' % pos))
                  for pos, inner in enumerate(typ.inner)]
        arg.annotations.append(AnnotationNode('fields', fields))

    return udtf_node
892 
894  """
895  * Checks for supported annotations in a UDTF
896  """
def visit_udtf_node(self, udtf_node):
    """Validate annotation keys and normalize function annotation values.

    Raises ``TransformerException`` on an unknown input, output or
    function annotation. Function annotation values spelled
    enable/on/1/true become ``'1'`` and disable/off/0/false become ``'0'``
    (case-insensitive); the node is modified in place and returned.
    """
    for arg in udtf_node.inputs:
        for ann in arg.annotations:
            if ann.key not in SupportedAnnotations:
                raise TransformerException('unknown input annotation: `%s`' % (ann.key))
    for arg in udtf_node.outputs:
        for ann in arg.annotations:
            if ann.key not in SupportedAnnotations:
                raise TransformerException('unknown output annotation: `%s`' % (ann.key))
    for annot in udtf_node.annotations:
        if annot.key not in SupportedFunctionAnnotations:
            raise TransformerException('unknown function annotation: `%s`' % (annot.key))
        flag = annot.value.lower()
        if flag in ['enable', 'on', '1', 'true']:
            annot.value = '1'
        elif flag in ['disable', 'off', '0', 'false']:
            annot.value = '0'
    return udtf_node
914 
915 
917  """
918  * Append require annotation if range is used
919  """
def visit_arg_node(self, arg_node):
    """Set a ``require`` annotation for every ``range`` annotation.

    A ``range=[lo, hi]`` on an argument named ``x`` yields the require
    expression ``"lo <= x && x <= hi"``. Raises ``TransformerException``
    when the argument has no name or the range is not a pair.
    """
    for ann in arg_node.annotations:
        if ann.key != 'range':
            continue
        name = arg_node.get_annotation('name')
        if name is None:
            raise TransformerException('"range" requires a named argument')

        interval = ann.value
        if len(interval) != 2:
            raise TransformerException('"range" requires an interval. Got {l}'.format(l=interval))
        lo, hi = ann.value
        value = '"{lo} <= {name} && {name} <= {hi}"'.format(lo=lo, hi=hi, name=name)
        arg_node.set_annotation('require', value)
    return arg_node
935 
936 
938 
def visit_udtf_node(self, udtf_node):
    """Build a ``Signature`` from a UDTF node.

    Inputs are kept as converted declarations; for outputs only the
    declaration's type is kept. Annotations are collected in parallel
    lists, and the sizer is passed through unchanged.
    """
    sizer = udtf_node.sizer

    input_decls = [arg.accept(self) for arg in udtf_node.inputs]
    input_annotations = [decl.annotations for decl in input_decls]

    output_decls = [arg.accept(self) for arg in udtf_node.outputs]
    outputs = [decl.type for decl in output_decls]
    output_annotations = [decl.annotations for decl in output_decls]

    function_annotations = [annot.accept(self) for annot in udtf_node.annotations]

    return Signature(udtf_node.name, input_decls, outputs, input_annotations,
                     output_annotations, function_annotations, sizer)
963 
def visit_arg_node(self, arg_node):
    """Convert an argument node into a ``Declaration``."""
    converted_type = arg_node.type.accept(self)
    converted_annotations = [ann.accept(self) for ann in arg_node.annotations]
    return Declaration(converted_type, converted_annotations)
968 
def visit_composed_node(self, composed_node):
    """Convert a composed node into a ``Bracket``.

    Cursors keep their arguments (each passed through ``apply_column()``),
    sizers keep their arguments as-is, and every other composed type is
    collapsed into a single name made of the type plus its first inner item.
    """
    typ = translate_map.get(composed_node.type, composed_node.type)
    inner = [child.accept(self) for child in composed_node.inner]
    if composed_node.is_cursor():
        return Bracket(typ, args=tuple(item.apply_column() for item in inner))
    if composed_node.is_output_buffer_sizer():
        return Bracket(typ, args=tuple(inner))
    return Bracket(typ + str(inner[0]))
979 
def visit_primitive_node(self, primitive_node):
    """Wrap the primitive node's type in an argument-less ``Bracket``."""
    return Bracket(primitive_node.type)
983 
def visit_annotation_node(self, annotation_node):
    """Return the annotation as a ``(key, value)`` tuple."""
    return (annotation_node.key, annotation_node.value)
988 
989 
class Node(object):
    """Base class of AST nodes."""

    __metaclass__ = ABC

    @abstractmethod
    def accept(self, visitor):
        """Dispatch to the matching ``visit_*`` method of `visitor`."""

    @abstractmethod
    def __str__(self):
        """Return a textual form of the node."""

    def get_parent(self, cls):
        """Return the nearest node (starting at self) that is an instance of `cls`.

        Follows the ``parent`` links; raises ``ValueError`` when the chain
        ends without a match.
        """
        if isinstance(self, cls):
            return self

        if self.parent is None:
            raise ValueError("could not find parent with given class %s" % (cls))

        return self.parent.get_parent(cls)

    def copy(self, *args):
        """Create ``self.__class__(*args)``, carrying over parent and arg_pos."""
        other = self.__class__(*args)

        # copy parent and arg_pos
        own_attributes = self.__dict__
        for attr in ['parent', 'arg_pos']:
            if attr in own_attributes:
                setattr(other, attr, getattr(self, attr))

        return other
1020 
1021 
class IterableNode(Iterable):
    """Base for nodes that can be iterated over; adds nothing to ``Iterable``."""
1024 
1025 
class UdtfNode(Node, IterableNode):
    """Root AST node describing one UDTF signature."""

    def __init__(self, name, inputs, outputs, annotations, templates, sizer, line):
        """
        Parameters
        ----------
        name : str
        inputs : list[ArgNode]
        outputs : list[ArgNode]
        annotations : Optional[List[AnnotationNode]]
        templates : Optional[list[TemplateNode]]
        sizer : Optional[str]
        line: str
        """
        self.name = name
        self.inputs = inputs
        self.outputs = outputs
        self.annotations = annotations
        self.templates = templates
        self.sizer = sizer
        self.line = line

    def accept(self, visitor):
        return visitor.visit_udtf_node(self)

    def __str__(self):
        name = self.name
        inputs = list(map(str, self.inputs))
        outputs = list(map(str, self.outputs))
        annotations = list(map(str, self.annotations))
        sizer = "| %s" % str(self.sizer) if self.sizer else ""
        if self.templates:
            templates = list(map(str, self.templates))
            if annotations:
                return "UDTF: %s (%s) | %s -> %s, %s %s" % (name, inputs, annotations, outputs, templates, sizer)
            return "UDTF: %s (%s) -> %s, %s %s" % (name, inputs, outputs, templates, sizer)
        if annotations:
            return "UDTF: %s (%s) | %s -> %s %s" % (name, inputs, annotations, outputs, sizer)
        return "UDTF: %s (%s) -> %s %s" % (name, inputs, outputs, sizer)

    def __iter__(self):
        """Yield inputs, outputs, annotations, then templates (if any)."""
        for arg in self.inputs:
            yield arg
        for arg in self.outputs:
            yield arg
        for annot in self.annotations:
            yield annot
        if self.templates:
            for tmpl in self.templates:
                yield tmpl

    __repr__ = __str__
1081 
1082 
1084 
def __init__(self, type, annotations):
    """Store the argument's type node and its annotation list.

    Parameters
    ----------
    type : TypeNode
    annotations : List[AnnotationNode]
    """
    # The position is unknown at construction time; it starts out unset.
    self.arg_pos = None
    self.annotations = annotations
    self.type = type
1095 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_arg_node`` and return its result."""
    handler = visitor.visit_arg_node
    return handler(self)
1098 
def __str__(self):
    """Render as ``ArgNode(<type>)`` or ``ArgNode(<type> | <ann> | ...)``."""
    type_text = str(self.type)
    if not self.annotations:
        return "ArgNode(%s)" % (type_text)
    joined = " | ".join(map(str, self.annotations))
    return "ArgNode(%s | %s)" % (type_text, joined)
1106 
def __iter__(self):
    """Yield the type node first, then each annotation in order."""
    yield self.type
    for annotation in self.annotations:
        yield annotation
1111 
1112  __repr__ = __str__
1113 
def get_annotation(self, key, default=None):
    """Return the value of the first annotation whose key equals *key*.

    *default* is returned when no annotation carries that key.
    """
    matching_values = (ann.value for ann in self.annotations if ann.key == key)
    return next(matching_values, default)
1119 
def set_annotation(self, key, value):
    """Replace the annotation stored under *key*, or append a new one.

    An AssertionError is raised when more than one existing annotation
    uses *key*.
    """
    replaced = False
    for pos, ann in enumerate(self.annotations):
        if ann.key != key:
            continue
        assert not replaced, (pos, ann)  # annotations with the same key not supported
        self.annotations[pos] = AnnotationNode(key, value)
        replaced = True
    if replaced:
        return
    self.annotations.append(AnnotationNode(key, value))
1129 
1130 
def is_array(self):
    """Tell whether this node's ``type`` equals ``"Array"``."""
    kind = self.type
    return kind == "Array"
1134 
def is_column_any(self):
    """Tell whether this node is a column or a column list."""
    if self.is_column():
        return True
    return self.is_column_list()
1137 
def is_column(self):
    """Tell whether this node's ``type`` equals ``"Column"``."""
    kind = self.type
    return kind == "Column"
1140 
def is_column_list(self):
    """Tell whether this node's ``type`` equals ``"ColumnList"``."""
    kind = self.type
    return kind == "ColumnList"
1143 
def is_cursor(self):
    """Tell whether this node's ``type`` equals ``"Cursor"``."""
    kind = self.type
    return kind == "Cursor"
1146 
1148  t = self.type
1149  return translate_map.get(t, t) in OutputBufferSizeTypes
1150 
1151 
1153 
def __init__(self, type):
    """Remember the primitive's textual type.

    Parameters
    ----------
    type : str
    """
    self.type = type
1161 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_primitive_node`` and return its result."""
    handler = visitor.visit_primitive_node
    return handler(self)
1164 
def __str__(self):
    """Return what a fresh ``AstPrinter`` produces for this node."""
    printer = AstPrinter()
    return self.accept(printer)
1167 
1169  return self.type == 'TextEncodingDict'
1170 
1172  return self.type == 'ArrayTextEncodingDict'
1173 
1174  __repr__ = __str__
1175 
1176 
1178 
def __init__(self, type, inner):
    """Store the outer type name and the list of inner type nodes.

    Parameters
    ----------
    type : str
    inner : list[TypeNode]
    """
    self.inner = inner
    self.type = type
1188 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_composed_node`` and return its result."""
    handler = visitor.visit_composed_node
    return handler(self)
1191 
def cursor_length(self):
    """Return the number of inner nodes; asserts that this node is a cursor."""
    node_is_cursor = self.is_cursor()
    assert node_is_cursor
    return len(self.inner)
1195 
def __str__(self):
    """Render as ``Composed(<type><<inner>, ...>)``."""
    inner_text = ", ".join(map(str, self.inner))
    return "Composed(%s<%s>)" % (self.type, inner_text)
1199 
def __iter__(self):
    """Yield each inner node in order."""
    children = self.inner
    for child in children:
        yield child
1203 
1205  return False
1206 
1208  return self.is_array() and self.inner[0].is_text_encoding_dict()
1209 
1211  return self.is_column() and self.inner[0].is_text_encoding_dict()
1212 
1214  return self.is_column_list() and self.inner[0].is_text_encoding_dict()
1215 
1217  return self.is_column() and self.inner[0].is_array_text_encoding_dict()
1218 
1220  return self.inner[0].is_text_encoding_dict() or self.inner[0].is_array_text_encoding_dict()
1221 
1222  __repr__ = __str__
1223 
1224 
1226 
def __init__(self, key, value):
    """Store the annotation's label and value.

    Parameters
    ----------
    key : str
    value : {str, list}
    """
    self.value = value
    self.key = key
1236 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_annotation_node`` and return its result."""
    handler = visitor.visit_annotation_node
    return handler(self)
1239 
def __str__(self):
    """Return what a fresh ``AstPrinter`` produces for this node."""
    return self.accept(AstPrinter())
1243 
1244  __repr__ = __str__
1245 
1246 
1248 
def __init__(self, key, types):
    """Store the template name and the types it stands for.

    Parameters
    ----------
    key : str
    types : tuple[str]
    """
    self.types = types
    self.key = key
1258 
def accept(self, visitor):
    """Dispatch to ``visitor.visit_template_node`` and return its result."""
    handler = visitor.visit_template_node
    return handler(self)
1261 
def __str__(self):
    """Return what a fresh ``AstPrinter`` produces for this node."""
    return self.accept(AstPrinter())
1265 
1266  __repr__ = __str__
1267 
1268 
class Pipeline(object):
    """Run a sequence of visitor classes over one or more AST nodes."""

    def __init__(self, *passes):
        """Remember the visitor classes, in the order they will be applied."""
        self.passes = passes

    def __call__(self, ast_list):
        """Apply every pass in turn and return the resulting list of nodes.

        A non-list argument is treated as a single node. Each node is
        visited by a fresh instance of the pass class; when a visit
        returns a list, its items are spliced into the output.
        """
        nodes = ast_list if isinstance(ast_list, list) else [ast_list]

        for pass_cls in self.passes:
            visited = [node.accept(pass_cls()) for node in nodes]
            flattened = []
            for outcome in visited:
                if isinstance(outcome, list):
                    flattened.extend(outcome)
                else:
                    flattened.append(outcome)
            nodes = flattened

        return list(nodes)
1283 
1284 
class Parser:
    """Recursive-descent parser turning one UDTF signature line into an AST.

    The full grammar is given in the docstring of ``parse``; every
    ``parse_*`` method implements the rule it is named after.
    """

    def __init__(self, line):
        """Tokenize *line* and place the cursor on the first token."""
        self._tokens = Tokenize(line).tokens
        self._curr = 0
        self.line = line

    @property
    def tokens(self):
        """The tokens produced for the line given to the constructor."""
        return self._tokens

    def is_at_end(self):
        """Return True once the cursor has moved past the last token."""
        return self._curr >= len(self._tokens)

    def current_token(self):
        """Return the token under the cursor (IndexError past the end)."""
        return self._tokens[self._curr]

    def advance(self):
        """Move the cursor one token forward."""
        self._curr += 1

    def _raise_unexpected_end(self, expected_type):
        """Raise ParserException for running out of tokens."""
        self.raise_parser_error(
            'Unexpected end of input. Expected type "%s"\n\n'
            'Tokens: %s\n' % (Token.tok_name(expected_type), self._tokens)
        )

    def expect(self, expected_type):
        """Assert that the current token has *expected_type*, then advance.

        Raises AssertionError on a type mismatch and ParserException when
        there is no token left.
        """
        if self.is_at_end():
            self._raise_unexpected_end(expected_type)
        curr_token = self.current_token()
        # The expected type comes first, the offending token second.
        msg = "Expected token %s but got %s at pos %d.\n Tokens: %s" % (
            Token.tok_name(expected_type),
            curr_token,
            self._curr,
            self._tokens,
        )
        assert curr_token.type == expected_type, msg
        self.advance()

    def consume(self, expected_type):
        """consumes the current token iff its type matches the
        expected_type. Otherwise, an error is raised
        """
        if self.is_at_end():
            self._raise_unexpected_end(expected_type)
        curr_token = self.current_token()
        if curr_token.type == expected_type:
            self.advance()
            return curr_token
        else:
            expected_token = Token.tok_name(expected_type)
            self.raise_parser_error(
                'Token mismatch at function consume. '
                'Expected type "%s" but got token "%s"\n\n'
                'Tokens: %s\n' % (expected_token, curr_token, self._tokens)
            )

    def current_pos(self):
        """Return the index of the token under the cursor."""
        return self._curr

    def raise_parser_error(self, msg=None):
        """Raise ParserException, building a default message if none given."""
        if not msg:
            # At end of input there is no current token to report.
            token = "<end of input>" if self.is_at_end() else self.current_token()
            pos = self.current_pos()
            tokens = self.tokens
            msg = "\n\nError while trying to parse token %s at pos %d.\n" "Tokens: %s" % (
                token,
                pos,
                tokens,
            )
        raise ParserException(msg)

    def match(self, expected_type):
        """Return True when the current token has *expected_type*.

        Returns False at end of input instead of raising IndexError.
        """
        if self.is_at_end():
            return False
        curr_token = self.current_token()
        return curr_token.type == expected_type

    def lookahead(self):
        """Return the token after the current one, or None if there is none."""
        nxt = self._curr + 1
        if nxt >= len(self._tokens):
            return None
        return self._tokens[nxt]

    def parse_udtf(self):
        """fmt: off

        udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?

        fmt: on
        """
        name = self.parse_identifier()
        self.expect(Token.LPAR)  # (
        input_args = []
        if not self.match(Token.RPAR):
            input_args = self.parse_args()
        self.expect(Token.RPAR)  # )
        annotations = []
        while not self.is_at_end() and self.match(Token.VBAR):  # |
            self.consume(Token.VBAR)
            annotations.append(self.parse_annotation())
        self.expect(Token.RARROW)  # ->
        output_args = self.parse_args()

        templates = None
        if not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            templates = self.parse_templates()

        sizer = None
        if not self.is_at_end() and self.match(Token.VBAR):
            self.consume(Token.VBAR)
            idtn = self.parse_identifier()
            assert idtn == "output_row_size", idtn
            self.consume(Token.EQUAL)
            node = self.parse_primitive()
            key = "kPreFlightParameter"
            sizer = AnnotationNode(key, value=node.type)

        # set arg_pos
        # A Cursor input advances the position by the number of its inner
        # arguments, every other input by one.
        i = 0
        for arg in input_args:
            arg.arg_pos = i
            arg.kind = "input"
            i += arg.type.cursor_length() if arg.type.is_cursor() else 1

        for i, arg in enumerate(output_args):
            arg.arg_pos = i
            arg.kind = "output"

        return UdtfNode(name, input_args, output_args, annotations, templates, sizer, self.line)

    def parse_args(self):
        """fmt: off

        args: arg ("," arg)*

        fmt: on
        """
        args = []
        args.append(self.parse_arg())
        while not self.is_at_end() and self.match(Token.COMMA):
            curr = self._curr
            self.consume(Token.COMMA)
            self.parse_type()  # assuming that we are not ending with COMMA
            if not self.is_at_end() and self.match(Token.EQUAL):
                # arg type cannot be assigned, so this must be a template specification
                self._curr = curr  # step back and let the code below parse the templates
                break
            else:
                self._curr = curr + 1  # step back from self.parse_type(), parse_arg will parse it again
                args.append(self.parse_arg())
        return args

    def parse_arg(self):
        """fmt: off

        arg: type IDENTIFIER? ("|" annotation)*

        fmt: on
        """
        typ = self.parse_type()

        annotations = []

        if not self.is_at_end() and self.match(Token.IDENTIFIER):
            name = self.parse_identifier()
            annotations.append(AnnotationNode('name', name))

        while not self.is_at_end() and self.match(Token.VBAR):
            ahead = self.lookahead()
            # "| output_row_size" belongs to the udtf rule, not to this arg.
            if ahead is not None and ahead.type == Token.IDENTIFIER and ahead.lexeme == 'output_row_size':
                break
            self.consume(Token.VBAR)
            annotations.append(self.parse_annotation())

        return ArgNode(typ, annotations)

    def parse_type(self):
        """fmt: off

        type: composed
            | primitive

        fmt: on
        """
        curr = self._curr  # save state
        primitive = self.parse_primitive()
        if self.is_at_end():
            return primitive

        if not self.match(Token.LESS):
            return primitive

        self._curr = curr  # return state

        return self.parse_composed()

    def parse_composed(self):
        """fmt: off

        composed: "Cursor" "<" arg ("," arg)* ">"
                | IDENTIFIER "<" type ("," type)* ">"

        fmt: on
        """
        idtn = self.parse_identifier()
        self.consume(Token.LESS)
        if is_identifier_cursor(idtn):
            inner = [self.parse_arg()]
            while self.match(Token.COMMA):
                self.consume(Token.COMMA)
                inner.append(self.parse_arg())
        else:
            inner = [self.parse_type()]
            while self.match(Token.COMMA):
                self.consume(Token.COMMA)
                inner.append(self.parse_type())
        self.consume(Token.GREATER)
        return ComposedNode(idtn, inner)

    def parse_primitive(self):
        """fmt: off

        primitive: IDENTIFIER
                 | NUMBER
                 | STRING

        fmt: on
        """
        if self.match(Token.IDENTIFIER):
            lexeme = self.parse_identifier()
        elif self.match(Token.NUMBER):
            lexeme = self.parse_number()
        elif self.match(Token.STRING):
            lexeme = self.parse_string()
        else:
            # raise_parser_error always raises, so no `raise` is needed here.
            self.raise_parser_error()
        return PrimitiveNode(lexeme)

    def parse_templates(self):
        """fmt: off

        templates: template ("," template)*

        fmt: on
        """
        T = []
        T.append(self.parse_template())
        while not self.is_at_end() and self.match(Token.COMMA):
            self.consume(Token.COMMA)
            T.append(self.parse_template())
        return T

    def parse_template(self):
        """fmt: off

        template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"

        fmt: on
        """
        key = self.parse_identifier()
        types = []
        self.consume(Token.EQUAL)
        self.consume(Token.LSQB)
        types.append(self.parse_identifier())
        while self.match(Token.COMMA):
            self.consume(Token.COMMA)
            types.append(self.parse_identifier())
        self.consume(Token.RSQB)
        return TemplateNode(key, tuple(types))

    def parse_annotation(self):
        """fmt: off

        annotation: IDENTIFIER "=" IDENTIFIER ("<" (NUMBER ("," NUMBER)?)? ">")?
                  | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
                  | "require" "=" STRING

        fmt: on
        """
        key = self.parse_identifier()
        self.consume(Token.EQUAL)

        if key == "require":
            value = self.parse_string()
        elif not self.is_at_end() and self.match(Token.LSQB):
            value = []
            self.consume(Token.LSQB)
            if not self.match(Token.RSQB):
                value.append(self.parse_primitive())
                while self.match(Token.COMMA):
                    self.consume(Token.COMMA)
                    value.append(self.parse_primitive())
            self.consume(Token.RSQB)
        else:
            value = self.parse_identifier()
            if not self.is_at_end() and self.match(Token.LESS):
                self.consume(Token.LESS)
                if self.match(Token.GREATER):
                    value += "<%s>" % (-1)  # Signifies no input
                else:
                    num1 = self.parse_number()
                    if self.match(Token.COMMA):
                        self.consume(Token.COMMA)
                        num2 = self.parse_number()
                        value += "<%s,%s>" % (num1, num2)
                    else:
                        value += "<%s>" % (num1)
                self.consume(Token.GREATER)
        return AnnotationNode(key, value)

    def parse_identifier(self):
        """ fmt: off

        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*

        fmt: on
        """
        token = self.consume(Token.IDENTIFIER)
        return token.lexeme

    def parse_string(self):
        """ fmt: off

        STRING: \".*?\"

        fmt: on
        """
        token = self.consume(Token.STRING)
        return token.lexeme

    def parse_number(self):
        """ fmt: off

        NUMBER: [0-9]+

        fmt: on
        """
        token = self.consume(Token.NUMBER)
        return token.lexeme

    def parse(self):
        """fmt: off

        udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?

        args: arg ("," arg)*

        arg: type IDENTIFIER? ("|" annotation)*

        type: composed
            | primitive

        composed: "Cursor" "<" arg ("," arg)* ">"
                | IDENTIFIER "<" type ("," type)* ">"

        primitive: IDENTIFIER
                 | NUMBER
                 | STRING

        annotation: IDENTIFIER "=" IDENTIFIER ("<" (NUMBER ("," NUMBER)?)? ">")?
                  | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
                  | "require" "=" STRING

        templates: template ("," template)*
        template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"

        IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
        NUMBER: [0-9]+
        STRING: \".*?\"

        fmt: on
        """
        self._curr = 0
        udtf = self.parse_udtf()

        # set parent
        # Depth-first walk: every child of an iterable node gets a
        # back-reference to the node it was yielded from.
        udtf.parent = None
        d = deque()
        d.append(udtf)
        while d:
            node = d.pop()
            if isinstance(node, Iterable):
                for child in node:
                    child.parent = node
                    d.append(child)
        return udtf
1657 
1658 
1659 # fmt: off
def find_signatures(input_file):
    """Returns a list of parsed UDTF signatures.

    Lines starting with ``UDTF:`` are parsed with ``Parser``. A line that
    ``line_is_incomplete`` reports as unfinished is joined (with a single
    space) to the line that follows it. A specification without both
    ``(`` and ``)`` is reported on stderr and skipped.

    When a specification contains ``separator``, the text after it is a
    list of expected results: the AST is run through the printing
    pipeline and compared against them with ``assert``; nothing is added
    to the returned list for such lines. Otherwise the AST is run through
    the declaration pipeline and the results are appended to the list.
    """
    # Both pipelines share every pass except the last one.
    common_passes = (TemplateTransformer,
                     FieldAnnotationTransformer,
                     TextEncodingDictTransformer,
                     SupportedAnnotationsTransformer,
                     RangeAnnotationTransformer,
                     FixRowMultiplierPosArgTransformer,
                     RenameNodesTransformer)

    signatures = []

    # Read inside a context manager so the handle is closed even when a
    # later parse or assertion raises.
    with open(input_file) as stream:
        lines = stream.readlines()

    last_line = None
    for line in lines:
        line = line.strip()
        if last_line is not None:
            line = last_line + ' ' + line
            last_line = None
        if not line.startswith('UDTF:'):
            continue
        if line_is_incomplete(line):
            last_line = line
            continue
        last_line = None
        line = line[5:].lstrip()
        i = line.find('(')
        j = line.find(')')
        if i == -1 or j == -1:
            sys.stderr.write('Invalid UDTF specification: `%s`. Skipping.\n' % (line))
            continue

        expected_result = None
        if separator in line:
            line, expected_result = line.split(separator, 1)
            expected_result = expected_result.strip().split(separator)
            expected_result = list(map(lambda s: s.strip(), expected_result))

        ast = Parser(line).parse()

        if expected_result is not None:
            # Template transformer expands templates into multiple lines
            try:
                result = Pipeline(*(common_passes + (AstPrinter,)))(ast)
            except TransformerException as msg:
                result = ['%s: %s' % (type(msg).__name__, msg)]
            assert len(result) == len(expected_result), "\n\tresult: %s \n!= \n\texpected: %s" % (
                '\n\t\t '.join(result),
                '\n\t\t '.join(expected_result)
            )
            assert set(result) == set(expected_result), "\n\tresult: %s != \n\texpected: %s" % (
                '\n\t\t '.join(result),
                '\n\t\t '.join(expected_result),
            )

        else:
            signature = Pipeline(*(common_passes + (DeclBracketTransformer,)))(ast)

            signatures.extend(signature)

    return signatures
1726 
1727 
def format_function_args(input_types, output_types, uses_manager, use_generic_arg_name, emit_output_args):
    """Build the C++ parameter list and the matching argument-name list.

    Returns a pair of comma-separated strings ``(cpp_args, name_args)``.
    When *uses_manager* is true a ``TableFunctionManager& mgr`` parameter
    comes first. Each type contributes whatever its ``format_cpp_type``
    returns; output types are included only when *emit_output_args* is true.
    """
    declarations = ['TableFunctionManager& mgr'] if uses_manager else []
    names = ['mgr'] if uses_manager else []

    # (types, is_input) groups, processed in order: inputs, then outputs.
    groups = [(input_types, True)]
    if emit_output_args:
        groups.append((output_types, False))

    for types, as_input in groups:
        for position, typ in enumerate(types):
            declaration, arg_name = typ.format_cpp_type(position,
                                                        use_generic_arg_name=use_generic_arg_name,
                                                        is_input=as_input)
            declarations.append(declaration)
            names.append(arg_name)

    return ', '.join(declarations), ', '.join(names)
1754 
1755 
def build_template_function_call(caller, called, input_types, output_types, uses_manager):
    """Return C++ source of a wrapper named *caller* that forwards to *called*.

    The wrapper takes the input and output arguments (with generic
    argument names) and returns the result of calling *called* with them.
    """
    signature, forwarded = format_function_args(input_types,
                                                output_types,
                                                uses_manager,
                                                use_generic_arg_name=True,
                                                emit_output_args=True)

    pieces = [
        "EXTENSION_NOINLINE int32_t\n",
        "%s(%s) {\n" % (caller, signature),
        " return %s(%s);\n" % (called, forwarded),
        "}\n",
    ]
    return "".join(pieces)
1768 
1769 
def build_preflight_function(fn_name, sizer, input_types, output_types, uses_manager):
    """Return C++ source of the ``<fn_name.lower()>__preflight`` function.

    The generated function takes the input arguments only. For every
    ``require`` annotation found on a ``Declaration`` input it emits a
    check that returns an error when the condition does not hold. When
    ``sizer.is_arg_sizer()`` is true it then evaluates the sizer
    expression, rejects negative values and returns the value; otherwise
    it returns 0.
    """

    def format_error_msg(err_msg, uses_manager):
        # Error reporting goes through the manager when one is available.
        if uses_manager:
            return " return mgr.error_message(%s);\n" % (err_msg,)
        else:
            return " return table_function_error(%s);\n" % (err_msg,)

    cpp_args, _ = format_function_args(input_types,
                                       output_types,
                                       uses_manager,
                                       use_generic_arg_name=False,
                                       emit_output_args=False)

    # The function header does not depend on uses_manager (the manager
    # parameter, if any, is already part of cpp_args).
    fn = "EXTENSION_NOINLINE int32_t\n"
    fn += "%s(%s) {\n" % (fn_name.lower() + "__preflight", cpp_args)

    for typ in input_types:
        if isinstance(typ, Declaration):
            ann = typ.annotations
            for key, value in ann:
                if key == 'require':
                    # value[1:-1] drops the first and last character of the
                    # annotation value (presumably the surrounding quotes).
                    err_msg = '"Constraint `%s` is not satisfied."' % (value[1:-1])

                    fn += " if (!(%s)) {\n" % (value[1:-1].replace('\\', ''),)
                    fn += format_error_msg(err_msg, uses_manager)
                    fn += " }\n"

    if sizer.is_arg_sizer():
        precomputed_nrows = str(sizer.args[0])
        if '"' in precomputed_nrows:
            precomputed_nrows = precomputed_nrows[1:-1]
        # check to see if the precomputed number of rows > 0
        err_msg = '"Output size expression `%s` evaluated in a negative value."' % (precomputed_nrows)
        fn += " auto _output_size = %s;\n" % (precomputed_nrows)
        fn += " if (_output_size < 0) {\n"
        fn += format_error_msg(err_msg, uses_manager)
        fn += " }\n"
        fn += " return _output_size;\n"
    else:
        fn += " return 0;\n"
    fn += "}\n\n"

    return fn
1818 
1819 
1821  if sizer.is_arg_sizer():
1822  return True
1823  for arg_annotations in sig.input_annotations:
1824  d = dict(arg_annotations)
1825  if 'require' in d.keys():
1826  return True
1827  return False
1828 
1829 
def format_annotations(annotations_):
    """Render annotations as a C++ ``std::vector<std::map<...>>`` initializer.

    *annotations_* is an iterable of iterables of ``(key, value)`` pairs;
    each inner iterable becomes one map. The value of a ``require`` entry
    is emitted without its first and last character.
    """
    rendered_maps = []
    for pairs in annotations_:
        entries = []
        for key, raw in pairs:
            # type(raw) is not always 'str'
            shown = raw[1:-1] if key == 'require' else raw
            entries.append('{"%s", "%s"}' % (key, shown))
        rendered_maps.append('{' + ', '.join(entries) + '}')

    return "std::vector<std::map<std::string, std::string>>{" + ', '.join(rendered_maps) + "}"
1841 
1842 
1844  i = sig.name.rfind('_template')
1845  return i >= 0 and '__' in sig.name[:i + 1]
1846 
1847 
def uses_manager(sig):
    """Tell whether the first input of *sig* is named ``TableFunctionManager``.

    For a signature without inputs the (falsy) ``sig.inputs`` value itself
    is returned.
    """
    inputs = sig.inputs
    if not inputs:
        return inputs
    return inputs[0].name == 'TableFunctionManager'
1850 
1851 
1853  # Any function that does not have _gpu_ suffix is a cpu function.
1854  i = sig.name.rfind('_gpu_')
1855  if i >= 0 and '__' in sig.name[:i + 1]:
1856  if uses_manager(sig):
1857  raise ValueError('Table function {} with gpu execution target cannot have TableFunctionManager argument'.format(sig.name))
1858  return False
1859  return True
1860 
1861 
1863  # A function with TableFunctionManager argument is a cpu-only function
1864  if uses_manager(sig):
1865  return False
1866  # Any function that does not have _cpu_ suffix is a gpu function.
1867  i = sig.name.rfind('_cpu_')
1868  return not (i >= 0 and '__' in sig.name[:i + 1])
1869 
1870 
# parse_annotations(input_files)
#
# For every signature that find_signatures() returns for each file in
# input_files, this function collects:
#   - a `TableFunctionsFactory::add(...)` statement (add_stmts);
#   - for template functions, a forwarding wrapper built by
#     build_template_function_call() (cpu_/gpu_template_functions);
#   - an `avoid_opt_address(...)` expression for the registered function
#     (cpu_/gpu_function_address_expressions);
#   - a preflight function when must_emit_preflight_function() asks for
#     one (cond_fns).
# The six lists are returned as a tuple, in the order of the final
# return statement.
1871 def parse_annotations(input_files):
1872 
# counter numbers the wrappers generated for template functions
# (the wrapper is named sig.name + '_' + str(counter)).
1873  counter = 0
1874 
1875  add_stmts = []
1876  cpu_template_functions = []
1877  gpu_template_functions = []
1878  cpu_function_address_expressions = []
1879  gpu_function_address_expressions = []
1880  cond_fns = []
1881 
1882  for input_file in input_files:
1883  for sig in find_signatures(input_file):
1884 
1885  # Compute sql_types, input_types, and sizer
1886  sql_types_ = []
1887  input_types_ = []
1888  input_annotations = []
1889 
# An explicit sizer on the signature becomes a kPreFlightParameter
# bracket wrapping the sizer's value.
1890  sizer = None
1891  if sig.sizer is not None:
1892  expr = sig.sizer.value
1893  sizer = Bracket('kPreFlightParameter', (expr,))
1894 
# NOTE: within this function the local name `uses_manager` shadows the
# module-level function uses_manager().
1895  uses_manager = False
1896  for i, (t, annot) in enumerate(zip(sig.inputs, sig.input_annotations)):
1897  if t.is_output_buffer_sizer():
1898  if t.is_user_specified():
1899  sql_types_.append(Bracket.parse('int32').normalize(kind='input'))
1900  input_types_.append(sql_types_[-1])
1901  input_annotations.append(annot)
1902  assert sizer is None # exactly one sizer argument is allowed
1903  assert len(t.args) == 1, t
1904  sizer = t
1905  elif t.name == 'Cursor':
1906  for t_ in t.args:
1907  input_types_.append(t_)
1908  input_annotations.append(annot)
1909  sql_types_.append(Bracket('Cursor', args=()))
1910  elif t.name == 'TableFunctionManager':
1911  if i != 0:
1912  raise ValueError('{} must appear as a first argument of {}, but found it at position {}.'.format(t, sig.name, i))
1913  uses_manager = True
1914  else:
1915  input_types_.append(t)
1916  input_annotations.append(annot)
1917  if t.is_column_any():
1918  # XXX: let Bracket handle mapping of column to cursor(column)
1919  sql_types_.append(Bracket('Cursor', args=()))
1920  else:
1921  sql_types_.append(t)
1922 
# No sizer was found above: fall back to
# kTableFunctionSpecifiedParameter with index 1.
1923  if sizer is None:
1924  name = 'kTableFunctionSpecifiedParameter'
1925  idx = 1 # this sizer is not actually materialized in the UDTF
1926  sizer = Bracket(name, (idx,))
1927 
1928  assert sizer is not None
1929  ns_output_types = tuple([a.apply_namespace(ns='ExtArgumentType') for a in sig.outputs])
1930  ns_input_types = tuple([t.apply_namespace(ns='ExtArgumentType') for t in input_types_])
1931  ns_sql_types = tuple([t.apply_namespace(ns='ExtArgumentType') for t in sql_types_])
1932 
1933  sig.function_annotations.append(('uses_manager', str(uses_manager).lower()))
1934 
# From here on input_types / output_types / sql_types are C++ initializer
# strings; the underscore-suffixed lists still hold the type objects.
1935  input_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_input_types)))
1936  output_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_output_types)))
1937  sql_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(map(tostring, ns_sql_types)))
1938  annotations = format_annotations(input_annotations + sig.output_annotations + [sig.function_annotations])
1939 
1940  # Notice that input_types and sig.input_types, (and
1941  # similarly, input_annotations and sig.input_annotations)
1942  # have different lengths when the sizer argument is
1943  # Constant or TableFunctionSpecifiedParameter. That is,
1944  # input_types contains all the user-specified arguments
1945  # while sig.input_types contains all arguments of the
1946  # implementation of an UDTF.
1947 
# The preflight function of a template function uses the same
# '<name>_<counter>' name as the wrapper generated just below.
1948  if must_emit_preflight_function(sig, sizer):
1949  fn_name = '%s_%s' % (sig.name, str(counter)) if is_template_function(sig) else sig.name
1950  check_fn = build_preflight_function(fn_name, sizer, input_types_, sig.outputs, uses_manager)
1951  cond_fns.append(check_fn)
1952 
1953  if is_template_function(sig):
1954  name = sig.name + '_' + str(counter)
1955  counter += 1
1956  t = build_template_function_call(name, sig.name, input_types_, sig.outputs, uses_manager)
1957  address_expression = ('avoid_opt_address(reinterpret_cast<void*>(%s))' % name)
1958  if is_cpu_function(sig):
1959  cpu_template_functions.append(t)
1960  cpu_function_address_expressions.append(address_expression)
1961  if is_gpu_function(sig):
1962  gpu_template_functions.append(t)
1963  gpu_function_address_expressions.append(address_expression)
1964  add = ('TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
1965  % (name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
1966  add_stmts.append(add)
1967 
1968  else:
1969  add = ('TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
1970  % (sig.name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
1971  add_stmts.append(add)
1972  address_expression = ('avoid_opt_address(reinterpret_cast<void*>(%s))' % sig.name)
1973 
1974  if is_cpu_function(sig):
1975  cpu_function_address_expressions.append(address_expression)
1976  if is_gpu_function(sig):
1977  gpu_function_address_expressions.append(address_expression)
1978 
1979  return add_stmts, cpu_template_functions, gpu_template_functions, cpu_function_address_expressions, gpu_function_address_expressions, cond_fns
1980 
1981 
# Without at least one input file and an output file on the command line,
# run the signatures from test_udtf_signatures.hpp (located next to this
# script) through parse_annotations, print the usage text and exit with
# status 1.
if len(sys.argv) < 3:
    input_files = [os.path.join(os.path.dirname(__file__), 'test_udtf_signatures.hpp')]
    print('Running tests from %s' % (', '.join(input_files)))
    add_stmts, _, _, _, _, _ = parse_annotations(input_files)
    print('Usage:\n %s %s input1.hpp input2.hpp ... output.hpp' % (sys.executable, sys.argv[0], ))
    sys.exit(1)

# Every argument except the last is an input file; the last one names the
# output file, from which the CPU and GPU header names are derived.
output_filename = sys.argv[-1]
input_files = sys.argv[1:-1]
_output_stem = os.path.splitext(output_filename)[0]
cpu_output_header = _output_stem + '_cpu.hpp'
gpu_output_header = _output_stem + '_gpu.hpp'
assert input_files, sys.argv

(add_stmts,
 cpu_template_functions,
 gpu_template_functions,
 cpu_address_expressions,
 gpu_address_expressions,
 cond_fns) = parse_annotations(sys.argv[1:-1])

# Include paths start right after the slash that precedes "QueryEngine/";
# a path without "/QueryEngine/" is used unchanged (find() gives -1).
canonical_input_files = []
for input_file in input_files:
    start = input_file.find("/QueryEngine/") + 1
    canonical_input_files.append(input_file[start:])
header_includes = ['#include "' + canonical_input_file + '"' for canonical_input_file in canonical_input_files]

# Split up calls to TableFunctionsFactory::add() into chunks
ADD_FUNC_CHUNK_SIZE = 100
2004 
def add_method(i, chunk):
    """Return C++ source of method ``add_table_functions_<i>``.

    The method body consists of the statements in *chunk*, one per line.
    """
    body = '\n '.join(chunk)
    return ('\n'
            ' NO_OPT_ATTRIBUTE void add_table_functions_%d() const {\n'
            ' %s\n'
            ' }\n') % (i, body)
2011 
def add_methods(add_stmts):
    """Return one ``add_method`` result per ADD_FUNC_CHUNK_SIZE statements."""
    methods = []
    for index, start in enumerate(range(0, len(add_stmts), ADD_FUNC_CHUNK_SIZE)):
        chunk = add_stmts[start:start + ADD_FUNC_CHUNK_SIZE]
        methods.append(add_method(index, chunk))
    return methods
2015 
def call_methods(add_stmts):
    """Return the ``add_table_functions_<i>();`` calls, one per chunk.

    The number of calls is ``len(add_stmts) / ADD_FUNC_CHUNK_SIZE``
    rounded up.
    """
    full_chunks, leftover = divmod(len(add_stmts), ADD_FUNC_CHUNK_SIZE)
    chunk_count = full_chunks + (1 if 0 < leftover else 0)
    calls = []
    for i in range(chunk_count):
        calls.append('add_table_functions_%d();' % (i))
    return calls
2019 
# C++ source that is later written to output_filename. The six %s slots
# are filled, in order, with:
#   1. sys.argv[0] (the generator script, for the "generated by" banner);
#   2. the '#include' lines in header_includes;
#   3. the CPU address expressions joined with ' &&\n' (functions_exist);
#   4. the add_table_functions_<i>() methods from add_methods();
#   5. the calls to those methods from call_methods();
#   6. the preflight functions collected in cond_fns.
2020 content = '''
2021 /*
2022  This file is generated by %s. Do no edit!
2023 */
2024 
2025 #include "QueryEngine/TableFunctions/TableFunctionsFactory.h"
2026 %s
2027 
2028 /*
2029  Include the UDTF template initiations:
2030 */
2031 #include "TableFunctionsFactory_init_cpu.hpp"
2032 
2033 // volatile+noinline prevents compiler optimization
2034 #ifdef _WIN32
2035 __declspec(noinline)
2036 #else
2037  __attribute__((noinline))
2038 #endif
2039 volatile
2040 bool avoid_opt_address(void *address) {
2041  return address != nullptr;
2042 }
2043 
2044 bool functions_exist() {
2045  bool ret = true;
2046 
2047  ret &= (%s);
2048 
2049  return ret;
2050 }
2051 
2052 extern bool g_enable_table_functions;
2053 
2054 namespace table_functions {
2055 
2056 std::once_flag init_flag;
2057 
2058 #if defined(__clang__)
2059 #define NO_OPT_ATTRIBUTE __attribute__((optnone))
2060 
2061 #elif defined(__GNUC__) || defined(__GNUG__)
2062 #define NO_OPT_ATTRIBUTE __attribute((optimize("O0")))
2063 
2064 #elif defined(_MSC_VER)
2065 #define NO_OPT_ATTRIBUTE
2066 
2067 #endif
2068 
2069 #if defined(_MSC_VER)
2070 #pragma optimize("", off)
2071 #endif
2072 
2073 struct AddTableFunctions {
2074 %s
2075  NO_OPT_ATTRIBUTE void operator()() {
2076  %s
2077  }
2078 };
2079 
2080 void TableFunctionsFactory::init() {
2081  if (!g_enable_table_functions) {
2082  return;
2083  }
2084 
2085  if (!functions_exist()) {
2086  UNREACHABLE();
2087  return;
2088  }
2089 
2090  std::call_once(init_flag, AddTableFunctions{});
2091 }
2092 #if defined(_MSC_VER)
2093 #pragma optimize("", on)
2094 #endif
2095 
2096 // conditional check functions
2097 %s
2098 
2099 } // namespace table_functions
2100 
2101 ''' % (sys.argv[0],
2102  '\n'.join(header_includes),
2103  ' &&\n'.join(cpu_address_expressions),
2104  ''.join(add_methods(add_stmts)),
2105  '\n '.join(call_methods(add_stmts)),
2106  ''.join(cond_fns))
2107 
# Template shared by the CPU and GPU headers. Its three slots receive the
# generator script name, the '#include' lines and the template wrappers.
header_content = '''
/*
 This file is generated by %s. Do no edit!
*/
%s

%s
'''

dirname = os.path.dirname(output_filename)

# Create the output directory if needed; a concurrent creation by another
# process (EEXIST) is not an error.
if dirname and not os.path.exists(dirname):
    try:
        os.makedirs(dirname)
    except OSError as e:
        import errno
        if e.errno != errno.EEXIST:
            raise

# Each file is written inside a context manager so its handle is closed
# even when write() raises.
with open(output_filename, 'w') as f:
    f.write(content)

with open(cpu_output_header, 'w') as f:
    f.write(header_content % (sys.argv[0], '\n'.join(header_includes), '\n'.join(cpu_template_functions)))

with open(gpu_output_header, 'w') as f:
    f.write(header_content % (sys.argv[0], '\n'.join(header_includes), '\n'.join(gpu_template_functions)))
std::string strip(std::string_view str)
trim any whitespace from the left and right ends of a string
std::string join(T const &container, std::string const &delim)
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings
int open(const char *path, int flags, int mode)
Definition: heavyai_fs.cpp:66