OmniSciDB  06b3bd477c
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MapDParser.java
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.mapd.calcite.parser;
17 
18 import static org.apache.calcite.sql.parser.SqlParserPos.ZERO;
19 
20 import com.google.common.collect.ImmutableList;
22 import com.mapd.parser.extension.ddl.ExtendedSqlParser;
26 
27 import org.apache.calcite.avatica.util.Casing;
28 import org.apache.calcite.config.CalciteConnectionConfig;
29 import org.apache.calcite.config.CalciteConnectionConfigImpl;
30 import org.apache.calcite.config.CalciteConnectionProperty;
31 import org.apache.calcite.plan.Context;
32 import org.apache.calcite.plan.RelOptLattice;
33 import org.apache.calcite.plan.RelOptMaterialization;
34 import org.apache.calcite.plan.RelOptTable;
35 import org.apache.calcite.plan.RelOptUtil;
38 import org.apache.calcite.rel.RelNode;
39 import org.apache.calcite.rel.RelRoot;
40 import org.apache.calcite.rel.RelShuttleImpl;
41 import org.apache.calcite.rel.core.Join;
42 import org.apache.calcite.rel.core.RelFactories;
43 import org.apache.calcite.rel.core.TableModify;
44 import org.apache.calcite.rel.core.TableModify.Operation;
45 import org.apache.calcite.rel.hint.HintStrategyTable;
46 import org.apache.calcite.rel.logical.LogicalProject;
47 import org.apache.calcite.rel.logical.LogicalTableModify;
48 import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
49 import org.apache.calcite.rel.rules.FilterMergeRule;
50 import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
51 import org.apache.calcite.rel.rules.JoinProjectTransposeRule;
52 import org.apache.calcite.rel.rules.ProjectMergeRule;
53 import org.apache.calcite.rel.type.RelDataType;
54 import org.apache.calcite.rel.type.RelDataTypeFactory;
55 import org.apache.calcite.rel.type.RelDataTypeField;
56 import org.apache.calcite.rel.type.RelDataTypeSystem;
57 import org.apache.calcite.rex.RexBuilder;
58 import org.apache.calcite.rex.RexCall;
59 import org.apache.calcite.rex.RexInputRef;
60 import org.apache.calcite.rex.RexNode;
61 import org.apache.calcite.rex.RexShuttle;
62 import org.apache.calcite.runtime.CalciteException;
63 import org.apache.calcite.schema.SchemaPlus;
64 import org.apache.calcite.sql.JoinConditionType;
66 import org.apache.calcite.sql.SqlBasicCall;
67 import org.apache.calcite.sql.SqlBasicTypeNameSpec;
68 import org.apache.calcite.sql.SqlCall;
69 import org.apache.calcite.sql.SqlDataTypeSpec;
70 import org.apache.calcite.sql.SqlDelete;
71 import org.apache.calcite.sql.SqlDialect;
72 import org.apache.calcite.sql.SqlFunctionCategory;
73 import org.apache.calcite.sql.SqlIdentifier;
74 import org.apache.calcite.sql.SqlJoin;
75 import org.apache.calcite.sql.SqlKind;
76 import org.apache.calcite.sql.SqlLiteral;
77 import org.apache.calcite.sql.SqlNode;
78 import org.apache.calcite.sql.SqlNodeList;
79 import org.apache.calcite.sql.SqlNumericLiteral;
80 import org.apache.calcite.sql.SqlOrderBy;
81 import org.apache.calcite.sql.SqlSelect;
82 import org.apache.calcite.sql.SqlUnresolvedFunction;
83 import org.apache.calcite.sql.SqlUpdate;
84 import org.apache.calcite.sql.SqlWith;
85 import org.apache.calcite.sql.fun.SqlCase;
86 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
87 import org.apache.calcite.sql.parser.SqlParseException;
88 import org.apache.calcite.sql.parser.SqlParser;
89 import org.apache.calcite.sql.parser.SqlParserPos;
90 import org.apache.calcite.sql.type.SqlTypeName;
91 import org.apache.calcite.sql.type.SqlTypeUtil;
92 import org.apache.calcite.sql.util.SqlBasicVisitor;
93 import org.apache.calcite.sql.util.SqlShuttle;
94 import org.apache.calcite.sql.validate.SqlConformanceEnum;
96 import org.apache.calcite.tools.FrameworkConfig;
97 import org.apache.calcite.tools.Frameworks;
98 import org.apache.calcite.tools.Planner;
99 import org.apache.calcite.tools.Program;
100 import org.apache.calcite.tools.Programs;
101 import org.apache.calcite.tools.RelConversionException;
102 import org.apache.calcite.tools.ValidationException;
103 import org.apache.calcite.util.Pair;
104 import org.apache.calcite.util.Util;
105 import org.slf4j.Logger;
106 import org.slf4j.LoggerFactory;
107 
108 import java.lang.reflect.Field;
109 import java.util.ArrayList;
110 import java.util.Arrays;
111 import java.util.Collections;
112 import java.util.EnumSet;
113 import java.util.HashSet;
114 import java.util.List;
115 import java.util.Map;
116 import java.util.Properties;
117 import java.util.Set;
118 import java.util.concurrent.ConcurrentHashMap;
119 import java.util.function.BiPredicate;
120 import java.util.function.Supplier;
121 
126 public final class MapDParser {
// Thread-local MapDParser slot. NOTE(review): it is neither read nor written in this part
// of the file; check its users elsewhere.
127  public static final ThreadLocal<MapDParser> CURRENT_PARSER = new ThreadLocal<>();
// SqlKind categories passed to SqlNode.isA(...) when classifying parsed statements and
// sub-query expressions in this class.
128  private static final EnumSet<SqlKind> SCALAR =
129  EnumSet.of(SqlKind.SCALAR_QUERY, SqlKind.SELECT);
130  private static final EnumSet<SqlKind> EXISTS = EnumSet.of(SqlKind.EXISTS);
131  private static final EnumSet<SqlKind> DELETE = EnumSet.of(SqlKind.DELETE);
132  private static final EnumSet<SqlKind> UPDATE = EnumSet.of(SqlKind.UPDATE);
133  private static final EnumSet<SqlKind> IN = EnumSet.of(SqlKind.IN);
134  private static final EnumSet<SqlKind> ARRAY_VALUE =
135  EnumSet.of(SqlKind.ARRAY_VALUE_CONSTRUCTOR);
136 
137  final static Logger MAPDLOGGER = LoggerFactory.getLogger(MapDParser.class);
138 
139  // private SqlTypeFactoryImpl typeFactory;
140  // private MapDCatalogReader catalogReader;
141  // private SqlValidatorImpl validator;
142  // private SqlToRelConverter converter;
// Supplies the operator table installed into every planner this parser builds.
143  private final Supplier<MapDSqlOperatorTable> mapDSqlOperatorTable;
144  private final String dataDir;
145 
// Incremented by the processSql overloads; not read in this part of the file.
146  private int callCount = 0;
147  private final int mapdPort;
// NOTE(review): source lines 148 and 150 are absent from this listing; they presumably
// declare the mapdUser and sock_transport_properties fields assigned by setUser and the
// constructor — confirm against the original file.
// NOTE(review): sqlNode_ is not referenced in this part of the file.
149  SqlNode sqlNode_;
151 
// Memo of isCorrelated() results keyed by the sub-query's Calcite-dialect SQL text.
// Static, so it is shared by every MapDParser instance; clearMemo() empties it.
// NOTE(review): the reference is never reassigned here and could be declared final.
152  private static Map<String, Boolean> SubqueryCorrMemo = new ConcurrentHashMap<>();
153 
// Creates a parser bound to a data directory, an operator-table supplier and a port; the
// stored values are used later when building schemas and planners.
// NOTE(review): source line 157 (the last parameter, `skT`, which is assigned to
// sock_transport_properties below) is absent from this listing — confirm its type.
154  public MapDParser(String dataDir,
155  final Supplier<MapDSqlOperatorTable> mapDSqlOperatorTable,
156  int mapdPort,
158  this.dataDir = dataDir;
159  this.mapDSqlOperatorTable = mapDSqlOperatorTable;
160  this.mapdPort = mapdPort;
161  this.sock_transport_properties = skT;
162  }
163 
164  public void clearMemo() {
165  SubqueryCorrMemo.clear();
166  }
167 
168  private static final Context MAPD_CONNECTION_CONTEXT = new Context() {
169  MapDTypeSystem myTypeSystem = new MapDTypeSystem();
170  CalciteConnectionConfig config = new CalciteConnectionConfigImpl(new Properties()) {
171  {
172  properties.put(CalciteConnectionProperty.CASE_SENSITIVE.camelName(),
173  String.valueOf(false));
174  properties.put(CalciteConnectionProperty.CONFORMANCE.camelName(),
175  String.valueOf(SqlConformanceEnum.LENIENT));
176  }
177 
178  @SuppressWarnings("unchecked")
179  public <T extends Object> T typeSystem(
180  java.lang.Class<T> typeSystemClass, T defaultTypeSystem) {
181  return (T) myTypeSystem;
182  };
183 
184  public boolean caseSensitive() {
185  return false;
186  };
187 
188  public org.apache.calcite.sql.validate.SqlConformance conformance() {
189  return SqlConformanceEnum.LENIENT;
190  };
191  };
192 
193  @Override
194  public <C> C unwrap(Class<C> aClass) {
195  if (aClass.isInstance(config)) {
196  return aClass.cast(config);
197  }
198  return null;
199  }
200  };
201 
// NOTE(review): the signature line of this no-argument getPlanner() overload (source
// line 202) is absent from this listing. It delegates with sub-query expansion and
// join-condition pushdown both enabled.
203  return getPlanner(true, true);
204  }
205 
// Heuristically decides whether a sub-query expression is correlated. The expression is
// processed on its own by a separate parser; any exception (presumably an unresolvable
// reference to an outer query — confirm) is taken to mean "correlated".
// Results are memoized in the static SubqueryCorrMemo, keyed by the expression's
// Calcite-dialect SQL text.
// NOTE(review): the get-then-put on the memo is not atomic, so concurrent callers may
// process the same expression more than once before the result is cached.
206  private boolean isCorrelated(SqlNode expression) {
207  String queryString = expression.toSqlString(SqlDialect.CALCITE).getSql();
208  Boolean isCorrelatedSubquery = SubqueryCorrMemo.get(queryString);
209  if (null != isCorrelatedSubquery) {
210  return isCorrelatedSubquery;
211  }
212 
213  try {
// NOTE(review): source lines 214-215 are absent from this listing; they presumably
// declare the `parser` used below — confirm.
216  MapDParserOptions options = new MapDParserOptions();
217  parser.setUser(mapdUser);
218  parser.processSql(expression, options);
219  } catch (Exception e) {
220  // if we are not able to parse, then assume correlated
221  SubqueryCorrMemo.put(queryString, true);
222  return true;
223  }
224  SubqueryCorrMemo.put(queryString, false);
225  return false;
226  }
227 
// Builds a fresh MapDPlanner whose default schema is the current user's database.
// allowSubQueryExpansion: when false, the expand predicate never expands sub-queries.
// allowPushdownJoinCondition: when false, join-condition pushdown is always refused;
// when true it is refused only for joins whose row type has a field of family "any"
// (treated as geo columns).
228  private MapDPlanner getPlanner(final boolean allowSubQueryExpansion,
229  final boolean allowPushdownJoinCondition) {
230  final MapDSchema mapd =
// NOTE(review): source line 231 (the MapDSchema constructor arguments) is absent from
// this listing.
232  final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
233 
// Decides, per sub-query, whether SqlToRelConverter should expand it:
// - EXISTS: always;
// - IN: only when its second operand is a SELECT with a WHERE clause and that SELECT is
//   correlated;
// - scalar sub-queries: only when correlated.
// A correlated SELECT using FETCH, OFFSET or a non-empty ORDER BY is rejected with a
// CalciteException.
234  BiPredicate<SqlNode, SqlNode> expandPredicate = new BiPredicate<SqlNode, SqlNode>() {
235  @Override
236  public boolean test(SqlNode root, SqlNode expression) {
237  if (!allowSubQueryExpansion) {
238  return false;
239  }
240 
241  // special handling of sub-queries
242  if (expression.isA(SCALAR) || expression.isA(EXISTS) || expression.isA(IN)) {
243  // only expand if it is correlated.
244 
245  if (expression.isA(EXISTS)) {
246  // always expand subquery by EXISTS clause
247  return true;
248  }
249 
250  if (expression.isA(IN)) {
251  // expand subquery by IN clause
252  // but correlated subquery by NOT_IN clause is not available
253  // currently due to a lack of supporting in Calcite
254  boolean found_expression = false;
255  if (expression instanceof SqlCall) {
256  SqlCall call = (SqlCall) expression;
257  if (call.getOperandList().size() == 2) {
258  // if IN clause is correlated, its second operand of corresponding
259  // expression is SELECT clause which indicates a correlated subquery.
260  // Here, an expression "f.val IN (SELECT ...)" has two operands.
261  // Since we have interest in its subquery, so try to check whether
262  // the second operand, i.e., call.getOperandList().get(1)
263  // is a type of SqlSelect and also is correlated.
264  // Note that the second operand of non-correlated IN clause
265  // does not have SqlSelect as its second operand
266  if (call.getOperandList().get(1) instanceof SqlSelect) {
267  expression = call.getOperandList().get(1);
268  SqlSelect select_call = (SqlSelect) expression;
269  if (select_call.hasWhere()) {
270  found_expression = true;
271  }
272  }
273  }
274  }
275  if (!found_expression) {
276  return false;
277  }
278  }
279 
// For IN, `expression` now refers to the inner SELECT, so that is what gets checked.
280  if (isCorrelated(expression)) {
281  SqlSelect select = null;
282  if (expression instanceof SqlCall) {
283  SqlCall call = (SqlCall) expression;
284  if (call.getOperator().equals(SqlStdOperatorTable.SCALAR_QUERY)) {
285  expression = call.getOperandList().get(0);
286  }
287  }
288 
289  if (expression instanceof SqlSelect) {
290  select = (SqlSelect) expression;
291  }
292 
293  if (null != select) {
294  if (null != select.getFetch() || null != select.getOffset()
295  || (null != select.getOrderList()
296  && select.getOrderList().size() != 0)) {
297  throw new CalciteException(
298  "Correlated sub-queries with ordering not supported.", null);
299  }
300  }
301  return true;
302  }
303  }
304 
305  // per default we do not want to expand
306  return false;
307  }
308  };
309 
// Allows join-condition pushdown unless it is disabled or the join's row type contains
// geo columns.
310  BiPredicate<SqlNode, Join> pushdownJoinPredicate = new BiPredicate<SqlNode, Join>() {
311  @Override
312  public boolean test(SqlNode t, Join u) {
313  if (!allowPushdownJoinCondition) {
314  return false;
315  }
316 
317  return !hasGeoColumns(u.getRowType());
318  }
319 
320  private boolean hasGeoColumns(RelDataType type) {
321  for (RelDataTypeField f : type.getFieldList()) {
322  if ("any".equalsIgnoreCase(f.getType().getFamily().toString())) {
323  // any indicates geo types at the moment
324  return true;
325  }
326  }
327 
328  return false;
329  }
330  };
331 
332  final FrameworkConfig config =
333  Frameworks.newConfigBuilder()
334  .defaultSchema(rootSchema.add(mapdUser.getDB(), mapd))
335  .operatorTable(mapDSqlOperatorTable.get())
336  .parserConfig(SqlParser.configBuilder()
337  .setConformance(SqlConformanceEnum.LENIENT)
338  .setUnquotedCasing(Casing.UNCHANGED)
339  .setCaseSensitive(false)
340  // allow identifiers of up to 512 chars
341  .setIdentifierMaxLength(512)
342  .setParserFactory(ExtendedSqlParser.FACTORY)
343  .build())
344  .sqlToRelConverterConfig(
345  SqlToRelConverter
346  .configBuilder()
347  // enable sub-query expansion (de-correlation)
348  .withExpandPredicate(expandPredicate)
349  // allow as many as possible IN operator values
350  .withInSubQueryThreshold(Integer.MAX_VALUE)
351  .withPushdownJoinCondition(pushdownJoinPredicate)
352  .withHintStrategyTable(
// NOTE(review): source line 353 (the hint strategy table argument) is absent from this
// listing.
354  .build())
355  .typeSystem(createTypeSystem())
356  .context(MAPD_CONNECTION_CONTEXT)
357  .build();
358  return new MapDPlanner(config);
359  }
360 
361  public void setUser(MapDUser mapdUser) {
362  this.mapdUser = mapdUser;
363  }
364 
365  public Pair<String, SqlIdentifierCapturer> process(
366  String sql, final MapDParserOptions parserOptions)
367  throws SqlParseException, ValidationException, RelConversionException {
368  final MapDPlanner planner = getPlanner(true, true);
369  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
370  String res = processSql(sqlNode, parserOptions);
371  SqlIdentifierCapturer capture = captureIdentifiers(sqlNode);
372 
373  return new Pair<String, SqlIdentifierCapturer>(res, capture);
374  }
375 
376  public String processSql(String sql, final MapDParserOptions parserOptions)
377  throws SqlParseException, ValidationException, RelConversionException {
378  callCount++;
379 
380  final MapDPlanner planner = getPlanner(true, true);
381  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
382 
383  return processSql(sqlNode, parserOptions);
384  }
385 
386  public String processSql(final SqlNode sqlNode, final MapDParserOptions parserOptions)
387  throws SqlParseException, ValidationException, RelConversionException {
388  callCount++;
389 
390  if (sqlNode instanceof JsonSerializableDdl) {
391  return ((JsonSerializableDdl) sqlNode).toJsonString();
392  }
393 
394  final MapDPlanner planner = getPlanner(true, true);
395  planner.advanceToValidate();
396 
397  final RelRoot sqlRel = convertSqlToRelNode(sqlNode, planner, parserOptions);
398  RelNode project = sqlRel.project();
399 
400  if (parserOptions.isExplain()) {
401  return RelOptUtil.toString(sqlRel.project());
402  }
403 
404  String res = MapDSerializer.toString(project);
405 
406  return res;
407  }
408 
409  public MapDPlanner.CompletionResult getCompletionHints(
410  String sql, int cursor, List<String> visible_tables) {
411  return getPlanner().getCompletionHints(sql, cursor, visible_tables);
412  }
413 
// Expands the SELECT-referenced names in `capturer` into underlying table names: views
// are resolved recursively through their accessed objects, plain tables are kept as-is.
// Throws RuntimeException when a captured name resolves to no table or view.
// NOTE(review): there is no guard against a view that (directly or indirectly) refers to
// itself; such a cycle would recurse without bound.
414  public Set<String> resolveSelectIdentifiers(SqlIdentifierCapturer capturer) {
415  MapDSchema schema =
// NOTE(review): source line 416 (the MapDSchema constructor arguments) is absent from
// this listing.
417  HashSet<String> resolved = new HashSet<>();
418 
419  for (String name : capturer.selects) {
420  MapDTable table = (MapDTable) schema.getTable(name);
421  if (null == table) {
422  throw new RuntimeException("table/view not found: " + name);
423  }
424 
425  if (table instanceof MapDView) {
426  MapDView view = (MapDView) table;
427  resolved.addAll(resolveSelectIdentifiers(view.getAccessedObjects()));
428  } else {
429  resolved.add(name);
430  }
431  }
432 
433  return resolved;
434  }
435 
436  private String getTableName(SqlNode node) {
437  if (node.isA(EnumSet.of(SqlKind.AS))) {
438  node = ((SqlCall) node).getOperandList().get(1);
439  }
440  if (node instanceof SqlIdentifier) {
441  SqlIdentifier id = (SqlIdentifier) node;
442  return id.names.get(id.names.size() - 1);
443  }
444  return null;
445  }
446 
447  private SqlSelect rewriteSimpleUpdateAsSelect(final SqlUpdate update) {
448  SqlNode where = update.getCondition();
449 
450  if (update.getSourceExpressionList().size() != 1) {
451  return null;
452  }
453 
454  if (!(update.getSourceExpressionList().get(0) instanceof SqlSelect)) {
455  return null;
456  }
457 
458  final SqlSelect inner = (SqlSelect) update.getSourceExpressionList().get(0);
459 
460  if (null != inner.getGroup() || null != inner.getFetch() || null != inner.getOffset()
461  || (null != inner.getOrderList() && inner.getOrderList().size() != 0)
462  || (null != inner.getGroup() && inner.getGroup().size() != 0)
463  || null == getTableName(inner.getFrom())) {
464  return null;
465  }
466 
467  if (!isCorrelated(inner)) {
468  return null;
469  }
470 
471  final String updateTableName = getTableName(update.getTargetTable());
472 
473  if (null != where) {
474  where = where.accept(new SqlShuttle() {
475  @Override
476  public SqlNode visit(SqlIdentifier id) {
477  if (id.isSimple()) {
478  id = new SqlIdentifier(Arrays.asList(updateTableName, id.getSimple()),
479  id.getParserPosition());
480  }
481 
482  return id;
483  }
484  });
485  }
486 
487  SqlJoin join = new SqlJoin(ZERO,
488  update.getTargetTable(),
489  SqlLiteral.createBoolean(false, ZERO),
490  SqlLiteral.createSymbol(JoinType.LEFT, ZERO),
491  inner.getFrom(),
492  SqlLiteral.createSymbol(JoinConditionType.ON, ZERO),
493  inner.getWhere());
494 
495  SqlNode select0 = inner.getSelectList().get(0);
496 
497  boolean wrapInSingleValue = true;
498  if (select0 instanceof SqlCall) {
499  SqlCall selectExprCall = (SqlCall) select0;
500  if (Util.isSingleValue(selectExprCall)) {
501  wrapInSingleValue = false;
502  }
503  }
504 
505  if (wrapInSingleValue) {
506  select0 = new SqlBasicCall(
507  SqlStdOperatorTable.SINGLE_VALUE, new SqlNode[] {select0}, ZERO);
508  }
509 
510  SqlNodeList selectList = new SqlNodeList(ZERO);
511  selectList.add(select0);
512  selectList.add(new SqlBasicCall(SqlStdOperatorTable.AS,
513  new SqlNode[] {new SqlBasicCall(
514  new SqlUnresolvedFunction(
515  new SqlIdentifier("OFFSET_IN_FRAGMENT", ZERO),
516  null,
517  null,
518  null,
519  null,
520  SqlFunctionCategory.USER_DEFINED_FUNCTION),
521  new SqlNode[0],
522  SqlParserPos.ZERO),
523  new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO)},
524  ZERO));
525 
526  SqlNodeList groupBy = new SqlNodeList(ZERO);
527  groupBy.add(new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO));
528 
529  SqlSelect select = new SqlSelect(ZERO,
530  null,
531  selectList,
532  join,
533  where,
534  groupBy,
535  null,
536  null,
537  null,
538  null,
539  null,
540  null);
541  return select;
542  }
543 
544  private LogicalTableModify getDummyUpdate(SqlUpdate update)
545  throws SqlParseException, ValidationException, RelConversionException {
546  SqlIdentifier targetTable = (SqlIdentifier) update.getTargetTable();
547  String targetTableName = targetTable.names.get(targetTable.names.size() - 1);
548  MapDPlanner planner = getPlanner();
549  String dummySql = "DELETE FROM " + targetTableName;
550  SqlNode dummyNode = planner.parse(dummySql);
551  dummyNode = planner.validate(dummyNode);
552  RelRoot dummyRoot = planner.rel(dummyNode);
553  LogicalTableModify dummyModify = (LogicalTableModify) dummyRoot.rel;
554  return dummyModify;
555  }
556 
// Lowers a validated UPDATE (or a DELETE already recast as an UPDATE) into a
// LogicalTableModify(UPDATE) over a projection of the new column values followed by an
// EXPR$DELETE_OFFSET_IN_FRAGMENT column produced by OFFSET_IN_FRAGMENT().
// Uses rewriteSimpleUpdateAsSelect when it applies; otherwise builds
// SELECT <source expressions>, OFFSET_IN_FRAGMENT() AS ... FROM <target> WHERE <cond>.
// Throws CalciteException when the statement contains more than one correlated
// sub-query, and RuntimeException for a target column that is not an identifier.
// NOTE(review): parserOptions is not used in this method.
557  private RelRoot rewriteUpdateAsSelect(SqlUpdate update, MapDParserOptions parserOptions)
558  throws SqlParseException, ValidationException, RelConversionException {
// Counts correlated scalar sub-queries. A single-operand SqlBasicCall whose operand is
// itself SCALAR is skipped, so a wrapper and the SELECT it wraps are not both counted.
559  int correlatedQueriesCount[] = new int[1];
560  SqlBasicVisitor<Void> correlatedQueriesCounter = new SqlBasicVisitor<Void>() {
561  @Override
562  public Void visit(SqlCall call) {
563  if (call.isA(SCALAR)
564  && ((call instanceof SqlBasicCall && call.operandCount() == 1
565  && !call.operand(0).isA(SCALAR))
566  || !(call instanceof SqlBasicCall))) {
567  if (isCorrelated(call)) {
568  correlatedQueriesCount[0]++;
569  }
570  }
571  return super.visit(call);
572  }
573  };
574 
575  update.accept(correlatedQueriesCounter);
576  if (correlatedQueriesCount[0] > 1) {
577  throw new CalciteException(
578  "table modifications with multiple correlated sub-queries not supported.",
579  null);
580  }
581 
// Join-condition pushdown is always disabled for the planner built below.
582  boolean allowPushdownJoinCondition = false;
583  SqlNodeList sourceExpression = new SqlNodeList(SqlParserPos.ZERO);
// A throwaway DELETE on the target table supplies its RelOptTable, row type and catalog
// reader.
584  LogicalTableModify dummyModify = getDummyUpdate(update);
585  RelOptTable targetTable = dummyModify.getTable();
586  RelDataType targetTableType = targetTable.getRowType();
587 
588  SqlSelect select = rewriteSimpleUpdateAsSelect(update);
// Casts to the target column types are applied further down only on the general path.
589  boolean applyRexCast = null == select;
590 
591  if (null == select) {
592  for (int i = 0; i < update.getSourceExpressionList().size(); i++) {
593  SqlNode targetColumn = update.getTargetColumnList().get(i);
594  SqlNode expression = update.getSourceExpressionList().get(i);
595 
596  if (!(targetColumn instanceof SqlIdentifier)) {
597  throw new RuntimeException("Unknown identifier type!");
598  }
599  SqlIdentifier id = (SqlIdentifier) targetColumn;
600  RelDataType fieldType =
601  targetTableType.getField(id.names.get(id.names.size() - 1), false, false)
602  .getType();
603 
604  if (expression.isA(ARRAY_VALUE) && null != fieldType.getComponentType()) {
605  // apply a cast to all array value elements
606 
// NOTE(review): precision, scale and charset are read from the array-typed fieldType
// rather than from fieldType.getComponentType(); if the array type does not carry its
// element's precision/scale, literal elements may be cast with defaults — confirm.
607  SqlDataTypeSpec elementType = new SqlDataTypeSpec(
608  new SqlBasicTypeNameSpec(fieldType.getComponentType().getSqlTypeName(),
609  fieldType.getPrecision(),
610  fieldType.getScale(),
611  null == fieldType.getCharset() ? null
612  : fieldType.getCharset().name(),
613  SqlParserPos.ZERO),
614  SqlParserPos.ZERO);
615  SqlCall array_expression = (SqlCall) expression;
616  ArrayList<SqlNode> values = new ArrayList<>();
617 
// Only literal elements are cast; other elements are passed through unchanged.
618  for (SqlNode value : array_expression.getOperandList()) {
619  if (value.isA(EnumSet.of(SqlKind.LITERAL))) {
620  SqlNode casted_value = new SqlBasicCall(SqlStdOperatorTable.CAST,
621  new SqlNode[] {value, elementType},
622  value.getParserPosition());
623  values.add(casted_value);
624  } else {
625  values.add(value);
626  }
627  }
628 
629  expression = new SqlBasicCall(MapDSqlOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
630  values.toArray(new SqlNode[0]),
631  expression.getParserPosition());
632  }
633  sourceExpression.add(expression);
634  }
635 
// The row-offset column is always the last select-list item.
636  sourceExpression.add(new SqlBasicCall(SqlStdOperatorTable.AS,
637  new SqlNode[] {
638  new SqlBasicCall(new SqlUnresolvedFunction(
639  new SqlIdentifier("OFFSET_IN_FRAGMENT",
640  SqlParserPos.ZERO),
641  null,
642  null,
643  null,
644  null,
645  SqlFunctionCategory.USER_DEFINED_FUNCTION),
646  new SqlNode[0],
647  SqlParserPos.ZERO),
648  new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO)},
649  ZERO));
650 
651  select = new SqlSelect(SqlParserPos.ZERO,
652  null,
653  sourceExpression,
654  update.getTargetTable(),
655  update.getCondition(),
656  null,
657  null,
658  null,
659  null,
660  null,
661  null,
662  null);
663  }
664 
// The rewritten SELECT is unparsed and re-planned from its SQL text.
665  MapDPlanner planner = getPlanner(true, allowPushdownJoinCondition);
666  SqlNode node = planner.parse(select.toSqlString(SqlDialect.CALCITE).getSql());
667  node = planner.validate(node);
668  RelRoot root = planner.rel(node);
669  LogicalProject project = (LogicalProject) root.project();
670 
671  ArrayList<String> fields = new ArrayList<String>();
672  ArrayList<RexNode> nodes = new ArrayList<RexNode>();
673  final RexBuilder builder = new RexBuilder(planner.getTypeFactory());
674 
675  for (SqlNode n : update.getTargetColumnList()) {
676  if (n instanceof SqlIdentifier) {
677  SqlIdentifier id = (SqlIdentifier) n;
678  fields.add(id.names.get(id.names.size() - 1));
679  } else {
680  throw new RuntimeException("Unknown identifier type!");
681  }
682  }
683 
// On the general path, cast each new value to its target column type, skipping the
// trailing offset column and array constructors.
684  int idx = 0;
685  for (RexNode exp : project.getChildExps()) {
686  if (applyRexCast && idx + 1 < project.getChildExps().size()) {
687  RelDataType expectedFieldType =
688  targetTableType.getField(fields.get(idx), false, false).getType();
689  if (!exp.getType().equals(expectedFieldType) && !exp.isA(ARRAY_VALUE)) {
690  exp = builder.makeCast(expectedFieldType, exp);
691  }
692  }
693 
694  nodes.add(exp);
695  idx++;
696  }
697 
// The modify's source expressions reference the projected columns positionally,
// followed by the offset column.
698  ArrayList<RexNode> inputs = new ArrayList<RexNode>();
699  int n = 0;
700  for (int i = 0; i < fields.size(); i++) {
701  inputs.add(
702  new RexInputRef(n, project.getRowType().getFieldList().get(n).getType()));
703  n++;
704  }
705 
706  fields.add("EXPR$DELETE_OFFSET_IN_FRAGMENT");
707  inputs.add(new RexInputRef(n, project.getRowType().getFieldList().get(n).getType()));
708 
// NOTE(review): the copy keeps the original row type even where casts changed an
// expression's type — confirm this is accepted downstream.
709  project = project.copy(
710  project.getTraitSet(), project.getInput(), nodes, project.getRowType());
711 
712  LogicalTableModify modify = LogicalTableModify.create(targetTable,
713  dummyModify.getCatalogReader(),
714  project,
715  Operation.UPDATE,
716  fields,
717  inputs,
718  true);
719  return RelRoot.of(modify, SqlKind.UPDATE);
720  }
721 
722  RelRoot queryToRelNode(final String sql, final MapDParserOptions parserOptions)
723  throws SqlParseException, ValidationException, RelConversionException {
724  final MapDPlanner planner = getPlanner(true, true);
725  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
726  return convertSqlToRelNode(sqlNode, planner, parserOptions);
727  }
728 
// Converts a parsed statement into a RelRoot using `mapDPlanner`.
// - DELETE is recast as an UPDATE with empty SET lists and lowered through
//   rewriteUpdateAsSelect. The resulting modify's operation is then patched to DELETE
//   via reflection.
// - UPDATE/DELETE return early.
// - Other statements are validated and converted, after an optional legacy-syntax
//   re-parse. IS [NOT] TRUE/FALSE calls are then rewritten by replaceIsTrue.
// - When view optimization is enabled and the query references a view, a HEP program is
//   run repeatedly until the plan text stops changing.
// Throws RuntimeException if a referenced table/view cannot be found while checking for
// views.
// NOTE(review): on the UPDATE/DELETE path the planner is not closed here.
729  RelRoot convertSqlToRelNode(final SqlNode sqlNode,
730  final MapDPlanner mapDPlanner,
731  final MapDParserOptions parserOptions)
732  throws SqlParseException, ValidationException, RelConversionException {
733  SqlNode node = sqlNode;
734  MapDPlanner planner = mapDPlanner;
// Both flags are constant true; they only parameterize the legacy-syntax re-plan below.
735  boolean allowCorrelatedSubQueryExpansion = true;
736  boolean allowPushdownJoinCondition = true;
737  boolean patchUpdateToDelete = false;
738 
739  if (node.isA(DELETE)) {
740  SqlDelete sqlDelete = (SqlDelete) node;
741  node = new SqlUpdate(node.getParserPosition(),
742  sqlDelete.getTargetTable(),
743  SqlNodeList.EMPTY,
744  SqlNodeList.EMPTY,
745  sqlDelete.getCondition(),
746  sqlDelete.getSourceSelect(),
747  sqlDelete.getAlias());
748 
749  patchUpdateToDelete = true;
750  }
751 
752  if (node.isA(UPDATE)) {
753  SqlUpdate update = (SqlUpdate) node;
754  update = (SqlUpdate) planner.validate(update);
755  RelRoot root = rewriteUpdateAsSelect(update, parserOptions);
756 
757  if (patchUpdateToDelete) {
758  LogicalTableModify modify = (LogicalTableModify) root.rel;
759 
// The modify was built as an UPDATE; its private `operation` field is overwritten
// reflectively, presumably because TableModify offers no way to change it afterwards.
760  try {
761  Field f = TableModify.class.getDeclaredField("operation");
762  f.setAccessible(true);
763  f.set(modify, Operation.DELETE);
764  } catch (Throwable e) {
765  throw new RuntimeException(e);
766  }
767 
768  root = RelRoot.of(modify, SqlKind.DELETE);
769  }
770 
771  return root;
772  }
773 
774  if (parserOptions.isLegacySyntax()) {
775  // close original planner
776  planner.close();
777  // create a new one
778  planner = getPlanner(allowCorrelatedSubQueryExpansion, allowPushdownJoinCondition);
// NOTE(review): SqlString.toString() is used here, while the rest of the file uses
// getSql().
779  node = parseSql(node.toSqlString(SqlDialect.CALCITE).toString(), false, planner);
780  }
781 
782  SqlNode validateR = planner.validate(node);
783  planner.setFilterPushDownInfo(parserOptions.getFilterPushDownInfo());
784  RelRoot relR = planner.rel(validateR);
785  relR = replaceIsTrue(planner.getTypeFactory(), relR);
786  planner.close();
787 
788  if (!parserOptions.isViewOptimizeEnabled()) {
789  return relR;
790  } else {
791  // check to see if a view is involved in the query
792  boolean foundView = false;
793  MapDSchema schema = new MapDSchema(
794  dataDir, this, mapdPort, mapdUser, sock_transport_properties);
// Identifiers are captured from the caller's original sqlNode, not the re-parsed node.
795  SqlIdentifierCapturer capturer = captureIdentifiers(sqlNode);
796  for (String name : capturer.selects) {
797  MapDTable table = (MapDTable) schema.getTable(name);
798  if (null == table) {
799  throw new RuntimeException("table/view not found: " + name);
800  }
801  if (table instanceof MapDView) {
802  foundView = true;
803  }
804  }
805 
806  if (!foundView) {
807  return relR;
808  }
809 
810  // do some calcite based optimization
811  // will allow duplicate projects to merge
812  ProjectMergeRule projectMergeRule =
813  new ProjectMergeRule(true, RelFactories.LOGICAL_BUILDER);
814  final Program program =
815  Programs.hep(ImmutableList.of(FilterProjectTransposeRule.INSTANCE,
816  projectMergeRule,
817  ProjectProjectRemoveRule.INSTANCE,
818  FilterMergeRule.INSTANCE,
819  JoinProjectTransposeRule.LEFT_PROJECT_INCLUDE_OUTER,
820  JoinProjectTransposeRule.RIGHT_PROJECT_INCLUDE_OUTER,
821  JoinProjectTransposeRule.BOTH_PROJECT_INCLUDE_OUTER),
822  true,
823  DefaultRelMetadataProvider.INSTANCE);
824 
// Re-run the program until a pass leaves the plan's text form unchanged.
825  RelNode oldRel;
826  RelNode newRel = relR.project();
827 
828  do {
829  oldRel = newRel;
830  newRel = program.run(null,
831  oldRel,
832  null,
833  ImmutableList.<RelOptMaterialization>of(),
834  ImmutableList.<RelOptLattice>of());
835  // there must be a better way to compare these
836  } while (!RelOptUtil.toString(oldRel).equals(RelOptUtil.toString(newRel)));
837  RelRoot optRel = RelRoot.of(newRel, relR.kind);
838  return optRel;
839  }
840  }
841 
842  private RelRoot replaceIsTrue(final RelDataTypeFactory typeFactory, RelRoot root) {
843  final RexShuttle callShuttle = new RexShuttle() {
844  RexBuilder builder = new RexBuilder(typeFactory);
845 
846  public RexNode visitCall(RexCall call) {
847  call = (RexCall) super.visitCall(call);
848  if (call.getKind() == SqlKind.IS_TRUE) {
849  return builder.makeCall(SqlStdOperatorTable.AND,
850  builder.makeCall(
851  SqlStdOperatorTable.IS_NOT_NULL, call.getOperands().get(0)),
852  call.getOperands().get(0));
853  } else if (call.getKind() == SqlKind.IS_NOT_TRUE) {
854  return builder.makeCall(SqlStdOperatorTable.OR,
855  builder.makeCall(
856  SqlStdOperatorTable.IS_NULL, call.getOperands().get(0)),
857  builder.makeCall(SqlStdOperatorTable.NOT, call.getOperands().get(0)));
858  } else if (call.getKind() == SqlKind.IS_FALSE) {
859  return builder.makeCall(SqlStdOperatorTable.AND,
860  builder.makeCall(
861  SqlStdOperatorTable.IS_NOT_NULL, call.getOperands().get(0)),
862  builder.makeCall(SqlStdOperatorTable.NOT, call.getOperands().get(0)));
863  } else if (call.getKind() == SqlKind.IS_NOT_FALSE) {
864  return builder.makeCall(SqlStdOperatorTable.OR,
865  builder.makeCall(
866  SqlStdOperatorTable.IS_NULL, call.getOperands().get(0)),
867  call.getOperands().get(0));
868  }
869 
870  return call;
871  }
872  };
873 
874  RelNode node = root.rel.accept(new RelShuttleImpl() {
875  @Override
876  protected RelNode visitChild(RelNode parent, int i, RelNode child) {
877  RelNode node = super.visitChild(parent, i, child);
878  return node.accept(callShuttle);
879  }
880  });
881 
882  return new RelRoot(node,
883  root.validatedRowType,
884  root.kind,
885  root.fields,
886  root.collation,
887  Collections.emptyList());
888  }
889 
890  private SqlNode parseSql(String sql, final boolean legacy_syntax, Planner planner)
891  throws SqlParseException {
892  SqlNode parseR = null;
893  try {
894  parseR = planner.parse(sql);
895  MAPDLOGGER.debug(" node is \n" + parseR.toString());
896  } catch (SqlParseException ex) {
897  MAPDLOGGER.error("failed to parse SQL '" + sql + "' \n" + ex.toString());
898  throw ex;
899  }
900 
901  if (!legacy_syntax) {
902  return parseR;
903  }
904 
905  RelDataTypeFactory typeFactory = planner.getTypeFactory();
906  SqlSelect select_node = null;
907  if (parseR instanceof SqlSelect) {
908  select_node = (SqlSelect) parseR;
909  desugar(select_node, typeFactory);
910  } else if (parseR instanceof SqlOrderBy) {
911  SqlOrderBy order_by_node = (SqlOrderBy) parseR;
912  if (order_by_node.query instanceof SqlSelect) {
913  select_node = (SqlSelect) order_by_node.query;
914  SqlOrderBy new_order_by_node = desugar(select_node, order_by_node, typeFactory);
915  if (new_order_by_node != null) {
916  return new_order_by_node;
917  }
918  } else if (order_by_node.query instanceof SqlWith) {
919  SqlWith old_with_node = (SqlWith) order_by_node.query;
920  if (old_with_node.body instanceof SqlSelect) {
921  select_node = (SqlSelect) old_with_node.body;
922  desugar(select_node, typeFactory);
923  }
924  }
925  } else if (parseR instanceof SqlWith) {
926  SqlWith old_with_node = (SqlWith) parseR;
927  if (old_with_node.body instanceof SqlSelect) {
928  select_node = (SqlSelect) old_with_node.body;
929  desugar(select_node, typeFactory);
930  }
931  }
932  return parseR;
933  }
934 
935  private void desugar(SqlSelect select_node, RelDataTypeFactory typeFactory) {
936  desugar(select_node, null, typeFactory);
937  }
938 
939  private SqlNode expandCase(SqlCase old_case_node, RelDataTypeFactory typeFactory) {
940  SqlNodeList newWhenList =
941  new SqlNodeList(old_case_node.getWhenOperands().getParserPosition());
942  SqlNodeList newThenList =
943  new SqlNodeList(old_case_node.getThenOperands().getParserPosition());
944  java.util.Map<String, SqlNode> id_to_expr = new java.util.HashMap<String, SqlNode>();
945  for (SqlNode node : old_case_node.getWhenOperands()) {
946  SqlNode newCall = expand(node, id_to_expr, typeFactory);
947  if (null != newCall) {
948  newWhenList.add(newCall);
949  } else {
950  newWhenList.add(node);
951  }
952  }
953  for (SqlNode node : old_case_node.getThenOperands()) {
954  SqlNode newCall = expand(node, id_to_expr, typeFactory);
955  if (null != newCall) {
956  newThenList.add(newCall);
957  } else {
958  newThenList.add(node);
959  }
960  }
961  SqlNode new_else_operand = old_case_node.getElseOperand();
962  if (null != new_else_operand) {
963  SqlNode candidate_else_operand =
964  expand(old_case_node.getElseOperand(), id_to_expr, typeFactory);
965  if (null != candidate_else_operand) {
966  new_else_operand = candidate_else_operand;
967  }
968  }
969  SqlNode new_value_operand = old_case_node.getValueOperand();
970  if (null != new_value_operand) {
971  SqlNode candidate_value_operand =
972  expand(old_case_node.getValueOperand(), id_to_expr, typeFactory);
973  if (null != candidate_value_operand) {
974  new_value_operand = candidate_value_operand;
975  }
976  }
977  SqlNode newCaseNode = SqlCase.createSwitched(old_case_node.getParserPosition(),
978  new_value_operand,
979  newWhenList,
980  newThenList,
981  new_else_operand);
982  return newCaseNode;
983  }
984 
985  private SqlOrderBy desugar(SqlSelect select_node,
986  SqlOrderBy order_by_node,
987  RelDataTypeFactory typeFactory) {
988  MAPDLOGGER.debug("desugar: before: " + select_node.toString());
989  desugarExpression(select_node.getFrom(), typeFactory);
990  desugarExpression(select_node.getWhere(), typeFactory);
991  SqlNodeList select_list = select_node.getSelectList();
992  SqlNodeList new_select_list = new SqlNodeList(select_list.getParserPosition());
993  java.util.Map<String, SqlNode> id_to_expr = new java.util.HashMap<String, SqlNode>();
994  for (SqlNode proj : select_list) {
995  if (!(proj instanceof SqlBasicCall)) {
996  if (proj instanceof SqlCase) {
997  new_select_list.add(expandCase((SqlCase) proj, typeFactory));
998  } else {
999  new_select_list.add(proj);
1000  }
1001  } else {
1002  assert proj instanceof SqlBasicCall;
1003  SqlBasicCall proj_call = (SqlBasicCall) proj;
1004  if (proj_call.operands.length > 0) {
1005  for (int i = 0; i < proj_call.operands.length; i++) {
1006  if (proj_call.operand(i) instanceof SqlCase) {
1007  SqlNode new_op = expandCase(proj_call.operand(i), typeFactory);
1008  proj_call.setOperand(i, new_op);
1009  }
1010  }
1011  }
1012  new_select_list.add(expand(proj_call, id_to_expr, typeFactory));
1013  }
1014  }
1015  select_node.setSelectList(new_select_list);
1016  SqlNodeList group_by_list = select_node.getGroup();
1017  if (group_by_list != null) {
1018  select_node.setGroupBy(expand(group_by_list, id_to_expr, typeFactory));
1019  }
1020  SqlNode having = select_node.getHaving();
1021  if (having != null) {
1022  expand(having, id_to_expr, typeFactory);
1023  }
1024  SqlOrderBy new_order_by_node = null;
1025  if (order_by_node != null && order_by_node.orderList != null
1026  && order_by_node.orderList.size() > 0) {
1027  SqlNodeList new_order_by_list =
1028  expand(order_by_node.orderList, id_to_expr, typeFactory);
1029  new_order_by_node = new SqlOrderBy(order_by_node.getParserPosition(),
1030  select_node,
1031  new_order_by_list,
1032  order_by_node.offset,
1033  order_by_node.fetch);
1034  }
1035 
1036  MAPDLOGGER.debug("desugar: after: " + select_node.toString());
1037  return new_order_by_node;
1038  }
1039 
1040  private void desugarExpression(SqlNode node, RelDataTypeFactory typeFactory) {
1041  if (node instanceof SqlSelect) {
1042  desugar((SqlSelect) node, typeFactory);
1043  return;
1044  }
1045  if (!(node instanceof SqlBasicCall)) {
1046  return;
1047  }
1048  SqlBasicCall basic_call = (SqlBasicCall) node;
1049  for (SqlNode operator : basic_call.getOperands()) {
1050  if (operator instanceof SqlOrderBy) {
1051  desugarExpression(((SqlOrderBy) operator).query, typeFactory);
1052  } else {
1053  desugarExpression(operator, typeFactory);
1054  }
1055  }
1056  }
1057 
1058  private SqlNode expand(final SqlNode node,
1059  final java.util.Map<String, SqlNode> id_to_expr,
1060  RelDataTypeFactory typeFactory) {
1061  MAPDLOGGER.debug("expand: " + node.toString());
1062  if (node instanceof SqlBasicCall) {
1063  SqlBasicCall node_call = (SqlBasicCall) node;
1064  SqlNode[] operands = node_call.getOperands();
1065  for (int i = 0; i < operands.length; ++i) {
1066  node_call.setOperand(i, expand(operands[i], id_to_expr, typeFactory));
1067  }
1068  SqlNode expanded_variance = expandVariance(node_call, typeFactory);
1069  if (expanded_variance != null) {
1070  return expanded_variance;
1071  }
1072  SqlNode expanded_covariance = expandCovariance(node_call, typeFactory);
1073  if (expanded_covariance != null) {
1074  return expanded_covariance;
1075  }
1076  SqlNode expanded_correlation = expandCorrelation(node_call, typeFactory);
1077  if (expanded_correlation != null) {
1078  return expanded_correlation;
1079  }
1080  }
1081  if (node instanceof SqlSelect) {
1082  SqlSelect select_node = (SqlSelect) node;
1083  desugar(select_node, typeFactory);
1084  }
1085  return node;
1086  }
1087 
1088  private SqlNodeList expand(final SqlNodeList group_by_list,
1089  final java.util.Map<String, SqlNode> id_to_expr,
1090  RelDataTypeFactory typeFactory) {
1091  SqlNodeList new_group_by_list = new SqlNodeList(new SqlParserPos(-1, -1));
1092  for (SqlNode group_by : group_by_list) {
1093  if (!(group_by instanceof SqlIdentifier)) {
1094  new_group_by_list.add(expand(group_by, id_to_expr, typeFactory));
1095  continue;
1096  }
1097  SqlIdentifier group_by_id = ((SqlIdentifier) group_by);
1098  if (id_to_expr.containsKey(group_by_id.toString())) {
1099  new_group_by_list.add(id_to_expr.get(group_by_id.toString()));
1100  } else {
1101  new_group_by_list.add(group_by);
1102  }
1103  }
1104  return new_group_by_list;
1105  }
1106 
1107  private SqlNode expandVariance(
1108  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1109  // Expand variance aggregates that are not supported natively
1110  if (proj_call.operandCount() != 1) {
1111  return null;
1112  }
1113  boolean biased;
1114  boolean sqrt;
1115  boolean flt;
1116  if (proj_call.getOperator().isName("STDDEV_POP", false)) {
1117  biased = true;
1118  sqrt = true;
1119  flt = false;
1120  } else if (proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_POP_FLOAT")) {
1121  biased = true;
1122  sqrt = true;
1123  flt = true;
1124  } else if (proj_call.getOperator().isName("STDDEV_SAMP", false)
1125  || proj_call.getOperator().getName().equalsIgnoreCase("STDDEV")) {
1126  biased = false;
1127  sqrt = true;
1128  flt = false;
1129  } else if (proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_SAMP_FLOAT")
1130  || proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_FLOAT")) {
1131  biased = false;
1132  sqrt = true;
1133  flt = true;
1134  } else if (proj_call.getOperator().isName("VAR_POP", false)) {
1135  biased = true;
1136  sqrt = false;
1137  flt = false;
1138  } else if (proj_call.getOperator().getName().equalsIgnoreCase("VAR_POP_FLOAT")) {
1139  biased = true;
1140  sqrt = false;
1141  flt = true;
1142  } else if (proj_call.getOperator().isName("VAR_SAMP", false)
1143  || proj_call.getOperator().getName().equalsIgnoreCase("VARIANCE")) {
1144  biased = false;
1145  sqrt = false;
1146  flt = false;
1147  } else if (proj_call.getOperator().getName().equalsIgnoreCase("VAR_SAMP_FLOAT")
1148  || proj_call.getOperator().getName().equalsIgnoreCase("VARIANCE_FLOAT")) {
1149  biased = false;
1150  sqrt = false;
1151  flt = true;
1152  } else {
1153  return null;
1154  }
1155  final SqlNode operand = proj_call.operand(0);
1156  final SqlParserPos pos = proj_call.getParserPosition();
1157  SqlNode expanded_proj_call =
1158  expandVariance(pos, operand, biased, sqrt, flt, typeFactory);
1159  MAPDLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1160  MAPDLOGGER.debug("to : " + expanded_proj_call.toString());
1161  return expanded_proj_call;
1162  }
1163 
1164  private SqlNode expandVariance(final SqlParserPos pos,
1165  final SqlNode operand,
1166  boolean biased,
1167  boolean sqrt,
1168  boolean flt,
1169  RelDataTypeFactory typeFactory) {
1170  // stddev_pop(x) ==>
1171  // power(
1172  // (sum(x * x) - sum(x) * sum(x) / (case count(x) when 0 then NULL else count(x)
1173  // end)) / (case count(x) when 0 then NULL else count(x) end), .5)
1174  //
1175  // stddev_samp(x) ==>
1176  // power(
1177  // (sum(x * x) - sum(x) * sum(x) / (case count(x) when 0 then NULL else count(x)
1178  // )) / ((case count(x) when 1 then NULL else count(x) - 1 end)), .5)
1179  //
1180  // var_pop(x) ==>
1181  // (sum(x * x) - sum(x) * sum(x) / ((case count(x) when 0 then NULL else
1182  // count(x)
1183  // end))) / ((case count(x) when 0 then NULL else count(x) end))
1184  //
1185  // var_samp(x) ==>
1186  // (sum(x * x) - sum(x) * sum(x) / ((case count(x) when 0 then NULL else
1187  // count(x)
1188  // end))) / ((case count(x) when 1 then NULL else count(x) - 1 end))
1189  //
1190  final SqlNode arg = SqlStdOperatorTable.CAST.createCall(pos,
1191  operand,
1192  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1193  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1194  final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
1195  final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
1196  final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
1197  final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
1198  final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
1199  final SqlLiteral nul = SqlLiteral.createNull(pos);
1200  final SqlNumericLiteral zero = SqlLiteral.createExactNumeric("0", pos);
1201  final SqlNode countEqZero = SqlStdOperatorTable.EQUALS.createCall(pos, count, zero);
1202  SqlNodeList whenList = new SqlNodeList(pos);
1203  SqlNodeList thenList = new SqlNodeList(pos);
1204  whenList.add(countEqZero);
1205  thenList.add(nul);
1206  final SqlNode int_denominator = SqlStdOperatorTable.CASE.createCall(
1207  null, pos, null, whenList, thenList, count);
1208  final SqlNode denominator = SqlStdOperatorTable.CAST.createCall(pos,
1209  int_denominator,
1210  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1211  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1212  final SqlNode avgSumSquared =
1213  SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, denominator);
1214  final SqlNode diff =
1215  SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
1216  final SqlNode denominator1;
1217  if (biased) {
1218  denominator1 = denominator;
1219  } else {
1220  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
1221  final SqlNode countEqOne = SqlStdOperatorTable.EQUALS.createCall(pos, count, one);
1222  final SqlNode countMinusOne = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
1223  SqlNodeList whenList1 = new SqlNodeList(pos);
1224  SqlNodeList thenList1 = new SqlNodeList(pos);
1225  whenList1.add(countEqOne);
1226  thenList1.add(nul);
1227  final SqlNode int_denominator1 = SqlStdOperatorTable.CASE.createCall(
1228  null, pos, null, whenList1, thenList1, countMinusOne);
1229  denominator1 = SqlStdOperatorTable.CAST.createCall(pos,
1230  int_denominator1,
1231  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1232  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1233  }
1234  final SqlNode div = SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator1);
1235  SqlNode result = div;
1236  if (sqrt) {
1237  final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
1238  result = SqlStdOperatorTable.POWER.createCall(pos, div, half);
1239  }
1240  return SqlStdOperatorTable.CAST.createCall(pos,
1241  result,
1242  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1243  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1244  }
1245 
1246  private SqlNode expandCovariance(
1247  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1248  // Expand covariance aggregates
1249  if (proj_call.operandCount() != 2) {
1250  return null;
1251  }
1252  boolean pop;
1253  boolean flt;
1254  if (proj_call.getOperator().isName("COVAR_POP", false)) {
1255  pop = true;
1256  flt = false;
1257  } else if (proj_call.getOperator().isName("COVAR_SAMP", false)) {
1258  pop = false;
1259  flt = false;
1260  } else if (proj_call.getOperator().getName().equalsIgnoreCase("COVAR_POP_FLOAT")) {
1261  pop = true;
1262  flt = true;
1263  } else if (proj_call.getOperator().getName().equalsIgnoreCase("COVAR_SAMP_FLOAT")) {
1264  pop = false;
1265  flt = true;
1266  } else {
1267  return null;
1268  }
1269  final SqlNode operand0 = proj_call.operand(0);
1270  final SqlNode operand1 = proj_call.operand(1);
1271  final SqlParserPos pos = proj_call.getParserPosition();
1272  SqlNode expanded_proj_call =
1273  expandCovariance(pos, operand0, operand1, pop, flt, typeFactory);
1274  MAPDLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1275  MAPDLOGGER.debug("to : " + expanded_proj_call.toString());
1276  return expanded_proj_call;
1277  }
1278 
1279  private SqlNode expandCovariance(SqlParserPos pos,
1280  final SqlNode operand0,
1281  final SqlNode operand1,
1282  boolean pop,
1283  boolean flt,
1284  RelDataTypeFactory typeFactory) {
1285  // covar_pop(x, y) ==> avg(x * y) - avg(x) * avg(y)
1286  // covar_samp(x, y) ==> (sum(x * y) - sum(x) * avg(y))
1287  // ((case count(x) when 1 then NULL else count(x) - 1 end))
1288  final SqlNode arg0 = SqlStdOperatorTable.CAST.createCall(operand0.getParserPosition(),
1289  operand0,
1290  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1291  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1292  final SqlNode arg1 = SqlStdOperatorTable.CAST.createCall(operand1.getParserPosition(),
1293  operand1,
1294  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1295  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1296  final SqlNode mulArg = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1);
1297  final SqlNode avgArg1 = SqlStdOperatorTable.AVG.createCall(pos, arg1);
1298  if (pop) {
1299  final SqlNode avgMulArg = SqlStdOperatorTable.AVG.createCall(pos, mulArg);
1300  final SqlNode avgArg0 = SqlStdOperatorTable.AVG.createCall(pos, arg0);
1301  final SqlNode mulAvgAvg =
1302  SqlStdOperatorTable.MULTIPLY.createCall(pos, avgArg0, avgArg1);
1303  final SqlNode covarPop =
1304  SqlStdOperatorTable.MINUS.createCall(pos, avgMulArg, mulAvgAvg);
1305  return SqlStdOperatorTable.CAST.createCall(pos,
1306  covarPop,
1307  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1308  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1309  }
1310  final SqlNode sumMulArg = SqlStdOperatorTable.SUM.createCall(pos, mulArg);
1311  final SqlNode sumArg0 = SqlStdOperatorTable.SUM.createCall(pos, arg0);
1312  final SqlNode mulSumAvg =
1313  SqlStdOperatorTable.MULTIPLY.createCall(pos, sumArg0, avgArg1);
1314  final SqlNode sub = SqlStdOperatorTable.MINUS.createCall(pos, sumMulArg, mulSumAvg);
1315  final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, operand0);
1316  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
1317  final SqlNode countEqOne = SqlStdOperatorTable.EQUALS.createCall(pos, count, one);
1318  final SqlNode countMinusOne = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
1319  final SqlLiteral nul = SqlLiteral.createNull(pos);
1320  SqlNodeList whenList1 = new SqlNodeList(pos);
1321  SqlNodeList thenList1 = new SqlNodeList(pos);
1322  whenList1.add(countEqOne);
1323  thenList1.add(nul);
1324  final SqlNode int_denominator = SqlStdOperatorTable.CASE.createCall(
1325  null, pos, null, whenList1, thenList1, countMinusOne);
1326  final SqlNode denominator = SqlStdOperatorTable.CAST.createCall(pos,
1327  int_denominator,
1328  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1329  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1330  final SqlNode covarSamp =
1331  SqlStdOperatorTable.DIVIDE.createCall(pos, sub, denominator);
1332  return SqlStdOperatorTable.CAST.createCall(pos,
1333  covarSamp,
1334  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1335  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1336  }
1337 
1338  private SqlNode expandCorrelation(
1339  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1340  // Expand correlation coefficient
1341  if (proj_call.operandCount() != 2) {
1342  return null;
1343  }
1344  boolean flt;
1345  if (proj_call.getOperator().isName("CORR", false)
1346  || proj_call.getOperator().getName().equalsIgnoreCase("CORRELATION")) {
1347  // expand correlation coefficient
1348  flt = false;
1349  } else if (proj_call.getOperator().getName().equalsIgnoreCase("CORR_FLOAT")
1350  || proj_call.getOperator().getName().equalsIgnoreCase("CORRELATION_FLOAT")) {
1351  // expand correlation coefficient
1352  flt = true;
1353  } else {
1354  return null;
1355  }
1356  // corr(x, y) ==> (avg(x * y) - avg(x) * avg(y)) / (stddev_pop(x) *
1357  // stddev_pop(y))
1358  // ==> covar_pop(x, y) / (stddev_pop(x) * stddev_pop(y))
1359  final SqlNode operand0 = proj_call.operand(0);
1360  final SqlNode operand1 = proj_call.operand(1);
1361  final SqlParserPos pos = proj_call.getParserPosition();
1362  SqlNode covariance =
1363  expandCovariance(pos, operand0, operand1, true, flt, typeFactory);
1364  SqlNode stddev0 = expandVariance(pos, operand0, true, true, flt, typeFactory);
1365  SqlNode stddev1 = expandVariance(pos, operand1, true, true, flt, typeFactory);
1366  final SqlNode mulStddev =
1367  SqlStdOperatorTable.MULTIPLY.createCall(pos, stddev0, stddev1);
1368  final SqlNumericLiteral zero = SqlLiteral.createExactNumeric("0.0", pos);
1369  final SqlNode mulStddevEqZero =
1370  SqlStdOperatorTable.EQUALS.createCall(pos, mulStddev, zero);
1371  final SqlLiteral nul = SqlLiteral.createNull(pos);
1372  SqlNodeList whenList1 = new SqlNodeList(pos);
1373  SqlNodeList thenList1 = new SqlNodeList(pos);
1374  whenList1.add(mulStddevEqZero);
1375  thenList1.add(nul);
1376  final SqlNode denominator = SqlStdOperatorTable.CASE.createCall(
1377  null, pos, null, whenList1, thenList1, mulStddev);
1378  final SqlNode expanded_proj_call =
1379  SqlStdOperatorTable.DIVIDE.createCall(pos, covariance, denominator);
1380  MAPDLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1381  MAPDLOGGER.debug("to : " + expanded_proj_call.toString());
1382  return expanded_proj_call;
1383  }
1384 
1385  public SqlIdentifierCapturer captureIdentifiers(String sql, boolean legacy_syntax)
1386  throws SqlParseException {
1387  try {
1388  Planner planner = getPlanner();
1389  SqlNode node = parseSql(sql, legacy_syntax, planner);
1390  return captureIdentifiers(node);
1391  } catch (Exception | Error e) {
1392  MAPDLOGGER.error("Error parsing sql: " + sql, e);
1393  return new SqlIdentifierCapturer();
1394  }
1395  }
1396 
1397  public SqlIdentifierCapturer captureIdentifiers(SqlNode node) throws SqlParseException {
1398  try {
1400  capturer.scan(node);
1401  return capturer;
1402  } catch (Exception | Error e) {
1403  MAPDLOGGER.error("Error parsing sql: " + node, e);
1404  return new SqlIdentifierCapturer();
1405  }
1406  }
1407 
1408  public int getCallCount() {
1409  return callCount;
1410  }
1411 
1412  public void updateMetaData(String schema, String table) {
1413  MAPDLOGGER.debug("schema :" + schema + " table :" + table);
1414  MapDSchema mapd =
1415  new MapDSchema(dataDir, this, mapdPort, null, sock_transport_properties);
1416  mapd.updateMetaData(schema, table);
1417  }
1418 
1419  protected RelDataTypeSystem createTypeSystem() {
1420  final MapDTypeSystem typeSystem = new MapDTypeSystem();
1421  return typeSystem;
1422  }
1423 }
SqlNode expand(final SqlNode node, final java.util.Map< String, SqlNode > id_to_expr, RelDataTypeFactory typeFactory)
static final EnumSet< SqlKind > IN
JoinType
Definition: sqldefs.h:107
static final EnumSet< SqlKind > SCALAR
RelRoot replaceIsTrue(final RelDataTypeFactory typeFactory, RelRoot root)
void desugarExpression(SqlNode node, RelDataTypeFactory typeFactory)
SqlNode expandCovariance(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
SqlNode expandCorrelation(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
RelRoot queryToRelNode(final String sql, final MapDParserOptions parserOptions)
Pair< String, SqlIdentifierCapturer > process(String sql, final MapDParserOptions parserOptions)
std::string join(T const &container, std::string const &delim)
SqlNode expandCase(SqlCase old_case_node, RelDataTypeFactory typeFactory)
static final EnumSet< SqlKind > ARRAY_VALUE
final Supplier< MapDSqlOperatorTable > mapDSqlOperatorTable
SqlNode expandVariance(final SqlParserPos pos, final SqlNode operand, boolean biased, boolean sqrt, boolean flt, RelDataTypeFactory typeFactory)
SqlIdentifierCapturer getAccessedObjects()
Definition: MapDView.java:66
void updateMetaData(String schema, String table)
MapDParser(String dataDir, final Supplier< MapDSqlOperatorTable > mapDSqlOperatorTable, int mapdPort, SockTransportProperties skT)
String processSql(String sql, final MapDParserOptions parserOptions)
SqlIdentifierCapturer captureIdentifiers(SqlNode node)
SqlSelect rewriteSimpleUpdateAsSelect(final SqlUpdate update)
SockTransportProperties sock_transport_properties
static final SqlArrayValueConstructorAllowingEmpty ARRAY_VALUE_CONSTRUCTOR
static final EnumSet< SqlKind > EXISTS
SqlNode parseSql(String sql, final boolean legacy_syntax, Planner planner)
static final EnumSet< SqlKind > UPDATE
SqlIdentifierCapturer captureIdentifiers(String sql, boolean legacy_syntax)
static Map< String, Boolean > SubqueryCorrMemo
Set< String > resolveSelectIdentifiers(SqlIdentifierCapturer capturer)
RelRoot convertSqlToRelNode(final SqlNode sqlNode, final MapDPlanner mapDPlanner, final MapDParserOptions parserOptions)
String getTableName(SqlNode node)
boolean isCorrelated(SqlNode expression)
Table getTable(String string)
Definition: MapDSchema.java:51
int64_t const int32_t sz assert(dest)
String processSql(final SqlNode sqlNode, final MapDParserOptions parserOptions)
void setUser(MapDUser mapdUser)
SqlNodeList expand(final SqlNodeList group_by_list, final java.util.Map< String, SqlNode > id_to_expr, RelDataTypeFactory typeFactory)
SqlNode expandCovariance(SqlParserPos pos, final SqlNode operand0, final SqlNode operand1, boolean pop, boolean flt, RelDataTypeFactory typeFactory)
void desugar(SqlSelect select_node, RelDataTypeFactory typeFactory)
SqlOrderBy desugar(SqlSelect select_node, SqlOrderBy order_by_node, RelDataTypeFactory typeFactory)
static final ThreadLocal< MapDParser > CURRENT_PARSER
SqlNode expandVariance(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
MapDPlanner.CompletionResult getCompletionHints(String sql, int cursor, List< String > visible_tables)
static final EnumSet< SqlKind > DELETE
MapDPlanner getPlanner(final boolean allowSubQueryExpansion, final boolean allowPushdownJoinCondition)
static final Context MAPD_CONNECTION_CONTEXT
RelDataTypeSystem createTypeSystem()
LogicalTableModify getDummyUpdate(SqlUpdate update)
RelRoot rewriteUpdateAsSelect(SqlUpdate update, MapDParserOptions parserOptions)