OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionsFactory_transformers.py
Go to the documentation of this file.
# Public API of this module: the exception type, the AST printer, the
# individual transformer passes, and the Pipeline class that chains passes.
__all__ = [
    'TransformerException',
    'AstPrinter',
    'TemplateTransformer',
    'FixRowMultiplierPosArgTransformer',
    'RenameNodesTransformer',
    'TextEncodingDictTransformer',
    'FieldAnnotationTransformer',
    'SupportedAnnotationsTransformer',
    'RangeAnnotationTransformer',
    'CursorAnnotationTransformer',
    'AmbiguousSignatureCheckTransformer',
    'DefaultValueAnnotationTransformer',
    'DeclBracketTransformer',
    'Pipeline',
]
8 
9 
10 import sys
11 import copy
12 import warnings
13 import itertools
14 from ast import literal_eval
15 from abc import abstractmethod
16 
17 if sys.version_info > (3, 0):
18  from abc import ABC
19 else:
20  from abc import ABCMeta as ABC
21 
22 
23 import TableFunctionsFactory_util as util
24 import TableFunctionsFactory_node as tf_node
25 import TableFunctionsFactory_declbracket as declbracket
26 
27 
class TransformerException(Exception):
    """Raised by a transformer pass when it rejects a UDTF declaration."""
30 
31 
class TransformerWarning(UserWarning):
    """Warning category used by transformer passes (e.g. ambiguous signatures)."""
34 
35 
class AstVisitor(ABC if sys.version_info > (3, 0) else object):
    """Abstract visitor over the table-function AST.

    There is one ``visit_*`` hook per node kind. Nodes call back into these
    hooks through their ``accept(visitor)`` method, and concrete visitors must
    override all six.

    Only Python 2 honours the ``__metaclass__`` attribute; there ``ABC`` is
    bound to ``abc.ABCMeta``. Python 3 silently ignores that attribute, so on
    Python 3 the class inherits from ``abc.ABC`` instead. That is what makes
    ``@abstractmethod`` actually block instantiation of incomplete visitors.
    """
    __metaclass__ = ABC  # effective on Python 2 only

    @abstractmethod
    def visit_udtf_node(self, node):
        raise NotImplementedError()

    @abstractmethod
    def visit_composed_node(self, node):
        raise NotImplementedError()

    @abstractmethod
    def visit_arg_node(self, node):
        raise NotImplementedError()

    @abstractmethod
    def visit_primitive_node(self, node):
        raise NotImplementedError()

    @abstractmethod
    def visit_annotation_node(self, node):
        raise NotImplementedError()

    @abstractmethod
    def visit_template_node(self, node):
        raise NotImplementedError()
58 
59  @abstractmethod
60  def visit_template_node(self, node):
61  raise NotImplementedError()
62 
63 
65  """Only overload the methods you need"""
66 
def visit_udtf_node(self, udtf_node):
    """Return a shallow copy of *udtf_node* whose children were visited.

    Inputs, outputs and annotations are always re-visited. Templates are
    re-visited only when present; a falsy ``templates`` value is kept as is.
    """
    result = copy.copy(udtf_node)

    def _visit_all(nodes):
        return [n.accept(self) for n in nodes]

    result.inputs = _visit_all(result.inputs)
    result.outputs = _visit_all(result.outputs)
    if result.templates:
        result.templates = _visit_all(result.templates)
    result.annotations = _visit_all(result.annotations)
    return result
75 
def visit_composed_node(self, composed_node):
    """Return a shallow copy of *composed_node* with its inner nodes visited."""
    clone = copy.copy(composed_node)
    clone.inner = list(map(lambda child: child.accept(self), clone.inner))
    return clone
80 
def visit_arg_node(self, arg_node):
    """Return a shallow copy of *arg_node* with its type and annotations visited.

    The annotations list is always rebuilt, even when it is empty. Later
    passes append to ``copy.annotations`` (for example default ``input_id``
    or ``fields`` annotations), and those appends must not leak back into the
    original node through a shared list object.
    """
    arg_node = copy.copy(arg_node)
    arg_node.type = arg_node.type.accept(self)
    # copy.copy is shallow: without a fresh list, the copy and the original
    # would share the same (possibly empty) annotations list.
    if arg_node.annotations is not None:
        arg_node.annotations = [a.accept(self) for a in arg_node.annotations]
    return arg_node
87 
def visit_primitive_node(self, primitive_node):
    """Leaf node: hand back an independent shallow copy."""
    clone = copy.copy(primitive_node)
    return clone

def visit_template_node(self, template_node):
    """Template node: hand back an independent shallow copy."""
    clone = copy.copy(template_node)
    return clone

def visit_annotation_node(self, annotation_node):
    """Annotation node: hand back an independent shallow copy."""
    clone = copy.copy(annotation_node)
    return clone
96 
97 
99  """Returns a line formatted. Useful for testing"""
100 
def visit_udtf_node(self, udtf_node):
    """Render a UDTF declaration as a single line of text.

    Shape: ``name(in1, in2) | ann1| ann2 -> out1, out2[, T=[...]][ | sizer]``.
    The annotation part and the sizer part are omitted when empty.
    """
    def render(nodes, separator):
        return separator.join([node.accept(self) for node in nodes])

    inputs = render(udtf_node.inputs, ", ")
    outputs = render(udtf_node.outputs, ", ")
    # NOTE(review): function annotations are joined with "| " (no space
    # before the bar), unlike the " | " used for the leading separator and
    # the sizer. Kept verbatim because printed signatures may be compared
    # as strings elsewhere.
    annotations = render(udtf_node.annotations, "| ")
    sizer = (" | " + udtf_node.sizer.accept(self)) if udtf_node.sizer else ""
    annotation_part = (" | " + annotations) if annotations else ""

    line = "%s(%s)%s -> %s" % (udtf_node.name, inputs, annotation_part, outputs)
    if udtf_node.templates:
        line += ", " + render(udtf_node.templates, ", ")
    return line + sizer
114 
def visit_template_node(self, template_node):
    """Render a template binding such as ``T=["int32", "float"]``."""
    quoted = ", ".join('"%s"' % name for name in template_node.types)
    return "%s=[%s]" % (template_node.key, quoted)
120 
def visit_annotation_node(self, annotation_node):
    """Render ``key=value``; a list value is rendered element-wise as ``key=[a,b]``."""
    key, value = annotation_node.key, annotation_node.value
    if not isinstance(value, list):
        return "%s=%s" % (key, value)
    joined = ','.join([item.accept(self) for item in value])
    return "%s=[%s]" % (key, joined)
128 
def visit_arg_node(self, arg_node):
    """Render an argument as ``type`` or ``type | ann1 | ann2``."""
    rendered_type = arg_node.type.accept(self)
    if not arg_node.annotations:
        return "%s" % (rendered_type,)
    rendered_annotations = " | ".join([ann.accept(self) for ann in arg_node.annotations])
    return "%s | %s" % (rendered_type, rendered_annotations)
138 
def visit_composed_node(self, composed_node):
    """Render a composed type node.

    Array, Column and ColumnList are rendered by concatenating the prefix with
    the rendered subtype, without angle brackets. Output-buffer sizers become
    ``<translated type><N>``, and cursors become ``Cursor<T1, ..., TN>``.
    Any other composed node raises ValueError.
    """
    first = composed_node.inner[0].accept(self)

    single_inner_prefixes = (
        (composed_node.is_array, "Array"),
        (composed_node.is_column, "Column"),
        (composed_node.is_column_list, "ColumnList"),
    )
    for predicate, prefix in single_inner_prefixes:
        if predicate():
            assert len(composed_node.inner) == 1
            return prefix + first

    if composed_node.is_output_buffer_sizer():
        # kConstant<N>
        assert len(composed_node.inner) == 1
        return util.translate_map.get(composed_node.type) + "<%s>" % (first,)

    if composed_node.is_cursor():
        members = ", ".join([member.accept(self) for member in composed_node.inner])
        return "Cursor<%s>" % (members)

    raise ValueError(composed_node)
163 
def visit_primitive_node(self, primitive_node):
    """Render a primitive type through util.translate_map.

    An output-buffer sizer additionally gets a ``<pos>`` suffix. ``pos`` is
    the enclosing argument's ``arg_pos`` (zero-based) plus one.
    """
    name = util.translate_map.get(primitive_node.type, primitive_node.type)
    if not primitive_node.is_output_buffer_sizer():
        return name
    one_based_pos = primitive_node.get_parent(tf_node.ArgNode).arg_pos + 1
    return name + "<%d>" % (one_based_pos,)
172 
173 
175  """Like AstPrinter but returns a node instead of a string
176  """
177 
178 
179 def product_dict(**kwargs):
180  keys = kwargs.keys()
181  vals = kwargs.values()
182  for instance in itertools.product(*vals):
183  yield dict(zip(keys, instance))
184 
185 
187  """Expand template definition into multiple inputs"""
188 
def visit_udtf_node(self, udtf_node):
    """Expand a templated UDTF into concrete UDTFs, one per template binding.

    Returns *udtf_node* itself when it has no templates. Otherwise the
    expansions are deduplicated by ``str()``. A single distinct expansion is
    returned as a node; several are returned as a list.
    """
    if not udtf_node.templates:
        return udtf_node

    choices = {tmpl.key: tmpl.types for tmpl in udtf_node.templates}
    expanded = dict()
    for binding in product_dict(**choices):
        # read by visit_composed_node / visit_primitive_node while the
        # arguments below are visited
        self.mapping_dict = binding
        concrete = tf_node.UdtfNode(
            udtf_node.name,
            [arg.accept(self) for arg in udtf_node.inputs],
            [arg.accept(self) for arg in udtf_node.outputs],
            udtf_node.annotations,
            None,
            udtf_node.sizer,
            udtf_node.line,
        )
        expanded[str(concrete)] = concrete
        self.mapping_dict = {}

    unique = list(expanded.values())
    return unique[0] if len(unique) == 1 else unique
212 
def visit_composed_node(self, composed_node):
    """Substitute template parameters in a composed node and in its inner nodes."""
    children = [child.accept(self) for child in composed_node.inner]
    resolved = self.mapping_dict.get(composed_node.type, composed_node.type)
    return composed_node.copy(resolved, children)
219 
def visit_primitive_node(self, primitive_node):
    """Replace a template parameter name by its bound concrete type, if bound."""
    original_name = primitive_node.type
    return primitive_node.copy(self.mapping_dict.get(original_name, original_name))
224 
225 
227  def visit_primitive_node(self, primitive_node):
228  """
229  * Fix kUserSpecifiedRowMultiplier without a pos arg
230  """
231  t = primitive_node.type
232 
233  if primitive_node.is_output_buffer_sizer():
234  pos = tf_node.PrimitiveNode(str(primitive_node.get_parent(tf_node.ArgNode).arg_pos + 1))
235  node = tf_node.ComposedNode(t, inner=[pos])
236  return node
237 
238  return primitive_node
239 
240 
242  def visit_primitive_node(self, primitive_node):
243  """
244  * Rename nodes using util.translate_map as dictionary
245  int -> Int32
246  float -> Float
247  """
248  t = primitive_node.type
249  return primitive_node.copy(util.translate_map.get(t, t))
250 
251 
253  def visit_udtf_node(self, udtf_node):
254  """
255  * Add default_input_id to Column(List)<TextEncodingDict> without one
256  """
257  udtf_node = super(type(self), self).visit_udtf_node(udtf_node)
258  # add default input_id
259  default_input_id = None
260  for idx, t in enumerate(udtf_node.inputs):
261 
262  if not isinstance(t.type, tf_node.ComposedNode):
263  continue
264  if default_input_id is not None:
265  pass
266  elif t.type.is_column_text_encoding_dict() or t.type.is_column_array_text_encoding_dict():
267  default_input_id = tf_node.AnnotationNode('input_id', 'args<%s>' % (idx,))
268  elif t.type.is_column_list_text_encoding_dict():
269  default_input_id = tf_node.AnnotationNode('input_id', 'args<%s, 0>' % (idx,))
270 
271  for t in udtf_node.outputs:
272  if isinstance(t.type, tf_node.ComposedNode) and t.type.is_any_text_encoding_dict():
273  for a in t.annotations:
274  if a.key == 'input_id':
275  break
276  else:
277  if default_input_id is None:
278  raise TypeError('Cannot parse line "%s".\n'
279  'Missing TextEncodingDict input?' %
280  (udtf_node.line))
281  t.annotations.append(default_input_id)
282 
283  return udtf_node
284 
285 
287 
288  def visit_udtf_node(self, udtf_node):
289  """
290  * Generate fields annotation to Cursor if non-existing
291  """
292  udtf_node = super(type(self), self).visit_udtf_node(udtf_node)
293 
294  for t in udtf_node.inputs:
295 
296  if not isinstance(t.type, tf_node.ComposedNode):
297  continue
298 
299  if t.type.is_cursor() and t.get_annotation('fields') is None:
300  fields = list(tf_node.PrimitiveNode(a.get_annotation('name', 'field%s' % i)) for i, a in enumerate(t.type.inner))
301  t.annotations.append(tf_node.AnnotationNode('fields', fields))
302 
303  return udtf_node
304 
305 
307  def visit_udtf_node(self, udtf_node):
308  """
309  * Typechecks default value annotations.
310  """
311  udtf_node = super(type(self), self).visit_udtf_node(udtf_node)
312 
313  for t in udtf_node.inputs:
314  for a in filter(lambda x: x.key == "default", t.annotations):
315  if not t.type.is_scalar():
316  raise TransformerException(
317  'Error in function "%s", input annotation \'%s=%s\'. '
318  '\"default\" annotation is only supported for scalar types!'\
319  % (udtf_node.name, a.key, a.value)
320  )
321  literal = literal_eval(a.value)
322  lst = [(bool, 'is_boolean_scalar'), (int, 'is_integer_scalar'), (float, 'is_float_scalar'),
323  (str, 'is_string_scalar')]
324 
325  for (cls, mthd) in lst:
326  if type(literal) is cls:
327  assert isinstance(t, tf_node.ArgNode)
328  m = getattr(t.type, mthd)
329  if not m():
330  raise TransformerException(
331  'Error in function "%s", input annotation \'%s=%s\'. '
332  'Argument is of type "%s" but value type was inferred as "%s".'
333  % (udtf_node.name, a.key, a.value, t.type.type, type(literal).__name__))
334  break
335 
336  return udtf_node
337 
338 
340  """
341  * Checks for supported annotations in a UDTF
342  """
343  def visit_udtf_node(self, udtf_node):
344  for t in udtf_node.inputs:
345  for a in t.annotations:
346  if a.key not in util.SupportedAnnotations:
347  raise TransformerException('unknown input annotation: `%s`' % (a.key))
348  for t in udtf_node.outputs:
349  for a in t.annotations:
350  if a.key not in util.SupportedAnnotations:
351  raise TransformerException('unknown output annotation: `%s`' % (a.key))
352  for annot in udtf_node.annotations:
353  if annot.key not in util.SupportedFunctionAnnotations:
354  raise TransformerException('unknown function annotation: `%s`' % (annot.key))
355  if annot.value.lower() in ['enable', 'on', '1', 'true']:
356  annot.value = '1'
357  elif annot.value.lower() in ['disable', 'off', '0', 'false']:
358  annot.value = '0'
359  return udtf_node
360 
361 
363  """
364  * A UDTF declaration is ambiguous if two or more ColumnLists are adjacent
365  to each other:
366  func__0(ColumnList<T> X, ColumnList<T> Z) -> Column<U>
367  func__1(ColumnList<T> X, Column<T> Y, ColumnList<T> Z) -> Column<U>
368  The first ColumnList ends up consuming all of the arguments leaving a single
369  one for the last ColumnList. In other words, Z becomes a Column
370  """
def check_ambiguity(self, udtf_name, lst):
    """Warn when a group of arguments has two ColumnLists with the same subtype.

    udtf_name: str -- used in the warning message and to exempt known cases
    lst: list[list[Node]] -- argument type nodes, grouped as built by
        visit_udtf_node (each cursor forms its own group)

    For each ColumnList<T> in a group, the members after it are scanned in
    order:
      - a Column<T> is skipped over;
      - a ColumnList<T> triggers a TransformerWarning;
      - any other member ends the scan.
    """
    # Known-ambiguous UDTFs that are intentionally left as they are.
    exempt = ('ct_overload_column_list_test2__cpu_template',)
    for group in lst:
        # enumerate replaces the original range/index loop, whose manual
        # ``i += 1`` had no effect inside a for-loop.
        for i, node in enumerate(group):
            if not node.is_column_list():
                continue

            T = node.inner[0]
            for other in group[i + 1:]:
                if other.is_column() and other.is_column_of(T):
                    continue
                if other.is_column_list() and other.is_column_list_of(T):
                    msg = ('%s signature is ambiguous as there are two '
                           'ColumnList with the same subtype in the same '
                           'group.') % (udtf_name)
                    if udtf_name not in exempt:
                        # warn only when the function ought to be fixed
                        warnings.warn(msg, TransformerWarning)
                else:
                    break
405  lst.append(s) # Cursor
406  cursor = True
407  else:
408  # Aggregate single arguments in a list
409  if cursor or len(lst) == 0:
410  lst.append([s])
411  else:
412  lst[-1].append(s)
413  cursor = False
414 
415  self.check_ambiguity(udtf_node.name, lst)
416 
417  return udtf_node
418 
419  def visit_composed_node(self, composed_node):
420  s = super(type(self), self).visit_composed_node(composed_node)
421  if composed_node.is_cursor():
422  return [i.accept(self) for i in composed_node.inner]
423  return s
424 
425  def visit_arg_node(self, arg_node):
426  # skip annotations
427  return arg_node.type.accept(self)
428 
429 
431  """
432  * Append require annotation if range is used
433  """
434  def visit_arg_node(self, arg_node):
435  for ann in arg_node.annotations:
436  if ann.key == 'range':
437  name = arg_node.get_annotation('name')
438  if name is None:
439  raise TransformerException('"range" requires a named argument')
440 
441  v = ann.value
442  if len(v) == 2:
443  lo, hi = ann.value
444  value = '"{lo} <= {name} && {name} <= {hi}"'.format(lo=lo, hi=hi, name=name)
445  else:
446  raise TransformerException('"range" requires an interval. Got {v}'.format(v=v))
447  arg_node.set_annotation('require', value)
448  return arg_node
449 
450 
452  """
453  * Move a "require" annotation from inside a cursor to the cursor
454  """
455 
456  def visit_arg_node(self, arg_node):
457  if arg_node.type.is_cursor():
458  for inner in arg_node.type.inner:
459  for ann in inner.annotations:
460  if ann.key == 'require':
461  arg_node.annotations.append(ann)
462  return arg_node
463 
464 
466 
467  def visit_udtf_node(self, udtf_node):
468  name = udtf_node.name
469  inputs = []
470  input_annotations = []
471  outputs = []
472  output_annotations = []
473  function_annotations = []
474  sizer = udtf_node.sizer
475 
476  for i in udtf_node.inputs:
477  decl = i.accept(self)
478  inputs.append(decl)
479  input_annotations.append(decl.annotations)
480 
481  for o in udtf_node.outputs:
482  decl = o.accept(self)
483  outputs.append(decl.type)
484  output_annotations.append(decl.annotations)
485 
486  for annot in udtf_node.annotations:
487  annot = annot.accept(self)
488  function_annotations.append(annot)
489 
490  return util.Signature(name, inputs, outputs, input_annotations, output_annotations, function_annotations, sizer)
491 
492  def visit_arg_node(self, arg_node):
493  t = arg_node.type.accept(self)
494  anns = [a.accept(self) for a in arg_node.annotations]
495  return declbracket.Declaration(t, anns)
496 
497  def visit_composed_node(self, composed_node):
498  typ = util.translate_map.get(composed_node.type, composed_node.type)
499  inner = [i.accept(self) for i in composed_node.inner]
500  if composed_node.is_cursor():
501  inner = list(map(lambda x: x.apply_column(), inner))
502  return declbracket.Bracket(typ, args=tuple(inner))
503  elif composed_node.is_output_buffer_sizer():
504  return declbracket.Bracket(typ, args=tuple(inner))
505  else:
506  return declbracket.Bracket(typ + str(inner[0]))
507 
508  def visit_primitive_node(self, primitive_node):
509  t = primitive_node.type
510  return declbracket.Bracket(t)
511 
512  def visit_annotation_node(self, annotation_node):
513  key = annotation_node.key
514  value = annotation_node.value
515  return (key, value)
516 
517 
518 class Pipeline(object):
519  def __init__(self, *passes):
520  self.passes = passes
521 
522  def __call__(self, ast_list):
523  if not isinstance(ast_list, list):
524  ast_list = [ast_list]
525 
526  for c in self.passes:
527  ast_list = [ast.accept(c()) for ast in ast_list]
528  ast_list = itertools.chain.from_iterable( # flatten the list
529  map(lambda x: x if isinstance(x, list) else [x], ast_list))
530 
531  return list(ast_list)
size_t append(FILE *f, const size_t size, const int8_t *buf)
Appends the specified number of bytes to the end of the file f from buf.
Definition: File.cpp:158
std::string join(T const &container, std::string const &delim)