1 """Given a list of input files, scan for lines containing UDTF
2 specification statements in the following form:
4 UDTF: function_name(<arguments>) -> <output column types> (, <template type specifications>)?
6 where <arguments> is a comma-separated list of argument types. The
7 argument types specifications are:
10 Int8, Int16, Int32, Int64, Float, Double, Bool, TextEncodingDict, etc
12 ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble, ColumnBool, etc
14 ColumnListInt8, ColumnListInt16, ColumnListInt32, ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool, etc
17 where t0, t1 are column or column list types
18 - output buffer size parameter type:
19 RowMultiplier<i>, ConstantParameter<i>, Constant<i>, TableFunctionSpecifiedParameter<i>
20 where i is a literal integer.
22 The output column types are a comma-separated list of column types, see above.
24 In addition, the following equivalents are supported:
27 ColumnList<T> == ColumnListT
28 Cursor<T, V, ...> == Cursor<ColumnT, ColumnV, ...>
29 int8 == int8_t == Int8, etc
30 float == Float, double == Double, bool == Bool
31 T == ColumnT for output column types
32 RowMultiplier == RowMultiplier<i> where i is the one-based position of the sizer argument
33 when no sizer argument is provided, Constant<1> is assumed
35 Argument types can be annotated using `|' (bar) symbol after an
36 argument type specification. An annotation is specified by a label and
37 a value separated by `=' (equal) symbol. Multiple annotations can be
38 specified by using `|` (bar) symbol as the annotations separator.
39 Supported annotation labels are:
41 - name: to specify argument name
42 - input_id: to specify the dict id mapping for output TextEncodingDict columns.
44 If argument type follows an identifier, it will be mapped to name
45 annotations. For example, the following argument type specifications
51 Template type specifications are a comma-separated list of template
52 type assignments where values are lists of argument type names. For
55 T = [Int8, Int16, Int32, Float], V = [Float, Double]
66 from abc
import abstractmethod
68 from collections
import deque, namedtuple
70 if sys.version_info > (3, 0):
72 from collections.abc
import Iterable
74 from abc
import ABCMeta
as ABC
75 from collections
import Iterable
80 Signature = namedtuple(
'Signature', [
'name',
'inputs',
'outputs',
'input_annotations',
'output_annotations',
'function_annotations',
'sizer'])
82 ExtArgumentTypes =
''' Int8, Int16, Int32, Int64, Float, Double, Void, PInt8, PInt16,
83 PInt32, PInt64, PFloat, PDouble, PBool, Bool, ArrayInt8, ArrayInt16,
84 ArrayInt32, ArrayInt64, ArrayFloat, ArrayDouble, ArrayBool, GeoPoint,
85 GeoLineString, Cursor, GeoPolygon, GeoMultiPolygon, ColumnInt8,
86 ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble,
87 ColumnBool, ColumnTextEncodingDict, ColumnTimestamp, TextEncodingNone,
88 TextEncodingDict, Timestamp, ColumnListInt8, ColumnListInt16, ColumnListInt32,
89 ColumnListInt64, ColumnListFloat, ColumnListDouble, ColumnListBool,
90 ColumnListTextEncodingDict'''.
strip().replace(
' ',
'').replace(
'\n',
'').
split(
',')
92 OutputBufferSizeTypes =
'''
93 kConstant, kUserSpecifiedConstantParameter, kUserSpecifiedRowMultiplier, kTableFunctionSpecifiedParameter, kPreFlightParameter
96 SupportedAnnotations =
'''
97 input_id, name, fields, require
101 SupportedFunctionAnnotations =
'''
102 filter_table_function_transpose, uses_manager
105 translate_map = dict(
106 Constant=
'kConstant',
107 PreFlight=
'kPreFlightParameter',
108 ConstantParameter=
'kUserSpecifiedConstantParameter',
109 RowMultiplier=
'kUserSpecifiedRowMultiplier',
110 UserSpecifiedConstantParameter=
'kUserSpecifiedConstantParameter',
111 UserSpecifiedRowMultiplier=
'kUserSpecifiedRowMultiplier',
112 TableFunctionSpecifiedParameter=
'kTableFunctionSpecifiedParameter',
117 for t
in [
'Int8',
'Int16',
'Int32',
'Int64',
'Float',
'Double',
'Bool',
118 'TextEncodingDict',
'TextEncodingNone']:
119 translate_map[t.lower()] = t
120 if t.startswith(
'Int'):
121 translate_map[t.lower() +
'_t'] = t
125 """Holds a `TYPE | ANNOTATIONS`-like structure.
133 return self.type.name
137 return self.type.args
140 return self.type.format_sizer()
147 return str(self.
type)
151 return self.type.tostring()
154 return self.__class__(self.type.apply_column(), self.
annotations)
157 return self.__class__(self.type.apply_namespace(ns), self.
annotations)
160 return self.type.get_cpp_type()
163 real_arg_name = dict(self.
annotations).get(
'name',
None)
164 return self.type.format_cpp_type(idx,
165 use_generic_arg_name=use_generic_arg_name,
166 real_arg_name=real_arg_name,
170 if name.startswith(
'is_'):
171 return getattr(self.
type, name)
172 raise AttributeError(name)
176 return obj.tostring()
180 """Holds a `NAME<ARGS>`-like structure.
184 assert isinstance(name, str)
185 assert isinstance(args, tuple)
or args
is None, args
190 return 'Bracket(%r, args=%r)' % (self.
name, self.
args)
195 return '%s<%s>' % (self.
name,
', '.
join(map(str, self.
args)))
200 return '%s<%s>' % (self.
name,
', '.
join(map(tostring, self.
args)))
203 """Normalize bracket for given kind
205 assert kind
in [
'input',
'output'], kind
209 if self.
name ==
'Cursor':
210 args = [(a
if a.is_column_any()
else Bracket(
'Column', args=(a,))).
normalize(kind=kind)
for a
in self.
args]
218 """Apply cursor to a non-cursor column argument type.
219 TODO: this method is currently unused but we should apply
220 cursor to all input column arguments in order to distinguish
222 foo(Cursor(Column<int32>, Column<float>)) -> Column<int32>
223 foo(Cursor(Column<int32>), Cursor(Column<float>)) -> Column<int32>
224 that at the moment are treated as the same :(
227 return Bracket(
'Cursor', args=(self,))
236 if self.
name ==
'Cursor':
237 return Bracket(ns +
'::' + self.
name, args=tuple(a.apply_namespace(ns=ns)
for a
in self.
args))
238 if not self.name.startswith(ns +
'::'):
243 return self.name.rsplit(
"::", 1)[-1] ==
'Cursor'
246 return self.name.rsplit(
"::", 1)[-1].startswith(
'Column')
249 return self.name.rsplit(
"::", 1)[-1].startswith(
'ColumnList')
252 return self.name.rsplit(
"::", 1)[-1].startswith(
'Column')
and not self.
is_column_list()
255 return self.name.rsplit(
"::", 1)[-1].endswith(
'TextEncodingDict')
258 return self.name.rsplit(
"::", 1)[-1] ==
'ColumnTextEncodingDict'
261 return self.name.rsplit(
"::", 1)[-1] ==
'ColumnListTextEncodingDict'
264 return self.name.rsplit(
"::", 1)[-1]
in OutputBufferSizeTypes
267 return self.name.rsplit(
"::", 1)[-1] ==
'kUserSpecifiedRowMultiplier'
270 return self.name.rsplit(
"::", 1)[-1] ==
'kPreFlightParameter'
275 return self.name.rsplit(
"::", 1)[-1]
not in (
'kConstant',
'kTableFunctionSpecifiedParameter',
'kPreFlightParameter')
280 return 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (self.
name, val)
283 name = self.name.rsplit(
"::", 1)[-1]
285 if name.startswith(
'ColumnList'):
286 name = name.lstrip(
'ColumnList')
287 clsname =
'ColumnList'
288 elif name.startswith(
'Column'):
289 name = name.lstrip(
'Column')
291 if name.startswith(
'Bool'):
293 elif name.startswith(
'Int'):
294 ctype = name.lower() +
'_t'
295 elif name
in [
'Double',
'Float']:
297 elif name ==
'TextEncodingDict':
299 elif name ==
'TextEncodingNone':
301 elif name ==
'Timestamp':
304 raise NotImplementedError(self)
307 return '%s<%s>' % (clsname, ctype)
309 def format_cpp_type(self, idx, use_generic_arg_name=False, real_arg_name=None, is_input=True):
310 col_typs = (
'Column',
'ColumnList')
311 literal_ref_typs = (
'TextEncodingNone',)
312 if use_generic_arg_name:
313 arg_name =
'input' + str(idx)
if is_input
else 'output' + str(idx)
314 elif real_arg_name
is not None:
315 arg_name = real_arg_name
318 arg_name =
'input' + str(idx)
if is_input
else 'output' + str(idx)
319 const =
'const ' if is_input
else ''
322 if any(cpp_type.startswith(t)
for t
in col_typs + literal_ref_typs):
323 return '%s%s& %s' % (const, cpp_type, arg_name), arg_name
325 return '%s %s' % (cpp_type, arg_name), arg_name
329 """typ is a string in format NAME<ARGS> or NAME
331 Returns Bracket instance.
338 assert typ.endswith(
'>'), typ
339 name = typ[:i].
strip()
341 rest = typ[i + 1:-1].
strip()
347 a, rest = rest[:i].rstrip(), rest[i + 1:].lstrip()
348 args.append(cls.parse(a))
351 name = translate_map.get(name, name)
352 return cls(name, args)
357 for i, c
in enumerate(line):
362 elif d == 0
and c ==
',':
370 return line.endswith(
',')
or line.endswith(
'->')
or line.endswith(separator)
or line.endswith(
'|')
374 return identifier.lower() ==
'cursor'
384 class ParserException(Exception):
414 One of the tokens in the list above
416 Corresponding string in the text
425 Token.GREATER:
"GREATER",
426 Token.COMMA:
"COMMA",
427 Token.EQUAL:
"EQUAL",
428 Token.RARROW:
"RARROW",
429 Token.STRING:
"STRING",
430 Token.NUMBER:
"NUMBER",
437 Token.IDENTIFIER:
"IDENTIFIER",
438 Token.COLON:
"COLON",
440 return names.get(token)
443 return 'Token(%s, "%s")' % (Token.tok_name(self.
type), self.
lexeme)
489 self._tokens.append(
Token(type, lexeme))
504 return char
in (
"-",)
553 if char ==
'"' and curr !=
'\\':
566 if char
and char.isdigit():
575 IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
579 if char
and char.isalnum()
or char ==
"_":
587 return self.
peek().isalpha()
or self.
peek() ==
"_"
590 return self.
peek() ==
'"'
593 return self.
peek().isdigit()
596 return self.
peek().isalpha()
599 return self.
peek().isspace()
605 'Could not match char "%s" at pos %d on line\n %s' % (char, curr, self.
line)
637 class AstTransformer(AstVisitor):
638 """Only overload the methods you need"""
641 udtf = copy.copy(udtf_node)
642 udtf.inputs = [arg.accept(self)
for arg
in udtf.inputs]
643 udtf.outputs = [arg.accept(self)
for arg
in udtf.outputs]
645 udtf.templates = [t.accept(self)
for t
in udtf.templates]
646 udtf.annotations = [annot.accept(self)
for annot
in udtf.annotations]
650 c = copy.copy(composed_node)
651 c.inner = [i.accept(self)
for i
in c.inner]
655 arg_node = copy.copy(arg_node)
656 arg_node.type = arg_node.type.accept(self)
657 if arg_node.annotations:
658 arg_node.annotations = [a.accept(self)
for a
in arg_node.annotations]
662 return copy.copy(primitive_node)
665 return copy.copy(template_node)
668 return copy.copy(annotation_node)
672 """Returns a line formatted. Useful for testing"""
675 name = udtf_node.name
676 inputs =
", ".
join([arg.accept(self)
for arg
in udtf_node.inputs])
677 outputs =
", ".
join([arg.accept(self)
for arg
in udtf_node.outputs])
678 annotations =
"| ".
join([annot.accept(self)
for annot
in udtf_node.annotations])
679 sizer =
" | " + udtf_node.sizer.accept(self)
if udtf_node.sizer
else ""
681 annotations =
' | ' + annotations
682 if udtf_node.templates:
683 templates =
", ".
join([t.accept(self)
for t
in udtf_node.templates])
684 return "%s(%s)%s -> %s, %s%s" % (name, inputs, annotations, outputs, templates, sizer)
686 return "%s(%s)%s -> %s%s" % (name, inputs, annotations, outputs, sizer)
690 key = template_node.key
691 types = [
'"%s"' % typ
for typ
in template_node.types]
692 return "%s=[%s]" % (key,
", ".
join(types))
696 key = annotation_node.key
697 value = annotation_node.value
698 if isinstance(value, list):
699 return "%s=[%s]" % (key,
','.
join([v.accept(self)
for v
in value]))
700 return "%s=%s" % (key, value)
704 typ = arg_node.type.accept(self)
705 if arg_node.annotations:
706 ann =
" | ".
join([a.accept(self)
for a
in arg_node.annotations])
707 s =
"%s | %s" % (typ, ann)
711 if s ==
"ColumnTextEncodingDict" and arg_node.kind ==
"output":
712 return s +
" | input_id=args<0>"
716 T = composed_node.inner[0].accept(self)
717 if composed_node.is_column():
719 assert len(composed_node.inner) == 1
721 if composed_node.is_column_list():
723 assert len(composed_node.inner) == 1
724 return "ColumnList" + T
725 if composed_node.is_output_buffer_sizer():
728 assert len(composed_node.inner) == 1
729 return translate_map.get(composed_node.type) +
"<%s>" % (N,)
730 if composed_node.is_cursor():
732 Ts =
", ".
join([i.accept(self)
for i
in composed_node.inner])
733 return "Cursor<%s>" % (Ts)
734 raise ValueError(composed_node)
737 t = primitive_node.type
738 if primitive_node.is_output_buffer_sizer():
740 return translate_map.get(t, t) +
"<%d>" % (
741 primitive_node.get_parent(ArgNode).arg_pos + 1,
743 return translate_map.get(t, t)
747 """Like AstPrinter but returns a node instead of a string
753 vals = kwargs.values()
754 for instance
in itertools.product(*vals):
755 yield dict(zip(keys, instance))
759 """Expand template definition into multiple inputs"""
762 if not udtf_node.templates:
767 d = dict([(node.key, node.types)
for node
in udtf_node.templates])
768 name = udtf_node.name
772 inputs = [input_arg.accept(self)
for input_arg
in udtf_node.inputs]
773 outputs = [output_arg.accept(self)
for output_arg
in udtf_node.outputs]
774 udtfs.append(
UdtfNode(name, inputs, outputs, udtf_node.annotations,
None, udtf_node.sizer, udtf_node.line))
783 typ = composed_node.type
784 typ = self.mapping_dict.get(typ, typ)
786 inner = [i.accept(self)
for i
in composed_node.inner]
787 return composed_node.copy(typ, inner)
790 typ = primitive_node.type
791 typ = self.mapping_dict.get(typ, typ)
792 return primitive_node.copy(typ)
798 * Fix kUserSpecifiedRowMultiplier without a pos arg
800 t = primitive_node.type
802 if primitive_node.is_output_buffer_sizer():
803 pos =
PrimitiveNode(str(primitive_node.get_parent(ArgNode).arg_pos + 1))
807 return primitive_node
813 * Rename nodes using translate_map as dictionary
817 t = primitive_node.type
818 return primitive_node.copy(translate_map.get(t, t))
824 * Add default_input_id to Column(List)<TextEncodingDict> without one
829 default_input_id =
None
830 for idx, t
in enumerate(udtf_node.inputs):
832 if not isinstance(t.type, ComposedNode):
834 if default_input_id
is not None:
836 elif t.type.is_column_text_encoding_dict():
837 default_input_id =
AnnotationNode(
'input_id',
'args<%s>' % (idx,))
838 elif t.type.is_column_list_text_encoding_dict():
839 default_input_id =
AnnotationNode(
'input_id',
'args<%s, 0>' % (idx,))
841 for t
in udtf_node.outputs:
842 if isinstance(t.type, ComposedNode)
and t.type.is_any_text_encoding_dict():
843 for a
in t.annotations:
844 if a.key ==
'input_id':
847 if default_input_id
is None:
848 raise TypeError(
'Cannot parse line "%s".\n'
849 'Missing TextEncodingDict input?' %
851 t.annotations.append(default_input_id)
860 * Generate fields annotation to Cursor if non-existing
864 for _, t
in enumerate(udtf_node.inputs):
866 if not isinstance(t.type, ComposedNode):
869 if t.type.is_cursor()
and t.get_annotation(
'fields')
is None:
870 fields = list(
PrimitiveNode(a.get_annotation(
'name',
'field%s' % i))
for i, a
in enumerate(t.type.inner))
878 * Checks for supported annotations in a UDTF
881 for idx, t
in enumerate(udtf_node.inputs):
882 for a
in t.annotations:
883 if a.key
not in SupportedAnnotations:
885 for t
in udtf_node.outputs:
886 for a
in t.annotations:
887 if a.key
not in SupportedAnnotations:
889 for annot
in udtf_node.annotations:
890 if annot.key
not in SupportedFunctionAnnotations:
892 if annot.value.lower()
in [
'enable',
'on',
'1',
'true']:
894 elif annot.value.lower()
in [
'disable',
'off',
'0',
'false']:
902 name = udtf_node.name
904 input_annotations = []
906 output_annotations = []
907 function_annotations = []
908 sizer = udtf_node.sizer
910 for i
in udtf_node.inputs:
911 decl = i.accept(self)
913 input_annotations.append(decl.annotations)
915 for o
in udtf_node.outputs:
916 decl = o.accept(self)
917 outputs.append(decl.type)
918 output_annotations.append(decl.annotations)
920 for annot
in udtf_node.annotations:
921 annot = annot.accept(self)
922 function_annotations.append(annot)
924 return Signature(name, inputs, outputs, input_annotations, output_annotations, function_annotations, sizer)
927 t = arg_node.type.accept(self)
928 anns = [a.accept(self)
for a
in arg_node.annotations]
932 typ = translate_map.get(composed_node.type, composed_node.type)
933 inner = [i.accept(self)
for i
in composed_node.inner]
934 if composed_node.is_cursor():
935 inner = list(map(
lambda x: x.apply_column(), inner))
936 return Bracket(typ, args=tuple(inner))
937 elif composed_node.is_output_buffer_sizer():
938 return Bracket(typ, args=tuple(inner))
940 return Bracket(typ + str(inner[0]))
943 t = primitive_node.type
947 key = annotation_node.key
948 value = annotation_node.value
965 if isinstance(self, cls):
968 if self.parent
is not None:
969 return self.parent.get_parent(cls)
971 raise ValueError(
"could not find parent with given class %s" % (cls))
974 other = self.__class__(*args)
977 for attr
in [
'parent',
'arg_pos']:
978 if attr
in self.__dict__:
979 setattr(other, attr, getattr(self, attr))
988 class UdtfNode(Node, IterableNode):
990 def __init__(self, name, inputs, outputs, annotations, templates, sizer, line):
995 inputs : list[ArgNode]
996 outputs : list[ArgNode]
997 annotations : Optional[List[AnnotationNode]]
998 templates : Optional[list[TemplateNode]]
999 sizer : Optional[str]
1011 return visitor.visit_udtf_node(self)
1015 inputs = [str(i)
for i
in self.
inputs]
1016 outputs = [str(o)
for o
in self.
outputs]
1018 sizer =
"| %s" % str(self.
sizer)
if self.
sizer else ""
1020 templates = [str(t)
for t
in self.
templates]
1022 return "UDTF: %s (%s) | %s -> %s, %s %s" % (name, inputs, annotations, outputs, templates, sizer)
1024 return "UDTF: %s (%s) -> %s, %s %s" % (name, inputs, outputs, templates, sizer)
1027 return "UDTF: %s (%s) | %s -> %s %s" % (name, inputs, annotations, outputs, sizer)
1029 return "UDTF: %s (%s) -> %s %s" % (name, inputs, outputs, sizer)
1052 annotations : List[AnnotationNode]
1059 return visitor.visit_arg_node(self)
1066 return "ArgNode(%s %s)" % (t, anns)
1067 return "ArgNode(%s)" % (t)
1088 return self.
type ==
"ColumnList"
1091 return self.
type ==
"Cursor"
1095 return translate_map.get(t, t)
in OutputBufferSizeTypes
1109 return visitor.visit_primitive_node(self)
1115 return self.
type ==
'TextEncodingDict'
1127 inner : list[TypeNode]
1133 return visitor.visit_composed_node(self)
1137 return len(self.
inner)
1140 i =
", ".
join([str(i)
for i
in self.
inner])
1141 return "Composed(%s<%s>)" % (self.
type, i)
1144 for i
in self.
inner:
1148 return self.
is_column()
and self.
inner[0].is_text_encoding_dict()
1154 return self.
inner[0].is_text_encoding_dict()
1172 return visitor.visit_annotation_node(self)
1176 return self.
accept(printer)
1194 return visitor.visit_template_node(self)
1198 return self.
accept(printer)
1208 if not isinstance(ast_list, list):
1209 ast_list = [ast_list]
1212 ast_list = [ast.accept(c())
for ast
in ast_list]
1213 ast_list = itertools.chain.from_iterable(
1214 map(
lambda x: x
if isinstance(x, list)
else [x], ast_list))
1216 return list(ast_list)
1240 msg =
"Expected token %s but got %s at pos %d.\n Tokens: %s" % (
1242 Token.tok_name(expected_type),
1246 assert curr_token.type == expected_type, msg
1250 """consumes the current token iff its type matches the
1251 expected_type. Otherwise, an error is raised
1254 if curr_token.type == expected_type:
1258 expected_token = Token.tok_name(expected_type)
1260 'Token mismatch at function consume. '
1261 'Expected type "%s" but got token "%s"\n\n'
1262 'Tokens: %s\n' % (expected_token, curr_token, self.
_tokens)
1273 msg =
"\n\nError while trying to parse token %s at pos %d.\n" "Tokens: %s" % (
1282 return curr_token.type == expected_type
1290 udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?
1297 if not self.
match(Token.RPAR):
1304 self.
expect(Token.RARROW)
1316 assert idtn ==
"output_row_size"
1319 key =
"kPreFlightParameter"
1324 for arg
in input_args:
1327 i += arg.type.cursor_length()
if arg.type.is_cursor()
else 1
1329 for i, arg
in enumerate(output_args):
1333 return UdtfNode(name, input_args, output_args, annotations, templates, sizer, self.
line)
1338 args: arg ("," arg)*
1353 self.
_curr = curr + 1
1360 arg: type IDENTIFIER? ("|" annotation)*
1374 if ahead.type == Token.IDENTIFIER
and ahead.lexeme ==
'output_row_size':
1379 return ArgNode(typ, annotations)
1394 if not self.
match(Token.LESS):
1404 composed: "Cursor" "<" arg ("," arg)* ">"
1405 | IDENTIFIER "<" type ("," type)* ">"
1413 while self.
match(Token.COMMA):
1418 while self.
match(Token.COMMA):
1427 primitive: IDENTIFIER
1433 if self.
match(Token.IDENTIFIER):
1435 elif self.
match(Token.NUMBER):
1437 elif self.
match(Token.STRING):
1446 templates: template ("," template)*
1460 template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"
1469 while self.
match(Token.COMMA):
1478 annotation: IDENTIFIER "=" IDENTIFIER ("<" NUMBER ("," NUMBER) ">")?
1479 | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
1480 | "require" "=" STRING
1487 if key ==
"require":
1492 if not self.
match(Token.RSQB):
1494 while self.
match(Token.COMMA):
1502 if self.
match(Token.GREATER):
1503 value +=
"<%s>" % (-1)
1506 if self.
match(Token.COMMA):
1509 value +=
"<%s,%s>" % (num1, num2)
1511 value +=
"<%s>" % (num1)
1518 IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
1522 token = self.
consume(Token.IDENTIFIER)
1532 token = self.
consume(Token.STRING)
1542 token = self.
consume(Token.NUMBER)
1548 udtf: IDENTIFIER "(" (args)? ")" ("|" annotation)* "->" args ("," templates)? ("|" "output_row_size" "=" primitive)?
1550 args: arg ("," arg)*
1552 arg: type IDENTIFIER? ("|" annotation)*
1557 composed: "Cursor" "<" arg ("," arg)* ">"
1558 | IDENTIFIER "<" type ("," type)* ">"
1560 primitive: IDENTIFIER
1564 annotation: IDENTIFIER "=" IDENTIFIER ("<" NUMBER ("," NUMBER) ">")?
1565 | IDENTIFIER "=" "[" PRIMITIVE? ("," PRIMITIVE)* "]"
1566 | "require" "=" STRING
1568 templates: template ("," template)*
1569 template: IDENTIFIER "=" "[" IDENTIFIER ("," IDENTIFIER)* "]"
1571 IDENTIFIER: [A-Za-z_][A-Za-z0-9_]*
1586 if isinstance(node, Iterable):
1595 """Returns a list of parsed UDTF signatures."""
1599 for line
in open(input_file).readlines():
1601 if last_line
is not None:
1602 line = last_line +
' ' + line
1604 if not line.startswith(
'UDTF:'):
1610 line = line[5:].lstrip()
1613 if i == -1
or j == -1:
1614 sys.stderr.write(
'Invalid UDTF specification: `%s`. Skipping.\n' % (line))
1617 expected_result =
None
1618 if separator
in line:
1619 line, expected_result = line.split(separator, 1)
1620 expected_result = expected_result.strip().
split(separator)
1621 expected_result = list(map(
lambda s: s.strip(), expected_result))
1623 ast =
Parser(line).parse()
1625 if expected_result
is not None:
1627 skip_signature =
False
1629 result =
Pipeline(TemplateTransformer,
1630 FieldAnnotationTransformer,
1631 TextEncodingDictTransformer,
1632 SupportedAnnotationsTransformer,
1633 FixRowMultiplierPosArgTransformer,
1634 RenameNodesTransformer,
1636 except TransformerException
as msg:
1637 result = [
'%s: %s' % (
type(msg).__name__, msg)]
1638 skip_signature =
True
1639 assert set(result) == set(expected_result),
"\n\tresult: %s != \n\texpected: %s" % (
1646 signature =
Pipeline(TemplateTransformer,
1647 FieldAnnotationTransformer,
1648 TextEncodingDictTransformer,
1649 SupportedAnnotationsTransformer,
1650 FixRowMultiplierPosArgTransformer,
1651 RenameNodesTransformer,
1652 DeclBracketTransformer)(ast)
1654 signatures.extend(signature)
1664 cpp_args.append(
'TableFunctionManager& mgr')
1665 name_args.append(
'mgr')
1667 for idx, typ
in enumerate(input_types):
1668 cpp_arg, name = typ.format_cpp_type(idx,
1669 use_generic_arg_name=use_generic_arg_name,
1671 cpp_args.append(cpp_arg)
1672 name_args.append(name)
1674 if emit_output_args:
1675 for idx, typ
in enumerate(output_types):
1676 cpp_arg, name = typ.format_cpp_type(idx,
1677 use_generic_arg_name=use_generic_arg_name,
1679 cpp_args.append(cpp_arg)
1680 name_args.append(name)
1682 cpp_args =
', '.
join(cpp_args)
1683 name_args =
', '.
join(name_args)
1684 return cpp_args, name_args
1691 use_generic_arg_name=
True,
1692 emit_output_args=
True)
1694 template = (
"EXTENSION_NOINLINE int32_t\n"
1697 "}\n") % (caller, cpp_args, called, name_args)
1703 def format_error_msg(err_msg, uses_manager):
1705 return " return mgr.error_message(%s);\n" % (err_msg,)
1707 return " return table_function_error(%s);\n" % (err_msg,)
1712 use_generic_arg_name=
False,
1713 emit_output_args=
False)
1716 fn =
"EXTENSION_NOINLINE int32_t\n"
1717 fn +=
"%s(%s) {\n" % (fn_name.lower() +
"__preflight", cpp_args)
1719 fn =
"EXTENSION_NOINLINE int32_t\n"
1720 fn +=
"%s(%s) {\n" % (fn_name.lower() +
"__preflight", cpp_args)
1722 for typ
in input_types:
1723 ann = typ.annotations
1724 for key, value
in ann:
1725 if key ==
'require':
1726 err_msg =
'"Constraint `%s` is not satisfied."' % (value[1:-1])
1728 fn +=
" if (!(%s)) {\n" % (value[1:-1].replace(
'\\',
''),)
1729 fn += format_error_msg(err_msg, uses_manager)
1732 if sizer.is_arg_sizer():
1733 precomputed_nrows = str(sizer.args[0])
1734 if '"' in precomputed_nrows:
1735 precomputed_nrows = precomputed_nrows[1:-1]
1737 err_msg =
'"Output size expression `%s` evaluated in a negative value."' % (precomputed_nrows)
1738 fn +=
" auto _output_size = %s;\n" % (precomputed_nrows)
1739 fn +=
" if (_output_size < 0) {\n"
1740 fn += format_error_msg(err_msg, uses_manager)
1742 fn +=
" return _output_size;\n"
1744 fn +=
" return 0;\n"
1751 if sizer.is_arg_sizer():
1753 for arg_annotations
in sig.input_annotations:
1754 d = dict(arg_annotations)
1755 if 'require' in d.keys():
1767 s =
"std::vector<std::map<std::string, std::string>>{"
1768 s +=
', '.
join((
'{' +
', '.
join(
'{"%s", "%s"}' % (k, fmt(k, v))
for k, v
in a) +
'}')
for a
in annotations_)
1774 i = sig.name.rfind(
'_template')
1775 return i >= 0
and '__' in sig.name[:i + 1]
1779 return sig.inputs
and sig.inputs[0].name ==
'TableFunctionManager'
1784 i = sig.name.rfind(
'_gpu_')
1785 if i >= 0
and '__' in sig.name[:i + 1]:
1787 raise ValueError(
'Table function {} with gpu execution target cannot have TableFunctionManager argument'.format(sig.name))
1797 i = sig.name.rfind(
'_cpu_')
1798 return not (i >= 0
and '__' in sig.name[:i + 1])
1806 cpu_template_functions = []
1807 gpu_template_functions = []
1808 cpu_function_address_expressions = []
1809 gpu_function_address_expressions = []
1812 for input_file
in input_files:
1818 input_annotations = []
1821 if sig.sizer
is not None:
1822 expr = sig.sizer.value
1823 sizer =
Bracket(
'kPreFlightParameter', (expr,))
1825 uses_manager =
False
1826 for i, (t, annot)
in enumerate(zip(sig.inputs, sig.input_annotations)):
1827 if t.is_output_buffer_sizer():
1828 if t.is_user_specified():
1829 sql_types_.append(Bracket.parse(
'int32').normalize(kind=
'input'))
1830 input_types_.append(sql_types_[-1])
1831 input_annotations.append(annot)
1832 assert sizer
is None
1833 assert len(t.args) == 1, t
1835 elif t.name ==
'Cursor':
1837 input_types_.append(t_)
1838 input_annotations.append(annot)
1839 sql_types_.append(
Bracket(
'Cursor', args=()))
1840 elif t.name ==
'TableFunctionManager':
1842 raise ValueError(
'{} must appear as a first argument of {}, but found it at position {}.'.format(t, sig.name, i))
1845 input_types_.append(t)
1846 input_annotations.append(annot)
1847 if t.is_column_any():
1849 sql_types_.append(
Bracket(
'Cursor', args=()))
1851 sql_types_.append(t)
1854 name =
'kTableFunctionSpecifiedParameter'
1858 assert sizer
is not None
1859 ns_output_types = tuple([a.apply_namespace(ns=
'ExtArgumentType')
for a
in sig.outputs])
1860 ns_input_types = tuple([t.apply_namespace(ns=
'ExtArgumentType')
for t
in input_types_])
1861 ns_sql_types = tuple([t.apply_namespace(ns=
'ExtArgumentType')
for t
in sql_types_])
1863 sig.function_annotations.append((
'uses_manager', str(uses_manager).lower()))
1865 input_types =
'std::vector<ExtArgumentType>{%s}' % (
', '.
join(map(tostring, ns_input_types)))
1866 output_types =
'std::vector<ExtArgumentType>{%s}' % (
', '.
join(map(tostring, ns_output_types)))
1867 sql_types =
'std::vector<ExtArgumentType>{%s}' % (
', '.
join(map(tostring, ns_sql_types)))
1868 annotations =
format_annotations(input_annotations + sig.output_annotations + [sig.function_annotations])
1881 cond_fns.append(check_fn)
1884 name = sig.name +
'_' + str(counter)
1887 address_expression = (
'avoid_opt_address(reinterpret_cast<void*>(%s))' % name)
1889 cpu_template_functions.append(t)
1890 cpu_function_address_expressions.append(address_expression)
1892 gpu_template_functions.append(t)
1893 gpu_function_address_expressions.append(address_expression)
1894 add = (
'TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
1895 % (name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
1896 add_stmts.append(add)
1899 add = (
'TableFunctionsFactory::add("%s", %s, %s, %s, %s, %s, /*is_runtime:*/false);'
1900 % (sig.name, sizer.format_sizer(), input_types, output_types, sql_types, annotations))
1901 add_stmts.append(add)
1902 address_expression = (
'avoid_opt_address(reinterpret_cast<void*>(%s))' % sig.name)
1905 cpu_function_address_expressions.append(address_expression)
1907 gpu_function_address_expressions.append(address_expression)
1909 return add_stmts, cpu_template_functions, gpu_template_functions, cpu_function_address_expressions, gpu_function_address_expressions, cond_fns
1912 if len(sys.argv) < 3:
1914 input_files = [os.path.join(os.path.dirname(__file__),
'test_udtf_signatures.hpp')]
1915 print(
'Running tests from %s' % (
', '.
join(input_files)))
1918 print(
'Usage:\n %s %s input1.hpp input2.hpp ... output.hpp' % (sys.executable, sys.argv[0], ))
1922 input_files, output_filename = sys.argv[1:-1], sys.argv[-1]
1923 cpu_output_header = os.path.splitext(output_filename)[0] +
'_cpu.hpp'
1924 gpu_output_header = os.path.splitext(output_filename)[0] +
'_gpu.hpp'
1925 assert input_files, sys.argv
1927 add_stmts, cpu_template_functions, gpu_template_functions, cpu_address_expressions, gpu_address_expressions, cond_fns =
parse_annotations(sys.argv[1:-1])
1929 canonical_input_files = [input_file[input_file.find(
"/QueryEngine/") + 1:]
for input_file
in input_files]
1930 header_includes = [
'#include "' + canonical_input_file +
'"' for canonical_input_file
in canonical_input_files]
1933 ADD_FUNC_CHUNK_SIZE = 100
1937 NO_OPT_ATTRIBUTE void add_table_functions_%d() const {
1940 ''' % (i,
'\n '.
join(chunk))
1943 chunks = [ add_stmts[n:n+ADD_FUNC_CHUNK_SIZE]
for n
in range(0, len(add_stmts), ADD_FUNC_CHUNK_SIZE) ]
1944 return [
add_method(i,chunk)
for i,chunk
in enumerate(chunks) ]
1947 quot, rem = divmod(len(add_stmts), ADD_FUNC_CHUNK_SIZE)
1948 return [
'add_table_functions_%d();' % (i)
for i
in range(quot + int(0 < rem)) ]
1952 This file is generated by %s. Do no edit!
1955 #include "QueryEngine/TableFunctions/TableFunctionsFactory.h"
1959 Include the UDTF template initiations:
1961 #include "TableFunctionsFactory_init_cpu.hpp"
1963 // volatile+noinline prevents compiler optimization
1965 __declspec(noinline)
1967 __attribute__((noinline))
1970 bool avoid_opt_address(void *address) {
1971 return address != nullptr;
1974 bool functions_exist() {
1982 extern bool g_enable_table_functions;
1984 namespace table_functions {
1986 std::once_flag init_flag;
1988 #if defined(__clang__)
1989 #define NO_OPT_ATTRIBUTE __attribute__((optnone))
1991 #elif defined(__GNUC__) || defined(__GNUG__)
1992 #define NO_OPT_ATTRIBUTE __attribute((optimize("O0")))
1994 #elif defined(_MSC_VER)
1995 #define NO_OPT_ATTRIBUTE
1999 #if defined(_MSC_VER)
2000 #pragma optimize("", off)
2003 struct AddTableFunctions {
2005 NO_OPT_ATTRIBUTE void operator()() {
2010 void TableFunctionsFactory::init() {
2011 if (!g_enable_table_functions) {
2015 if (!functions_exist()) {
2020 std::call_once(init_flag, AddTableFunctions{});
2022 #if defined(_MSC_VER)
2023 #pragma optimize("", on)
2026 // conditional check functions
2029 } // namespace table_functions
2032 '\n'.
join(header_includes),
2033 ' &&\n'.
join(cpu_address_expressions),
2038 header_content =
'''
2040 This file is generated by %s. Do no edit!
2047 dirname = os.path.dirname(output_filename)
2049 if dirname
and not os.path.exists(dirname):
2051 os.makedirs(dirname)
2052 except OSError
as e:
2054 if e.errno != errno.EEXIST:
2061 f =
open(cpu_output_header,
'w')
2062 f.write(header_content % (sys.argv[0],
'\n'.
join(header_includes),
'\n'.
join(cpu_template_functions)))
2065 f =
open(gpu_output_header,
'w')
2066 f.write(header_content % (sys.argv[0],
'\n'.
join(header_includes),
'\n'.
join(gpu_template_functions)))
def build_template_function_call
def is_output_buffer_sizer
def is_column_list_text_encoding_dict
def is_text_encoding_dict
def visit_annotation_node
def is_any_text_encoding_dict
def is_column_list_text_encoding_dict
int open(const char *path, int flags, int mode)
def can_token_be_double_char
def build_preflight_function
def is_column_text_encoding_dict
def is_any_text_encoding_dict
def visit_annotation_node
def is_output_buffer_sizer
def is_column_text_encoding_dict
def must_emit_preflight_function