1 """Given a list of input files, scan for lines containing UDTF
2 specification statements in the following form:
4 UDTF: function_name(<arguments>) -> <output column types> (, <template type specifications>)?
6 where <arguments> is a comma-separated list of argument types. The
7 argument type specifications are:
10 Int8, Int16, Int32, Int64, Float, Double, Bool, TextEncodingDict, etc
12 ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble, ColumnBool, etc
14 ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool, etc
17 where t0, t1 are column or column list types
18 - output buffer size parameter type:
19 RowMultiplier<i>, ConstantParameter<i>, Constant<i>, TableFunctionSpecifiedParameter<i>
20 where i is a literal integer.
22 The output column types are a comma-separated list of column types, see above.
24 In addition, the following equivalents are supported:
27 ColumnList<T> == ColumnListT
28 Cursor<T, V, ...> == Cursor<ColumnT, ColumnV, ...>
29 int8 == int8_t == Int8, etc
30 float == Float, double == Double, bool == Bool
31 T == ColumnT for output column types
32 RowMultiplier == RowMultiplier<i> where i is the one-based position of the sizer argument
33 when no sizer argument is provided, Constant<1> is assumed
35 Argument types can be annotated using `|' (bar) symbol after an
36 argument type specification. An annotation is specified by a label and
37 a value separated by `=' (equal) symbol. Multiple annotations can be
38 specified by using `|` (bar) symbol as the annotations separator.
39 Supported annotation labels are:
41 - name: to specify argument name
42 - input_id: to specify the dict id mapping for output TextEncodingDict columns.
44 If argument type follows an identifier, it will be mapped to name
45 annotations. For example, the following argument type specifications
51 Template type specifications are a comma-separated list of template
52 type assignments where values are lists of argument type names. For
55 T = [Int8, Int16, Int32, Float], V = [Float, Double]
66 from abc
import abstractmethod
68 from collections
import deque, namedtuple
70 if sys.version_info > (3, 0):
72 from collections.abc
import Iterable
74 from abc
import ABCMeta
as ABC
75 from collections
import Iterable
80 Signature = namedtuple(
'Signature', [
'name',
'inputs',
'outputs',
'input_annotations',
'output_annotations',
'function_annotations',
'sizer'])
82 OutputBufferSizeTypes =
'''
83 kConstant, kUserSpecifiedConstantParameter, kUserSpecifiedRowMultiplier, kTableFunctionSpecifiedParameter, kPreFlightParameter
86 SupportedAnnotations =
'''
87 input_id, name, fields, require, range
91 SupportedFunctionAnnotations =
'''
92 filter_table_function_transpose, uses_manager
97 PreFlight=
'kPreFlightParameter',
98 ConstantParameter=
'kUserSpecifiedConstantParameter',
99 RowMultiplier=
'kUserSpecifiedRowMultiplier',
100 UserSpecifiedConstantParameter=
'kUserSpecifiedConstantParameter',
101 UserSpecifiedRowMultiplier=
'kUserSpecifiedRowMultiplier',
102 TableFunctionSpecifiedParameter=
'kTableFunctionSpecifiedParameter',
107 for t
in [
'Int8',
'Int16',
'Int32',
'Int64',
'Float',
'Double',
'Bool',
108 'TextEncodingDict',
'TextEncodingNone']:
109 translate_map[t.lower()] = t
110 if t.startswith(
'Int'):
111 translate_map[t.lower() +
'_t'] = t
115 """Holds a `TYPE | ANNOTATIONS`-like structure.
123 return self.type.name
127 return self.type.args
130 return self.type.format_sizer()
137 return str(self.
type)
141 return self.type.tostring()
144 return self.__class__(self.type.apply_column(), self.
annotations)
147 return self.__class__(self.type.apply_namespace(ns), self.
annotations)
150 return self.type.get_cpp_type()
153 real_arg_name = dict(self.
annotations).get(
'name',
None)
154 return self.type.format_cpp_type(idx,
155 use_generic_arg_name=use_generic_arg_name,
156 real_arg_name=real_arg_name,
160 if name.startswith(
'is_'):
161 return getattr(self.
type, name)
162 raise AttributeError(name)
166 return obj.tostring()
170 """Holds a `NAME<ARGS>`-like structure.
174 assert isinstance(name, str)
175 assert isinstance(args, tuple)
or args
is None, args
180 return 'Bracket(%r, args=%r)' % (self.
name, self.
args)
185 return '%s<%s>' % (self.
name,
', '.
join(map(str, self.
args)))
190 return '%s<%s>' % (self.
name,
', '.
join(map(tostring, self.
args)))
193 """Normalize bracket for given kind
195 assert kind
in [
'input',
'output'], kind
199 if self.
name ==
'Cursor':
200 args = [(a
if a.is_column_any()
else Bracket(
'Column', args=(a,))).
normalize(kind=kind)
for a
in self.
args]
208 """Apply cursor to a non-cursor column argument type.
209 TODO: this method is currently unused but we should apply
210 cursor to all input column arguments in order to distinguish
212 foo(Cursor(Column<int32>, Column<float>)) -> Column<int32>
213 foo(Cursor(Column<int32>), Cursor(Column<float>)) -> Column<int32>
214 that at the moment are treated as the same :(
217 return Bracket(
'Cursor', args=(self,))
226 if self.
name ==
'Cursor':
227 return Bracket(ns +
'::' + self.
name, args=tuple(a.apply_namespace(ns=ns)
for a
in self.
args))
228 if not self.name.startswith(ns +
'::'):
233 return self.name.rsplit(
"::", 1)[-1] ==
'Cursor'
236 return self.name.rsplit(
"::", 1)[-1].startswith(
'Array')
239 return self.name.rsplit(
"::", 1)[-1].startswith(
'Column')
242 return self.name.rsplit(
"::", 1)[-1].startswith(
'ColumnList')
245 return self.name.rsplit(
"::", 1)[-1].startswith(
'Column')
and not self.
is_column_list()
248 return self.name.rsplit(
"::", 1)[-1].endswith(
'TextEncodingDict')
251 return self.name.rsplit(
"::", 1)[-1] ==
'ArrayTextEncodingDict'
254 return self.name.rsplit(
"::", 1)[-1] ==
'ColumnTextEncodingDict'
257 return self.name.rsplit(
"::", 1)[-1] ==
'ColumnArrayTextEncodingDict'
260 return self.name.rsplit(
"::", 1)[-1] ==
'ColumnListTextEncodingDict'
263 return self.name.rsplit(
"::", 1)[-1]
in OutputBufferSizeTypes
266 return self.name.rsplit(
"::", 1)[-1] ==
'kUserSpecifiedRowMultiplier'
269 return self.name.rsplit(
"::", 1)[-1] ==
'kPreFlightParameter'
274 return self.name.rsplit(
"::", 1)[-1]
not in (
'kConstant',
'kTableFunctionSpecifiedParameter',
'kPreFlightParameter')
279 return 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (self.
name, val)
282 name = self.name.rsplit(
"::", 1)[-1]
285 if name.startswith(
'ColumnList'):
286 name = name.lstrip(
'ColumnList')
287 clsname =
'ColumnList'
288 elif name.startswith(
'Column'):
289 name = name.lstrip(
'Column')
291 if name.startswith(
'Array'):
292 name = name.lstrip(
'Array')
298 if name.startswith(
'Bool'):
300 elif name.startswith(
'Int'):
301 ctype = name.lower() +
'_t'
302 elif name
in [
'Double',
'Float']:
304 elif name ==
'TextEncodingDict':
306 elif name ==
'TextEncodingNone':
308 elif name ==
'Timestamp':
310 elif name ==
'DayTimeInterval':
312 elif name ==
'YearMonthTimeInterval':
315 raise NotImplementedError(self)
318 if subclsname
is None:
319 return '%s<%s>' % (clsname, ctype)
320 return '%s<%s<%s>>' % (clsname, subclsname, ctype)
322 def format_cpp_type(self, idx, use_generic_arg_name=False, real_arg_name=None, is_input=True):
323 col_typs = (
'Column',
'ColumnList')
324 literal_ref_typs = (
'TextEncodingNone',)
325 if use_generic_arg_name:
326 arg_name =
'input' + str(idx)
if is_input
else 'output' + str(idx)
327 elif real_arg_name
is not None:
328 arg_name = real_arg_name
331 arg_name =
'input' + str(idx)
if is_input
else 'output' + str(idx)
332 const =
'const ' if is_input
else ''
334 if any(cpp_type.startswith(t)
for t
in col_typs + literal_ref_typs):
335 return '%s%s& %s' % (const, cpp_type, arg_name), arg_name
337 return '%s %s' % (cpp_type, arg_name), arg_name
341 """typ is a string in format NAME<ARGS> or NAME
343 Returns Bracket instance.
350 assert typ.endswith(
'>'), typ
351 name = typ[:i].
strip()
353 rest = typ[i + 1:-1].
strip()
359 a, rest = rest[:i].rstrip(), rest[i + 1:].lstrip()
360 args.append(cls.parse(a))
363 name = translate_map.get(name, name)
364 return cls(name, args)
369 for i, c
in enumerate(line):
374 elif d == 0
and c ==
',':
382 return line.endswith(
',')
or line.endswith(
'->')
or line.endswith(separator)
or line.endswith(
'|')
386 return identifier.lower() ==
'cursor'
396 class ParserException(Exception):
426 One of the tokens in the list above
428 Corresponding string in the text
437 Token.GREATER:
"GREATER",
438 Token.COMMA:
"COMMA",
439 Token.EQUAL:
"EQUAL",
440 Token.RARROW:
"RARROW",
441 Token.STRING:
"STRING",
442 Token.NUMBER:
"NUMBER",
449 Token.IDENTIFIER:
"IDENTIFIER",
450 Token.COLON:
"COLON",
452 return names.get(token)
455 return 'Token(%s, "%s")' % (Token.tok_name(self.
type), self.
lexeme)
501 self._tokens.append(
Token(type, lexeme))
516 return char
in (
"-",)
565 if char ==
'"' and curr !=
'\\':
578 if char
and char.isdigit():
587 IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
591 if char
and char.isalnum()
or char ==
"_":
599 return self.
peek().isalpha()
or self.
peek() ==
"_"
602 return self.
peek() ==
'"'
605 return self.
peek().isdigit()
608 return self.
peek().isalpha()
611 return self.
peek().isspace()
617 'Could not match char "%s" at pos %d on line\n %s' % (char, curr, self.
line)
649 class AstTransformer(AstVisitor):
650 """Only overload the methods you need"""
653 udtf = copy.copy(udtf_node)
654 udtf.inputs = [arg.accept(self)
for arg
in udtf.inputs]
655 udtf.outputs = [arg.accept(self)
for arg
in udtf.outputs]
657 udtf.templates = [t.accept(self)
for t
in udtf.templates]
658 udtf.annotations = [annot.accept(self)
for annot
in udtf.annotations]
662 c = copy.copy(composed_node)
663 c.inner = [i.accept(self)
for i
in c.inner]
667 arg_node = copy.copy(arg_node)
668 arg_node.type = arg_node.type.accept(self)
669 if arg_node.annotations:
670 arg_node.annotations = [a.accept(self)
for a
in arg_node.annotations]
674 return copy.copy(primitive_node)
677 return copy.copy(template_node)
680 return copy.copy(annotation_node)
684 """Returns a line formatted. Useful for testing"""
687 name = udtf_node.name
688 inputs =
", ".
join([arg.accept(self)
for arg
in udtf_node.inputs])
689 outputs =
", ".
join([arg.accept(self)
for arg
in udtf_node.outputs])
690 annotations =
"| ".
join([annot.accept(self)
for annot
in udtf_node.annotations])
691 sizer =
" | " + udtf_node.sizer.accept(self)
if udtf_node.sizer
else ""
693 annotations =
' | ' + annotations
694 if udtf_node.templates:
695 templates =
", ".
join([t.accept(self)
for t
in udtf_node.templates])
696 return "%s(%s)%s -> %s, %s%s" % (name, inputs, annotations, outputs, templates, sizer)
698 return "%s(%s)%s -> %s%s" % (name, inputs, annotations, outputs, sizer)
702 key = template_node.key
703 types = [
'"%s"' % typ
for typ
in template_node.types]
704 return "%s=[%s]" % (key,
", ".
join(types))
708 key = annotation_node.key
709 value = annotation_node.value
710 if isinstance(value, list):
711 return "%s=[%s]" % (key,
','.
join([v.accept(self)
for v
in value]))
712 return "%s=%s" % (key, value)
716 typ = arg_node.type.accept(self)
717 if arg_node.annotations:
718 ann =
" | ".
join([a.accept(self)
for a
in arg_node.annotations])
719 s =
"%s | %s" % (typ, ann)
723 if s ==
"ColumnTextEncodingDict" and arg_node.kind ==
"output":
724 return s +
" | input_id=args<0>"
728 T = composed_node.inner[0].accept(self)
729 if composed_node.is_array():
731 assert len(composed_node.inner) == 1
733 if composed_node.is_column():
735 assert len(composed_node.inner) == 1
737 if composed_node.is_column_list():
739 assert len(composed_node.inner) == 1
740 return "ColumnList" + T
741 if composed_node.is_output_buffer_sizer():
744 assert len(composed_node.inner) == 1
745 return translate_map.get(composed_node.type) +
"<%s>" % (N,)
746 if composed_node.is_cursor():
748 Ts =
", ".
join([i.accept(self)
for i
in composed_node.inner])
749 return "Cursor<%s>" % (Ts)
750 raise ValueError(composed_node)
753 t = primitive_node.type
754 if primitive_node.is_output_buffer_sizer():
756 return translate_map.get(t, t) +
"<%d>" % (
757 primitive_node.get_parent(ArgNode).arg_pos + 1,
759 return translate_map.get(t, t)
763 """Like AstPrinter but returns a node instead of a string
769 vals = kwargs.values()
770 for instance
in itertools.product(*vals):
771 yield dict(zip(keys, instance))
775 """Expand template definition into multiple inputs"""
778 if not udtf_node.templates:
783 d = dict([(node.key, node.types)
for node
in udtf_node.templates])
784 name = udtf_node.name
788 inputs = [input_arg.accept(self)
for input_arg
in udtf_node.inputs]
789 outputs = [output_arg.accept(self)
for output_arg
in udtf_node.outputs]
790 udtf =
UdtfNode(name, inputs, outputs, udtf_node.annotations,
None, udtf_node.sizer, udtf_node.line)
791 udtfs[str(udtf)] = udtf
794 udtfs = list(udtfs.values())
802 typ = composed_node.type
803 typ = self.mapping_dict.get(typ, typ)
805 inner = [i.accept(self)
for i
in composed_node.inner]
806 return composed_node.copy(typ, inner)
809 typ = primitive_node.type
810 typ = self.mapping_dict.get(typ, typ)
811 return primitive_node.copy(typ)
817 * Fix kUserSpecifiedRowMultiplier without a pos arg
819 t = primitive_node.type
821 if primitive_node.is_output_buffer_sizer():
822 pos =
PrimitiveNode(str(primitive_node.get_parent(ArgNode).arg_pos + 1))
826 return primitive_node
832 * Rename nodes using translate_map as dictionary
836 t = primitive_node.type
837 return primitive_node.copy(translate_map.get(t, t))
843 * Add default_input_id to Column(List)<TextEncodingDict> without one
847 default_input_id =
None
848 for idx, t
in enumerate(udtf_node.inputs):
850 if not isinstance(t.type, ComposedNode):
852 if default_input_id
is not None:
854 elif t.type.is_column_text_encoding_dict()
or t.type.is_column_array_text_encoding_dict():
855 default_input_id =
AnnotationNode(
'input_id',
'args<%s>' % (idx,))
856 elif t.type.is_column_list_text_encoding_dict():
857 default_input_id =
AnnotationNode(
'input_id',
'args<%s, 0>' % (idx,))
859 for t
in udtf_node.outputs:
860 if isinstance(t.type, ComposedNode)
and t.type.is_any_text_encoding_dict():
861 for a
in t.annotations:
862 if a.key ==
'input_id':
865 if default_input_id
is None:
866 raise TypeError(
'Cannot parse line "%s".\n'
867 'Missing TextEncodingDict input?' %
869 t.annotations.append(default_input_id)
878 * Generate fields annotation to Cursor if non-existing
882 for t
in udtf_node.inputs:
884 if not isinstance(t.type, ComposedNode):
887 if t.type.is_cursor()
and t.get_annotation(
'fields')
is None:
888 fields = list(
PrimitiveNode(a.get_annotation(
'name',
'field%s' % i))
for i, a
in enumerate(t.type.inner))
895 * Checks for supported annotations in a UDTF
898 for t
in udtf_node.inputs:
899 for a
in t.annotations:
900 if a.key
not in SupportedAnnotations:
902 for t
in udtf_node.outputs:
903 for a
in t.annotations:
904 if a.key
not in SupportedAnnotations:
906 for annot
in udtf_node.annotations:
907 if annot.key
not in SupportedFunctionAnnotations:
909 if annot.value.lower()
in [
'enable',
'on',
'1',
'true']:
911 elif annot.value.lower()
in [
'disable',
'off',
'0',
'false']:
918 * Append require annotation if range is used
921 for ann
in arg_node.annotations:
922 if ann.key ==
'range':
923 name = arg_node.get_annotation(
'name')
930 value =
'"{lo} <= {name} && {name} <= {hi}"'.format(lo=lo, hi=hi, name=name)
933 arg_node.set_annotation(
'require', value)
940 name = udtf_node.name
942 input_annotations = []
944 output_annotations = []
945 function_annotations = []
946 sizer = udtf_node.sizer
948 for i
in udtf_node.inputs:
949 decl = i.accept(self)
951 input_annotations.append(decl.annotations)
953 for o
in udtf_node.outputs:
954 decl = o.accept(self)
955 outputs.append(decl.type)
956 output_annotations.append(decl.annotations)
958 for annot
in udtf_node.annotations:
959 annot = annot.accept(self)
960 function_annotations.append(annot)
962 return Signature(name, inputs, outputs, input_annotations, output_annotations, function_annotations, sizer)
965 t = arg_node.type.accept(self)
966 anns = [a.accept(self)
for a
in arg_node.annotations]
970 typ = translate_map.get(composed_node.type, composed_node.type)
971 inner = [i.accept(self)
for i
in composed_node.inner]
972 if composed_node.is_cursor():
973 inner = list(map(
lambda x: x.apply_column(), inner))
974 return Bracket(typ, args=tuple(inner))
975 elif composed_node.is_output_buffer_sizer():
976 return Bracket(typ, args=tuple(inner))
978 return Bracket(typ + str(inner[0]))
981 t = primitive_node.type
985 key = annotation_node.key
986 value = annotation_node.value
1003 if isinstance(self, cls):
1006 if self.parent
is not None:
1007 return self.parent.get_parent(cls)
1009 raise ValueError(
"could not find parent with given class %s" % (cls))
1012 other = self.__class__(*args)
1015 for attr
in [
'parent',
'arg_pos']:
1016 if attr
in self.__dict__:
1017 setattr(other, attr, getattr(self, attr))
1026 class UdtfNode(Node, IterableNode):
1028 def __init__(self, name, inputs, outputs, annotations, templates, sizer, line):
1033 inputs : list[ArgNode]
1034 outputs : list[ArgNode]
1035 annotations : Optional[List[AnnotationNode]]
1036 templates : Optional[list[TemplateNode]]
1037 sizer : Optional[str]
1049 return visitor.visit_udtf_node(self)
1053 inputs = [str(i)
for i
in self.
inputs]
1054 outputs = [str(o)
for o
in self.
outputs]
1056 sizer =
"| %s" % str(self.
sizer)
if self.
sizer else ""
1058 templates = [str(t)
for t
in self.
templates]
1060 return "UDTF: %s (%s) | %s -> %s, %s %s" % (name, inputs, annotations, outputs, templates, sizer)
1062 return "UDTF: %s (%s) -> %s, %s %s" % (name, inputs, outputs, templates, sizer)
1065 return "UDTF: %s (%s) | %s -> %s %s" % (name, inputs, annotations, outputs, sizer)
1067 return "UDTF: %s (%s) -> %s %s" % (name, inputs, outputs, sizer)
1090 annotations : List[AnnotationNode]
1097 return visitor.visit_arg_node(self)
1104 return "ArgNode(%s | %s)" % (t, anns)
1105 return "ArgNode(%s)" % (t)
1124 assert not found, (i, a)
1139 return self.
type ==
"Column"
1142 return self.
type ==
"ColumnList"
1145 return self.
type ==
"Cursor"
1149 return translate_map.get(t, t)
in OutputBufferSizeTypes
1163 return visitor.visit_primitive_node(self)
1169 return self.
type ==
'TextEncodingDict'
1172 return self.
type ==
'ArrayTextEncodingDict'
1184 inner : list[TypeNode]
1190 return visitor.visit_composed_node(self)
1194 return len(self.
inner)
1197 i =
", ".
join([str(i)
for i
in self.
inner])
1198 return "Composed(%s<%s>)" % (self.
type, i)
1201 for i
in self.
inner:
1238 return visitor.visit_annotation_node(self)
1242 return self.
accept(printer)
1260 return visitor.visit_template_node(self)
1264 return self.
accept(printer)
1274 if not isinstance(ast_list, list):
1275 ast_list = [ast_list]
1278 ast_list = [ast.accept(c())
for ast
in ast_list]
1279 ast_list = itertools.chain.from_iterable(
1280 map(
lambda x: x
if isinstance(x, list)
else [x], ast_list))
1282 return list(ast_list)
1306 msg =
"Expected token %s but got %s at pos %d.\n Tokens: %s" % (
1308 Token.tok_name(expected_type),
1312 assert curr_token.type == expected_type, msg
1316 """consumes the current token iff its type matches the
1317 expected_type. Otherwise, an error is raised
1320 if curr_token.type == expected_type:
1324 expected_token = Token.tok_name(expected_type)
1326 'Token mismatch at function consume. '
1327 'Expected type "%s" but got token "%s"\n\n'
1328 'Tokens: %s\n' % (expected_token, curr_token, self.
_tokens)
1339 msg =
"\n\nError while trying to parse token %s at pos %d.\n" "Tokens: %s" % (
1348 return curr_token.type == expected_type
1356 udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?
1363 if not self.
match(Token.RPAR):
1370 self.
expect(Token.RARROW)
1382 assert idtn ==
"output_row_size", idtn
1385 key =
"kPreFlightParameter"
1390 for arg
in input_args:
1393 i += arg.type.cursor_length()
if arg.type.is_cursor()
else 1
1395 for i, arg
in enumerate(output_args):
1399 return UdtfNode(name, input_args, output_args, annotations, templates, sizer, self.
line)
1404 args: arg ("," arg)*
1419 self.
_curr = curr + 1
1426 arg: type IDENTIFIER? ("|" annotation)*
1440 if ahead.type == Token.IDENTIFIER
and ahead.lexeme ==
'output_row_size':
1445 return ArgNode(typ, annotations)
1460 if not self.
match(Token.LESS):
1470 composed: "Cursor" "<" arg ("," arg)* ">"
1471 | IDENTIFIER "<" type ("," type)* ">"
1479 while self.
match(Token.COMMA):
1484 while self.
match(Token.COMMA):
1493 primitive: IDENTIFIER
1499 if self.
match(Token.IDENTIFIER):
1501 elif self.
match(Token.NUMBER):
1503 elif self.
match(Token.STRING):
1512 templates: template ("," template)*
1526 template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"
1535 while self.
match(Token.COMMA):
1544 annotation: IDENTIFIER "=" IDENTIFIER ("<" NUMBER ("," NUMBER) ">")?
1545 | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
1546 | "require" "=" STRING
1553 if key ==
"require":
1558 if not self.
match(Token.RSQB):
1560 while self.
match(Token.COMMA):
1568 if self.
match(Token.GREATER):
1569 value +=
"<%s>" % (-1)
1572 if self.
match(Token.COMMA):
1575 value +=
"<%s,%s>" % (num1, num2)
1577 value +=
"<%s>" % (num1)
1584 IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
1588 token = self.
consume(Token.IDENTIFIER)
1598 token = self.
consume(Token.STRING)
1608 token = self.
consume(Token.NUMBER)
1614 udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?
1616 args: arg ("," arg)*
1618 arg: type IDENTIFIER? ("|" annotation)*
1623 composed: "Cursor" "<" arg ("," arg)* ">"
1624 | IDENTIFIER "<" type ("," type)* ">"
1626 primitive: IDENTIFIER
1630 annotation: IDENTIFIER "=" IDENTIFIER ("<" NUMBER ("," NUMBER) ">")?
1631 | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
1632 | "require" "=" STRING
1634 templates: template ("," template)*
1635 template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"
1637 IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
1652 if isinstance(node, Iterable):
1661 """Returns a list of parsed UDTF signatures."""
1665 for line
in open(input_file).readlines():
1667 if last_line
is not None:
1668 line = last_line +
' ' + line
1670 if not line.startswith(
'UDTF:'):
1676 line = line[5:].lstrip()
1679 if i == -1
or j == -1:
1680 sys.stderr.write(
'Invalid UDTF specification: `%s`. Skipping.\n' % (line))
1683 expected_result =
None
1684 if separator
in line:
1685 line, expected_result = line.split(separator, 1)
1686 expected_result = expected_result.strip().
split(separator)
1687 expected_result = list(map(
lambda s: s.strip(), expected_result))
1689 ast =
Parser(line).parse()
1691 if expected_result
is not None:
1694 result =
Pipeline(TemplateTransformer,
1695 FieldAnnotationTransformer,
1696 TextEncodingDictTransformer,
1697 SupportedAnnotationsTransformer,
1698 RangeAnnotationTransformer,
1699 FixRowMultiplierPosArgTransformer,
1700 RenameNodesTransformer,
1702 except TransformerException
as msg:
1703 result = [
'%s: %s' % (
type(msg).__name__, msg)]
1704 assert len(result) == len(expected_result),
"\n\tresult: %s \n!= \n\texpected: %s" % (
1705 '\n\t\t '.
join(result),
1706 '\n\t\t '.
join(expected_result)
1708 assert set(result) == set(expected_result),
"\n\tresult: %s != \n\texpected: %s" % (
1709 '\n\t\t '.
join(result),
1710 '\n\t\t '.
join(expected_result),
1714 signature =
Pipeline(TemplateTransformer,
1715 FieldAnnotationTransformer,
1716 TextEncodingDictTransformer,
1717 SupportedAnnotationsTransformer,
1718 RangeAnnotationTransformer,
1719 FixRowMultiplierPosArgTransformer,
1720 RenameNodesTransformer,
1721 DeclBracketTransformer)(ast)
1723 signatures.extend(signature)
1733 cpp_args.append(
'TableFunctionManager& mgr')
1734 name_args.append(
'mgr')
1736 for idx, typ
in enumerate(input_types):
1737 cpp_arg, name = typ.format_cpp_type(idx,
1738 use_generic_arg_name=use_generic_arg_name,
1740 cpp_args.append(cpp_arg)
1741 name_args.append(name)
1743 if emit_output_args:
1744 for idx, typ
in enumerate(output_types):
1745 cpp_arg, name = typ.format_cpp_type(idx,
1746 use_generic_arg_name=use_generic_arg_name,
1748 cpp_args.append(cpp_arg)
1749 name_args.append(name)
1751 cpp_args =
', '.
join(cpp_args)
1752 name_args =
', '.
join(name_args)
1753 return cpp_args, name_args
1760 use_generic_arg_name=
True,
1761 emit_output_args=
True)
1763 template = (
"EXTENSION_NOINLINE int32_t\n"
1766 "}\n") % (caller, cpp_args, called, name_args)
1772 def format_error_msg(err_msg, uses_manager):
1774 return " return mgr.error_message(%s);\n" % (err_msg,)
1776 return " return table_function_error(%s);\n" % (err_msg,)
1781 use_generic_arg_name=
False,
1782 emit_output_args=
False)
1785 fn =
"EXTENSION_NOINLINE int32_t\n"
1786 fn +=
"%s(%s) {\n" % (fn_name.lower() +
"__preflight", cpp_args)
1788 fn =
"EXTENSION_NOINLINE int32_t\n"
1789 fn +=
"%s(%s) {\n" % (fn_name.lower() +
"__preflight", cpp_args)
1791 for typ
in input_types:
1792 if isinstance(typ, Declaration):
1793 ann = typ.annotations
1794 for key, value
in ann:
1795 if key ==
'require':
1796 err_msg =
'"Constraint `%s` is not satisfied."' % (value[1:-1])
1798 fn +=
" if (!(%s)) {\n" % (value[1:-1].replace(
'\\',
''),)
1799 fn += format_error_msg(err_msg, uses_manager)
1802 if sizer.is_arg_sizer():
1803 precomputed_nrows = str(sizer.args[0])
1804 if '"' in precomputed_nrows:
1805 precomputed_nrows = precomputed_nrows[1:-1]
1807 err_msg =
'"Output size expression `%s` evaluated in a negative value."' % (precomputed_nrows)
1808 fn +=
" auto _output_size = %s;\n" % (precomputed_nrows)
1809 fn +=
" if (_output_size < 0) {\n"
1810 fn += format_error_msg(err_msg, uses_manager)
1812 fn +=
" return _output_size;\n"
1814 fn +=
" return 0;\n"
1821 if sizer.is_arg_sizer():
1823 for arg_annotations
in sig.input_annotations:
1824 d = dict(arg_annotations)
1825 if 'require' in d.keys():
1837 s =
"std::vector<std::map<std::string, std::string>>{"
1838 s +=
', '.
join((
'{' +
', '.
join(
'{"%s", "%s"}' % (k, fmt(k, v))
for k, v
in a) +
'}')
for a
in annotations_)
1844 i = sig.name.rfind(
'_template')
1845 return i >= 0
and '__' in sig.name[:i + 1]
1849 return sig.inputs
and sig.inputs[0].name ==
'TableFunctionManager'
1854 i = sig.name.rfind(
'_gpu_')
1855 if i >= 0
and '__' in sig.name[:i + 1]:
1857 raise ValueError(
'Table function {} with gpu execution target cannot have TableFunctionManager argument'.format(sig.name))
1867 i = sig.name.rfind(
'_cpu_')
1868 return not (i >= 0
and '__' in sig.name[:i + 1])
1876 cpu_template_functions = []
1877 gpu_template_functions = []
1878 cpu_function_address_expressions = []
1879 gpu_function_address_expressions = []
1882 for input_file
in input_files:
1888 input_annotations = []
1891 if sig.sizer
is not None:
1892 expr = sig.sizer.value
1893 sizer =
Bracket(
'kPreFlightParameter', (expr,))
1895 uses_manager =
False
1896 for i, (t, annot)
in enumerate(zip(sig.inputs, sig.input_annotations)):
1897 if t.is_output_buffer_sizer():
1898 if t.is_user_specified():
1899 sql_types_.append(Bracket.parse(
'int32').normalize(kind=
'input'))
1900 input_types_.append(sql_types_[-1])
1901 input_annotations.append(annot)
1902 assert sizer
is None
1903 assert len(t.args) == 1, t
1905 elif t.name ==
'Cursor':
1907 input_types_.append(t_)
1908 input_annotations.append(annot)
1909 sql_types_.append(
Bracket(
'Cursor', args=()))
1910 elif t.name ==
'TableFunctionManager':
1912 raise ValueError(
'{} must appear as a first argument of {}, but found it at position {}.'.format(t, sig.name, i))
1915 input_types_.append(t)
1916 input_annotations.append(annot)
1917 if t.is_column_any():
1919 sql_types_.append(
Bracket(
'Cursor', args=()))
1921 sql_types_.append(t)
1924 name =
'kTableFunctionSpecifiedParameter'
1928 assert sizer
is not None
1929 ns_output_types = tuple([a.apply_namespace(ns=
'ExtArgumentType')
for a
in sig.outputs])
1930 ns_input_types = tuple([t.apply_namespace(ns=
'ExtArgumentType')
for t
in input_types_])
1931 ns_sql_types = tuple([t.apply_namespace(ns=
'ExtArgumentType')
for t
in sql_types_])
1933 sig.function_annotations.append((
'uses_manager', str(uses_manager).lower()))
1935 input_types =
'std::vector<ExtArgumentType>{%s}' % (
', '.
join(map(tostring, ns_input_types)))
1936 output_types =
'std::vector<ExtArgumentType>{%s}' % (
', '.
join(map(tostring, ns_output_types)))
1937 sql_types =
'std::vector<ExtArgumentType>{%s}' % (
', '.
join(map(tostring, ns_sql_types)))
1938 annotations =
format_annotations(input_annotations + sig.output_annotations + [sig.function_annotations])
1951 cond_fns.append(check_fn)
1954 name = sig.name +
'_' + str(counter)
1957 address_expression = (
'avoid_opt_address(reinterpret_cast<void*>(%s))' % name)
1959 cpu_template_functions.append(t)
1960 cpu_function_address_expressions.append(address_expression)
1962 gpu_template_functions.append(t)
1963 gpu_function_address_expressions.append(address_expression)
1964 add = (
'TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
1965 % (name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
1966 add_stmts.append(add)
1969 add = (
'TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
1970 % (sig.name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
1971 add_stmts.append(add)
1972 address_expression = (
'avoid_opt_address(reinterpret_cast<void*>(%s))' % sig.name)
1975 cpu_function_address_expressions.append(address_expression)
1977 gpu_function_address_expressions.append(address_expression)
1979 return add_stmts, cpu_template_functions, gpu_template_functions, cpu_function_address_expressions, gpu_function_address_expressions, cond_fns
1982 if len(sys.argv) < 3:
1984 input_files = [os.path.join(os.path.dirname(__file__),
'test_udtf_signatures.hpp')]
1985 print(
'Running tests from %s' % (
', '.
join(input_files)))
1988 print(
'Usage:\n %s %s input1.hpp input2.hpp ... output.hpp' % (sys.executable, sys.argv[0], ))
1992 input_files, output_filename = sys.argv[1:-1], sys.argv[-1]
1993 cpu_output_header = os.path.splitext(output_filename)[0] +
'_cpu.hpp'
1994 gpu_output_header = os.path.splitext(output_filename)[0] +
'_gpu.hpp'
1995 assert input_files, sys.argv
1997 add_stmts, cpu_template_functions, gpu_template_functions, cpu_address_expressions, gpu_address_expressions, cond_fns =
parse_annotations(sys.argv[1:-1])
1999 canonical_input_files = [input_file[input_file.find(
"/QueryEngine/") + 1:]
for input_file
in input_files]
2000 header_includes = [
'#include "' + canonical_input_file +
'"' for canonical_input_file
in canonical_input_files]
2003 ADD_FUNC_CHUNK_SIZE = 100
2007 NO_OPT_ATTRIBUTE void add_table_functions_%d() const {
2010 ''' % (i,
'\n '.
join(chunk))
2013 chunks = [ add_stmts[n:n+ADD_FUNC_CHUNK_SIZE]
for n
in range(0, len(add_stmts), ADD_FUNC_CHUNK_SIZE) ]
2014 return [
add_method(i,chunk)
for i,chunk
in enumerate(chunks) ]
2017 quot, rem = divmod(len(add_stmts), ADD_FUNC_CHUNK_SIZE)
2018 return [
'add_table_functions_%d();' % (i)
for i
in range(quot + int(0 < rem)) ]
2022 This file is generated by %s. Do no edit!
2025 #include "QueryEngine/TableFunctions/TableFunctionsFactory.h"
2029 Include the UDTF template initiations:
2031 #include "TableFunctionsFactory_init_cpu.hpp"
2033 // volatile+noinline prevents compiler optimization
2035 __declspec(noinline)
2037 __attribute__((noinline))
2040 bool avoid_opt_address(void *address) {
2041 return address != nullptr;
2044 bool functions_exist() {
2052 extern bool g_enable_table_functions;
2054 namespace table_functions {
2056 std::once_flag init_flag;
2058 #if defined(__clang__)
2059 #define NO_OPT_ATTRIBUTE __attribute__((optnone))
2061 #elif defined(__GNUC__) || defined(__GNUG__)
2062 #define NO_OPT_ATTRIBUTE __attribute((optimize("O0")))
2064 #elif defined(_MSC_VER)
2065 #define NO_OPT_ATTRIBUTE
2069 #if defined(_MSC_VER)
2070 #pragma optimize("", off)
2073 struct AddTableFunctions {
2075 NO_OPT_ATTRIBUTE void operator()() {
2080 void TableFunctionsFactory::init() {
2081 if (!g_enable_table_functions) {
2085 if (!functions_exist()) {
2090 std::call_once(init_flag, AddTableFunctions{});
2092 #if defined(_MSC_VER)
2093 #pragma optimize("", on)
2096 // conditional check functions
2099 } // namespace table_functions
2102 '\n'.
join(header_includes),
2103 ' &&\n'.
join(cpu_address_expressions),
2108 header_content =
'''
2110 This file is generated by %s. Do no edit!
2117 dirname = os.path.dirname(output_filename)
2119 if dirname
and not os.path.exists(dirname):
2121 os.makedirs(dirname)
2122 except OSError
as e:
2124 if e.errno != errno.EEXIST:
2131 f =
open(cpu_output_header,
'w')
2132 f.write(header_content % (sys.argv[0],
'\n'.
join(header_includes),
'\n'.
join(cpu_template_functions)))
2135 f =
open(gpu_output_header,
'w')
2136 f.write(header_content % (sys.argv[0],
'\n'.
join(header_includes),
'\n'.
join(gpu_template_functions)))
def is_column_array_text_encoding_dict
def build_template_function_call
def is_output_buffer_sizer
def is_array_text_encoding_dict
def is_column_list_text_encoding_dict
def is_text_encoding_dict
def is_text_encoding_dict
def visit_annotation_node
def is_any_text_encoding_dict
def is_column_list_text_encoding_dict
int open(const char *path, int flags, int mode)
def can_token_be_double_char
def build_preflight_function
def is_array_text_encoding_dict
def is_column_text_encoding_dict
def is_array_text_encoding_dict
def is_any_text_encoding_dict
def visit_annotation_node
def is_output_buffer_sizer
def is_column_array_text_encoding_dict
def is_column_text_encoding_dict
def must_emit_preflight_function