OmniSciDB  94e8789169
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
generate_TableFunctionsFactory_init.py
Go to the documentation of this file.
1 """Given a list of input files, scan for lines containing UDTF
2 specification statements in the following form:
3 
4  UDTF: function_name(<arguments>) -> <output column types>
5 
where <arguments> is a comma-separated list of argument types. The
supported argument type specifications are:
8 
9 - scalar types:
10  Int8, Int16, Int32, Int64, Float, Double, Bool, etc
11 - column types:
12  ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble, ColumnBool, etc
13 - cursor type:
14  Cursor<t0, t1, ...>
15  where t0, t1 are column types
16 - output buffer size parameter type:
17  RowMultiplier<i>, ConstantParameter<i>, Constant<i>
18  where i is literal integer
19 
The output column types are given as a comma-separated list of column types (see above).
21 
In addition, the following equivalents are supported:
23  Column<T> == ColumnT
24  Cursor<T, V, ...> == Cursor<ColumnT, ColumnV, ...>
25  int8 == int8_t == Int8, etc
26  float == Float, double == Double, bool == Bool
27  T == ColumnT for output column types
  RowMultiplier == RowMultiplier<i>, where i is the one-based position of the sizer argument.
  When no sizer argument is provided, Constant<1> is assumed.
30 """
31 # Author: Pearu Peterson
32 # Created: January 2021
33 
34 import os
35 import re
36 import sys
37 
# Type names accepted by type_parse(); the generator emits them as
# `ExtArgumentType::<name>` in the generated C++.
#
# The list spans several lines, so it is split on commas AND any whitespace.
# Splitting only on ',' after removing ' ' would leave a '\n' glued to the
# first name of each continuation line (e.g. '\nColumnInt16'), making those
# names fail every membership test.
ExtArgumentTypes = '''
Int8, Int16, Int32, Int64, Float, Double, Void, PInt8, PInt16, PInt32,
PInt64, PFloat, PDouble, PBool, Bool, ArrayInt8, ArrayInt16,
ArrayInt32, ArrayInt64, ArrayFloat, ArrayDouble, ArrayBool, GeoPoint,
GeoLineString, Cursor, GeoPolygon, GeoMultiPolygon, ColumnInt8,
ColumnInt16, ColumnInt32, ColumnInt64, ColumnFloat, ColumnDouble,
ColumnBool, TextEncodingNone, TextEncodingDict8, TextEncodingDict16,
TextEncodingDict32
'''.replace(',', ' ').split()

# Sizer kinds; emitted as `OutputBufferSizeType::<name>` in the generated C++.
OutputBufferSizeTypes = '''
kConstant, kUserSpecifiedConstantParameter, kUserSpecifiedRowMultiplier
'''.replace(',', ' ').split()

# Spelling aliases accepted in UDTF specifications, mapped to canonical names.
translate_map = dict(
    Constant='kConstant',
    ConstantParameter='kUserSpecifiedConstantParameter',
    RowMultiplier='kUserSpecifiedRowMultiplier',
    UserSpecifiedConstantParameter='kUserSpecifiedConstantParameter',
    UserSpecifiedRowMultiplier='kUserSpecifiedRowMultiplier',
    short='Int16',
    int='Int32',
    long='Int64',
)
# Lower-case aliases (int8, float, bool, ...) plus C fixed-width names (int8_t, ...).
for _t in ['Int8', 'Int16', 'Int32', 'Int64', 'Float', 'Double', 'Bool']:
    translate_map[_t.lower()] = _t
    if _t.startswith('Int'):
        translate_map[_t.lower() + '_t'] = _t


# True only when the WHOLE string is a decimal integer literal; a prefix match
# would wrongly accept e.g. "1x" inside RowMultiplier<1x>.
_is_int = re.compile(r'\d+').fullmatch
69 
def type_parse(a):
    """Parse one whitespace-free type specification.

    Returns one of:
      - a ``str`` naming a type, e.g. ``'Int32'`` or ``'ColumnDouble'``
        (aliases from ``translate_map`` are resolved, ``Column<T>`` becomes
        ``'Column' + T``);
      - ``(n, v)`` where ``n`` is an ``OutputBufferSizeTypes`` entry and ``v``
        is the integer literal string from ``n<v>``, or ``None`` when the
        bare name was given;
      - ``('Cursor', (t0, t1, ...))`` where each element is promoted to its
        ``Column`` form when that form is listed in ``ExtArgumentTypes``.

    Raises ValueError when ``a`` cannot be parsed. Validation uses explicit
    raises rather than ``assert`` so that it survives ``python -O`` and so
    callers catching ValueError can attach context.
    """
    i = a.find('<')
    if i >= 0:
        if not a.endswith('>'):
            raise ValueError('Missing closing `>` in `%s`' % (a,))
        n = a[:i]
        n = translate_map.get(n, n)
        inner = a[i+1:-1]
        if n in OutputBufferSizeTypes:
            if not _is_int(inner):
                raise ValueError('Expected integer literal parameter in `%s`' % (a,))
            return n, inner
        if n == 'Cursor':
            lst = []
            for t in map(type_parse, inner.split(',')):
                # Only plain type names can be cursor elements; a nested
                # sizer/cursor tuple would otherwise crash with TypeError.
                if not isinstance(t, str):
                    raise ValueError('Cursor elements must be column types in `%s`' % (a,))
                lst.append('Column' + t if 'Column' + t in ExtArgumentTypes else t)
            return n, tuple(lst)
        if n == 'Column':
            t = type_parse(inner)
            if not isinstance(t, str):
                raise ValueError('Column element must be a scalar type in `%s`' % (a,))
            return n + t
    else:
        a = translate_map.get(a, a)
        if a in ExtArgumentTypes:
            return a
        if a in OutputBufferSizeTypes:
            return a, None
    raise ValueError('Cannot parse `%s` to ExtArgumentTypes or OutputBufferSizeTypes' % (a,))
97 
98 
# One `TableFunctionsFactory::add(...)` C++ statement per UDTF specification.
add_stmts = []


def _split_args(args_line):
    """Split a UDTF argument list on top-level commas only.

    Commas nested inside ``<...>`` do not split, so
    ``Cursor<Column<int32>,Column<float>>`` stays a single argument.
    A trailing comma is ignored. An empty argument between two commas is
    kept, so that ``type_parse`` reports it.

    Raises ValueError on unbalanced angle brackets.
    """
    args = []
    depth = 0
    start = 0
    for pos, ch in enumerate(args_line):
        if ch == '<':
            depth += 1
        elif ch == '>':
            depth -= 1
            if depth < 0:
                raise ValueError('Unbalanced `>` in `%s`' % (args_line,))
        elif ch == ',' and depth == 0:
            args.append(args_line[start:pos])
            start = pos + 1
    if depth != 0:
        raise ValueError('Unbalanced `<` in `%s`' % (args_line,))
    tail = args_line[start:]
    if tail:
        args.append(tail)
    return args


# All command-line arguments except the last (the output file) are inputs.
for input_file in sys.argv[1:-1]:
    with open(input_file) as fh:  # ensure each input file is closed
        for line in fh:
            line = line.replace(' ', '').strip()
            if not line.startswith('UDTF:'):
                continue
            line = line[5:]
            i = line.find('(')
            j = line.find(')')
            if i == -1 or j == -1:
                sys.stderr.write('Invalid UDTF specification: `%s`. Skipping.\n' % (line))
                continue
            name = line[:i]
            args_line = line[i+1:j]
            outputs = line[j+1:]
            if outputs.startswith('->'):
                outputs = outputs[2:]
            outputs = outputs.split(',')

            try:
                args = _split_args(args_line)
            except ValueError as err:
                raise ValueError('`%s`: %s' % (line, err)) from err

            input_types = []
            output_types = []
            sql_types = []
            sizer = None
            for i, a in enumerate(args):
                try:
                    r = type_parse(a)
                except ValueError as err:
                    raise ValueError('`%s`: %s' % (line, err)) from err
                # A bare column argument is treated as a single-column cursor.
                if isinstance(r, str) and r.startswith('Column'):
                    r = 'Cursor', (r,)
                if isinstance(r, str):
                    # Scalar argument: qualify with the enum name like every
                    # other entry, otherwise the generated C++ has bare `Int32`.
                    input_types.append('ExtArgumentType::%s' % (r))
                    sql_types.append('ExtArgumentType::%s' % (r))
                    continue
                n, t = r
                if n in OutputBufferSizeTypes:
                    # User-specified sizers are passed to the function as Int32.
                    if n != 'kConstant':
                        input_types.append('ExtArgumentType::Int32')
                        sql_types.append('ExtArgumentType::Int32')
                    if n == 'kUserSpecifiedRowMultiplier':
                        if not t:
                            t = str(i + 1)
                        if t != str(i + 1):
                            raise ValueError('`%s`: Expected %s<%s> got %s<%s> from %s'
                                             % (line, n, i + 1, n, t, a))
                    elif t is None:
                        # Would otherwise emit a literal `None` into the C++ sizer.
                        raise ValueError('`%s`: %s requires an explicit <i> parameter, got `%s`'
                                         % (line, n, a))
                    if sizer is not None:
                        raise ValueError('`%s`: exactly one sizer argument is allowed' % (line,))
                    sizer = 'TableFunctionOutputRowSizer{OutputBufferSizeType::%s, %s}' % (n, t)
                else:
                    # type_parse only returns sizer tuples or Cursor tuples.
                    assert n == 'Cursor', (a, r)
                    for t_ in t:
                        input_types.append('ExtArgumentType::%s' % (t_))
                    sql_types.append('ExtArgumentType::%s' % (n))

            for a in outputs:
                try:
                    r = type_parse(a)
                except ValueError as err:
                    raise ValueError('`%s`: %s' % (line, err)) from err
                if not isinstance(r, str):
                    raise ValueError('`%s`: output `%s` must be a column type' % (line, a))
                # Outputs may be written as T, meaning ColumnT.
                if 'Column' + r in ExtArgumentTypes:
                    r = 'Column' + r
                output_types.append('ExtArgumentType::%s' % (r))

            if sizer is None:
                sizer = 'TableFunctionOutputRowSizer{OutputBufferSizeType::kConstant, 1}'

            input_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(input_types))
            output_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(output_types))
            sql_types = 'std::vector<ExtArgumentType>{%s}' % (', '.join(sql_types))
            add = 'TableFunctionsFactory::add("%s", %s, %s, %s, %s);' % (
                name, sizer, input_types, output_types, sql_types)
            add_stmts.append(add)
185 
186 
# C++ translation unit that registers every collected UDTF exactly once.
content = '''
/*
  This file is generated by %s. Do not edit!
*/

#include "QueryEngine/TableFunctions/TableFunctionsFactory.h"

extern bool g_enable_table_functions;

namespace table_functions {

std::once_flag init_flag;

void TableFunctionsFactory::init() {
  if (!g_enable_table_functions) {
    return;
  }
  std::call_once(init_flag, []() {
    %s
  });
}

}  // namespace table_functions
''' % (sys.argv[0], '\n    '.join(add_stmts))

# The last command-line argument is the output file path.
output_filename = sys.argv[-1]
dirname = os.path.dirname(output_filename)
# dirname is '' for a bare filename (current directory); os.makedirs('')
# would raise, so only create a directory when there is one to create.
if dirname:
    os.makedirs(dirname, exist_ok=True)

with open(output_filename, 'w') as f:
    f.write(content)
int open(const char *path, int flags, int mode)
Definition: omnisci_fs.cpp:64
std::string strip(std::string_view str)
trim any whitespace from the left and right ends of a string
std::string join(T const &container, std::string const &delim)
std::vector< std::string > split(std::string_view str, std::string_view delim, std::optional< size_t > maxsplit)
split apart a string into a vector of substrings