OmniSciDB  04ee39c94c
MapDParser.java
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.mapd.calcite.parser;
17 
18 import com.google.common.collect.ImmutableList;
21 
22 import org.apache.calcite.avatica.util.Casing;
23 import org.apache.calcite.config.CalciteConnectionConfig;
24 import org.apache.calcite.config.CalciteConnectionConfigImpl;
25 import org.apache.calcite.plan.Context;
26 import org.apache.calcite.plan.RelOptLattice;
27 import org.apache.calcite.plan.RelOptMaterialization;
28 import org.apache.calcite.plan.RelOptUtil;
31 import org.apache.calcite.rel.RelNode;
32 import org.apache.calcite.rel.RelRoot;
33 import org.apache.calcite.rel.RelShuttleImpl;
34 import org.apache.calcite.rel.core.RelFactories;
35 import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
36 import org.apache.calcite.rel.rules.FilterMergeRule;
37 import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
38 import org.apache.calcite.rel.rules.JoinProjectTransposeRule;
39 import org.apache.calcite.rel.rules.ProjectMergeRule;
40 import org.apache.calcite.rel.type.RelDataTypeFactory;
41 import org.apache.calcite.rel.type.RelDataTypeSystem;
42 import org.apache.calcite.rex.RexBuilder;
43 import org.apache.calcite.rex.RexCall;
44 import org.apache.calcite.rex.RexNode;
45 import org.apache.calcite.rex.RexShuttle;
46 import org.apache.calcite.schema.SchemaPlus;
47 import org.apache.calcite.sql.SqlAsOperator;
48 import org.apache.calcite.sql.SqlBasicCall;
49 import org.apache.calcite.sql.SqlCall;
50 import org.apache.calcite.sql.SqlDialect;
51 import org.apache.calcite.sql.SqlIdentifier;
52 import org.apache.calcite.sql.SqlKind;
53 import org.apache.calcite.sql.SqlLiteral;
54 import org.apache.calcite.sql.SqlNode;
55 import org.apache.calcite.sql.SqlNodeList;
56 import org.apache.calcite.sql.SqlNumericLiteral;
57 import org.apache.calcite.sql.SqlOperatorTable;
58 import org.apache.calcite.sql.SqlOrderBy;
59 import org.apache.calcite.sql.SqlSelect;
60 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
61 import org.apache.calcite.sql.parser.SqlParseException;
62 import org.apache.calcite.sql.parser.SqlParser;
63 import org.apache.calcite.sql.parser.SqlParserPos;
64 import org.apache.calcite.sql.type.SqlTypeName;
65 import org.apache.calcite.sql.type.SqlTypeUtil;
66 import org.apache.calcite.sql.validate.SqlConformanceEnum;
68 import org.apache.calcite.tools.FrameworkConfig;
69 import org.apache.calcite.tools.Frameworks;
70 import org.apache.calcite.tools.Planner;
71 import org.apache.calcite.tools.Program;
72 import org.apache.calcite.tools.Programs;
73 import org.apache.calcite.tools.RelConversionException;
74 import org.apache.calcite.tools.ValidationException;
75 import org.apache.calcite.util.ConversionUtil;
76 import org.slf4j.Logger;
77 import org.slf4j.LoggerFactory;
78 
79 import java.util.HashSet;
80 import java.util.List;
81 import java.util.Map;
82 import java.util.Properties;
83 import java.util.Set;
84 
89 public final class MapDParser {
// Per-thread slot typed to hold a MapDParser. It is not read or written anywhere in the code
// visible in this listing. NOTE(review): confirm which callers set/clear it.
90  public static final ThreadLocal<MapDParser> CURRENT_PARSER = new ThreadLocal<>();
91 
// Class-wide SLF4J logger used for parse/desugar debug tracing and parse-error reporting.
92  final static Logger MAPDLOGGER = LoggerFactory.getLogger(MapDParser.class);
93 
// NOTE(review): commented-out fields from an earlier design; nothing in this file refers to
// them, so they are candidates for deletion.
94  // private SqlTypeFactoryImpl typeFactory;
95  // private MapDCatalogReader catalogReader;
96  // private SqlValidatorImpl validator;
97  // private SqlToRelConverter converter;
// Extension-function signatures; passed to createOperatorTable() for every planner built.
98  private final Map<String, ExtensionFunction> extSigs;
// Data directory handed to every MapDSchema this parser creates.
99  private final String dataDir;
100 
// Incremented once per relational-algebra translation request; reported by getCallCount().
101  private int callCount = 0;
// Port handed to every MapDSchema this parser creates.
102  private final int mapdPort;
// NOTE(review): the mapdUser and sock_transport_properties fields used throughout this class
// are declared on source lines that are missing from this listing.
// NOTE(review): sqlNode_ is not read or written anywhere in the code visible here; possibly
// dead state -- confirm before removing.
104  SqlNode sqlNode_;
106 
/**
 * Creates a parser bound to a data directory, a set of extension-function signatures, a MapD
 * server port and socket transport properties (stored from {@code skT}).
 *
 * <p>Side effect: sets the JVM-wide system properties {@code saffron.default.charset},
 * {@code saffron.default.nationalcharset} (both to
 * {@code ConversionUtil.NATIVE_UTF16_CHARSET_NAME}) and {@code saffron.default.collation.name}
 * (that charset name followed by {@code $en_US}). Because these are process-global, every
 * constructor call overwrites them.
 *
 * <p>NOTE(review): the source line declaring the last parameter ({@code skT}, per the signature
 * index at the end of this listing) is missing here.
 */
107  public MapDParser(String dataDir,
108  final Map<String, ExtensionFunction> extSigs,
109  int mapdPort,
// Presumably makes Calcite default to UTF-16 character data/collation -- confirm against
// Calcite's saffron.* property handling.
111  System.setProperty(
112  "saffron.default.charset", ConversionUtil.NATIVE_UTF16_CHARSET_NAME);
113  System.setProperty(
114  "saffron.default.nationalcharset", ConversionUtil.NATIVE_UTF16_CHARSET_NAME);
115  System.setProperty("saffron.default.collation.name",
116  ConversionUtil.NATIVE_UTF16_CHARSET_NAME + "$en_US");
117  this.dataDir = dataDir;
118  this.extSigs = extSigs;
119  this.mapdPort = mapdPort;
120  this.sock_transport_properties = skT;
121  }
122 
/**
 * Shared Calcite planner context, installed by the planner configuration built in this class.
 *
 * <p>{@code unwrap} hands out a {@code CalciteConnectionConfig} when the requested class
 * accepts it, and returns {@code null} otherwise. That config overrides three settings:
 * <ul>
 *   <li>{@code typeSystem} returns {@code myTypeSystem};</li>
 *   <li>{@code caseSensitive} returns {@code false};</li>
 *   <li>{@code conformance} returns {@code SqlConformanceEnum.LENIENT}.</li>
 * </ul>
 */
123  private static final Context MAPD_CONNECTION_CONTEXT = new Context() {
// NOTE(review): myTypeSystem (returned by typeSystem() below) is declared on a source line that
// is missing from this listing.
125  CalciteConnectionConfig config = new CalciteConnectionConfigImpl(new Properties()) {
// Unchecked cast: the configured type system is returned regardless of typeSystemClass.
126  @SuppressWarnings("unchecked")
127  public <T extends Object> T typeSystem(
128  java.lang.Class<T> typeSystemClass, T defaultTypeSystem) {
129  return (T) myTypeSystem;
// The ';' after each overriding method body is an empty declaration and has no effect.
130  };
131 
132  public boolean caseSensitive() {
133  return false;
134  };
135 
136  public org.apache.calcite.sql.validate.SqlConformance conformance() {
137  return SqlConformanceEnum.LENIENT;
138  };
139  };
140 
// Returns the connection config for any class it is an instance of; null for everything else.
141  @Override
142  public <C> C unwrap(Class<C> aClass) {
143  if (aClass.isInstance(config)) {
144  return aClass.cast(config);
145  }
146  return null;
147  }
148  };
149 
// This appears to be the body of getPlanner(): its signature line is missing from this listing,
// and callers obtain a MapDPlanner via getPlanner(). Every call builds a brand-new planner:
//  - a fresh MapDSchema is registered under mapdUser.getDB() in a new root schema, so setUser()
//    must have supplied a non-null user first (mapdUser is dereferenced below);
//  - operators come from createOperatorTable(extSigs);
//  - parser: LENIENT conformance, unquoted identifiers keep their case, case-insensitive;
//  - sql-to-rel: sub-query expansion disabled, IN sub-query threshold = Integer.MAX_VALUE;
//  - type system from createTypeSystem(), context MAPD_CONNECTION_CONTEXT.
151  MapDSchema mapd =
152  new MapDSchema(dataDir, this, mapdPort, mapdUser, sock_transport_properties);
153  final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
154  final FrameworkConfig config =
155  Frameworks.newConfigBuilder()
156  .defaultSchema(rootSchema.add(mapdUser.getDB(), mapd))
157  .operatorTable(createOperatorTable(extSigs))
158  .parserConfig(SqlParser.configBuilder()
159  .setConformance(SqlConformanceEnum.LENIENT)
160  .setUnquotedCasing(Casing.UNCHANGED)
161  .setCaseSensitive(false)
162  .build())
// NOTE(review): the line that starts this argument (source line 164) is missing from this
// listing.
163  .sqlToRelConverterConfig(
165  .configBuilder()
166  // disable sub-query expansion (in-lining)
167  .withExpand(false)
168  // allow as many as possible IN operator values
169  .withInSubQueryThreshold(Integer.MAX_VALUE)
170  .build())
171  .typeSystem(createTypeSystem())
172  .context(MAPD_CONNECTION_CONTEXT)
173  .build();
174  return new MapDPlanner(config);
175  }
176 
177  public void setUser(MapDUser mapdUser) {
178  this.mapdUser = mapdUser;
179  }
180 
/*
 * Translates sql into a relational-algebra string. The method's signature line is missing from
 * this listing; these are its parameters and body.
 *
 * Steps: increments callCount, builds the RelRoot with queryToSqlNode(sql, parserOptions), then
 *  - when parserOptions.isExplain(): returns RelOptUtil's textual plan of the projected root;
 *  - otherwise: returns MapDSerializer's serialization of the projected root.
 *
 * NOTE(review): mapDUser is not referenced in this body. Planning uses the mapdUser field
 * set via setUser().
 */
182  String sql, final MapDParserOptions parserOptions, final MapDUser mapDUser)
183  throws SqlParseException, ValidationException, RelConversionException {
184  callCount++;
185  final RelRoot sqlRel = queryToSqlNode(sql, parserOptions);
186 
187  RelNode project = sqlRel.project();
188 
189  if (parserOptions.isExplain()) {
// NOTE(review): recomputes sqlRel.project() instead of reusing `project`.
190  return RelOptUtil.toString(sqlRel.project());
191  }
192 
193  String res = MapDSerializer.toString(project);
194 
195  return res;
196  }
197 
/*
 * Returns auto-completion hints for sql at cursor position, restricted by visible_tables.
 * The signature line is missing from this listing. It delegates to a freshly built planner.
 * NOTE(review): unlike queryToSqlNode, the planner obtained here is never close()d.
 */
199  String sql, int cursor, List<String> visible_tables) {
200  return getPlanner().getCompletionHints(sql, cursor, visible_tables);
201  }
202 
203  public Set<String> resolveSelectIdentifiers(SqlIdentifierCapturer capturer) {
204  MapDSchema schema =
205  new MapDSchema(dataDir, this, mapdPort, mapdUser, sock_transport_properties);
206  HashSet<String> resolved = new HashSet<>();
207 
208  for (String name : capturer.selects) {
209  MapDTable table = (MapDTable) schema.getTable(name);
210  if (null == table) {
211  throw new RuntimeException("table/view not found: " + name);
212  }
213 
214  if (table instanceof MapDView) {
215  MapDView view = (MapDView) table;
216  resolved.addAll(resolveSelectIdentifiers(view.getAccessedObjects()));
217  } else {
218  resolved.add(name);
219  }
220  }
221 
222  return resolved;
223  }
224 
225  RelRoot queryToSqlNode(final String sql, final MapDParserOptions parserOptions)
226  throws SqlParseException, ValidationException, RelConversionException {
227  MapDPlanner planner = getPlanner();
228 
229  SqlNode node = processSQL(sql, parserOptions.isLegacySyntax(), planner);
230 
231  if (parserOptions.isLegacySyntax()) {
232  // close original planner
233  planner.close();
234  // create a new one
235  planner = getPlanner();
236  node = processSQL(node.toSqlString(SqlDialect.CALCITE).toString(), false, planner);
237  }
238 
239  boolean is_select_star = isSelectStar(node);
240 
241  SqlNode validateR = planner.validate(node);
242  SqlSelect validate_select = getSelectChild(validateR);
243 
244  // Hide rowid from select * queries
245  if (parserOptions.isLegacySyntax() && is_select_star && validate_select != null) {
246  SqlNodeList proj_exprs = ((SqlSelect) validateR).getSelectList();
247  SqlNodeList new_proj_exprs = new SqlNodeList(proj_exprs.getParserPosition());
248  for (SqlNode proj_expr : proj_exprs) {
249  final SqlNode unaliased_proj_expr = getUnaliasedExpression(proj_expr);
250 
251  if (unaliased_proj_expr instanceof SqlIdentifier) {
252  if ((((SqlIdentifier) unaliased_proj_expr).toString().toLowerCase())
253  .endsWith(".rowid")) {
254  continue;
255  }
256  }
257  new_proj_exprs.add(proj_expr);
258  }
259  validate_select.setSelectList(new_proj_exprs);
260 
261  // trick planner back into correct state for validate
262  planner.close();
263  // create a new one
264  planner = getPlanner();
265  processSQL(validateR.toSqlString(SqlDialect.CALCITE).toString(), false, planner);
266  // now validate the new modified SqlNode;
267  validateR = planner.validate(validateR);
268  }
269 
270  planner.setFilterPushDownInfo(parserOptions.getFilterPushDownInfo());
271  RelRoot relR = planner.rel(validateR);
272  relR = replaceIsTrue(planner.getTypeFactory(), relR);
273  planner.close();
274 
275  if (!parserOptions.isViewOptimizeEnabled()) {
276  return relR;
277  } else {
278  // check to see if a view is involved in the query
279  boolean foundView = false;
280  MapDSchema schema = new MapDSchema(
281  dataDir, this, mapdPort, mapdUser, sock_transport_properties);
282  SqlIdentifierCapturer capturer =
283  captureIdentifiers(sql, parserOptions.isLegacySyntax());
284  for (String name : capturer.selects) {
285  MapDTable table = (MapDTable) schema.getTable(name);
286  if (null == table) {
287  throw new RuntimeException("table/view not found: " + name);
288  }
289  if (table instanceof MapDView) {
290  foundView = true;
291  }
292  }
293 
294  if (!foundView) {
295  return relR;
296  }
297 
298  // do some calcite based optimization
299  // will allow duplicate projects to merge
300  ProjectMergeRule projectMergeRule =
301  new ProjectMergeRule(true, RelFactories.LOGICAL_BUILDER);
302  final Program program =
303  Programs.hep(ImmutableList.of(FilterProjectTransposeRule.INSTANCE,
304  projectMergeRule,
306  FilterMergeRule.INSTANCE,
307  JoinProjectTransposeRule.LEFT_PROJECT_INCLUDE_OUTER,
308  JoinProjectTransposeRule.RIGHT_PROJECT_INCLUDE_OUTER,
309  JoinProjectTransposeRule.BOTH_PROJECT_INCLUDE_OUTER),
310  true,
311  DefaultRelMetadataProvider.INSTANCE);
312 
313  RelNode oldRel;
314  RelNode newRel = relR.project();
315 
316  do {
317  oldRel = newRel;
318  newRel = program.run(null,
319  oldRel,
320  null,
321  ImmutableList.<RelOptMaterialization>of(),
322  ImmutableList.<RelOptLattice>of());
323  // there must be a better way to compare these
324  } while (!RelOptUtil.toString(oldRel).equals(RelOptUtil.toString(newRel)));
325  RelRoot optRel = RelRoot.of(newRel, relR.kind);
326  return optRel;
327  }
328  }
329 
330  private RelRoot replaceIsTrue(final RelDataTypeFactory typeFactory, RelRoot root) {
331  final RexShuttle callShuttle = new RexShuttle() {
332  RexBuilder builder = new RexBuilder(typeFactory);
333 
334  public RexNode visitCall(RexCall call) {
335  call = (RexCall) super.visitCall(call);
336  if (call.getKind() == SqlKind.IS_TRUE) {
337  return builder.makeCall(SqlStdOperatorTable.AND,
338  builder.makeCall(
339  SqlStdOperatorTable.IS_NOT_NULL, call.getOperands().get(0)),
340  call.getOperands().get(0));
341  } else if (call.getKind() == SqlKind.IS_NOT_TRUE) {
342  return builder.makeCall(SqlStdOperatorTable.OR,
343  builder.makeCall(
344  SqlStdOperatorTable.IS_NULL, call.getOperands().get(0)),
345  builder.makeCall(SqlStdOperatorTable.NOT, call.getOperands().get(0)));
346  } else if (call.getKind() == SqlKind.IS_FALSE) {
347  return builder.makeCall(SqlStdOperatorTable.AND,
348  builder.makeCall(
349  SqlStdOperatorTable.IS_NOT_NULL, call.getOperands().get(0)),
350  builder.makeCall(SqlStdOperatorTable.NOT, call.getOperands().get(0)));
351  } else if (call.getKind() == SqlKind.IS_NOT_FALSE) {
352  return builder.makeCall(SqlStdOperatorTable.OR,
353  builder.makeCall(
354  SqlStdOperatorTable.IS_NULL, call.getOperands().get(0)),
355  call.getOperands().get(0));
356  }
357 
358  return call;
359  }
360  };
361 
362  RelNode node = root.rel.accept(new RelShuttleImpl() {
363  @Override
364  protected RelNode visitChild(RelNode parent, int i, RelNode child) {
365  RelNode node = super.visitChild(parent, i, child);
366  return node.accept(callShuttle);
367  }
368  });
369 
370  return new RelRoot(
371  node, root.validatedRowType, root.kind, root.fields, root.collation);
372  }
373 
374  private static SqlNode getUnaliasedExpression(final SqlNode node) {
375  if (node instanceof SqlBasicCall
376  && ((SqlBasicCall) node).getOperator() instanceof SqlAsOperator) {
377  SqlNode[] operands = ((SqlBasicCall) node).getOperands();
378  return operands[0];
379  }
380  return node;
381  }
382 
383  private static boolean isSelectStar(SqlNode node) {
384  SqlSelect select_node = getSelectChild(node);
385  if (select_node == null) {
386  return false;
387  }
388  SqlNode from = getUnaliasedExpression(select_node.getFrom());
389  if (from instanceof SqlCall) {
390  return false;
391  }
392  SqlNodeList proj_exprs = select_node.getSelectList();
393  if (proj_exprs.size() != 1) {
394  return false;
395  }
396  SqlNode proj_expr = proj_exprs.get(0);
397  if (!(proj_expr instanceof SqlIdentifier)) {
398  return false;
399  }
400  return ((SqlIdentifier) proj_expr).isStar();
401  }
402 
403  private static SqlSelect getSelectChild(SqlNode node) {
404  if (node instanceof SqlSelect) {
405  return (SqlSelect) node;
406  }
407  if (node instanceof SqlOrderBy) {
408  SqlOrderBy order_by_node = (SqlOrderBy) node;
409  if (order_by_node.query instanceof SqlSelect) {
410  return (SqlSelect) order_by_node.query;
411  }
412  }
413  return null;
414  }
415 
416  private SqlNode processSQL(String sql, final boolean legacy_syntax, Planner planner)
417  throws SqlParseException {
418  SqlNode parseR = null;
419  try {
420  parseR = planner.parse(sql);
421  MAPDLOGGER.debug(" node is \n" + parseR.toString());
422  } catch (SqlParseException ex) {
423  MAPDLOGGER.error("failed to process SQL '" + sql + "' \n" + ex.toString());
424  throw ex;
425  }
426 
427  if (!legacy_syntax) {
428  return parseR;
429  }
430  RelDataTypeFactory typeFactory = planner.getTypeFactory();
431  SqlSelect select_node = null;
432  if (parseR instanceof SqlSelect) {
433  select_node = (SqlSelect) parseR;
434  desugar(select_node, typeFactory);
435  } else if (parseR instanceof SqlOrderBy) {
436  SqlOrderBy order_by_node = (SqlOrderBy) parseR;
437  if (order_by_node.query instanceof SqlSelect) {
438  select_node = (SqlSelect) order_by_node.query;
439  SqlOrderBy new_order_by_node = desugar(select_node, order_by_node, typeFactory);
440  if (new_order_by_node != null) {
441  return new_order_by_node;
442  }
443  }
444  }
445  return parseR;
446  }
447 
448  private void desugar(SqlSelect select_node, RelDataTypeFactory typeFactory) {
449  desugar(select_node, null, typeFactory);
450  }
451 
452  private SqlOrderBy desugar(SqlSelect select_node,
453  SqlOrderBy order_by_node,
454  RelDataTypeFactory typeFactory) {
455  MAPDLOGGER.debug("desugar: before: " + select_node.toString());
456  desugarExpression(select_node.getFrom(), typeFactory);
457  desugarExpression(select_node.getWhere(), typeFactory);
458  SqlNodeList select_list = select_node.getSelectList();
459  SqlNodeList new_select_list = new SqlNodeList(select_list.getParserPosition());
460  java.util.Map<String, SqlNode> id_to_expr = new java.util.HashMap<String, SqlNode>();
461  for (SqlNode proj : select_list) {
462  if (!(proj instanceof SqlBasicCall)) {
463  new_select_list.add(proj);
464  continue;
465  }
466  SqlBasicCall proj_call = (SqlBasicCall) proj;
467  new_select_list.add(expand(proj_call, id_to_expr, typeFactory));
468  }
469  select_node.setSelectList(new_select_list);
470  SqlNodeList group_by_list = select_node.getGroup();
471  if (group_by_list != null) {
472  select_node.setGroupBy(expand(group_by_list, id_to_expr, typeFactory));
473  }
474  SqlNode having = select_node.getHaving();
475  if (having != null) {
476  expand(having, id_to_expr, typeFactory);
477  }
478  SqlOrderBy new_order_by_node = null;
479  if (order_by_node != null && order_by_node.orderList != null
480  && order_by_node.orderList.size() > 0) {
481  SqlNodeList new_order_by_list =
482  expand(order_by_node.orderList, id_to_expr, typeFactory);
483  new_order_by_node = new SqlOrderBy(order_by_node.getParserPosition(),
484  select_node,
485  new_order_by_list,
486  order_by_node.offset,
487  order_by_node.fetch);
488  }
489 
490  MAPDLOGGER.debug("desugar: after: " + select_node.toString());
491  return new_order_by_node;
492  }
493 
494  private void desugarExpression(SqlNode node, RelDataTypeFactory typeFactory) {
495  if (node instanceof SqlSelect) {
496  desugar((SqlSelect) node, typeFactory);
497  return;
498  }
499  if (!(node instanceof SqlBasicCall)) {
500  return;
501  }
502  SqlBasicCall basic_call = (SqlBasicCall) node;
503  for (SqlNode operator : basic_call.getOperands()) {
504  if (operator instanceof SqlOrderBy) {
505  desugarExpression(((SqlOrderBy) operator).query, typeFactory);
506  } else {
507  desugarExpression(operator, typeFactory);
508  }
509  }
510  }
511 
512  private SqlNode expand(final SqlNode node,
513  final java.util.Map<String, SqlNode> id_to_expr,
514  RelDataTypeFactory typeFactory) {
515  MAPDLOGGER.debug("expand: " + node.toString());
516  if (node instanceof SqlBasicCall) {
517  SqlBasicCall node_call = (SqlBasicCall) node;
518  SqlNode[] operands = node_call.getOperands();
519  for (int i = 0; i < operands.length; ++i) {
520  node_call.setOperand(i, expand(operands[i], id_to_expr, typeFactory));
521  }
522  SqlNode expanded_variance = expandVariance(node_call, typeFactory);
523  if (expanded_variance != null) {
524  return expanded_variance;
525  }
526  SqlNode expanded_covariance = expandCovariance(node_call, typeFactory);
527  if (expanded_covariance != null) {
528  return expanded_covariance;
529  }
530  SqlNode expanded_correlation = expandCorrelation(node_call, typeFactory);
531  if (expanded_correlation != null) {
532  return expanded_correlation;
533  }
534  }
535  if (node instanceof SqlSelect) {
536  SqlSelect select_node = (SqlSelect) node;
537  desugar(select_node, typeFactory);
538  }
539  return node;
540  }
541 
542  private SqlNodeList expand(final SqlNodeList group_by_list,
543  final java.util.Map<String, SqlNode> id_to_expr,
544  RelDataTypeFactory typeFactory) {
545  SqlNodeList new_group_by_list = new SqlNodeList(new SqlParserPos(-1, -1));
546  for (SqlNode group_by : group_by_list) {
547  if (!(group_by instanceof SqlIdentifier)) {
548  new_group_by_list.add(expand(group_by, id_to_expr, typeFactory));
549  continue;
550  }
551  SqlIdentifier group_by_id = ((SqlIdentifier) group_by);
552  if (id_to_expr.containsKey(group_by_id.toString())) {
553  new_group_by_list.add(id_to_expr.get(group_by_id.toString()));
554  } else {
555  new_group_by_list.add(group_by);
556  }
557  }
558  return new_group_by_list;
559  }
560 
561  private SqlNode expandVariance(
562  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
563  // Expand variance aggregates that are not supported natively
564  if (proj_call.operandCount() != 1) {
565  return null;
566  }
567  boolean biased;
568  boolean sqrt;
569  boolean flt;
570  if (proj_call.getOperator().isName("STDDEV_POP", false)) {
571  biased = true;
572  sqrt = true;
573  flt = false;
574  } else if (proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_POP_FLOAT")) {
575  biased = true;
576  sqrt = true;
577  flt = true;
578  } else if (proj_call.getOperator().isName("STDDEV_SAMP", false)
579  || proj_call.getOperator().getName().equalsIgnoreCase("STDDEV")) {
580  biased = false;
581  sqrt = true;
582  flt = false;
583  } else if (proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_SAMP_FLOAT")
584  || proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_FLOAT")) {
585  biased = false;
586  sqrt = true;
587  flt = true;
588  } else if (proj_call.getOperator().isName("VAR_POP", false)) {
589  biased = true;
590  sqrt = false;
591  flt = false;
592  } else if (proj_call.getOperator().getName().equalsIgnoreCase("VAR_POP_FLOAT")) {
593  biased = true;
594  sqrt = false;
595  flt = true;
596  } else if (proj_call.getOperator().isName("VAR_SAMP", false)
597  || proj_call.getOperator().getName().equalsIgnoreCase("VARIANCE")) {
598  biased = false;
599  sqrt = false;
600  flt = false;
601  } else if (proj_call.getOperator().getName().equalsIgnoreCase("VAR_SAMP_FLOAT")
602  || proj_call.getOperator().getName().equalsIgnoreCase("VARIANCE_FLOAT")) {
603  biased = false;
604  sqrt = false;
605  flt = true;
606  } else {
607  return null;
608  }
609  final SqlNode operand = proj_call.operand(0);
610  final SqlParserPos pos = proj_call.getParserPosition();
611  SqlNode expanded_proj_call =
612  expandVariance(pos, operand, biased, sqrt, flt, typeFactory);
613  MAPDLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
614  MAPDLOGGER.debug("to : " + expanded_proj_call.toString());
615  return expanded_proj_call;
616  }
617 
618  private SqlNode expandVariance(final SqlParserPos pos,
619  final SqlNode operand,
620  boolean biased,
621  boolean sqrt,
622  boolean flt,
623  RelDataTypeFactory typeFactory) {
624  // stddev_pop(x) ==>
625  // power(
626  // (sum(x * x) - sum(x) * sum(x) / (case count(x) when 0 then NULL else count(x)
627  // end)) / (case count(x) when 0 then NULL else count(x) end), .5)
628  //
629  // stddev_samp(x) ==>
630  // power(
631  // (sum(x * x) - sum(x) * sum(x) / (case count(x) when 0 then NULL else count(x)
632  // )) / ((case count(x) when 1 then NULL else count(x) - 1 end)), .5)
633  //
634  // var_pop(x) ==>
635  // (sum(x * x) - sum(x) * sum(x) / ((case count(x) when 0 then NULL else
636  // count(x)
637  // end))) / ((case count(x) when 0 then NULL else count(x) end))
638  //
639  // var_samp(x) ==>
640  // (sum(x * x) - sum(x) * sum(x) / ((case count(x) when 0 then NULL else
641  // count(x)
642  // end))) / ((case count(x) when 1 then NULL else count(x) - 1 end))
643  //
644  final SqlNode arg = SqlStdOperatorTable.CAST.createCall(pos,
645  operand,
646  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
647  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
648  final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
649  final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
650  final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
651  final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
652  final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
653  final SqlLiteral nul = SqlLiteral.createNull(pos);
654  final SqlNumericLiteral zero = SqlLiteral.createExactNumeric("0", pos);
655  final SqlNode countEqZero = SqlStdOperatorTable.EQUALS.createCall(pos, count, zero);
656  SqlNodeList whenList = new SqlNodeList(pos);
657  SqlNodeList thenList = new SqlNodeList(pos);
658  whenList.add(countEqZero);
659  thenList.add(nul);
660  final SqlNode int_denominator = SqlStdOperatorTable.CASE.createCall(
661  null, pos, null, whenList, thenList, count);
662  final SqlNode denominator = SqlStdOperatorTable.CAST.createCall(pos,
663  int_denominator,
664  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
665  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
666  final SqlNode avgSumSquared =
667  SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, denominator);
668  final SqlNode diff =
669  SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
670  final SqlNode denominator1;
671  if (biased) {
672  denominator1 = denominator;
673  } else {
674  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
675  final SqlNode countEqOne = SqlStdOperatorTable.EQUALS.createCall(pos, count, one);
676  final SqlNode countMinusOne = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
677  SqlNodeList whenList1 = new SqlNodeList(pos);
678  SqlNodeList thenList1 = new SqlNodeList(pos);
679  whenList1.add(countEqOne);
680  thenList1.add(nul);
681  final SqlNode int_denominator1 = SqlStdOperatorTable.CASE.createCall(
682  null, pos, null, whenList1, thenList1, countMinusOne);
683  denominator1 = SqlStdOperatorTable.CAST.createCall(pos,
684  int_denominator1,
685  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
686  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
687  }
688  final SqlNode div = SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator1);
689  SqlNode result = div;
690  if (sqrt) {
691  final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
692  result = SqlStdOperatorTable.POWER.createCall(pos, div, half);
693  }
694  return SqlStdOperatorTable.CAST.createCall(pos,
695  result,
696  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
697  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
698  }
699 
700  private SqlNode expandCovariance(
701  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
702  // Expand covariance aggregates
703  if (proj_call.operandCount() != 2) {
704  return null;
705  }
706  boolean pop;
707  boolean flt;
708  if (proj_call.getOperator().isName("COVAR_POP", false)) {
709  pop = true;
710  flt = false;
711  } else if (proj_call.getOperator().isName("COVAR_SAMP", false)) {
712  pop = false;
713  flt = false;
714  } else if (proj_call.getOperator().getName().equalsIgnoreCase("COVAR_POP_FLOAT")) {
715  pop = true;
716  flt = true;
717  } else if (proj_call.getOperator().getName().equalsIgnoreCase("COVAR_SAMP_FLOAT")) {
718  pop = false;
719  flt = true;
720  } else {
721  return null;
722  }
723  final SqlNode operand0 = proj_call.operand(0);
724  final SqlNode operand1 = proj_call.operand(1);
725  final SqlParserPos pos = proj_call.getParserPosition();
726  SqlNode expanded_proj_call =
727  expandCovariance(pos, operand0, operand1, pop, flt, typeFactory);
728  MAPDLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
729  MAPDLOGGER.debug("to : " + expanded_proj_call.toString());
730  return expanded_proj_call;
731  }
732 
733  private SqlNode expandCovariance(SqlParserPos pos,
734  final SqlNode operand0,
735  final SqlNode operand1,
736  boolean pop,
737  boolean flt,
738  RelDataTypeFactory typeFactory) {
739  // covar_pop(x, y) ==> avg(x * y) - avg(x) * avg(y)
740  // covar_samp(x, y) ==> (sum(x * y) - sum(x) * avg(y))
741  // ((case count(x) when 1 then NULL else count(x) - 1 end))
742  final SqlNode arg0 = SqlStdOperatorTable.CAST.createCall(operand0.getParserPosition(),
743  operand0,
744  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
745  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
746  final SqlNode arg1 = SqlStdOperatorTable.CAST.createCall(operand1.getParserPosition(),
747  operand1,
748  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
749  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
750  final SqlNode mulArg = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1);
751  final SqlNode avgArg1 = SqlStdOperatorTable.AVG.createCall(pos, arg1);
752  if (pop) {
753  final SqlNode avgMulArg = SqlStdOperatorTable.AVG.createCall(pos, mulArg);
754  final SqlNode avgArg0 = SqlStdOperatorTable.AVG.createCall(pos, arg0);
755  final SqlNode mulAvgAvg =
756  SqlStdOperatorTable.MULTIPLY.createCall(pos, avgArg0, avgArg1);
757  final SqlNode covarPop =
758  SqlStdOperatorTable.MINUS.createCall(pos, avgMulArg, mulAvgAvg);
759  return SqlStdOperatorTable.CAST.createCall(pos,
760  covarPop,
761  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
762  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
763  }
764  final SqlNode sumMulArg = SqlStdOperatorTable.SUM.createCall(pos, mulArg);
765  final SqlNode sumArg0 = SqlStdOperatorTable.SUM.createCall(pos, arg0);
766  final SqlNode mulSumAvg =
767  SqlStdOperatorTable.MULTIPLY.createCall(pos, sumArg0, avgArg1);
768  final SqlNode sub = SqlStdOperatorTable.MINUS.createCall(pos, sumMulArg, mulSumAvg);
769  final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, operand0);
770  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
771  final SqlNode countEqOne = SqlStdOperatorTable.EQUALS.createCall(pos, count, one);
772  final SqlNode countMinusOne = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
773  final SqlLiteral nul = SqlLiteral.createNull(pos);
774  SqlNodeList whenList1 = new SqlNodeList(pos);
775  SqlNodeList thenList1 = new SqlNodeList(pos);
776  whenList1.add(countEqOne);
777  thenList1.add(nul);
778  final SqlNode int_denominator = SqlStdOperatorTable.CASE.createCall(
779  null, pos, null, whenList1, thenList1, countMinusOne);
780  final SqlNode denominator = SqlStdOperatorTable.CAST.createCall(pos,
781  int_denominator,
782  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
783  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
784  final SqlNode covarSamp =
785  SqlStdOperatorTable.DIVIDE.createCall(pos, sub, denominator);
786  return SqlStdOperatorTable.CAST.createCall(pos,
787  covarSamp,
788  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
789  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
790  }
791 
792  private SqlNode expandCorrelation(
793  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
794  // Expand correlation coefficient
795  if (proj_call.operandCount() != 2) {
796  return null;
797  }
798  boolean flt;
799  if (proj_call.getOperator().isName("CORR", false)
800  || proj_call.getOperator().getName().equalsIgnoreCase("CORRELATION")) {
801  // expand correlation coefficient
802  flt = false;
803  } else if (proj_call.getOperator().getName().equalsIgnoreCase("CORR_FLOAT")
804  || proj_call.getOperator().getName().equalsIgnoreCase("CORRELATION_FLOAT")) {
805  // expand correlation coefficient
806  flt = true;
807  } else {
808  return null;
809  }
810  // corr(x, y) ==> (avg(x * y) - avg(x) * avg(y)) / (stddev_pop(x) *
811  // stddev_pop(y))
812  // ==> covar_pop(x, y) / (stddev_pop(x) * stddev_pop(y))
813  final SqlNode operand0 = proj_call.operand(0);
814  final SqlNode operand1 = proj_call.operand(1);
815  final SqlParserPos pos = proj_call.getParserPosition();
816  SqlNode covariance =
817  expandCovariance(pos, operand0, operand1, true, flt, typeFactory);
818  SqlNode stddev0 = expandVariance(pos, operand0, true, true, flt, typeFactory);
819  SqlNode stddev1 = expandVariance(pos, operand1, true, true, flt, typeFactory);
820  final SqlNode mulStddev =
821  SqlStdOperatorTable.MULTIPLY.createCall(pos, stddev0, stddev1);
822  final SqlNumericLiteral zero = SqlLiteral.createExactNumeric("0.0", pos);
823  final SqlNode mulStddevEqZero =
824  SqlStdOperatorTable.EQUALS.createCall(pos, mulStddev, zero);
825  final SqlLiteral nul = SqlLiteral.createNull(pos);
826  SqlNodeList whenList1 = new SqlNodeList(pos);
827  SqlNodeList thenList1 = new SqlNodeList(pos);
828  whenList1.add(mulStddevEqZero);
829  thenList1.add(nul);
830  final SqlNode denominator = SqlStdOperatorTable.CASE.createCall(
831  null, pos, null, whenList1, thenList1, mulStddev);
832  final SqlNode expanded_proj_call =
833  SqlStdOperatorTable.DIVIDE.createCall(pos, covariance, denominator);
834  MAPDLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
835  MAPDLOGGER.debug("to : " + expanded_proj_call.toString());
836  return expanded_proj_call;
837  }
838 
845  protected SqlOperatorTable createOperatorTable(
846  final Map<String, ExtensionFunction> extSigs) {
847  final MapDSqlOperatorTable tempOpTab =
848  new MapDSqlOperatorTable(SqlStdOperatorTable.instance());
849  // MAT 11 Nov 2015
850  // Example of how to add custom function
851  MapDSqlOperatorTable.addUDF(tempOpTab, extSigs);
852  return tempOpTab;
853  }
854 
/**
 * Parses sql with a fresh planner and collects the identifiers it references.
 *
 * <p>When legacy_syntax is set, processSQL also desugars the statement before scanning.
 * Any Exception or Error is logged together with the SQL text, and an empty
 * SqlIdentifierCapturer is returned instead. Despite the throws clause, no exception
 * escapes this method.
 *
 * <p>NOTE(review): catching Error also swallows OutOfMemoryError/StackOverflowError, and the
 * planner obtained here is never close()d. The line declaring `capturer` (source line 860) is
 * missing from this listing.
 */
855  public SqlIdentifierCapturer captureIdentifiers(String sql, boolean legacy_syntax)
856  throws SqlParseException {
857  try {
858  Planner planner = getPlanner();
859  SqlNode node = processSQL(sql, legacy_syntax, planner);
861  capturer.scan(node);
862  return capturer;
863  } catch (Exception | Error e) {
// Best-effort: callers get an empty capture rather than an exception.
864  MAPDLOGGER.error("Error parsing sql: " + sql, e);
865  return new SqlIdentifierCapturer();
866  }
867  }
868 
869  public int getCallCount() {
870  return callCount;
871  }
872 
873  public void updateMetaData(String schema, String table) {
874  MAPDLOGGER.debug("schema :" + schema + " table :" + table);
875  MapDSchema mapd =
876  new MapDSchema(dataDir, this, mapdPort, null, sock_transport_properties);
877  mapd.updateMetaData(schema, table);
878  }
879 
880  protected RelDataTypeSystem createTypeSystem() {
881  final MapDTypeSystem typeSystem = new MapDTypeSystem();
882  return typeSystem;
883  }
884 }
SqlNode expand(final SqlNode node, final java.util.Map< String, SqlNode > id_to_expr, RelDataTypeFactory typeFactory)
static boolean isSelectStar(SqlNode node)
RelRoot replaceIsTrue(final RelDataTypeFactory typeFactory, RelRoot root)
void desugarExpression(SqlNode node, RelDataTypeFactory typeFactory)
SqlNode expandCovariance(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
SqlNode expandCorrelation(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
static String toString(final RelNode rel)
auto sql(const std::string &sql_stmts)
Definition: DataGen.cpp:60
void updateMetaData(String schema, String table)
Definition: MapDSchema.java:95
SqlOperatorTable createOperatorTable(final Map< String, ExtensionFunction > extSigs)
SqlNode expandVariance(final SqlParserPos pos, final SqlNode operand, boolean biased, boolean sqrt, boolean flt, RelDataTypeFactory typeFactory)
SqlIdentifierCapturer getAccessedObjects()
Definition: MapDView.java:66
void updateMetaData(String schema, String table)
SockTransportProperties sock_transport_properties
static SqlNode getUnaliasedExpression(final SqlNode node)
SqlIdentifierCapturer captureIdentifiers(String sql, boolean legacy_syntax)
final Map< String, ExtensionFunction > extSigs
Definition: MapDParser.java:98
SqlNode processSQL(String sql, final boolean legacy_syntax, Planner planner)
static final Logger MAPDLOGGER
Definition: MapDParser.java:92
Set< String > resolveSelectIdentifiers(SqlIdentifierCapturer capturer)
static SqlSelect getSelectChild(SqlNode node)
Table getTable(String string)
Definition: MapDSchema.java:51
void setUser(MapDUser mapdUser)
SqlNodeList expand(final SqlNodeList group_by_list, final java.util.Map< String, SqlNode > id_to_expr, RelDataTypeFactory typeFactory)
SqlNode expandCovariance(SqlParserPos pos, final SqlNode operand0, final SqlNode operand1, boolean pop, boolean flt, RelDataTypeFactory typeFactory)
void desugar(SqlSelect select_node, RelDataTypeFactory typeFactory)
RelRoot queryToSqlNode(final String sql, final MapDParserOptions parserOptions)
MapDParser(String dataDir, final Map< String, ExtensionFunction > extSigs, int mapdPort, SockTransportProperties skT)
void setFilterPushDownInfo(final List< MapDParserOptions.FilterPushDownInfo > filterPushDownInfo)
static void addUDF(MapDSqlOperatorTable opTab, final Map< String, ExtensionFunction > extSigs)
CompletionResult getCompletionHints(final String sql, final int cursor, final List< String > visibleTables)
SqlOrderBy desugar(SqlSelect select_node, SqlOrderBy order_by_node, RelDataTypeFactory typeFactory)
static final ThreadLocal< MapDParser > CURRENT_PARSER
Definition: MapDParser.java:90
SqlNode expandVariance(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
MapDPlanner.CompletionResult getCompletionHints(String sql, int cursor, List< String > visible_tables)
String getRelAlgebra(String sql, final MapDParserOptions parserOptions, final MapDUser mapDUser)
static final Context MAPD_CONNECTION_CONTEXT
RelDataTypeSystem createTypeSystem()