OmniSciDB  4201147b46
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
HeavyDBParser.java
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.mapd.calcite.parser;
18 
19 import static org.apache.calcite.sql.parser.SqlParserPos.ZERO;
20 
21 import com.google.common.collect.ImmutableList;
25 import com.mapd.parser.extension.ddl.ExtendedSqlParser;
28 
29 import org.apache.calcite.avatica.util.Casing;
30 import org.apache.calcite.config.CalciteConnectionConfig;
31 import org.apache.calcite.config.CalciteConnectionConfigImpl;
32 import org.apache.calcite.config.CalciteConnectionProperty;
33 import org.apache.calcite.plan.Context;
34 import org.apache.calcite.plan.RelOptTable;
35 import org.apache.calcite.plan.RelOptUtil;
36 import org.apache.calcite.plan.hep.HepPlanner;
37 import org.apache.calcite.plan.hep.HepProgramBuilder;
40 import org.apache.calcite.rel.RelNode;
41 import org.apache.calcite.rel.RelRoot;
42 import org.apache.calcite.rel.RelShuttleImpl;
43 import org.apache.calcite.rel.core.TableModify;
44 import org.apache.calcite.rel.core.TableModify.Operation;
45 import org.apache.calcite.rel.logical.LogicalProject;
46 import org.apache.calcite.rel.logical.LogicalTableModify;
47 import org.apache.calcite.rel.rules.CoreRules;
48 import org.apache.calcite.rel.type.RelDataType;
49 import org.apache.calcite.rel.type.RelDataTypeFactory;
50 import org.apache.calcite.rel.type.RelDataTypeSystem;
51 import org.apache.calcite.rex.*;
52 import org.apache.calcite.runtime.CalciteException;
53 import org.apache.calcite.schema.SchemaPlus;
54 import org.apache.calcite.sql.*;
55 import org.apache.calcite.sql.dialect.CalciteSqlDialect;
56 import org.apache.calcite.sql.fun.SqlCase;
57 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
58 import org.apache.calcite.sql.parser.SqlParseException;
59 import org.apache.calcite.sql.parser.SqlParser;
60 import org.apache.calcite.sql.parser.SqlParserPos;
61 import org.apache.calcite.sql.type.SqlTypeName;
62 import org.apache.calcite.sql.type.SqlTypeUtil;
63 import org.apache.calcite.sql.util.SqlBasicVisitor;
64 import org.apache.calcite.sql.util.SqlShuttle;
65 import org.apache.calcite.sql.validate.SqlConformanceEnum;
67 import org.apache.calcite.tools.*;
68 import org.apache.calcite.util.Pair;
69 import org.apache.calcite.util.Util;
70 import org.slf4j.Logger;
71 import org.slf4j.LoggerFactory;
72 
73 import java.io.IOException;
74 import java.lang.reflect.Field;
75 import java.util.*;
76 import java.util.concurrent.ConcurrentHashMap;
77 import java.util.function.BiPredicate;
78 import java.util.function.Supplier;
79 import java.util.stream.Stream;
80 
81 import ai.heavy.thrift.server.TTableDetails;
82 
83 public final class HeavyDBParser {
84  public static final ThreadLocal<HeavyDBParser> CURRENT_PARSER = new ThreadLocal<>();
85  private static final EnumSet<SqlKind> SCALAR =
86  EnumSet.of(SqlKind.SCALAR_QUERY, SqlKind.SELECT);
87  private static final EnumSet<SqlKind> EXISTS = EnumSet.of(SqlKind.EXISTS);
88  private static final EnumSet<SqlKind> DELETE = EnumSet.of(SqlKind.DELETE);
89  private static final EnumSet<SqlKind> UPDATE = EnumSet.of(SqlKind.UPDATE);
90  private static final EnumSet<SqlKind> IN = EnumSet.of(SqlKind.IN);
91  private static final EnumSet<SqlKind> ARRAY_VALUE =
92  EnumSet.of(SqlKind.ARRAY_VALUE_CONSTRUCTOR);
93 
94  final static Logger HEAVYDBLOGGER = LoggerFactory.getLogger(HeavyDBParser.class);
95 
96  private final Supplier<HeavyDBSqlOperatorTable> dbSqlOperatorTable;
97  private final String dataDir;
98 
99  private int callCount = 0;
100  private final int dbPort;
103 
104  private static Map<String, Boolean> SubqueryCorrMemo = new ConcurrentHashMap<>();
105 
// Creates a parser: stores the data directory, the supplier of the operator table that
// every planner built by this parser installs, the server port, and the socket transport
// properties. No other initialization happens here.
// NOTE(review): the last constructor parameter (skT, assigned to sock_transport_properties)
// is declared on a line that is not visible here.
106  public HeavyDBParser(String dataDir,
107  final Supplier<HeavyDBSqlOperatorTable> dbSqlOperatorTable,
108  int dbPort,
110  this.dataDir = dataDir;
111  this.dbSqlOperatorTable = dbSqlOperatorTable;
112  this.dbPort = dbPort;
113  this.sock_transport_properties = skT;
114  }
115 
116  public void clearMemo() {
117  SubqueryCorrMemo.clear();
118  }
119 
120  private static final Context DB_CONNECTION_CONTEXT = new Context() {
121  HeavyDBTypeSystem myTypeSystem = new HeavyDBTypeSystem();
122  CalciteConnectionConfig config = new CalciteConnectionConfigImpl(new Properties()) {
123  {
124  properties.put(CalciteConnectionProperty.CASE_SENSITIVE.camelName(),
125  String.valueOf(false));
126  properties.put(CalciteConnectionProperty.CONFORMANCE.camelName(),
127  String.valueOf(SqlConformanceEnum.LENIENT));
128  }
129 
130  @SuppressWarnings("unchecked")
131  public <T extends Object> T typeSystem(
132  java.lang.Class<T> typeSystemClass, T defaultTypeSystem) {
133  return (T) myTypeSystem;
134  };
135 
136  public boolean caseSensitive() {
137  return false;
138  };
139 
140  public org.apache.calcite.sql.validate.SqlConformance conformance() {
141  return SqlConformanceEnum.LENIENT;
142  };
143  };
144 
145  @Override
146  public <C> C unwrap(Class<C> aClass) {
147  if (aClass.isInstance(config)) {
148  return aClass.cast(config);
149  }
150  return null;
151  }
152  };
153 
// Convenience overload (its header is on a line not visible here): sub-query expansion
// allowed, watchdog treated as disabled.
155  return getPlanner(true, false);
156  }
157 
// Decides whether a (sub-)query expression should be treated as correlated.
// The verdict is memoized in the static SubqueryCorrMemo, keyed by the expression's SQL text
// as unparsed with CalciteSqlDialect.DEFAULT.
// The expression is run through parser.processSql on its own. Success means "not
// correlated"; any exception means "correlated". This presumably works because a
// sub-query that references outer columns cannot be processed standalone — TODO confirm.
// NOTE(review): the memo is static and keyed only by SQL text, so verdicts are shared
// across parser instances (and hence users/databases); confirm this is intended.
158  private boolean isCorrelated(SqlNode expression) {
159  String queryString = expression.toSqlString(CalciteSqlDialect.DEFAULT).getSql();
160  Boolean isCorrelatedSubquery = SubqueryCorrMemo.get(queryString);
161  if (null != isCorrelatedSubquery) {
162  return isCorrelatedSubquery;
163  }
164 
165  try {
// NOTE(review): the lines declaring `parser` and `options` are not visible here.
169  parser.setUser(dbUser);
170  parser.processSql(expression, options);
171  } catch (Exception e) {
172  // if we are not able to parse, then assume correlated
173  SubqueryCorrMemo.put(queryString, true);
174  return true;
175  }
176  SubqueryCorrMemo.put(queryString, false);
177  return false;
178  }
179 
// Builds a HeavyDBPlanner for the current user (the method header is on a line not visible
// here). Parameters:
// - allowSubQueryExpansion: when false, the SqlToRelConverter expand predicate rejects
//   every sub-query.
// - isWatchdogEnabled: IN sub-queries are expanded only when their WHERE satisfies
//   JoinOperatorChecker.containsExpression (defined elsewhere).
// Expansion rules:
// - EXISTS/IN are expanded unless the IN is listed in the SELECT clause or under an OR in
//   WHERE/HAVING.
// - Correlated scalar sub-queries are expanded, but a CalciteException is thrown when they
//   carry ORDER BY, OFFSET or FETCH.
// Planner configuration:
// - LENIENT conformance, unchanged unquoted casing and case-insensitive identifiers of up to
//   512 characters;
// - the ExtendedSqlParser factory and an unlimited IN-list threshold;
// - the user's restrictions.
181  final boolean allowSubQueryExpansion, final boolean isWatchdogEnabled) {
182  BiPredicate<SqlNode, SqlNode> expandPredicate = new BiPredicate<SqlNode, SqlNode>() {
183  @Override
184  public boolean test(SqlNode root, SqlNode expression) {
185  if (!allowSubQueryExpansion) {
186  return false;
187  }
188 
189  if (expression.isA(EXISTS) || expression.isA(IN)) {
190  // try to expand subquery by EXISTS and IN clauses by default
191  // note that current Calcite decorrelator fails to flatten
192  // NOT-IN clause in some cases, so we do not decorrelate it for now
193 
194  if (expression.isA(IN)) {
195  // If we enable watchdog, we suffer from large projection exception in many
196  // cases since decorrelation needs de-duplication step which adds project -
197  // aggregate logic. And the added project is the source of the exception when
198  // its underlying table is large. Thus, we enable IN-clause decorrelation
199  // under watchdog iff we explicitly have correlated join in IN-clause
200  if (isWatchdogEnabled) {
201  boolean found_expression = false;
202  if (expression instanceof SqlCall) {
203  SqlCall call = (SqlCall) expression;
204  if (call.getOperandList().size() == 2) {
205  // if IN clause is correlated, its second operand of corresponding
206  // expression is SELECT clause which indicates a correlated subquery.
207  // Here, an expression "f.val IN (SELECT ...)" has two operands.
208  // Since we have interest in its subquery, so try to check whether
209  // the second operand, i.e., call.getOperandList().get(1)
210  // is a type of SqlSelect and also is correlated.
211  if (call.getOperandList().get(1) instanceof SqlSelect) {
// NOTE(review): this overwrites `expression` with the IN's sub-query, so the
// SELECT-clause and OR checks further below receive the sub-query rather than
// the IN call itself (only on this watchdog path) — confirm intended.
212  expression = call.getOperandList().get(1);
213  SqlSelect select_call = (SqlSelect) expression;
214  if (select_call.hasWhere()) {
215  // IN-clause may have correlated join within subquery's WHERE clause
216  // i.e., f.val IN (SELECT r.val FROM R r WHERE f.val2 = r.val2)
217  // then we have to decorrelate the IN-clause
218  JoinOperatorChecker joinOperatorChecker = new JoinOperatorChecker();
219  if (joinOperatorChecker.containsExpression(
220  select_call.getWhere())) {
221  found_expression = true;
222  }
223  }
224  }
225  }
226  }
227  if (!found_expression) {
228  return false;
229  }
230  }
231 
232  if (root instanceof SqlSelect) {
233  SqlSelect selectCall = (SqlSelect) root;
234  if (new ExpressionListedInSelectClauseChecker().containsExpression(
235  selectCall, expression)) {
236  // occasionally, Calcite cannot properly decorrelate IN-clause listed in
237  // SELECT clause e.g., SELECT x, CASE WHEN x in (SELECT x FROM R) ... FROM
238  // ... in that case we disable input query's decorrelation
239  return false;
240  }
241  if (null != selectCall.getWhere()) {
242  if (new ExpressionListedAsChildOROperatorChecker().containsExpression(
243  selectCall.getWhere(), expression)) {
244  // Decorrelation logic of the current Calcite cannot cover IN-clause
245  // well if it is listed as a child operand of OR-op
246  return false;
247  }
248  }
249  if (null != selectCall.getHaving()) {
250  if (new ExpressionListedAsChildOROperatorChecker().containsExpression(
251  selectCall.getHaving(), expression)) {
252  // Decorrelation logic of the current Calcite cannot cover IN-clause
253  // well if it is listed as a child operand of OR-op
254  return false;
255  }
256  }
257  }
258  }
259 
260  // otherwise, let's decorrelate the expression
261  return true;
262  }
263 
264  // special handling of sub-queries
265  if (expression.isA(SCALAR) && isCorrelated(expression)) {
266  // only expand if it is correlated.
267  SqlSelect select = null;
268  if (expression instanceof SqlCall) {
269  SqlCall call = (SqlCall) expression;
270  if (call.getOperator().equals(SqlStdOperatorTable.SCALAR_QUERY)) {
271  expression = call.getOperandList().get(0);
272  }
273  }
274 
275  if (expression instanceof SqlSelect) {
276  select = (SqlSelect) expression;
277  }
278 
279  if (null != select) {
280  if (null != select.getFetch() || null != select.getOffset()
281  || (null != select.getOrderList()
282  && select.getOrderList().size() != 0)) {
283  throw new CalciteException(
284  "Correlated sub-queries with ordering not supported.", null);
285  }
286  }
287  return true;
288  }
289 
290  // per default we do not want to expand
291  return false;
292  }
293  };
294 
295  // create the default schema
296  final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
297  final HeavyDBSchema defaultSchema =
299  final SchemaPlus defaultSchemaPlus = rootSchema.add(dbUser.getDB(), defaultSchema);
300 
301  // add the other potential schemas
302  // this is where the system schema would be added
// NOTE(review): mc is constructed on every call (initializer on a line not visible here)
// although its only visible use is inside the dead `if (false)` block below.
303  final MetaConnect mc =
305 
306  // TODO MAT for this checkin we are not going to actually allow any additional
307  // schemas
308  // Everything should work and perform as it ever did
// Dead code: the condition is the constant false, so no additional schemas are registered.
309  if (false) {
310  for (String db : mc.getDatabases()) {
311  if (!db.toUpperCase().equals(dbUser.getDB().toUpperCase())) {
312  rootSchema.add(db,
313  new HeavyDBSchema(
315  }
316  }
317  }
318 
319  final FrameworkConfig config =
320  Frameworks.newConfigBuilder()
321  .defaultSchema(defaultSchemaPlus)
322  .operatorTable(dbSqlOperatorTable.get())
323  .parserConfig(SqlParser.configBuilder()
324  .setConformance(SqlConformanceEnum.LENIENT)
325  .setUnquotedCasing(Casing.UNCHANGED)
326  .setCaseSensitive(false)
327  // allow identifiers of up to 512 chars
328  .setIdentifierMaxLength(512)
329  .setParserFactory(ExtendedSqlParser.FACTORY)
330  .build())
331  .sqlToRelConverterConfig(
332  SqlToRelConverter
333  .configBuilder()
334  // enable sub-query expansion (de-correlation)
335  .withExpandPredicate(expandPredicate)
336  // allow as many as possible IN operator values
337  .withInSubQueryThreshold(Integer.MAX_VALUE)
338  .withHintStrategyTable(
340  .build())
341 
342  .typeSystem(createTypeSystem())
343  .context(DB_CONNECTION_CONTEXT)
344  .build();
345  HeavyDBPlanner planner = new HeavyDBPlanner(config);
346  planner.setRestrictions(dbUser.getRestrictions());
347  return planner;
348  }
349 
350  public void setUser(HeavyDBUser dbUser) {
351  this.dbUser = dbUser;
352  }
353 
354  public Pair<String, SqlIdentifierCapturer> process(
355  String sql, final HeavyDBParserOptions parserOptions)
356  throws SqlParseException, ValidationException, RelConversionException {
357  final HeavyDBPlanner planner = getPlanner(true, parserOptions.isWatchdogEnabled());
358  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
359  String res = processSql(sqlNode, parserOptions);
360  SqlIdentifierCapturer capture = captureIdentifiers(sqlNode);
361  return new Pair<String, SqlIdentifierCapturer>(res, capture);
362  }
363 
// Builds and optimizes the relational-algebra tree for `query`:
// 1. planner.buildRATreeAndPerformQueryOptimization(query, schema), using a fresh planner
//    configured with the options' filter push-down info;
// 2. replaceIsTrue rewrites IS [NOT] TRUE/FALSE calls;
// 3. HeavyDBSerializer.toString serializes the projected root.
// NOTE(review): the method name/header and the schema initializer are on lines not visible
// here.
365  String query, final HeavyDBParserOptions parserOptions) throws IOException {
366  HeavyDBSchema schema =
368  HeavyDBPlanner planner = getPlanner(true, parserOptions.isWatchdogEnabled());
369 
370  planner.setFilterPushDownInfo(parserOptions.getFilterPushDownInfo());
371  RelRoot optRel = planner.buildRATreeAndPerformQueryOptimization(query, schema);
372  optRel = replaceIsTrue(planner.getTypeFactory(), optRel);
373  return HeavyDBSerializer.toString(optRel.project());
374  }
375 
376  public String processSql(String sql, final HeavyDBParserOptions parserOptions)
377  throws SqlParseException, ValidationException, RelConversionException {
378  callCount++;
379 
380  final HeavyDBPlanner planner = getPlanner(true, parserOptions.isWatchdogEnabled());
381  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
382 
383  return processSql(sqlNode, parserOptions);
384  }
385 
386  public String processSql(
387  final SqlNode sqlNode, final HeavyDBParserOptions parserOptions)
388  throws SqlParseException, ValidationException, RelConversionException {
389  callCount++;
390 
391  if (sqlNode instanceof JsonSerializableDdl) {
392  return ((JsonSerializableDdl) sqlNode).toJsonString();
393  }
394 
395  if (sqlNode instanceof SqlDdl) {
396  return sqlNode.toString();
397  }
398 
399  final HeavyDBPlanner planner = getPlanner(true, parserOptions.isWatchdogEnabled());
400  planner.advanceToValidate();
401 
402  final RelRoot sqlRel = convertSqlToRelNode(sqlNode, planner, parserOptions);
403  RelNode project = sqlRel.project();
404 
405  if (parserOptions.isExplain()) {
406  return RelOptUtil.toString(sqlRel.project());
407  }
408 
409  String res = HeavyDBSerializer.toString(project);
410 
411  return res;
412  }
413 
414  public HeavyDBPlanner.CompletionResult getCompletionHints(
415  String sql, int cursor, List<String> visible_tables) {
416  return getPlanner().getCompletionHints(sql, cursor, visible_tables);
417  }
418 
// Resolves the SELECT targets captured by `capturer` to underlying tables.
// - Each captured name is looked up in the schema by its first component.
// - Views are expanded recursively through their accessed objects.
// - Plain tables are returned as captured.
// Throws RuntimeException("table/view not found: ...") for an unknown name.
// NOTE(review): a view that (directly or indirectly) refers to itself would recurse without
// bound — presumably prevented when views are created; confirm.
// NOTE(review): the schema initializer continues on a line not visible here.
419  public HashSet<ImmutableList<String>> resolveSelectIdentifiers(
420  SqlIdentifierCapturer capturer) {
421  HeavyDBSchema schema =
423  HashSet<ImmutableList<String>> resolved = new HashSet<ImmutableList<String>>();
424 
425  for (ImmutableList<String> names : capturer.selects) {
426  HeavyDBTable table = (HeavyDBTable) schema.getTable(names.get(0));
427  if (null == table) {
428  throw new RuntimeException("table/view not found: " + names.get(0));
429  }
430 
431  if (table instanceof HeavyDBView) {
432  HeavyDBView view = (HeavyDBView) table;
433  resolved.addAll(resolveSelectIdentifiers(view.getAccessedObjects()));
434  } else {
435  resolved.add(names);
436  }
437  }
438 
439  return resolved;
440  }
441 
442  private String getTableName(SqlNode node) {
443  if (node.isA(EnumSet.of(SqlKind.AS))) {
444  node = ((SqlCall) node).getOperandList().get(1);
445  }
446  if (node instanceof SqlIdentifier) {
447  SqlIdentifier id = (SqlIdentifier) node;
448  return id.names.get(id.names.size() - 1);
449  }
450  return null;
451  }
452 
453  private SqlSelect rewriteSimpleUpdateAsSelect(final SqlUpdate update) {
454  SqlNode where = update.getCondition();
455 
456  if (update.getSourceExpressionList().size() != 1) {
457  return null;
458  }
459 
460  if (!(update.getSourceExpressionList().get(0) instanceof SqlSelect)) {
461  return null;
462  }
463 
464  final SqlSelect inner = (SqlSelect) update.getSourceExpressionList().get(0);
465 
466  if (null != inner.getGroup() || null != inner.getFetch() || null != inner.getOffset()
467  || (null != inner.getOrderList() && inner.getOrderList().size() != 0)
468  || (null != inner.getGroup() && inner.getGroup().size() != 0)
469  || null == getTableName(inner.getFrom())) {
470  return null;
471  }
472 
473  if (!isCorrelated(inner)) {
474  return null;
475  }
476 
477  final String updateTableName = getTableName(update.getTargetTable());
478 
479  if (null != where) {
480  where = where.accept(new SqlShuttle() {
481  @Override
482  public SqlNode visit(SqlIdentifier id) {
483  if (id.isSimple()) {
484  id = new SqlIdentifier(Arrays.asList(updateTableName, id.getSimple()),
485  id.getParserPosition());
486  }
487 
488  return id;
489  }
490  });
491  }
492 
493  SqlJoin join = new SqlJoin(ZERO,
494  update.getTargetTable(),
495  SqlLiteral.createBoolean(false, ZERO),
496  SqlLiteral.createSymbol(JoinType.LEFT, ZERO),
497  inner.getFrom(),
498  SqlLiteral.createSymbol(JoinConditionType.ON, ZERO),
499  inner.getWhere());
500 
501  SqlNode select0 = inner.getSelectList().get(0);
502 
503  boolean wrapInSingleValue = true;
504  if (select0 instanceof SqlCall) {
505  SqlCall selectExprCall = (SqlCall) select0;
506  if (Util.isSingleValue(selectExprCall)) {
507  wrapInSingleValue = false;
508  }
509  }
510 
511  if (wrapInSingleValue) {
512  if (select0.isA(EnumSet.of(SqlKind.AS))) {
513  select0 = ((SqlCall) select0).getOperandList().get(0);
514  }
515  select0 = new SqlBasicCall(
516  SqlStdOperatorTable.SINGLE_VALUE, new SqlNode[] {select0}, ZERO);
517  }
518 
519  SqlNodeList selectList = new SqlNodeList(ZERO);
520  selectList.add(select0);
521  selectList.add(new SqlBasicCall(SqlStdOperatorTable.AS,
522  new SqlNode[] {new SqlBasicCall(
523  new SqlUnresolvedFunction(
524  new SqlIdentifier("OFFSET_IN_FRAGMENT", ZERO),
525  null,
526  null,
527  null,
528  null,
529  SqlFunctionCategory.USER_DEFINED_FUNCTION),
530  new SqlNode[0],
531  SqlParserPos.ZERO),
532  new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO)},
533  ZERO));
534 
535  SqlNodeList groupBy = new SqlNodeList(ZERO);
536  groupBy.add(new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO));
537 
538  SqlSelect select = new SqlSelect(ZERO,
539  null,
540  selectList,
541  join,
542  where,
543  groupBy,
544  null,
545  null,
546  null,
547  null,
548  null,
549  null);
550  return select;
551  }
552 
553  private LogicalTableModify getDummyUpdate(SqlUpdate update)
554  throws SqlParseException, ValidationException, RelConversionException {
555  SqlIdentifier targetTable = (SqlIdentifier) update.getTargetTable();
556  String targetTableName = targetTable.names.get(targetTable.names.size() - 1);
557  HeavyDBPlanner planner = getPlanner();
558  String dummySql = "DELETE FROM " + targetTableName;
559  SqlNode dummyNode = planner.parse(dummySql);
560  dummyNode = planner.validate(dummyNode);
561  RelRoot dummyRoot = planner.rel(dummyNode);
562  LogicalTableModify dummyModify = (LogicalTableModify) dummyRoot.rel;
563  return dummyModify;
564  }
565 
// Plans an UPDATE, or a DELETE that convertSqlToRelNode re-expressed as an UPDATE with no
// target columns.
// 1. Rewrite it as a SELECT that projects the new column values, followed by
//    OFFSET_IN_FRAGMENT() AS EXPR$DELETE_OFFSET_IN_FRAGMENT.
// 2. Plan that SELECT.
// 3. Wrap the projection in a LogicalTableModify with Operation.UPDATE.
// The join-based rewrite from rewriteSimpleUpdateAsSelect is tried first. Otherwise the
// SELECT reads the target table directly, filtered by the UPDATE's condition.
// Throws CalciteException when the statement contains more than one correlated sub-query.
566  private RelRoot rewriteUpdateAsSelect(
567  SqlUpdate update, HeavyDBParserOptions parserOptions)
568  throws SqlParseException, ValidationException, RelConversionException {
// One-element array so the anonymous visitor below can increment the count.
569  int correlatedQueriesCount[] = new int[1];
// Counts correlated scalar sub-queries. A basic call whose single operand is itself SCALAR
// (e.g. a SCALAR_QUERY wrapping a SELECT) is skipped, so the wrapped SELECT is counted once.
570  SqlBasicVisitor<Void> correlatedQueriesCounter = new SqlBasicVisitor<Void>() {
571  @Override
572  public Void visit(SqlCall call) {
573  if (call.isA(SCALAR)
574  && ((call instanceof SqlBasicCall && call.operandCount() == 1
575  && !call.operand(0).isA(SCALAR))
576  || !(call instanceof SqlBasicCall))) {
577  if (isCorrelated(call)) {
578  correlatedQueriesCount[0]++;
579  }
580  }
581  return super.visit(call);
582  }
583  };
584 
585  update.accept(correlatedQueriesCounter);
586  if (correlatedQueriesCount[0] > 1) {
587  throw new CalciteException(
588  "table modifications with multiple correlated sub-queries not supported.",
589  null);
590  }
591 
// Sub-query decorrelation is disabled for the final planner when both hold:
// - the UPDATE condition contains an IN;
// - the target is a two-part (db.table) identifier whose table details report is_temporary.
592  boolean allowSubqueryDecorrelation = true;
593  SqlNode updateCondition = update.getCondition();
594  if (null != updateCondition) {
595  boolean hasInClause =
596  new FindSqlOperator().containsSqlOperator(updateCondition, SqlKind.IN);
597  if (hasInClause) {
598  SqlNode updateTargetTable = update.getTargetTable();
599  if (null != updateTargetTable && updateTargetTable instanceof SqlIdentifier) {
600  SqlIdentifier targetTable = (SqlIdentifier) updateTargetTable;
601  if (targetTable.names.size() == 2) {
602  final MetaConnect mc = new MetaConnect(dbPort,
603  dataDir,
604  dbUser,
605  this,
606  sock_transport_properties,
607  targetTable.names.get(0));
608  TTableDetails updateTargetTableDetails =
609  mc.get_table_details(targetTable.names.get(1));
610  if (null != updateTargetTableDetails
611  && updateTargetTableDetails.is_temporary) {
612  allowSubqueryDecorrelation = false;
613  }
614  }
615  }
616  }
617  }
618 
619  SqlNodeList sourceExpression = new SqlNodeList(SqlParserPos.ZERO);
620  LogicalTableModify dummyModify = getDummyUpdate(update);
621  RelOptTable targetTable = dummyModify.getTable();
622  RelDataType targetTableType = targetTable.getRowType();
623 
624  SqlSelect select = rewriteSimpleUpdateAsSelect(update);
// Casts to the target column types are only added on the general path below; the
// join-based rewrite's projections are used as planned.
625  boolean applyRexCast = null == select;
626 
627  if (null == select) {
628  for (int i = 0; i < update.getSourceExpressionList().size(); i++) {
629  SqlNode targetColumn = update.getTargetColumnList().get(i);
630  SqlNode expression = update.getSourceExpressionList().get(i);
631 
632  if (!(targetColumn instanceof SqlIdentifier)) {
633  throw new RuntimeException("Unknown identifier type!");
634  }
635  SqlIdentifier id = (SqlIdentifier) targetColumn;
636  RelDataType fieldType =
637  targetTableType.getField(id.names.get(id.names.size() - 1), false, false)
638  .getType();
639 
640  if (expression.isA(ARRAY_VALUE) && null != fieldType.getComponentType()) {
641  // apply a cast to all array value elements
642 
// NOTE(review): precision, scale and charset are read from the array field type itself,
// not from fieldType.getComponentType(); confirm this is right for e.g. DECIMAL(p,s)[].
643  SqlDataTypeSpec elementType = new SqlDataTypeSpec(
644  new SqlBasicTypeNameSpec(fieldType.getComponentType().getSqlTypeName(),
645  fieldType.getPrecision(),
646  fieldType.getScale(),
647  null == fieldType.getCharset() ? null
648  : fieldType.getCharset().name(),
649  SqlParserPos.ZERO),
650  SqlParserPos.ZERO);
651  SqlCall array_expression = (SqlCall) expression;
652  ArrayList<SqlNode> values = new ArrayList<>();
653 
654  for (SqlNode value : array_expression.getOperandList()) {
655  if (value.isA(EnumSet.of(SqlKind.LITERAL))) {
656  SqlNode casted_value = new SqlBasicCall(SqlStdOperatorTable.CAST,
657  new SqlNode[] {value, elementType},
658  value.getParserPosition());
659  values.add(casted_value);
660  } else {
661  values.add(value);
662  }
663  }
664 
665  expression = new SqlBasicCall(HeavyDBSqlOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
666  values.toArray(new SqlNode[0]),
667  expression.getParserPosition());
668  }
669  sourceExpression.add(expression);
670  }
671 
672  sourceExpression.add(new SqlBasicCall(SqlStdOperatorTable.AS,
673  new SqlNode[] {
674  new SqlBasicCall(new SqlUnresolvedFunction(
675  new SqlIdentifier("OFFSET_IN_FRAGMENT",
676  SqlParserPos.ZERO),
677  null,
678  null,
679  null,
680  null,
681  SqlFunctionCategory.USER_DEFINED_FUNCTION),
682  new SqlNode[0],
683  SqlParserPos.ZERO),
684  new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO)},
685  ZERO));
686 
687  select = new SqlSelect(SqlParserPos.ZERO,
688  null,
689  sourceExpression,
690  update.getTargetTable(),
691  update.getCondition(),
692  null,
693  null,
694  null,
695  null,
696  null,
697  null,
698  null);
699  }
700 
// The rewritten SELECT is unparsed and re-parsed with a planner that honors the
// decorrelation decision made above; on failure the rewritten text is logged and the
// original exception rethrown.
701  HeavyDBPlanner planner =
702  getPlanner(allowSubqueryDecorrelation, parserOptions.isWatchdogEnabled());
703  SqlNode node = null;
704  try {
705  node = planner.parse(select.toSqlString(CalciteSqlDialect.DEFAULT).getSql());
706  node = planner.validate(node);
707  } catch (Exception e) {
708  HEAVYDBLOGGER.error("Error processing UPDATE rewrite, rewritten stmt was: "
709  + select.toSqlString(CalciteSqlDialect.DEFAULT).getSql());
710  throw e;
711  }
712 
713  RelRoot root = planner.rel(node);
714  LogicalProject project = (LogicalProject) root.project();
715 
716  ArrayList<String> fields = new ArrayList<String>();
717  ArrayList<RexNode> nodes = new ArrayList<RexNode>();
718  final RexBuilder builder = new RexBuilder(planner.getTypeFactory());
719 
720  for (SqlNode n : update.getTargetColumnList()) {
721  if (n instanceof SqlIdentifier) {
722  SqlIdentifier id = (SqlIdentifier) n;
723  fields.add(id.names.get(id.names.size() - 1));
724  } else {
725  throw new RuntimeException("Unknown identifier type!");
726  }
727  }
728 
729  // The magical number here when processing the projection
730  // is skipping the OFFSET_IN_FRAGMENT() expression used by
731  // update and delete
732  int idx = 0;
733  for (RexNode exp : project.getProjects()) {
734  if (applyRexCast && idx + 1 < project.getProjects().size()) {
735  RelDataType expectedFieldType =
736  targetTableType.getField(fields.get(idx), false, false).getType();
737  if (!exp.getType().equals(expectedFieldType) && !exp.isA(ARRAY_VALUE)) {
738  exp = builder.makeCast(expectedFieldType, exp);
739  }
740  }
741 
742  nodes.add(exp);
743  idx++;
744  }
745 
// One input reference per updated column, in projection order, followed by the
// OFFSET_IN_FRAGMENT column (always the last projection).
746  ArrayList<RexNode> inputs = new ArrayList<RexNode>();
747  int n = 0;
748  for (int i = 0; i < fields.size(); i++) {
749  inputs.add(
750  new RexInputRef(n, project.getRowType().getFieldList().get(n).getType()));
751  n++;
752  }
753 
754  fields.add("EXPR$DELETE_OFFSET_IN_FRAGMENT");
755  inputs.add(new RexInputRef(n, project.getRowType().getFieldList().get(n).getType()));
756 
// NOTE(review): the copy keeps the original row type even when projections were wrapped in
// casts above; confirm nothing downstream relies on the row type matching the expressions.
757  project = project.copy(
758  project.getTraitSet(), project.getInput(), nodes, project.getRowType());
759 
760  LogicalTableModify modify = LogicalTableModify.create(targetTable,
761  dummyModify.getCatalogReader(),
762  project,
763  Operation.UPDATE,
764  fields,
765  inputs,
766  true);
767  return RelRoot.of(modify, SqlKind.UPDATE);
768  }
769 
770  RelRoot queryToRelNode(final String sql, final HeavyDBParserOptions parserOptions)
771  throws SqlParseException, ValidationException, RelConversionException {
772  final HeavyDBPlanner planner = getPlanner(true, parserOptions.isWatchdogEnabled());
773  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
774  return convertSqlToRelNode(sqlNode, planner, parserOptions);
775  }
776 
// Converts a parsed statement into an optimized RelRoot.
// - DELETE is re-expressed as an UPDATE with empty column/value lists and planned through
//   rewriteUpdateAsSelect. The resulting TableModify's private `operation` field is then
//   switched to DELETE via reflection.
// - UPDATE is validated and planned through rewriteUpdateAsSelect.
// - Any other statement goes through these steps:
//   1. with legacy syntax, it is unparsed and re-parsed with a fresh planner;
//   2. it is validated and checked for view references;
//   3. it is converted, and replaceIsTrue strips IS [NOT] TRUE/FALSE;
//   4. optimizeRATree optimizes it, after which the planner is closed.
// NOTE(review): the planner parameter is declared on a line not visible here (it is read as
// `HeavyDBPlanner` below). The UPDATE/DELETE path returns without planner.close(), unlike
// the other path.
777  RelRoot convertSqlToRelNode(final SqlNode sqlNode,
779  final HeavyDBParserOptions parserOptions)
780  throws SqlParseException, ValidationException, RelConversionException {
781  SqlNode node = sqlNode;
782  HeavyDBPlanner planner = HeavyDBPlanner;
// Never changed below; always true.
783  boolean allowCorrelatedSubQueryExpansion = true;
784  boolean patchUpdateToDelete = false;
785  if (node.isA(DELETE)) {
786  SqlDelete sqlDelete = (SqlDelete) node;
787  node = new SqlUpdate(node.getParserPosition(),
788  sqlDelete.getTargetTable(),
789  SqlNodeList.EMPTY,
790  SqlNodeList.EMPTY,
791  sqlDelete.getCondition(),
792  sqlDelete.getSourceSelect(),
793  sqlDelete.getAlias());
794 
795  patchUpdateToDelete = true;
796  }
797  if (node.isA(UPDATE)) {
798  SqlUpdate update = (SqlUpdate) node;
799  update = (SqlUpdate) planner.validate(update);
800  RelRoot root = rewriteUpdateAsSelect(update, parserOptions);
801 
802  if (patchUpdateToDelete) {
803  LogicalTableModify modify = (LogicalTableModify) root.rel;
804 
// NOTE(review): reflective write to a private Calcite field; this depends on
// TableModify keeping a field named `operation`.
805  try {
806  Field f = TableModify.class.getDeclaredField("operation");
807  f.setAccessible(true);
808  f.set(modify, Operation.DELETE);
809  } catch (Throwable e) {
810  throw new RuntimeException(e);
811  }
812 
813  root = RelRoot.of(modify, SqlKind.DELETE);
814  }
815 
816  return root;
817  }
818  if (parserOptions.isLegacySyntax()) {
819  // close original planner
820  planner.close();
821  // create a new one
822  planner = getPlanner(
823  allowCorrelatedSubQueryExpansion, parserOptions.isWatchdogEnabled());
824  node = parseSql(
825  node.toSqlString(CalciteSqlDialect.DEFAULT).toString(), false, planner);
826  }
827 
828  SqlNode validateR = planner.validate(node);
829  planner.setFilterPushDownInfo(parserOptions.getFilterPushDownInfo());
830  // check to see if a view is involved in the query
// (identifiers are captured from the original sqlNode, not the re-parsed node)
831  boolean foundView = false;
832  HeavyDBSchema schema =
833  new HeavyDBSchema(dataDir, this, dbPort, dbUser, sock_transport_properties);
834  SqlIdentifierCapturer capturer = captureIdentifiers(sqlNode);
835  for (ImmutableList<String> names : capturer.selects) {
836  HeavyDBTable table = (HeavyDBTable) schema.getTable(names.get(0));
837  if (null == table) {
838  throw new RuntimeException("table/view not found: " + names.get(0));
839  }
840  if (table instanceof HeavyDBView) {
841  foundView = true;
842  }
843  }
844  RelRoot relRootNode = planner.getRelRoot(validateR);
845  relRootNode = replaceIsTrue(planner.getTypeFactory(), relRootNode);
846  RelNode rootNode = planner.optimizeRATree(
847  relRootNode.project(), parserOptions.isViewOptimizeEnabled(), foundView);
848  planner.close();
849  return new RelRoot(rootNode,
850  relRootNode.validatedRowType,
851  relRootNode.kind,
852  relRootNode.fields,
853  relRootNode.collation,
854  Collections.emptyList());
855  }
856 
857  private RelRoot replaceIsTrue(final RelDataTypeFactory typeFactory, RelRoot root) {
858  final RexShuttle callShuttle = new RexShuttle() {
859  RexBuilder builder = new RexBuilder(typeFactory);
860 
861  public RexNode visitCall(RexCall call) {
862  call = (RexCall) super.visitCall(call);
863  if (call.getKind() == SqlKind.IS_TRUE) {
864  return builder.makeCall(SqlStdOperatorTable.AND,
865  builder.makeCall(
866  SqlStdOperatorTable.IS_NOT_NULL, call.getOperands().get(0)),
867  call.getOperands().get(0));
868  } else if (call.getKind() == SqlKind.IS_NOT_TRUE) {
869  return builder.makeCall(SqlStdOperatorTable.OR,
870  builder.makeCall(
871  SqlStdOperatorTable.IS_NULL, call.getOperands().get(0)),
872  builder.makeCall(SqlStdOperatorTable.NOT, call.getOperands().get(0)));
873  } else if (call.getKind() == SqlKind.IS_FALSE) {
874  return builder.makeCall(SqlStdOperatorTable.AND,
875  builder.makeCall(
876  SqlStdOperatorTable.IS_NOT_NULL, call.getOperands().get(0)),
877  builder.makeCall(SqlStdOperatorTable.NOT, call.getOperands().get(0)));
878  } else if (call.getKind() == SqlKind.IS_NOT_FALSE) {
879  return builder.makeCall(SqlStdOperatorTable.OR,
880  builder.makeCall(
881  SqlStdOperatorTable.IS_NULL, call.getOperands().get(0)),
882  call.getOperands().get(0));
883  }
884 
885  return call;
886  }
887  };
888 
889  RelNode node = root.rel.accept(new RelShuttleImpl() {
890  @Override
891  protected RelNode visitChild(RelNode parent, int i, RelNode child) {
892  RelNode node = super.visitChild(parent, i, child);
893  return node.accept(callShuttle);
894  }
895  });
896 
897  return new RelRoot(node,
898  root.validatedRowType,
899  root.kind,
900  root.fields,
901  root.collation,
902  Collections.emptyList());
903  }
904 
905  private SqlNode parseSql(String sql, final boolean legacy_syntax, Planner planner)
906  throws SqlParseException {
907  SqlNode parseR = null;
908  try {
909  parseR = planner.parse(sql);
910  HEAVYDBLOGGER.debug(" node is \n" + parseR.toString());
911  } catch (SqlParseException ex) {
912  HEAVYDBLOGGER.error("failed to parse SQL '" + sql + "' \n" + ex.toString());
913  throw ex;
914  }
915 
916  if (!legacy_syntax) {
917  return parseR;
918  }
919 
920  RelDataTypeFactory typeFactory = planner.getTypeFactory();
921  SqlSelect select_node = null;
922  if (parseR instanceof SqlSelect) {
923  select_node = (SqlSelect) parseR;
924  desugar(select_node, typeFactory);
925  } else if (parseR instanceof SqlOrderBy) {
926  SqlOrderBy order_by_node = (SqlOrderBy) parseR;
927  if (order_by_node.query instanceof SqlSelect) {
928  select_node = (SqlSelect) order_by_node.query;
929  SqlOrderBy new_order_by_node = desugar(select_node, order_by_node, typeFactory);
930  if (new_order_by_node != null) {
931  return new_order_by_node;
932  }
933  } else if (order_by_node.query instanceof SqlWith) {
934  SqlWith old_with_node = (SqlWith) order_by_node.query;
935  if (old_with_node.body instanceof SqlSelect) {
936  select_node = (SqlSelect) old_with_node.body;
937  desugar(select_node, typeFactory);
938  }
939  }
940  } else if (parseR instanceof SqlWith) {
941  SqlWith old_with_node = (SqlWith) parseR;
942  if (old_with_node.body instanceof SqlSelect) {
943  select_node = (SqlSelect) old_with_node.body;
944  desugar(select_node, typeFactory);
945  }
946  }
947  return parseR;
948  }
949 
950  private void desugar(SqlSelect select_node, RelDataTypeFactory typeFactory) {
951  desugar(select_node, null, typeFactory);
952  }
953 
954  private SqlNode expandCase(SqlCase old_case_node, RelDataTypeFactory typeFactory) {
955  SqlNodeList newWhenList =
956  new SqlNodeList(old_case_node.getWhenOperands().getParserPosition());
957  SqlNodeList newThenList =
958  new SqlNodeList(old_case_node.getThenOperands().getParserPosition());
959  java.util.Map<String, SqlNode> id_to_expr = new java.util.HashMap<String, SqlNode>();
960  for (SqlNode node : old_case_node.getWhenOperands()) {
961  SqlNode newCall = expand(node, id_to_expr, typeFactory);
962  if (null != newCall) {
963  newWhenList.add(newCall);
964  } else {
965  newWhenList.add(node);
966  }
967  }
968  for (SqlNode node : old_case_node.getThenOperands()) {
969  SqlNode newCall = expand(node, id_to_expr, typeFactory);
970  if (null != newCall) {
971  newThenList.add(newCall);
972  } else {
973  newThenList.add(node);
974  }
975  }
976  SqlNode new_else_operand = old_case_node.getElseOperand();
977  if (null != new_else_operand) {
978  SqlNode candidate_else_operand =
979  expand(old_case_node.getElseOperand(), id_to_expr, typeFactory);
980  if (null != candidate_else_operand) {
981  new_else_operand = candidate_else_operand;
982  }
983  }
984  SqlNode new_value_operand = old_case_node.getValueOperand();
985  if (null != new_value_operand) {
986  SqlNode candidate_value_operand =
987  expand(old_case_node.getValueOperand(), id_to_expr, typeFactory);
988  if (null != candidate_value_operand) {
989  new_value_operand = candidate_value_operand;
990  }
991  }
992  SqlNode newCaseNode = SqlCase.createSwitched(old_case_node.getParserPosition(),
993  new_value_operand,
994  newWhenList,
995  newThenList,
996  new_else_operand);
997  return newCaseNode;
998  }
999 
1000  private SqlOrderBy desugar(SqlSelect select_node,
1001  SqlOrderBy order_by_node,
1002  RelDataTypeFactory typeFactory) {
1003  HEAVYDBLOGGER.debug("desugar: before: " + select_node.toString());
1004  desugarExpression(select_node.getFrom(), typeFactory);
1005  desugarExpression(select_node.getWhere(), typeFactory);
1006  SqlNodeList select_list = select_node.getSelectList();
1007  SqlNodeList new_select_list = new SqlNodeList(select_list.getParserPosition());
1008  java.util.Map<String, SqlNode> id_to_expr = new java.util.HashMap<String, SqlNode>();
1009  for (SqlNode proj : select_list) {
1010  if (!(proj instanceof SqlBasicCall)) {
1011  if (proj instanceof SqlCase) {
1012  new_select_list.add(expandCase((SqlCase) proj, typeFactory));
1013  } else {
1014  new_select_list.add(proj);
1015  }
1016  } else {
1017  assert proj instanceof SqlBasicCall;
1018  SqlBasicCall proj_call = (SqlBasicCall) proj;
1019  if (proj_call.operands.length > 0) {
1020  for (int i = 0; i < proj_call.operands.length; i++) {
1021  if (proj_call.operand(i) instanceof SqlCase) {
1022  SqlNode new_op = expandCase(proj_call.operand(i), typeFactory);
1023  proj_call.setOperand(i, new_op);
1024  }
1025  }
1026  }
1027  new_select_list.add(expand(proj_call, id_to_expr, typeFactory));
1028  }
1029  }
1030  select_node.setSelectList(new_select_list);
1031  SqlNodeList group_by_list = select_node.getGroup();
1032  if (group_by_list != null) {
1033  select_node.setGroupBy(expand(group_by_list, id_to_expr, typeFactory));
1034  }
1035  SqlNode having = select_node.getHaving();
1036  if (having != null) {
1037  expand(having, id_to_expr, typeFactory);
1038  }
1039  SqlOrderBy new_order_by_node = null;
1040  if (order_by_node != null && order_by_node.orderList != null
1041  && order_by_node.orderList.size() > 0) {
1042  SqlNodeList new_order_by_list =
1043  expand(order_by_node.orderList, id_to_expr, typeFactory);
1044  new_order_by_node = new SqlOrderBy(order_by_node.getParserPosition(),
1045  select_node,
1046  new_order_by_list,
1047  order_by_node.offset,
1048  order_by_node.fetch);
1049  }
1050 
1051  HEAVYDBLOGGER.debug("desugar: after: " + select_node.toString());
1052  return new_order_by_node;
1053  }
1054 
1055  private void desugarExpression(SqlNode node, RelDataTypeFactory typeFactory) {
1056  if (node instanceof SqlSelect) {
1057  desugar((SqlSelect) node, typeFactory);
1058  return;
1059  }
1060  if (!(node instanceof SqlBasicCall)) {
1061  return;
1062  }
1063  SqlBasicCall basic_call = (SqlBasicCall) node;
1064  for (SqlNode operator : basic_call.getOperands()) {
1065  if (operator instanceof SqlOrderBy) {
1066  desugarExpression(((SqlOrderBy) operator).query, typeFactory);
1067  } else {
1068  desugarExpression(operator, typeFactory);
1069  }
1070  }
1071  }
1072 
1073  private SqlNode expand(final SqlNode node,
1074  final java.util.Map<String, SqlNode> id_to_expr,
1075  RelDataTypeFactory typeFactory) {
1076  HEAVYDBLOGGER.debug("expand: " + node.toString());
1077  if (node instanceof SqlBasicCall) {
1078  SqlBasicCall node_call = (SqlBasicCall) node;
1079  SqlNode[] operands = node_call.getOperands();
1080  for (int i = 0; i < operands.length; ++i) {
1081  node_call.setOperand(i, expand(operands[i], id_to_expr, typeFactory));
1082  }
1083  SqlNode expanded_substr = expandSubstr(node_call, typeFactory);
1084  if (expanded_substr != null) {
1085  return expanded_substr;
1086  }
1087  SqlNode expanded_variance = expandVariance(node_call, typeFactory);
1088  if (expanded_variance != null) {
1089  return expanded_variance;
1090  }
1091  SqlNode expanded_covariance = expandCovariance(node_call, typeFactory);
1092  if (expanded_covariance != null) {
1093  return expanded_covariance;
1094  }
1095  SqlNode expanded_correlation = expandCorrelation(node_call, typeFactory);
1096  if (expanded_correlation != null) {
1097  return expanded_correlation;
1098  }
1099  }
1100  if (node instanceof SqlSelect) {
1101  SqlSelect select_node = (SqlSelect) node;
1102  desugar(select_node, typeFactory);
1103  }
1104  return node;
1105  }
1106 
1107  private SqlNodeList expand(final SqlNodeList group_by_list,
1108  final java.util.Map<String, SqlNode> id_to_expr,
1109  RelDataTypeFactory typeFactory) {
1110  SqlNodeList new_group_by_list = new SqlNodeList(new SqlParserPos(-1, -1));
1111  for (SqlNode group_by : group_by_list) {
1112  if (!(group_by instanceof SqlIdentifier)) {
1113  new_group_by_list.add(expand(group_by, id_to_expr, typeFactory));
1114  continue;
1115  }
1116  SqlIdentifier group_by_id = ((SqlIdentifier) group_by);
1117  if (id_to_expr.containsKey(group_by_id.toString())) {
1118  new_group_by_list.add(id_to_expr.get(group_by_id.toString()));
1119  } else {
1120  new_group_by_list.add(group_by);
1121  }
1122  }
1123  return new_group_by_list;
1124  }
1125 
1126  private SqlNode expandSubstr(
1127  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1128  // Expand SUBSTR to Calcite-native SUBSTRING
1129  if (!proj_call.getOperator().isName("SUBSTR", false)) {
1130  return null;
1131  }
1132  if (proj_call.operandCount() < 2 || proj_call.operandCount() > 3) {
1133  return null;
1134  }
1135  final SqlParserPos pos = proj_call.getParserPosition();
1136  final SqlNode primary_operand = proj_call.operand(0);
1137  final SqlNode from_operand = proj_call.operand(1);
1138  if (proj_call.operandCount() == 2) {
1139  return SqlStdOperatorTable.SUBSTRING.createCall(pos, primary_operand, from_operand);
1140  }
1141  final SqlNode for_operand = proj_call.operand(2);
1142  return SqlStdOperatorTable.SUBSTRING.createCall(
1143  pos, primary_operand, from_operand, for_operand);
1144  }
1145 
1146  private SqlNode expandVariance(
1147  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1148  // Expand variance aggregates that are not supported natively
1149  if (proj_call.operandCount() != 1) {
1150  return null;
1151  }
1152  boolean biased;
1153  boolean sqrt;
1154  boolean flt;
1155  if (proj_call.getOperator().isName("STDDEV_POP", false)) {
1156  biased = true;
1157  sqrt = true;
1158  flt = false;
1159  } else if (proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_POP_FLOAT")) {
1160  biased = true;
1161  sqrt = true;
1162  flt = true;
1163  } else if (proj_call.getOperator().isName("STDDEV_SAMP", false)
1164  || proj_call.getOperator().getName().equalsIgnoreCase("STDDEV")) {
1165  biased = false;
1166  sqrt = true;
1167  flt = false;
1168  } else if (proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_SAMP_FLOAT")
1169  || proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_FLOAT")) {
1170  biased = false;
1171  sqrt = true;
1172  flt = true;
1173  } else if (proj_call.getOperator().isName("VAR_POP", false)) {
1174  biased = true;
1175  sqrt = false;
1176  flt = false;
1177  } else if (proj_call.getOperator().getName().equalsIgnoreCase("VAR_POP_FLOAT")) {
1178  biased = true;
1179  sqrt = false;
1180  flt = true;
1181  } else if (proj_call.getOperator().isName("VAR_SAMP", false)
1182  || proj_call.getOperator().getName().equalsIgnoreCase("VARIANCE")) {
1183  biased = false;
1184  sqrt = false;
1185  flt = false;
1186  } else if (proj_call.getOperator().getName().equalsIgnoreCase("VAR_SAMP_FLOAT")
1187  || proj_call.getOperator().getName().equalsIgnoreCase("VARIANCE_FLOAT")) {
1188  biased = false;
1189  sqrt = false;
1190  flt = true;
1191  } else {
1192  return null;
1193  }
1194  final SqlNode operand = proj_call.operand(0);
1195  final SqlParserPos pos = proj_call.getParserPosition();
1196  SqlNode expanded_proj_call =
1197  expandVariance(pos, operand, biased, sqrt, flt, typeFactory);
1198  HEAVYDBLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1199  HEAVYDBLOGGER.debug("to : " + expanded_proj_call.toString());
1200  return expanded_proj_call;
1201  }
1202 
1203  private SqlNode expandVariance(final SqlParserPos pos,
1204  final SqlNode operand,
1205  boolean biased,
1206  boolean sqrt,
1207  boolean flt,
1208  RelDataTypeFactory typeFactory) {
1209  // stddev_pop(x) ==>
1210  // power(
1211  // (sum(x * x) - sum(x) * sum(x) / (case count(x) when 0 then NULL else count(x)
1212  // end)) / (case count(x) when 0 then NULL else count(x) end), .5)
1213  //
1214  // stddev_samp(x) ==>
1215  // power(
1216  // (sum(x * x) - sum(x) * sum(x) / (case count(x) when 0 then NULL else count(x)
1217  // )) / ((case count(x) when 1 then NULL else count(x) - 1 end)), .5)
1218  //
1219  // var_pop(x) ==>
1220  // (sum(x * x) - sum(x) * sum(x) / ((case count(x) when 0 then NULL else
1221  // count(x)
1222  // end))) / ((case count(x) when 0 then NULL else count(x) end))
1223  //
1224  // var_samp(x) ==>
1225  // (sum(x * x) - sum(x) * sum(x) / ((case count(x) when 0 then NULL else
1226  // count(x)
1227  // end))) / ((case count(x) when 1 then NULL else count(x) - 1 end))
1228  //
1229  final SqlNode arg = SqlStdOperatorTable.CAST.createCall(pos,
1230  operand,
1231  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1232  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1233  final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
1234  final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
1235  final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
1236  final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
1237  final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
1238  final SqlLiteral nul = SqlLiteral.createNull(pos);
1239  final SqlNumericLiteral zero = SqlLiteral.createExactNumeric("0", pos);
1240  final SqlNode countEqZero = SqlStdOperatorTable.EQUALS.createCall(pos, count, zero);
1241  SqlNodeList whenList = new SqlNodeList(pos);
1242  SqlNodeList thenList = new SqlNodeList(pos);
1243  whenList.add(countEqZero);
1244  thenList.add(nul);
1245  final SqlNode int_denominator = SqlStdOperatorTable.CASE.createCall(
1246  null, pos, null, whenList, thenList, count);
1247  final SqlNode denominator = SqlStdOperatorTable.CAST.createCall(pos,
1248  int_denominator,
1249  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1250  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1251  final SqlNode avgSumSquared =
1252  SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, denominator);
1253  final SqlNode diff =
1254  SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
1255  final SqlNode denominator1;
1256  if (biased) {
1257  denominator1 = denominator;
1258  } else {
1259  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
1260  final SqlNode countEqOne = SqlStdOperatorTable.EQUALS.createCall(pos, count, one);
1261  final SqlNode countMinusOne = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
1262  SqlNodeList whenList1 = new SqlNodeList(pos);
1263  SqlNodeList thenList1 = new SqlNodeList(pos);
1264  whenList1.add(countEqOne);
1265  thenList1.add(nul);
1266  final SqlNode int_denominator1 = SqlStdOperatorTable.CASE.createCall(
1267  null, pos, null, whenList1, thenList1, countMinusOne);
1268  denominator1 = SqlStdOperatorTable.CAST.createCall(pos,
1269  int_denominator1,
1270  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1271  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1272  }
1273  final SqlNode div = SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator1);
1274  SqlNode result = div;
1275  if (sqrt) {
1276  final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
1277  result = SqlStdOperatorTable.POWER.createCall(pos, div, half);
1278  }
1279  return SqlStdOperatorTable.CAST.createCall(pos,
1280  result,
1281  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1282  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1283  }
1284 
1285  private SqlNode expandCovariance(
1286  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1287  // Expand covariance aggregates
1288  if (proj_call.operandCount() != 2) {
1289  return null;
1290  }
1291  boolean pop;
1292  boolean flt;
1293  if (proj_call.getOperator().isName("COVAR_POP", false)) {
1294  pop = true;
1295  flt = false;
1296  } else if (proj_call.getOperator().isName("COVAR_SAMP", false)) {
1297  pop = false;
1298  flt = false;
1299  } else if (proj_call.getOperator().getName().equalsIgnoreCase("COVAR_POP_FLOAT")) {
1300  pop = true;
1301  flt = true;
1302  } else if (proj_call.getOperator().getName().equalsIgnoreCase("COVAR_SAMP_FLOAT")) {
1303  pop = false;
1304  flt = true;
1305  } else {
1306  return null;
1307  }
1308  final SqlNode operand0 = proj_call.operand(0);
1309  final SqlNode operand1 = proj_call.operand(1);
1310  final SqlParserPos pos = proj_call.getParserPosition();
1311  SqlNode expanded_proj_call =
1312  expandCovariance(pos, operand0, operand1, pop, flt, typeFactory);
1313  HEAVYDBLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1314  HEAVYDBLOGGER.debug("to : " + expanded_proj_call.toString());
1315  return expanded_proj_call;
1316  }
1317 
1318  private SqlNode expandCovariance(SqlParserPos pos,
1319  final SqlNode operand0,
1320  final SqlNode operand1,
1321  boolean pop,
1322  boolean flt,
1323  RelDataTypeFactory typeFactory) {
1324  // covar_pop(x, y) ==> avg(x * y) - avg(x) * avg(y)
1325  // covar_samp(x, y) ==> (sum(x * y) - sum(x) * avg(y))
1326  // ((case count(x) when 1 then NULL else count(x) - 1 end))
1327  final SqlNode arg0 = SqlStdOperatorTable.CAST.createCall(operand0.getParserPosition(),
1328  operand0,
1329  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1330  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1331  final SqlNode arg1 = SqlStdOperatorTable.CAST.createCall(operand1.getParserPosition(),
1332  operand1,
1333  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1334  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1335  final SqlNode mulArg = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1);
1336  final SqlNode avgArg1 = SqlStdOperatorTable.AVG.createCall(pos, arg1);
1337  if (pop) {
1338  final SqlNode avgMulArg = SqlStdOperatorTable.AVG.createCall(pos, mulArg);
1339  final SqlNode avgArg0 = SqlStdOperatorTable.AVG.createCall(pos, arg0);
1340  final SqlNode mulAvgAvg =
1341  SqlStdOperatorTable.MULTIPLY.createCall(pos, avgArg0, avgArg1);
1342  final SqlNode covarPop =
1343  SqlStdOperatorTable.MINUS.createCall(pos, avgMulArg, mulAvgAvg);
1344  return SqlStdOperatorTable.CAST.createCall(pos,
1345  covarPop,
1346  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1347  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1348  }
1349  final SqlNode sumMulArg = SqlStdOperatorTable.SUM.createCall(pos, mulArg);
1350  final SqlNode sumArg0 = SqlStdOperatorTable.SUM.createCall(pos, arg0);
1351  final SqlNode mulSumAvg =
1352  SqlStdOperatorTable.MULTIPLY.createCall(pos, sumArg0, avgArg1);
1353  final SqlNode sub = SqlStdOperatorTable.MINUS.createCall(pos, sumMulArg, mulSumAvg);
1354  final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, operand0);
1355  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
1356  final SqlNode countEqOne = SqlStdOperatorTable.EQUALS.createCall(pos, count, one);
1357  final SqlNode countMinusOne = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
1358  final SqlLiteral nul = SqlLiteral.createNull(pos);
1359  SqlNodeList whenList1 = new SqlNodeList(pos);
1360  SqlNodeList thenList1 = new SqlNodeList(pos);
1361  whenList1.add(countEqOne);
1362  thenList1.add(nul);
1363  final SqlNode int_denominator = SqlStdOperatorTable.CASE.createCall(
1364  null, pos, null, whenList1, thenList1, countMinusOne);
1365  final SqlNode denominator = SqlStdOperatorTable.CAST.createCall(pos,
1366  int_denominator,
1367  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1368  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1369  final SqlNode covarSamp =
1370  SqlStdOperatorTable.DIVIDE.createCall(pos, sub, denominator);
1371  return SqlStdOperatorTable.CAST.createCall(pos,
1372  covarSamp,
1373  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1374  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1375  }
1376 
1377  private SqlNode expandCorrelation(
1378  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1379  // Expand correlation coefficient
1380  if (proj_call.operandCount() != 2) {
1381  return null;
1382  }
1383  boolean flt;
1384  if (proj_call.getOperator().isName("CORR", false)
1385  || proj_call.getOperator().getName().equalsIgnoreCase("CORRELATION")) {
1386  // expand correlation coefficient
1387  flt = false;
1388  } else if (proj_call.getOperator().getName().equalsIgnoreCase("CORR_FLOAT")
1389  || proj_call.getOperator().getName().equalsIgnoreCase("CORRELATION_FLOAT")) {
1390  // expand correlation coefficient
1391  flt = true;
1392  } else {
1393  return null;
1394  }
1395  // corr(x, y) ==> (avg(x * y) - avg(x) * avg(y)) / (stddev_pop(x) *
1396  // stddev_pop(y))
1397  // ==> covar_pop(x, y) / (stddev_pop(x) * stddev_pop(y))
1398  final SqlNode operand0 = proj_call.operand(0);
1399  final SqlNode operand1 = proj_call.operand(1);
1400  final SqlParserPos pos = proj_call.getParserPosition();
1401  SqlNode covariance =
1402  expandCovariance(pos, operand0, operand1, true, flt, typeFactory);
1403  SqlNode stddev0 = expandVariance(pos, operand0, true, true, flt, typeFactory);
1404  SqlNode stddev1 = expandVariance(pos, operand1, true, true, flt, typeFactory);
1405  final SqlNode mulStddev =
1406  SqlStdOperatorTable.MULTIPLY.createCall(pos, stddev0, stddev1);
1407  final SqlNumericLiteral zero = SqlLiteral.createExactNumeric("0.0", pos);
1408  final SqlNode mulStddevEqZero =
1409  SqlStdOperatorTable.EQUALS.createCall(pos, mulStddev, zero);
1410  final SqlLiteral nul = SqlLiteral.createNull(pos);
1411  SqlNodeList whenList1 = new SqlNodeList(pos);
1412  SqlNodeList thenList1 = new SqlNodeList(pos);
1413  whenList1.add(mulStddevEqZero);
1414  thenList1.add(nul);
1415  final SqlNode denominator = SqlStdOperatorTable.CASE.createCall(
1416  null, pos, null, whenList1, thenList1, mulStddev);
1417  final SqlNode expanded_proj_call =
1418  SqlStdOperatorTable.DIVIDE.createCall(pos, covariance, denominator);
1419  HEAVYDBLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1420  HEAVYDBLOGGER.debug("to : " + expanded_proj_call.toString());
1421  return expanded_proj_call;
1422  }
1423 
1424  public SqlIdentifierCapturer captureIdentifiers(String sql, boolean legacy_syntax)
1425  throws SqlParseException {
1426  try {
1427  Planner planner = getPlanner();
1428  SqlNode node = parseSql(sql, legacy_syntax, planner);
1429  return captureIdentifiers(node);
1430  } catch (Exception | Error e) {
1431  HEAVYDBLOGGER.error("Error parsing sql: " + sql, e);
1432  return new SqlIdentifierCapturer();
1433  }
1434  }
1435 
1436  public SqlIdentifierCapturer captureIdentifiers(SqlNode node) throws SqlParseException {
1437  try {
1439  capturer.scan(node);
1440  return capturer;
1441  } catch (Exception | Error e) {
1442  HEAVYDBLOGGER.error("Error parsing sql: " + node, e);
1443  return new SqlIdentifierCapturer();
1444  }
1445  }
1446 
1447  public int getCallCount() {
1448  return callCount;
1449  }
1450 
1451  public void updateMetaData(String schema, String table) {
1452  HEAVYDBLOGGER.debug("schema :" + schema + " table :" + table);
1453  HeavyDBSchema db =
1454  new HeavyDBSchema(dataDir, this, dbPort, null, sock_transport_properties);
1455  db.updateMetaData(schema, table);
1456  }
1457 
1458  protected RelDataTypeSystem createTypeSystem() {
1459  final HeavyDBTypeSystem typeSystem = new HeavyDBTypeSystem();
1460  return typeSystem;
1461  }
1462 
1464  extends SqlBasicVisitor<Void> {
1465  @Override
1466  public Void visit(SqlCall call) {
1467  if (call instanceof SqlSelect) {
1468  SqlSelect selectNode = (SqlSelect) call;
1469  String targetString = targetExpression.toString();
1470  for (SqlNode listedNode : selectNode.getSelectList()) {
1471  if (listedNode.toString().contains(targetString)) {
1472  throw Util.FoundOne.NULL;
1473  }
1474  }
1475  }
1476  return super.visit(call);
1477  }
1478 
1479  boolean containsExpression(SqlNode node, SqlNode targetExpression) {
1480  try {
1481  this.targetExpression = targetExpression;
1482  node.accept(this);
1483  return false;
1484  } catch (Util.FoundOne e) {
1485  return true;
1486  }
1487  }
1488 
1490  }
1491 
1493  extends SqlBasicVisitor<Void> {
1494  @Override
1495  public Void visit(SqlCall call) {
1496  if (call instanceof SqlBasicCall) {
1497  SqlBasicCall basicCall = (SqlBasicCall) call;
1498  if (basicCall.getKind() == SqlKind.OR) {
1499  String targetString = targetExpression.toString();
1500  for (SqlNode listedOperand : basicCall.operands) {
1501  if (listedOperand.toString().contains(targetString)) {
1502  throw Util.FoundOne.NULL;
1503  }
1504  }
1505  }
1506  }
1507  return super.visit(call);
1508  }
1509 
1510  boolean containsExpression(SqlNode node, SqlNode targetExpression) {
1511  try {
1512  this.targetExpression = targetExpression;
1513  node.accept(this);
1514  return false;
1515  } catch (Util.FoundOne e) {
1516  return true;
1517  }
1518  }
1519 
1521  }
1522 
1523  private static class JoinOperatorChecker extends SqlBasicVisitor<Void> {
1524  Set<SqlBasicCall> targetCalls = new HashSet<>();
1525 
1526  public boolean isEqualityJoinOperator(SqlBasicCall basicCall) {
1527  if (null != basicCall) {
1528  if (basicCall.operands.length == 2 && basicCall.getKind() == SqlKind.EQUALS
1529  && basicCall.operand(0) instanceof SqlIdentifier
1530  && basicCall.operand(1) instanceof SqlIdentifier) {
1531  return true;
1532  }
1533  }
1534  return false;
1535  }
1536 
1537  @Override
1538  public Void visit(SqlCall call) {
1539  if (call instanceof SqlBasicCall) {
1540  targetCalls.add((SqlBasicCall) call);
1541  }
1542  for (SqlNode node : call.getOperandList()) {
1543  if (null != node && !targetCalls.contains(node)) {
1544  node.accept(this);
1545  }
1546  }
1547  return super.visit(call);
1548  }
1549 
1550  boolean containsExpression(SqlNode node) {
1551  try {
1552  if (null != node) {
1553  node.accept(this);
1554  for (SqlBasicCall basicCall : targetCalls) {
1555  if (isEqualityJoinOperator(basicCall)) {
1556  throw Util.FoundOne.NULL;
1557  }
1558  }
1559  }
1560  return false;
1561  } catch (Util.FoundOne e) {
1562  return true;
1563  }
1564  }
1565  }
1566 
1567  // this visitor checks whether a parse tree contains at least one
1568  // specific SQL operator we have an interest in
1569  // (do not count the accurate # operators we found)
1570  private static class FindSqlOperator extends SqlBasicVisitor<Void> {
1571  @Override
1572  public Void visit(SqlCall call) {
1573  if (call instanceof SqlBasicCall) {
1574  SqlBasicCall basicCall = (SqlBasicCall) call;
1575  if (basicCall.getKind().equals(targetKind)) {
1576  throw Util.FoundOne.NULL;
1577  }
1578  }
1579  return super.visit(call);
1580  }
1581 
1582  boolean containsSqlOperator(SqlNode node, SqlKind operatorKind) {
1583  try {
1584  targetKind = operatorKind;
1585  node.accept(this);
1586  return false;
1587  } catch (Util.FoundOne e) {
1588  return true;
1589  }
1590  }
1591 
1592  private SqlKind targetKind;
1593  }
1594 }
void desugarExpression(SqlNode node, RelDataTypeFactory typeFactory)
boolean containsExpression(SqlNode node, SqlNode targetExpression)
SqlIdentifierCapturer captureIdentifiers(SqlNode node)
SqlNode expandCovariance(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
SqlNode expandVariance(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
JoinType
Definition: sqldefs.h:135
HeavyDBPlanner getPlanner(final boolean allowSubQueryExpansion, final boolean isWatchdogEnabled)
SqlSelect rewriteSimpleUpdateAsSelect(final SqlUpdate update)
SqlOrderBy desugar(SqlSelect select_node, SqlOrderBy order_by_node, RelDataTypeFactory typeFactory)
SqlNode expandSubstr(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
RelRoot replaceIsTrue(final RelDataTypeFactory typeFactory, RelRoot root)
std::string join(T const &container, std::string const &delim)
tuple root
Definition: setup.in.py:14
SqlNode expandCase(SqlCase old_case_node, RelDataTypeFactory typeFactory)
SqlNode parseSql(String sql, final boolean legacy_syntax, Planner planner)
static final EnumSet< SqlKind > ARRAY_VALUE
constexpr double f
Definition: Utm.h:31
HashSet< ImmutableList< String > > resolveSelectIdentifiers(SqlIdentifierCapturer capturer)
RelRoot convertSqlToRelNode(final SqlNode sqlNode, final HeavyDBPlanner HeavyDBPlanner, final HeavyDBParserOptions parserOptions)
static final ThreadLocal< HeavyDBParser > CURRENT_PARSER
static final EnumSet< SqlKind > UPDATE
LogicalTableModify getDummyUpdate(SqlUpdate update)
String processSql(final SqlNode sqlNode, final HeavyDBParserOptions parserOptions)
static final EnumSet< SqlKind > IN
SqlIdentifierCapturer getAccessedObjects()
static Map< String, Boolean > SubqueryCorrMemo
SockTransportProperties sock_transport_properties
RelRoot rewriteUpdateAsSelect(SqlUpdate update, HeavyDBParserOptions parserOptions)
String processSql(String sql, final HeavyDBParserOptions parserOptions)
void updateMetaData(String schema, String table)
static final Context DB_CONNECTION_CONTEXT
static final EnumSet< SqlKind > DELETE
String buildRATreeAndPerformQueryOptimization(String query, final HeavyDBParserOptions parserOptions)
SqlIdentifierCapturer captureIdentifiers(String sql, boolean legacy_syntax)
SqlNode expandVariance(final SqlParserPos pos, final SqlNode operand, boolean biased, boolean sqrt, boolean flt, RelDataTypeFactory typeFactory)
void desugar(SqlSelect select_node, RelDataTypeFactory typeFactory)
boolean containsSqlOperator(SqlNode node, SqlKind operatorKind)
static final EnumSet< SqlKind > SCALAR
HeavyDBPlanner.CompletionResult getCompletionHints(String sql, int cursor, List< String > visible_tables)
RelRoot queryToRelNode(final String sql, final HeavyDBParserOptions parserOptions)
SqlNode expandCovariance(SqlParserPos pos, final SqlNode operand0, final SqlNode operand1, boolean pop, boolean flt, RelDataTypeFactory typeFactory)
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
SqlNodeList expand(final SqlNodeList group_by_list, final java.util.Map< String, SqlNode > id_to_expr, RelDataTypeFactory typeFactory)
SqlNode expandCorrelation(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
Pair< String, SqlIdentifierCapturer > process(String sql, final HeavyDBParserOptions parserOptions)
static final SqlArrayValueConstructorAllowingEmpty ARRAY_VALUE_CONSTRUCTOR
static final EnumSet< SqlKind > EXISTS
boolean isCorrelated(SqlNode expression)
HeavyDBParser(String dataDir, final Supplier< HeavyDBSqlOperatorTable > dbSqlOperatorTable, int dbPort, SockTransportProperties skT)
SqlNode expand(final SqlNode node, final java.util.Map< String, SqlNode > id_to_expr, RelDataTypeFactory typeFactory)
final Supplier< HeavyDBSqlOperatorTable > dbSqlOperatorTable