OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
HeavyDBParser.java
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.mapd.calcite.parser;
18 
19 import static org.apache.calcite.sql.parser.SqlParserPos.ZERO;
20 
21 import com.google.common.collect.ImmutableList;
25 import com.mapd.parser.extension.ddl.ExtendedSqlParser;
28 
29 import org.apache.calcite.avatica.util.Casing;
30 import org.apache.calcite.config.CalciteConnectionConfig;
31 import org.apache.calcite.config.CalciteConnectionConfigImpl;
32 import org.apache.calcite.config.CalciteConnectionProperty;
33 import org.apache.calcite.plan.Context;
34 import org.apache.calcite.plan.RelOptTable;
35 import org.apache.calcite.plan.RelOptUtil;
36 import org.apache.calcite.plan.hep.HepPlanner;
37 import org.apache.calcite.plan.hep.HepProgramBuilder;
40 import org.apache.calcite.rel.RelNode;
41 import org.apache.calcite.rel.RelRoot;
42 import org.apache.calcite.rel.RelShuttleImpl;
43 import org.apache.calcite.rel.RelWriter;
44 import org.apache.calcite.rel.core.TableModify;
45 import org.apache.calcite.rel.core.TableModify.Operation;
47 import org.apache.calcite.rel.externalize.RelWriterImpl;
48 import org.apache.calcite.rel.logical.LogicalProject;
49 import org.apache.calcite.rel.logical.LogicalTableModify;
50 import org.apache.calcite.rel.rules.CoreRules;
52 import org.apache.calcite.rel.type.RelDataType;
53 import org.apache.calcite.rel.type.RelDataTypeFactory;
54 import org.apache.calcite.rel.type.RelDataTypeSystem;
55 import org.apache.calcite.rex.*;
56 import org.apache.calcite.runtime.CalciteException;
57 import org.apache.calcite.schema.SchemaPlus;
58 import org.apache.calcite.schema.Statistic;
59 import org.apache.calcite.schema.Table;
60 import org.apache.calcite.sql.*;
61 import org.apache.calcite.sql.advise.SqlAdvisorValidator;
62 import org.apache.calcite.sql.dialect.CalciteSqlDialect;
63 import org.apache.calcite.sql.fun.SqlCase;
64 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
65 import org.apache.calcite.sql.parser.SqlParseException;
66 import org.apache.calcite.sql.parser.SqlParser;
67 import org.apache.calcite.sql.parser.SqlParserPos;
68 import org.apache.calcite.sql.type.OperandTypes;
69 import org.apache.calcite.sql.type.ReturnTypes;
70 import org.apache.calcite.sql.type.SqlTypeName;
71 import org.apache.calcite.sql.type.SqlTypeUtil;
72 import org.apache.calcite.sql.util.SqlBasicVisitor;
73 import org.apache.calcite.sql.util.SqlShuttle;
74 import org.apache.calcite.sql.util.SqlVisitor;
75 import org.apache.calcite.sql.validate.SqlConformanceEnum;
76 import org.apache.calcite.sql.validate.SqlValidator;
79 import org.apache.calcite.tools.*;
80 import org.apache.calcite.util.Pair;
81 import org.apache.calcite.util.Util;
82 import org.slf4j.Logger;
83 import org.slf4j.LoggerFactory;
84 
85 import java.io.IOException;
86 import java.io.PrintWriter;
87 import java.io.StringWriter;
88 import java.lang.reflect.Field;
89 import java.util.*;
90 import java.util.concurrent.ConcurrentHashMap;
91 import java.util.function.BiPredicate;
92 import java.util.function.Supplier;
93 import java.util.stream.Stream;
94 
95 import ai.heavy.thrift.server.TColumnType;
96 import ai.heavy.thrift.server.TDatumType;
97 import ai.heavy.thrift.server.TEncodingType;
98 import ai.heavy.thrift.server.TTableDetails;
99 
100 public final class HeavyDBParser {
101  public static final ThreadLocal<HeavyDBParser> CURRENT_PARSER = new ThreadLocal<>();
102  private static final EnumSet<SqlKind> SCALAR =
103  EnumSet.of(SqlKind.SCALAR_QUERY, SqlKind.SELECT);
104  private static final EnumSet<SqlKind> EXISTS = EnumSet.of(SqlKind.EXISTS);
105  private static final EnumSet<SqlKind> DELETE = EnumSet.of(SqlKind.DELETE);
106  private static final EnumSet<SqlKind> UPDATE = EnumSet.of(SqlKind.UPDATE);
107  private static final EnumSet<SqlKind> IN = EnumSet.of(SqlKind.IN);
108  private static final EnumSet<SqlKind> ARRAY_VALUE =
109  EnumSet.of(SqlKind.ARRAY_VALUE_CONSTRUCTOR);
110  private static final EnumSet<SqlKind> OTHER_FUNCTION =
111  EnumSet.of(SqlKind.OTHER_FUNCTION);
112 
113  final static Logger HEAVYDBLOGGER = LoggerFactory.getLogger(HeavyDBParser.class);
114 
115  private final Supplier<HeavyDBSqlOperatorTable> dbSqlOperatorTable;
116  private final String dataDir;
117 
118  private int callCount = 0;
119  private final int dbPort;
122 
123  private static Map<String, Boolean> SubqueryCorrMemo = new ConcurrentHashMap<>();
124 
125  public HeavyDBParser(String dataDir,
126  final Supplier<HeavyDBSqlOperatorTable> dbSqlOperatorTable,
127  int dbPort,
129  this.dataDir = dataDir;
130  this.dbSqlOperatorTable = dbSqlOperatorTable;
131  this.dbPort = dbPort;
132  this.sock_transport_properties = skT;
133  }
134 
135  public void clearMemo() {
136  SubqueryCorrMemo.clear();
137  }
138 
139  private static final Context DB_CONNECTION_CONTEXT = new Context() {
140  HeavyDBTypeSystem myTypeSystem = new HeavyDBTypeSystem();
141  CalciteConnectionConfig config = new CalciteConnectionConfigImpl(new Properties()) {
142  {
143  properties.put(CalciteConnectionProperty.CASE_SENSITIVE.camelName(),
144  String.valueOf(false));
145  properties.put(CalciteConnectionProperty.CONFORMANCE.camelName(),
146  String.valueOf(SqlConformanceEnum.LENIENT));
147  }
148 
149  @SuppressWarnings("unchecked")
150  public <T extends Object> T typeSystem(
151  java.lang.Class<T> typeSystemClass, T defaultTypeSystem) {
152  return (T) myTypeSystem;
153  };
154 
155  public boolean caseSensitive() {
156  return false;
157  };
158 
159  public org.apache.calcite.sql.validate.SqlConformance conformance() {
160  return SqlConformanceEnum.LENIENT;
161  };
162  };
163 
164  @Override
165  public <C> C unwrap(Class<C> aClass) {
166  if (aClass.isInstance(config)) {
167  return aClass.cast(config);
168  }
169  return null;
170  }
171  };
172 
174  return getPlanner(true, false, false);
175  }
176 
177  private boolean isCorrelated(SqlNode expression) {
178  String queryString = expression.toSqlString(CalciteSqlDialect.DEFAULT).getSql();
179  Boolean isCorrelatedSubquery = SubqueryCorrMemo.get(queryString);
180  if (null != isCorrelatedSubquery) {
181  return isCorrelatedSubquery;
182  }
183 
184  try {
188  parser.setUser(dbUser);
189  parser.processSql(expression, options);
190  } catch (Exception e) {
191  // if we are not able to parse, then assume correlated
192  SubqueryCorrMemo.put(queryString, true);
193  return true;
194  }
195  SubqueryCorrMemo.put(queryString, false);
196  return false;
197  }
198 
199  private boolean isHashJoinableType(TColumnType type) {
200  switch (type.getCol_type().type) {
201  case TINYINT:
202  case SMALLINT:
203  case INT:
204  case BIGINT: {
205  return true;
206  }
207  case STR: {
208  return type.col_type.encoding == TEncodingType.DICT;
209  }
210  default: {
211  return false;
212  }
213  }
214  }
215 
216  private boolean isColumnHashJoinable(
217  List<String> joinColumnIdentifier, MetaConnect mc) {
218  try {
219  TTableDetails tableDetails = mc.get_table_details(joinColumnIdentifier.get(0));
220  return null
221  != tableDetails.row_desc.stream()
222  .filter(c
223  -> c.col_name.toLowerCase(Locale.ROOT)
224  .equals(joinColumnIdentifier.get(1)
225  .toLowerCase(
226  Locale.ROOT))
227  && isHashJoinableType(c))
228  .findFirst()
229  .orElse(null);
230  } catch (Exception e) {
231  return false;
232  }
233  }
234 
235  private HeavyDBPlanner getPlanner(final boolean allowSubQueryExpansion,
236  final boolean isWatchdogEnabled,
237  final boolean isDistributedMode) {
238  HeavyDBUser user = new HeavyDBUser(dbUser.getUser(),
239  dbUser.getSession(),
240  dbUser.getDB(),
241  -1,
242  ImmutableList.of());
243  final MetaConnect mc =
245  BiPredicate<SqlNode, SqlNode> expandPredicate = new BiPredicate<SqlNode, SqlNode>() {
246  @Override
247  public boolean test(SqlNode root, SqlNode expression) {
248  if (!allowSubQueryExpansion) {
249  return false;
250  }
251 
252  if (expression.isA(EXISTS) || expression.isA(IN)) {
253  // try to expand subquery by EXISTS and IN clauses by default
254  // note that current Calcite decorrelator fails to flat
255  // NOT-IN clause in some cases, so we do not decorrelate it for now
256 
257  if (expression.isA(IN)) {
258  // If we enable watchdog, we suffer from large projection exception in many
259  // cases since decorrelation needs de-duplication step which adds project -
260  // aggregate logic. And the added project is the source of the exception when
261  // its underlying table is large. Thus, we enable IN-clause decorrelation
262  // under watchdog iff we explicitly have correlated join in IN-clause
263  if (expression instanceof SqlCall) {
264  SqlCall outerSelectCall = (SqlCall) expression;
265  if (outerSelectCall.getOperandList().size() == 2) {
266  // if IN clause is correlated, its second operand of corresponding
267  // expression is SELECT clause which indicates a correlated subquery.
268  // Here, an expression "f.val IN (SELECT ...)" has two operands.
269  // Since we have interest in its subquery, so try to check whether
270  // the second operand, i.e., call.getOperandList().get(1)
271  // is a type of SqlSelect and also is correlated.
272  if (outerSelectCall.getOperandList().get(1) instanceof SqlSelect) {
273  // the below checking logic is to allow IN-clause decorrelation
274  // if it has hash joinable IN expression without correlated join
275  // i.e., SELECT ... WHERE a.intVal IN (SELECT b.intVal FROM b) ...;
276  SqlSelect innerSelectCall =
277  (SqlSelect) outerSelectCall.getOperandList().get(1);
278  if (innerSelectCall.hasWhere()) {
279  // IN-clause may have correlated join within subquery's WHERE clause
280  // i.e., f.val IN (SELECT r.val FROM R r WHERE f.val2 = r.val2)
281  // then we have to deccorrelate the IN-clause
282  JoinOperatorChecker joinOperatorChecker = new JoinOperatorChecker();
283  if (joinOperatorChecker.containsExpression(
284  innerSelectCall.getWhere())) {
285  return true;
286  }
287  }
288  if (isDistributedMode) {
289  // we temporarily disable IN-clause decorrelation in dist mode
290  // todo (yoonmin) : relax this in dist mode when available
291  return false;
292  }
293  boolean hasHashJoinableExpression = false;
294  if (isWatchdogEnabled) {
295  // when watchdog is enabled, we try to selectively allow decorrelation
296  // iff IN-expression is between two columns that both are hash
297  // joinable
298  Map<String, String> tableAliasMap = new HashMap<>();
299  if (root instanceof SqlSelect) {
300  tableAliasFinder(((SqlSelect) root).getFrom(), tableAliasMap);
301  }
302  tableAliasFinder(innerSelectCall.getFrom(), tableAliasMap);
303  if (outerSelectCall.getOperandList().get(0) instanceof SqlIdentifier
304  && innerSelectCall.getSelectList().get(0)
305  instanceof SqlIdentifier) {
306  SqlIdentifier outerColIdentifier =
307  (SqlIdentifier) outerSelectCall.getOperandList().get(0);
308  SqlIdentifier innerColIdentifier =
309  (SqlIdentifier) innerSelectCall.getSelectList().get(0);
310  if (tableAliasMap.containsKey(outerColIdentifier.names.get(0))
311  && tableAliasMap.containsKey(
312  innerColIdentifier.names.get(0))) {
313  String outerTableName =
314  tableAliasMap.get(outerColIdentifier.names.get(0));
315  String innerTableName =
316  tableAliasMap.get(innerColIdentifier.names.get(0));
317  if (isColumnHashJoinable(ImmutableList.of(outerTableName,
318  outerColIdentifier.names.get(1)),
319  mc)
321  ImmutableList.of(innerTableName,
322  innerColIdentifier.names.get(1)),
323  mc)) {
324  hasHashJoinableExpression = true;
325  }
326  }
327  }
328  if (!hasHashJoinableExpression) {
329  return false;
330  }
331  }
332  }
333  }
334  }
335  if (root instanceof SqlSelect) {
336  SqlSelect selectCall = (SqlSelect) root;
337  if (new ExpressionListedInSelectClauseChecker().containsExpression(
338  selectCall, expression)) {
339  // occasionally, Calcite cannot properly decorrelate IN-clause listed in
340  // SELECT clause e.g., SELECT x, CASE WHEN x in (SELECT x FROM R) ... FROM
341  // ... in that case we disable input query's decorrelation
342  return false;
343  }
344  if (null != selectCall.getWhere()) {
345  if (new ExpressionListedAsChildOROperatorChecker().containsExpression(
346  selectCall.getWhere(), expression)) {
347  // Decorrelation logic of the current Calcite cannot cover IN-clause
348  // well if it is listed as a child operand of OR-op
349  return false;
350  }
351  }
352  if (null != selectCall.getHaving()) {
353  if (new ExpressionListedAsChildOROperatorChecker().containsExpression(
354  selectCall.getHaving(), expression)) {
355  // Decorrelation logic of the current Calcite cannot cover IN-clause
356  // well if it is listed as a child operand of OR-op
357  return false;
358  }
359  }
360  }
361  }
362 
363  // otherwise, let's decorrelate the expression
364  return true;
365  }
366 
367  // special handling of sub-queries
368  if (expression.isA(SCALAR) && isCorrelated(expression)) {
369  // only expand if it is correlated.
370  SqlSelect select = null;
371  if (expression instanceof SqlCall) {
372  SqlCall call = (SqlCall) expression;
373  if (call.getOperator().equals(SqlStdOperatorTable.SCALAR_QUERY)) {
374  expression = call.getOperandList().get(0);
375  }
376  }
377 
378  if (expression instanceof SqlSelect) {
379  select = (SqlSelect) expression;
380  }
381 
382  if (null != select) {
383  if (null != select.getFetch() || null != select.getOffset()
384  || (null != select.getOrderList()
385  && select.getOrderList().size() != 0)) {
386  throw new CalciteException(
387  "Correlated sub-queries with ordering not supported.", null);
388  }
389  }
390  return true;
391  }
392 
393  // per default we do not want to expand
394  return false;
395  }
396  };
397 
398  final HeavyDBSchema defaultSchema = new HeavyDBSchema(
400  final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
401  final SchemaPlus defaultSchemaPlus = rootSchema.add(dbUser.getDB(), defaultSchema);
402  for (String db : mc.getDatabases()) {
403  if (!db.equalsIgnoreCase(dbUser.getDB())) {
404  rootSchema.add(db,
405  new HeavyDBSchema(
407  }
408  }
409 
410  final FrameworkConfig config =
411  Frameworks.newConfigBuilder()
412  .defaultSchema(defaultSchemaPlus)
413  .operatorTable(dbSqlOperatorTable.get())
414  .parserConfig(SqlParser.configBuilder()
415  .setConformance(SqlConformanceEnum.LENIENT)
416  .setUnquotedCasing(Casing.UNCHANGED)
417  .setCaseSensitive(false)
418  // allow identifiers of up to 512 chars
419  .setIdentifierMaxLength(512)
420  .setParserFactory(ExtendedSqlParser.FACTORY)
421  .build())
422  .sqlToRelConverterConfig(
423  SqlToRelConverter
424  .configBuilder()
425  // enable sub-query expansion (de-correlation)
426  .withExpandPredicate(expandPredicate)
427  // allow as many as possible IN operator values
428  .withInSubQueryThreshold(Integer.MAX_VALUE)
429  .withHintStrategyTable(
431  .build())
432 
433  .typeSystem(createTypeSystem())
434  .context(DB_CONNECTION_CONTEXT)
435  .build();
436  HeavyDBPlanner planner = new HeavyDBPlanner(config);
437  planner.setRestrictions(dbUser.getRestrictions());
438  return planner;
439  }
440 
441  public void setUser(HeavyDBUser dbUser) {
442  this.dbUser = dbUser;
443  }
444 
445  public Pair<String, SqlIdentifierCapturer> process(
446  String sql, final HeavyDBParserOptions parserOptions)
447  throws SqlParseException, ValidationException, RelConversionException {
448  final HeavyDBPlanner planner = getPlanner(
449  true, parserOptions.isWatchdogEnabled(), parserOptions.isDistributedMode());
450  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
451  String res = processSql(sqlNode, parserOptions);
452  SqlIdentifierCapturer capture = captureIdentifiers(sqlNode);
453  return new Pair<String, SqlIdentifierCapturer>(res, capture);
454  }
455 
457  String query, final HeavyDBParserOptions parserOptions) throws IOException {
458  HeavyDBSchema schema = new HeavyDBSchema(
460  HeavyDBPlanner planner = getPlanner(
461  true, parserOptions.isWatchdogEnabled(), parserOptions.isDistributedMode());
462 
463  planner.setFilterPushDownInfo(parserOptions.getFilterPushDownInfo());
464  RelRoot optRel = planner.buildRATreeAndPerformQueryOptimization(query, schema);
465  optRel = replaceIsTrue(planner.getTypeFactory(), optRel);
466  return HeavyDBSerializer.toString(optRel.project());
467  }
468 
469  public String processSql(String sql, final HeavyDBParserOptions parserOptions)
470  throws SqlParseException, ValidationException, RelConversionException {
471  callCount++;
472 
473  final HeavyDBPlanner planner = getPlanner(
474  true, parserOptions.isWatchdogEnabled(), parserOptions.isDistributedMode());
475  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
476 
477  return processSql(sqlNode, parserOptions);
478  }
479 
480  public String processSql(
481  final SqlNode sqlNode, final HeavyDBParserOptions parserOptions)
482  throws SqlParseException, ValidationException, RelConversionException {
483  callCount++;
484 
485  if (sqlNode instanceof JsonSerializableDdl) {
486  return ((JsonSerializableDdl) sqlNode).toJsonString();
487  }
488 
489  if (sqlNode instanceof SqlDdl) {
490  return sqlNode.toString();
491  }
492 
493  final HeavyDBPlanner planner = getPlanner(
494  true, parserOptions.isWatchdogEnabled(), parserOptions.isDistributedMode());
495  planner.advanceToValidate();
496 
497  final RelRoot sqlRel = convertSqlToRelNode(sqlNode, planner, parserOptions);
498  RelNode project = sqlRel.project();
499  if (project == null) {
500  throw new RuntimeException("Cannot convert the sql to AST");
501  }
502  if (parserOptions.isExplainDetail()) {
503  StringWriter sw = new StringWriter();
504  RelWriter planWriter = new HeavyDBRelWriterImpl(
505  new PrintWriter(sw), SqlExplainLevel.EXPPLAN_ATTRIBUTES, false);
506  project.explain(planWriter);
507  return sw.toString();
508  } else if (parserOptions.isExplain()) {
509  return RelOptUtil.toString(sqlRel.project());
510  }
511  return HeavyDBSerializer.toString(project);
512  }
513 
514  public HeavyDBPlanner.CompletionResult getCompletionHints(
515  String sql, int cursor, List<String> visible_tables) {
516  return getPlanner().getCompletionHints(sql, cursor, visible_tables);
517  }
518 
519  public HashSet<ImmutableList<String>> resolveSelectIdentifiers(
520  SqlIdentifierCapturer capturer) {
521  HashSet<ImmutableList<String>> resolved = new HashSet<ImmutableList<String>>();
522 
523  for (ImmutableList<String> names : capturer.selects) {
524  HeavyDBSchema schema = new HeavyDBSchema(
525  dataDir, this, dbPort, dbUser, sock_transport_properties, names.get(1));
526  HeavyDBTable table = (HeavyDBTable) schema.getTable(names.get(0));
527  if (null == table) {
528  throw new RuntimeException("table/view not found: " + names.get(0));
529  }
530 
531  if (table instanceof HeavyDBView) {
532  HeavyDBView view = (HeavyDBView) table;
533  resolved.addAll(resolveSelectIdentifiers(view.getAccessedObjects()));
534  } else {
535  resolved.add(names);
536  }
537  }
538 
539  return resolved;
540  }
541 
542  private String getTableName(SqlNode node) {
543  if (node.isA(EnumSet.of(SqlKind.AS))) {
544  node = ((SqlCall) node).getOperandList().get(1);
545  }
546  if (node instanceof SqlIdentifier) {
547  SqlIdentifier id = (SqlIdentifier) node;
548  return id.names.get(id.names.size() - 1);
549  }
550  return null;
551  }
552 
553  private SqlSelect rewriteSimpleUpdateAsSelect(final SqlUpdate update) {
554  SqlNode where = update.getCondition();
555 
556  if (update.getSourceExpressionList().size() != 1) {
557  return null;
558  }
559 
560  if (!(update.getSourceExpressionList().get(0) instanceof SqlSelect)) {
561  return null;
562  }
563 
564  final SqlSelect inner = (SqlSelect) update.getSourceExpressionList().get(0);
565 
566  if (null != inner.getGroup() || null != inner.getFetch() || null != inner.getOffset()
567  || (null != inner.getOrderList() && inner.getOrderList().size() != 0)
568  || (null != inner.getGroup() && inner.getGroup().size() != 0)
569  || null == getTableName(inner.getFrom())) {
570  return null;
571  }
572 
573  if (!isCorrelated(inner)) {
574  return null;
575  }
576 
577  final String updateTableName = getTableName(update.getTargetTable());
578 
579  if (null != where) {
580  where = where.accept(new SqlShuttle() {
581  @Override
582  public SqlNode visit(SqlIdentifier id) {
583  if (id.isSimple()) {
584  id = new SqlIdentifier(Arrays.asList(updateTableName, id.getSimple()),
585  id.getParserPosition());
586  }
587 
588  return id;
589  }
590  });
591  }
592 
593  SqlJoin join = new SqlJoin(ZERO,
594  update.getTargetTable(),
595  SqlLiteral.createBoolean(false, ZERO),
596  SqlLiteral.createSymbol(JoinType.LEFT, ZERO),
597  inner.getFrom(),
598  SqlLiteral.createSymbol(JoinConditionType.ON, ZERO),
599  inner.getWhere());
600 
601  SqlNode select0 = inner.getSelectList().get(0);
602 
603  boolean wrapInSingleValue = true;
604  if (select0 instanceof SqlCall) {
605  SqlCall selectExprCall = (SqlCall) select0;
606  if (Util.isSingleValue(selectExprCall)) {
607  wrapInSingleValue = false;
608  }
609  }
610 
611  if (wrapInSingleValue) {
612  if (select0.isA(EnumSet.of(SqlKind.AS))) {
613  select0 = ((SqlCall) select0).getOperandList().get(0);
614  }
615  select0 = new SqlBasicCall(
616  SqlStdOperatorTable.SINGLE_VALUE, new SqlNode[] {select0}, ZERO);
617  }
618 
619  SqlNodeList selectList = new SqlNodeList(ZERO);
620  selectList.add(select0);
621  selectList.add(new SqlBasicCall(SqlStdOperatorTable.AS,
622  new SqlNode[] {new SqlBasicCall(
623  new SqlUnresolvedFunction(
624  new SqlIdentifier("OFFSET_IN_FRAGMENT", ZERO),
625  null,
626  null,
627  null,
628  null,
629  SqlFunctionCategory.USER_DEFINED_FUNCTION),
630  new SqlNode[0],
631  SqlParserPos.ZERO),
632  new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO)},
633  ZERO));
634 
635  SqlNodeList groupBy = new SqlNodeList(ZERO);
636  groupBy.add(new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO));
637 
638  SqlSelect select = new SqlSelect(ZERO,
639  null,
640  selectList,
641  join,
642  where,
643  groupBy,
644  null,
645  null,
646  null,
647  null,
648  null,
649  null);
650  return select;
651  }
652 
653  private LogicalTableModify getDummyUpdate(SqlUpdate update)
654  throws SqlParseException, ValidationException, RelConversionException {
655  SqlIdentifier targetTable = (SqlIdentifier) update.getTargetTable();
656  String targetTableName = targetTable.toString();
657  HeavyDBPlanner planner = getPlanner();
658  String dummySql = "DELETE FROM " + targetTableName;
659  SqlNode dummyNode = planner.parse(dummySql);
660  dummyNode = planner.validate(dummyNode);
661  RelRoot dummyRoot = planner.rel(dummyNode);
662  LogicalTableModify dummyModify = (LogicalTableModify) dummyRoot.rel;
663  return dummyModify;
664  }
665 
666  private RelRoot rewriteUpdateAsSelect(
667  SqlUpdate update, HeavyDBParserOptions parserOptions)
668  throws SqlParseException, ValidationException, RelConversionException {
669  int correlatedQueriesCount[] = new int[1];
670  SqlBasicVisitor<Void> correlatedQueriesCounter = new SqlBasicVisitor<Void>() {
671  @Override
672  public Void visit(SqlCall call) {
673  if (call.isA(SCALAR)
674  && ((call instanceof SqlBasicCall && call.operandCount() == 1
675  && !call.operand(0).isA(SCALAR))
676  || !(call instanceof SqlBasicCall))) {
677  if (isCorrelated(call)) {
678  correlatedQueriesCount[0]++;
679  }
680  }
681  return super.visit(call);
682  }
683  };
684 
685  update.accept(correlatedQueriesCounter);
686  if (correlatedQueriesCount[0] > 1) {
687  throw new CalciteException(
688  "table modifications with multiple correlated sub-queries not supported.",
689  null);
690  }
691 
692  boolean allowSubqueryDecorrelation = true;
693  SqlNode updateCondition = update.getCondition();
694  if (null != updateCondition) {
695  boolean hasInClause =
696  new FindSqlOperator().containsSqlOperator(updateCondition, SqlKind.IN);
697  if (hasInClause) {
698  SqlNode updateTargetTable = update.getTargetTable();
699  if (null != updateTargetTable && updateTargetTable instanceof SqlIdentifier) {
700  SqlIdentifier targetTable = (SqlIdentifier) updateTargetTable;
701  if (targetTable.names.size() == 2) {
702  final MetaConnect mc = new MetaConnect(dbPort,
703  dataDir,
704  dbUser,
705  this,
706  sock_transport_properties,
707  targetTable.names.get(0));
708  TTableDetails updateTargetTableDetails =
709  mc.get_table_details(targetTable.names.get(1));
710  if (null != updateTargetTableDetails
711  && updateTargetTableDetails.is_temporary) {
712  allowSubqueryDecorrelation = false;
713  }
714  }
715  }
716  }
717  }
718 
719  SqlNodeList sourceExpression = new SqlNodeList(SqlParserPos.ZERO);
720  LogicalTableModify dummyModify = getDummyUpdate(update);
721  RelOptTable targetTable = dummyModify.getTable();
722  RelDataType targetTableType = targetTable.getRowType();
723 
724  SqlSelect select = rewriteSimpleUpdateAsSelect(update);
725  boolean applyRexCast = null == select;
726 
727  if (null == select) {
728  for (int i = 0; i < update.getSourceExpressionList().size(); i++) {
729  SqlNode targetColumn = update.getTargetColumnList().get(i);
730  SqlNode expression = update.getSourceExpressionList().get(i);
731 
732  if (!(targetColumn instanceof SqlIdentifier)) {
733  throw new RuntimeException("Unknown identifier type!");
734  }
735  SqlIdentifier id = (SqlIdentifier) targetColumn;
736  RelDataType fieldType =
737  targetTableType.getField(id.names.get(id.names.size() - 1), false, false)
738  .getType();
739 
740  if (expression.isA(ARRAY_VALUE) && null != fieldType.getComponentType()) {
741  // apply a cast to all array value elements
742 
743  SqlDataTypeSpec elementType = new SqlDataTypeSpec(
744  new SqlBasicTypeNameSpec(fieldType.getComponentType().getSqlTypeName(),
745  fieldType.getPrecision(),
746  fieldType.getScale(),
747  null == fieldType.getCharset() ? null
748  : fieldType.getCharset().name(),
749  SqlParserPos.ZERO),
750  SqlParserPos.ZERO);
751  SqlCall array_expression = (SqlCall) expression;
752  ArrayList<SqlNode> values = new ArrayList<>();
753 
754  for (SqlNode value : array_expression.getOperandList()) {
755  if (value.isA(EnumSet.of(SqlKind.LITERAL))) {
756  SqlNode casted_value = new SqlBasicCall(SqlStdOperatorTable.CAST,
757  new SqlNode[] {value, elementType},
758  value.getParserPosition());
759  values.add(casted_value);
760  } else {
761  values.add(value);
762  }
763  }
764 
765  expression = new SqlBasicCall(HeavyDBSqlOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
766  values.toArray(new SqlNode[0]),
767  expression.getParserPosition());
768  }
769  sourceExpression.add(expression);
770  }
771 
772  sourceExpression.add(new SqlBasicCall(SqlStdOperatorTable.AS,
773  new SqlNode[] {
774  new SqlBasicCall(new SqlUnresolvedFunction(
775  new SqlIdentifier("OFFSET_IN_FRAGMENT",
776  SqlParserPos.ZERO),
777  null,
778  null,
779  null,
780  null,
781  SqlFunctionCategory.USER_DEFINED_FUNCTION),
782  new SqlNode[0],
783  SqlParserPos.ZERO),
784  new SqlIdentifier("EXPR$DELETE_OFFSET_IN_FRAGMENT", ZERO)},
785  ZERO));
786 
787  select = new SqlSelect(SqlParserPos.ZERO,
788  null,
789  sourceExpression,
790  update.getTargetTable(),
791  update.getCondition(),
792  null,
793  null,
794  null,
795  null,
796  null,
797  null,
798  null);
799  }
800 
801  HeavyDBPlanner planner = getPlanner(allowSubqueryDecorrelation,
802  parserOptions.isWatchdogEnabled(),
803  parserOptions.isDistributedMode());
804  SqlNode node = null;
805  try {
806  node = planner.parse(select.toSqlString(CalciteSqlDialect.DEFAULT).getSql());
807  node = planner.validate(node);
808  } catch (Exception e) {
809  HEAVYDBLOGGER.error("Error processing UPDATE rewrite, rewritten stmt was: "
810  + select.toSqlString(CalciteSqlDialect.DEFAULT).getSql());
811  throw e;
812  }
813 
814  RelRoot root = planner.rel(node);
815  LogicalProject project = (LogicalProject) root.project();
816 
817  ArrayList<String> fields = new ArrayList<String>();
818  ArrayList<RexNode> nodes = new ArrayList<RexNode>();
819  final RexBuilder builder = new RexBuilder(planner.getTypeFactory());
820 
821  for (SqlNode n : update.getTargetColumnList()) {
822  if (n instanceof SqlIdentifier) {
823  SqlIdentifier id = (SqlIdentifier) n;
824  fields.add(id.names.get(id.names.size() - 1));
825  } else {
826  throw new RuntimeException("Unknown identifier type!");
827  }
828  }
829 
830  // The magical number here when processing the projection
831  // is skipping the OFFSET_IN_FRAGMENT() expression used by
832  // update and delete
833  int idx = 0;
834  for (RexNode exp : project.getProjects()) {
835  if (applyRexCast && idx + 1 < project.getProjects().size()) {
836  RelDataType expectedFieldType =
837  targetTableType.getField(fields.get(idx), false, false).getType();
838  boolean is_array_kind = exp.isA(ARRAY_VALUE);
839  boolean is_func_kind = exp.isA(OTHER_FUNCTION);
840  // runtime functions have expression kind == OTHER_FUNCTION, even if they
841  // return an array
842  if (!exp.getType().equals(expectedFieldType)
843  && !(is_array_kind || is_func_kind)) {
844  exp = builder.makeCast(expectedFieldType, exp);
845  }
846  }
847 
848  nodes.add(exp);
849  idx++;
850  }
851 
852  ArrayList<RexNode> inputs = new ArrayList<RexNode>();
853  int n = 0;
854  for (int i = 0; i < fields.size(); i++) {
855  inputs.add(
856  new RexInputRef(n, project.getRowType().getFieldList().get(n).getType()));
857  n++;
858  }
859 
860  fields.add("EXPR$DELETE_OFFSET_IN_FRAGMENT");
861  inputs.add(new RexInputRef(n, project.getRowType().getFieldList().get(n).getType()));
862 
863  project = project.copy(
864  project.getTraitSet(), project.getInput(), nodes, project.getRowType());
865 
866  LogicalTableModify modify = LogicalTableModify.create(targetTable,
867  dummyModify.getCatalogReader(),
868  project,
869  Operation.UPDATE,
870  fields,
871  inputs,
872  true);
873  return RelRoot.of(modify, SqlKind.UPDATE);
874  }
875 
876  RelRoot queryToRelNode(final String sql, final HeavyDBParserOptions parserOptions)
877  throws SqlParseException, ValidationException, RelConversionException {
878  final HeavyDBPlanner planner = getPlanner(
879  true, parserOptions.isWatchdogEnabled(), parserOptions.isDistributedMode());
880  final SqlNode sqlNode = parseSql(sql, parserOptions.isLegacySyntax(), planner);
881  return convertSqlToRelNode(sqlNode, planner, parserOptions);
882  }
883 
884  RelRoot convertSqlToRelNode(final SqlNode sqlNode,
886  final HeavyDBParserOptions parserOptions)
887  throws SqlParseException, ValidationException, RelConversionException {
888  SqlNode node = sqlNode;
889  HeavyDBPlanner planner = HeavyDBPlanner;
890  boolean allowCorrelatedSubQueryExpansion = true;
891  boolean patchUpdateToDelete = false;
892  if (node.isA(DELETE)) {
893  SqlDelete sqlDelete = (SqlDelete) node;
894  node = new SqlUpdate(node.getParserPosition(),
895  sqlDelete.getTargetTable(),
896  SqlNodeList.EMPTY,
897  SqlNodeList.EMPTY,
898  sqlDelete.getCondition(),
899  sqlDelete.getSourceSelect(),
900  sqlDelete.getAlias());
901 
902  patchUpdateToDelete = true;
903  }
904  if (node.isA(UPDATE)) {
905  SqlUpdate update = (SqlUpdate) node;
906  update = (SqlUpdate) planner.validate(update);
907  RelRoot root = rewriteUpdateAsSelect(update, parserOptions);
908 
909  if (patchUpdateToDelete) {
910  LogicalTableModify modify = (LogicalTableModify) root.rel;
911 
912  try {
913  Field f = TableModify.class.getDeclaredField("operation");
914  f.setAccessible(true);
915  f.set(modify, Operation.DELETE);
916  } catch (Throwable e) {
917  throw new RuntimeException(e);
918  }
919 
920  root = RelRoot.of(modify, SqlKind.DELETE);
921  }
922 
923  return root;
924  }
925  if (parserOptions.isLegacySyntax()) {
926  // close original planner
927  planner.close();
928  // create a new one
929  planner = getPlanner(allowCorrelatedSubQueryExpansion,
930  parserOptions.isWatchdogEnabled(),
931  parserOptions.isDistributedMode());
932  node = parseSql(
933  node.toSqlString(CalciteSqlDialect.DEFAULT).toString(), false, planner);
934  }
935 
936  SqlNode validateR = planner.validate(node);
937  planner.setFilterPushDownInfo(parserOptions.getFilterPushDownInfo());
938  // check to see if a view is involved in the query
939  boolean foundView = false;
940  SqlIdentifierCapturer capturer = captureIdentifiers(sqlNode);
941  for (ImmutableList<String> names : capturer.selects) {
942  HeavyDBSchema schema = new HeavyDBSchema(
943  dataDir, this, dbPort, dbUser, sock_transport_properties, names.get(1));
944  HeavyDBTable table = (HeavyDBTable) schema.getTable(names.get(0));
945  if (null == table) {
946  throw new RuntimeException("table/view not found: " + names.get(0));
947  }
948  if (table instanceof HeavyDBView) {
949  foundView = true;
950  }
951  }
952  RelRoot relRootNode = planner.getRelRoot(validateR);
953  relRootNode = replaceIsTrue(planner.getTypeFactory(), relRootNode);
954  RelNode rootNode = planner.optimizeRATree(
955  relRootNode.project(), parserOptions.isViewOptimizeEnabled(), foundView);
956  planner.close();
957  return new RelRoot(rootNode,
958  relRootNode.validatedRowType,
959  relRootNode.kind,
960  relRootNode.fields,
961  relRootNode.collation,
962  Collections.emptyList());
963  }
964 
965  private RelRoot replaceIsTrue(final RelDataTypeFactory typeFactory, RelRoot root) {
966  final RexShuttle callShuttle = new RexShuttle() {
967  RexBuilder builder = new RexBuilder(typeFactory);
968 
969  public RexNode visitCall(RexCall call) {
970  call = (RexCall) super.visitCall(call);
971  if (call.getKind() == SqlKind.IS_TRUE) {
972  return builder.makeCall(SqlStdOperatorTable.AND,
973  builder.makeCall(
974  SqlStdOperatorTable.IS_NOT_NULL, call.getOperands().get(0)),
975  call.getOperands().get(0));
976  } else if (call.getKind() == SqlKind.IS_NOT_TRUE) {
977  return builder.makeCall(SqlStdOperatorTable.OR,
978  builder.makeCall(
979  SqlStdOperatorTable.IS_NULL, call.getOperands().get(0)),
980  builder.makeCall(SqlStdOperatorTable.NOT, call.getOperands().get(0)));
981  } else if (call.getKind() == SqlKind.IS_FALSE) {
982  return builder.makeCall(SqlStdOperatorTable.AND,
983  builder.makeCall(
984  SqlStdOperatorTable.IS_NOT_NULL, call.getOperands().get(0)),
985  builder.makeCall(SqlStdOperatorTable.NOT, call.getOperands().get(0)));
986  } else if (call.getKind() == SqlKind.IS_NOT_FALSE) {
987  return builder.makeCall(SqlStdOperatorTable.OR,
988  builder.makeCall(
989  SqlStdOperatorTable.IS_NULL, call.getOperands().get(0)),
990  call.getOperands().get(0));
991  }
992 
993  return call;
994  }
995  };
996 
997  RelNode node = root.rel.accept(new RelShuttleImpl() {
998  @Override
999  protected RelNode visitChild(RelNode parent, int i, RelNode child) {
1000  RelNode node = super.visitChild(parent, i, child);
1001  return node.accept(callShuttle);
1002  }
1003  });
1004 
1005  return new RelRoot(node,
1006  root.validatedRowType,
1007  root.kind,
1008  root.fields,
1009  root.collation,
1010  Collections.emptyList());
1011  }
1012 
1013  private SqlNode parseSql(String sql, final boolean legacy_syntax, Planner planner)
1014  throws SqlParseException {
1015  SqlNode parseR = null;
1016  try {
1017  parseR = planner.parse(sql);
1018  HEAVYDBLOGGER.debug(" node is \n" + parseR.toString());
1019  } catch (SqlParseException ex) {
1020  HEAVYDBLOGGER.error("failed to parse SQL '" + sql + "' \n" + ex.toString());
1021  throw ex;
1022  }
1023 
1024  if (!legacy_syntax) {
1025  return parseR;
1026  }
1027 
1028  RelDataTypeFactory typeFactory = planner.getTypeFactory();
1029  SqlSelect select_node = null;
1030  if (parseR instanceof SqlSelect) {
1031  select_node = (SqlSelect) parseR;
1032  desugar(select_node, typeFactory);
1033  } else if (parseR instanceof SqlOrderBy) {
1034  SqlOrderBy order_by_node = (SqlOrderBy) parseR;
1035  if (order_by_node.query instanceof SqlSelect) {
1036  select_node = (SqlSelect) order_by_node.query;
1037  SqlOrderBy new_order_by_node = desugar(select_node, order_by_node, typeFactory);
1038  if (new_order_by_node != null) {
1039  return new_order_by_node;
1040  }
1041  } else if (order_by_node.query instanceof SqlWith) {
1042  SqlWith old_with_node = (SqlWith) order_by_node.query;
1043  if (old_with_node.body instanceof SqlSelect) {
1044  select_node = (SqlSelect) old_with_node.body;
1045  desugar(select_node, typeFactory);
1046  }
1047  }
1048  } else if (parseR instanceof SqlWith) {
1049  SqlWith old_with_node = (SqlWith) parseR;
1050  if (old_with_node.body instanceof SqlSelect) {
1051  select_node = (SqlSelect) old_with_node.body;
1052  desugar(select_node, typeFactory);
1053  }
1054  }
1055  return parseR;
1056  }
1057 
1058  private void desugar(SqlSelect select_node, RelDataTypeFactory typeFactory) {
1059  desugar(select_node, null, typeFactory);
1060  }
1061 
1062  private SqlNode expandCase(SqlCase old_case_node, RelDataTypeFactory typeFactory) {
1063  SqlNodeList newWhenList =
1064  new SqlNodeList(old_case_node.getWhenOperands().getParserPosition());
1065  SqlNodeList newThenList =
1066  new SqlNodeList(old_case_node.getThenOperands().getParserPosition());
1067  java.util.Map<String, SqlNode> id_to_expr = new java.util.HashMap<String, SqlNode>();
1068  for (SqlNode node : old_case_node.getWhenOperands()) {
1069  SqlNode newCall = expand(node, id_to_expr, typeFactory);
1070  if (null != newCall) {
1071  newWhenList.add(newCall);
1072  } else {
1073  newWhenList.add(node);
1074  }
1075  }
1076  for (SqlNode node : old_case_node.getThenOperands()) {
1077  SqlNode newCall = expand(node, id_to_expr, typeFactory);
1078  if (null != newCall) {
1079  newThenList.add(newCall);
1080  } else {
1081  newThenList.add(node);
1082  }
1083  }
1084  SqlNode new_else_operand = old_case_node.getElseOperand();
1085  if (null != new_else_operand) {
1086  SqlNode candidate_else_operand =
1087  expand(old_case_node.getElseOperand(), id_to_expr, typeFactory);
1088  if (null != candidate_else_operand) {
1089  new_else_operand = candidate_else_operand;
1090  }
1091  }
1092  SqlNode new_value_operand = old_case_node.getValueOperand();
1093  if (null != new_value_operand) {
1094  SqlNode candidate_value_operand =
1095  expand(old_case_node.getValueOperand(), id_to_expr, typeFactory);
1096  if (null != candidate_value_operand) {
1097  new_value_operand = candidate_value_operand;
1098  }
1099  }
1100  SqlNode newCaseNode = SqlCase.createSwitched(old_case_node.getParserPosition(),
1101  new_value_operand,
1102  newWhenList,
1103  newThenList,
1104  new_else_operand);
1105  return newCaseNode;
1106  }
1107 
1108  private SqlOrderBy desugar(SqlSelect select_node,
1109  SqlOrderBy order_by_node,
1110  RelDataTypeFactory typeFactory) {
1111  HEAVYDBLOGGER.debug("desugar: before: " + select_node.toString());
1112  desugarExpression(select_node.getFrom(), typeFactory);
1113  desugarExpression(select_node.getWhere(), typeFactory);
1114  SqlNodeList select_list = select_node.getSelectList();
1115  SqlNodeList new_select_list = new SqlNodeList(select_list.getParserPosition());
1116  java.util.Map<String, SqlNode> id_to_expr = new java.util.HashMap<String, SqlNode>();
1117  for (SqlNode proj : select_list) {
1118  if (!(proj instanceof SqlBasicCall)) {
1119  if (proj instanceof SqlCase) {
1120  new_select_list.add(expandCase((SqlCase) proj, typeFactory));
1121  } else {
1122  new_select_list.add(proj);
1123  }
1124  } else {
1125  assert proj instanceof SqlBasicCall;
1126  SqlBasicCall proj_call = (SqlBasicCall) proj;
1127  if (proj_call.operands.length > 0) {
1128  for (int i = 0; i < proj_call.operands.length; i++) {
1129  if (proj_call.operand(i) instanceof SqlCase) {
1130  SqlNode new_op = expandCase(proj_call.operand(i), typeFactory);
1131  proj_call.setOperand(i, new_op);
1132  }
1133  }
1134  }
1135  new_select_list.add(expand(proj_call, id_to_expr, typeFactory));
1136  }
1137  }
1138  select_node.setSelectList(new_select_list);
1139  SqlNodeList group_by_list = select_node.getGroup();
1140  if (group_by_list != null) {
1141  select_node.setGroupBy(expand(group_by_list, id_to_expr, typeFactory));
1142  }
1143  SqlNode having = select_node.getHaving();
1144  if (having != null) {
1145  expand(having, id_to_expr, typeFactory);
1146  }
1147  SqlOrderBy new_order_by_node = null;
1148  if (order_by_node != null && order_by_node.orderList != null
1149  && order_by_node.orderList.size() > 0) {
1150  SqlNodeList new_order_by_list =
1151  expand(order_by_node.orderList, id_to_expr, typeFactory);
1152  new_order_by_node = new SqlOrderBy(order_by_node.getParserPosition(),
1153  select_node,
1154  new_order_by_list,
1155  order_by_node.offset,
1156  order_by_node.fetch);
1157  }
1158 
1159  HEAVYDBLOGGER.debug("desugar: after: " + select_node.toString());
1160  return new_order_by_node;
1161  }
1162 
1163  private void desugarExpression(SqlNode node, RelDataTypeFactory typeFactory) {
1164  if (node instanceof SqlSelect) {
1165  desugar((SqlSelect) node, typeFactory);
1166  return;
1167  }
1168  if (!(node instanceof SqlBasicCall)) {
1169  return;
1170  }
1171  SqlBasicCall basic_call = (SqlBasicCall) node;
1172  for (SqlNode operator : basic_call.getOperands()) {
1173  if (operator instanceof SqlOrderBy) {
1174  desugarExpression(((SqlOrderBy) operator).query, typeFactory);
1175  } else {
1176  desugarExpression(operator, typeFactory);
1177  }
1178  }
1179  }
1180 
1181  private SqlNode expand(final SqlNode node,
1182  final java.util.Map<String, SqlNode> id_to_expr,
1183  RelDataTypeFactory typeFactory) {
1184  HEAVYDBLOGGER.debug("expand: " + node.toString());
1185  if (node instanceof SqlBasicCall) {
1186  SqlBasicCall node_call = (SqlBasicCall) node;
1187  SqlNode[] operands = node_call.getOperands();
1188  for (int i = 0; i < operands.length; ++i) {
1189  node_call.setOperand(i, expand(operands[i], id_to_expr, typeFactory));
1190  }
1191  SqlNode expanded_string_function = expandStringFunctions(node_call, typeFactory);
1192  if (expanded_string_function != null) {
1193  return expanded_string_function;
1194  }
1195  SqlNode expanded_variance = expandVariance(node_call, typeFactory);
1196  if (expanded_variance != null) {
1197  return expanded_variance;
1198  }
1199  SqlNode expanded_covariance = expandCovariance(node_call, typeFactory);
1200  if (expanded_covariance != null) {
1201  return expanded_covariance;
1202  }
1203  SqlNode expanded_correlation = expandCorrelation(node_call, typeFactory);
1204  if (expanded_correlation != null) {
1205  return expanded_correlation;
1206  }
1207  }
1208  if (node instanceof SqlSelect) {
1209  SqlSelect select_node = (SqlSelect) node;
1210  desugar(select_node, typeFactory);
1211  }
1212  return node;
1213  }
1214 
1215  private SqlNodeList expand(final SqlNodeList group_by_list,
1216  final java.util.Map<String, SqlNode> id_to_expr,
1217  RelDataTypeFactory typeFactory) {
1218  SqlNodeList new_group_by_list = new SqlNodeList(new SqlParserPos(-1, -1));
1219  for (SqlNode group_by : group_by_list) {
1220  if (!(group_by instanceof SqlIdentifier)) {
1221  new_group_by_list.add(expand(group_by, id_to_expr, typeFactory));
1222  continue;
1223  }
1224  SqlIdentifier group_by_id = ((SqlIdentifier) group_by);
1225  if (id_to_expr.containsKey(group_by_id.toString())) {
1226  new_group_by_list.add(id_to_expr.get(group_by_id.toString()));
1227  } else {
1228  new_group_by_list.add(group_by);
1229  }
1230  }
1231  return new_group_by_list;
1232  }
1233 
1234  private SqlNode expandStringFunctions(
1235  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1236  //
1237  // Expand string functions
1238  //
1239 
1240  final int operandCount = proj_call.operandCount();
1241 
1242  if (proj_call.getOperator().isName("MID", false)
1243  || proj_call.getOperator().isName("SUBSTR", false)) {
1244  // Replace MID/SUBSTR with SUBSTRING
1245  //
1246  // Note: SUBSTRING doesn't offer much flexibility for the numeric arg's type
1247  // "only constant, column, or other string operator arguments are allowed"
1248  final SqlParserPos pos = proj_call.getParserPosition();
1249  if (operandCount == 2) {
1250  final SqlNode primary_operand = proj_call.operand(0);
1251  final SqlNode from_operand = proj_call.operand(1);
1252  return SqlStdOperatorTable.SUBSTRING.createCall(
1253  pos, primary_operand, from_operand);
1254 
1255  } else if (operandCount == 3) {
1256  final SqlNode primary_operand = proj_call.operand(0);
1257  final SqlNode from_operand = proj_call.operand(1);
1258  final SqlNode for_operand = proj_call.operand(2);
1259  return SqlStdOperatorTable.SUBSTRING.createCall(
1260  pos, primary_operand, from_operand, for_operand);
1261  }
1262  return null;
1263 
1264  } else if (proj_call.getOperator().isName("CONTAINS", false)) {
1265  // Replace CONTAINS with LIKE
1266  // as noted by TABLEAU's own published documention
1267  final SqlParserPos pos = proj_call.getParserPosition();
1268  if (operandCount == 2) {
1269  final SqlNode primary = proj_call.operand(0);
1270  final SqlNode pattern = proj_call.operand(1);
1271 
1272  if (pattern instanceof SqlLiteral) {
1273  // LIKE only supports Literal patterns ... at the moment
1274  SqlLiteral literalPattern = (SqlLiteral) pattern;
1275  String sPattern = literalPattern.getValueAs(String.class);
1276  SqlLiteral withWildcards =
1277  SqlLiteral.createCharString("%" + sPattern + "%", pos);
1278  return SqlStdOperatorTable.LIKE.createCall(pos, primary, withWildcards);
1279  }
1280  }
1281  return null;
1282 
1283  } else if (proj_call.getOperator().isName("ENDSWITH", false)) {
1284  // Replace ENDSWITH with LIKE
1285  final SqlParserPos pos = proj_call.getParserPosition();
1286  if (operandCount == 2) {
1287  final SqlNode primary = proj_call.operand(0);
1288  final SqlNode pattern = proj_call.operand(1);
1289 
1290  if (pattern instanceof SqlLiteral) {
1291  // LIKE only supports Literal patterns ... at the moment
1292  SqlLiteral literalPattern = (SqlLiteral) pattern;
1293  String sPattern = literalPattern.getValueAs(String.class);
1294  SqlLiteral withWildcards = SqlLiteral.createCharString("%" + sPattern, pos);
1295  return SqlStdOperatorTable.LIKE.createCall(pos, primary, withWildcards);
1296  }
1297  }
1298  return null;
1299  } else if (proj_call.getOperator().isName("LCASE", false)) {
1300  // Expand LCASE with LOWER
1301  final SqlParserPos pos = proj_call.getParserPosition();
1302  if (operandCount == 1) {
1303  final SqlNode primary = proj_call.operand(0);
1304  return SqlStdOperatorTable.LOWER.createCall(pos, primary);
1305  }
1306  return null;
1307 
1308  } else if (proj_call.getOperator().isName("LEFT", false)) {
1309  // Replace LEFT with SUBSTRING
1310  final SqlParserPos pos = proj_call.getParserPosition();
1311 
1312  if (operandCount == 2) {
1313  final SqlNode primary = proj_call.operand(0);
1314  SqlNode start = SqlLiteral.createExactNumeric("0", SqlParserPos.ZERO);
1315  final SqlNode count = proj_call.operand(1);
1316  return SqlStdOperatorTable.SUBSTRING.createCall(pos, primary, start, count);
1317  }
1318  return null;
1319 
1320  } else if (proj_call.getOperator().isName("LEN", false)) {
1321  // Replace LEN with CHARACTER_LENGTH
1322  final SqlParserPos pos = proj_call.getParserPosition();
1323  if (operandCount == 1) {
1324  final SqlNode primary = proj_call.operand(0);
1325  return SqlStdOperatorTable.CHARACTER_LENGTH.createCall(pos, primary);
1326  }
1327  return null;
1328 
1329  } else if (proj_call.getOperator().isName("MAX", false)
1330  || proj_call.getOperator().isName("MIN", false)) {
1331  // Replace MAX(a,b), MIN(a,b) with CASE
1332  final SqlParserPos pos = proj_call.getParserPosition();
1333 
1334  if (operandCount == 2) {
1335  final SqlNode arg1 = proj_call.operand(0);
1336  final SqlNode arg2 = proj_call.operand(1);
1337 
1338  SqlNodeList whenList = new SqlNodeList(pos);
1339  SqlNodeList thenList = new SqlNodeList(pos);
1340  SqlNodeList elseClause = new SqlNodeList(pos);
1341 
1342  if (proj_call.getOperator().isName("MAX", false)) {
1343  whenList.add(
1344  SqlStdOperatorTable.GREATER_THAN_OR_EQUAL.createCall(pos, arg1, arg2));
1345  } else {
1346  whenList.add(
1347  SqlStdOperatorTable.LESS_THAN_OR_EQUAL.createCall(pos, arg1, arg2));
1348  }
1349  thenList.add(arg1);
1350  elseClause.add(arg2);
1351 
1352  SqlNode caseIdentifier = null;
1353  return SqlCase.createSwitched(
1354  pos, caseIdentifier, whenList, thenList, elseClause);
1355  }
1356  return null;
1357 
1358  } else if (proj_call.getOperator().isName("RIGHT", false)) {
1359  // Replace RIGHT with SUBSTRING
1360  final SqlParserPos pos = proj_call.getParserPosition();
1361 
1362  if (operandCount == 2) {
1363  final SqlNode primary = proj_call.operand(0);
1364  final SqlNode count = proj_call.operand(1);
1365  if (count instanceof SqlNumericLiteral) {
1366  SqlNumericLiteral numericCount = (SqlNumericLiteral) count;
1367  if (numericCount.intValue(true) > 0) {
1368  // common case
1369  final SqlNode negativeCount =
1370  SqlNumericLiteral.createNegative(numericCount, pos);
1371  return SqlStdOperatorTable.SUBSTRING.createCall(pos, primary, negativeCount);
1372  }
1373  // allow zero (or negative) to return an empty string
1374  // matches behavior of LEFT
1375  SqlNode zero = SqlLiteral.createExactNumeric("0", SqlParserPos.ZERO);
1376  return SqlStdOperatorTable.SUBSTRING.createCall(pos, primary, zero, zero);
1377  }
1378  // if not a simple literal ... attempt to evaluate
1379  // expected to fail ... with a useful error message
1380  return SqlStdOperatorTable.SUBSTRING.createCall(pos, primary, count);
1381  }
1382  return null;
1383 
1384  } else if (proj_call.getOperator().isName("SPACE", false)) {
1385  // Replace SPACE with REPEAT
1386  final SqlParserPos pos = proj_call.getParserPosition();
1387  if (operandCount == 1) {
1388  final SqlNode count = proj_call.operand(0);
1389  SqlFunction fn_repeat = new SqlFunction("REPEAT",
1390  SqlKind.OTHER_FUNCTION,
1391  ReturnTypes.ARG0_NULLABLE,
1392  null,
1393  OperandTypes.CHARACTER,
1394  SqlFunctionCategory.STRING);
1395  SqlLiteral space = SqlLiteral.createCharString(" ", pos);
1396  return fn_repeat.createCall(pos, space, count);
1397  }
1398  return null;
1399 
1400  } else if (proj_call.getOperator().isName("SPLIT", false)) {
1401  // Replace SPLIT with SPLIT_PART
1402  final SqlParserPos pos = proj_call.getParserPosition();
1403  if (operandCount == 3) {
1404  final SqlNode primary = proj_call.operand(0);
1405  final SqlNode delimeter = proj_call.operand(1);
1406  final SqlNode count = proj_call.operand(2);
1407  SqlFunction fn_split = new SqlFunction("SPLIT_PART",
1408  SqlKind.OTHER_FUNCTION,
1409  ReturnTypes.ARG0_NULLABLE,
1410  null,
1411  OperandTypes.CHARACTER,
1412  SqlFunctionCategory.STRING);
1413 
1414  return fn_split.createCall(pos, primary, delimeter, count);
1415  }
1416  return null;
1417 
1418  } else if (proj_call.getOperator().isName("STARTSWITH", false)) {
1419  // Replace STARTSWITH with LIKE
1420  final SqlParserPos pos = proj_call.getParserPosition();
1421  if (operandCount == 2) {
1422  final SqlNode primary = proj_call.operand(0);
1423  final SqlNode pattern = proj_call.operand(1);
1424 
1425  if (pattern instanceof SqlLiteral) {
1426  // LIKE only supports Literal patterns ... at the moment
1427  SqlLiteral literalPattern = (SqlLiteral) pattern;
1428  String sPattern = literalPattern.getValueAs(String.class);
1429  SqlLiteral withWildcards = SqlLiteral.createCharString(sPattern + "%", pos);
1430  return SqlStdOperatorTable.LIKE.createCall(pos, primary, withWildcards);
1431  }
1432  }
1433  return null;
1434 
1435  } else if (proj_call.getOperator().isName("UCASE", false)) {
1436  // Replace UCASE with UPPER
1437  final SqlParserPos pos = proj_call.getParserPosition();
1438  if (operandCount == 1) {
1439  final SqlNode primary = proj_call.operand(0);
1440  return SqlStdOperatorTable.UPPER.createCall(pos, primary);
1441  }
1442  return null;
1443  }
1444 
1445  return null;
1446  }
1447 
1448  private SqlNode expandVariance(
1449  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1450  // Expand variance aggregates that are not supported natively
1451  if (proj_call.operandCount() != 1) {
1452  return null;
1453  }
1454  boolean biased;
1455  boolean sqrt;
1456  boolean flt;
1457  if (proj_call.getOperator().isName("STDDEV_POP", false)) {
1458  biased = true;
1459  sqrt = true;
1460  flt = false;
1461  } else if (proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_POP_FLOAT")) {
1462  biased = true;
1463  sqrt = true;
1464  flt = true;
1465  } else if (proj_call.getOperator().isName("STDDEV_SAMP", false)
1466  || proj_call.getOperator().getName().equalsIgnoreCase("STDDEV")) {
1467  biased = false;
1468  sqrt = true;
1469  flt = false;
1470  } else if (proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_SAMP_FLOAT")
1471  || proj_call.getOperator().getName().equalsIgnoreCase("STDDEV_FLOAT")) {
1472  biased = false;
1473  sqrt = true;
1474  flt = true;
1475  } else if (proj_call.getOperator().isName("VAR_POP", false)) {
1476  biased = true;
1477  sqrt = false;
1478  flt = false;
1479  } else if (proj_call.getOperator().getName().equalsIgnoreCase("VAR_POP_FLOAT")) {
1480  biased = true;
1481  sqrt = false;
1482  flt = true;
1483  } else if (proj_call.getOperator().isName("VAR_SAMP", false)
1484  || proj_call.getOperator().getName().equalsIgnoreCase("VARIANCE")) {
1485  biased = false;
1486  sqrt = false;
1487  flt = false;
1488  } else if (proj_call.getOperator().getName().equalsIgnoreCase("VAR_SAMP_FLOAT")
1489  || proj_call.getOperator().getName().equalsIgnoreCase("VARIANCE_FLOAT")) {
1490  biased = false;
1491  sqrt = false;
1492  flt = true;
1493  } else {
1494  return null;
1495  }
1496  final SqlNode operand = proj_call.operand(0);
1497  final SqlParserPos pos = proj_call.getParserPosition();
1498  SqlNode expanded_proj_call =
1499  expandVariance(pos, operand, biased, sqrt, flt, typeFactory);
1500  HEAVYDBLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1501  HEAVYDBLOGGER.debug("to : " + expanded_proj_call.toString());
1502  return expanded_proj_call;
1503  }
1504 
1505  private SqlNode expandVariance(final SqlParserPos pos,
1506  final SqlNode operand,
1507  boolean biased,
1508  boolean sqrt,
1509  boolean flt,
1510  RelDataTypeFactory typeFactory) {
1511  // stddev_pop(x) ==>
1512  // power(
1513  // (sum(x * x) - sum(x) * sum(x) / (case count(x) when 0 then NULL else count(x)
1514  // end)) / (case count(x) when 0 then NULL else count(x) end), .5)
1515  //
1516  // stddev_samp(x) ==>
1517  // power(
1518  // (sum(x * x) - sum(x) * sum(x) / (case count(x) when 0 then NULL else count(x)
1519  // )) / ((case count(x) when 1 then NULL else count(x) - 1 end)), .5)
1520  //
1521  // var_pop(x) ==>
1522  // (sum(x * x) - sum(x) * sum(x) / ((case count(x) when 0 then NULL else
1523  // count(x)
1524  // end))) / ((case count(x) when 0 then NULL else count(x) end))
1525  //
1526  // var_samp(x) ==>
1527  // (sum(x * x) - sum(x) * sum(x) / ((case count(x) when 0 then NULL else
1528  // count(x)
1529  // end))) / ((case count(x) when 1 then NULL else count(x) - 1 end))
1530  //
1531  final SqlNode arg = SqlStdOperatorTable.CAST.createCall(pos,
1532  operand,
1533  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1534  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1535  final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
1536  final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
1537  final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
1538  final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
1539  final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
1540  final SqlLiteral nul = SqlLiteral.createNull(pos);
1541  final SqlNumericLiteral zero = SqlLiteral.createExactNumeric("0", pos);
1542  final SqlNode countEqZero = SqlStdOperatorTable.EQUALS.createCall(pos, count, zero);
1543  SqlNodeList whenList = new SqlNodeList(pos);
1544  SqlNodeList thenList = new SqlNodeList(pos);
1545  whenList.add(countEqZero);
1546  thenList.add(nul);
1547  final SqlNode int_denominator = SqlStdOperatorTable.CASE.createCall(
1548  null, pos, null, whenList, thenList, count);
1549  final SqlNode denominator = SqlStdOperatorTable.CAST.createCall(pos,
1550  int_denominator,
1551  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1552  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1553  final SqlNode avgSumSquared =
1554  SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, denominator);
1555  final SqlNode diff =
1556  SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
1557  final SqlNode denominator1;
1558  if (biased) {
1559  denominator1 = denominator;
1560  } else {
1561  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
1562  final SqlNode countEqOne = SqlStdOperatorTable.EQUALS.createCall(pos, count, one);
1563  final SqlNode countMinusOne = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
1564  SqlNodeList whenList1 = new SqlNodeList(pos);
1565  SqlNodeList thenList1 = new SqlNodeList(pos);
1566  whenList1.add(countEqOne);
1567  thenList1.add(nul);
1568  final SqlNode int_denominator1 = SqlStdOperatorTable.CASE.createCall(
1569  null, pos, null, whenList1, thenList1, countMinusOne);
1570  denominator1 = SqlStdOperatorTable.CAST.createCall(pos,
1571  int_denominator1,
1572  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1573  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1574  }
1575  final SqlNode div = SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator1);
1576  SqlNode result = div;
1577  if (sqrt) {
1578  final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
1579  result = SqlStdOperatorTable.POWER.createCall(pos, div, half);
1580  }
1581  return SqlStdOperatorTable.CAST.createCall(pos,
1582  result,
1583  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1584  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1585  }
1586 
1587  private SqlNode expandCovariance(
1588  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1589  // Expand covariance aggregates
1590  if (proj_call.operandCount() != 2) {
1591  return null;
1592  }
1593  boolean pop;
1594  boolean flt;
1595  if (proj_call.getOperator().isName("COVAR_POP", false)) {
1596  pop = true;
1597  flt = false;
1598  } else if (proj_call.getOperator().isName("COVAR_SAMP", false)) {
1599  pop = false;
1600  flt = false;
1601  } else if (proj_call.getOperator().getName().equalsIgnoreCase("COVAR_POP_FLOAT")) {
1602  pop = true;
1603  flt = true;
1604  } else if (proj_call.getOperator().getName().equalsIgnoreCase("COVAR_SAMP_FLOAT")) {
1605  pop = false;
1606  flt = true;
1607  } else {
1608  return null;
1609  }
1610  final SqlNode operand0 = proj_call.operand(0);
1611  final SqlNode operand1 = proj_call.operand(1);
1612  final SqlParserPos pos = proj_call.getParserPosition();
1613  SqlNode expanded_proj_call =
1614  expandCovariance(pos, operand0, operand1, pop, flt, typeFactory);
1615  HEAVYDBLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1616  HEAVYDBLOGGER.debug("to : " + expanded_proj_call.toString());
1617  return expanded_proj_call;
1618  }
1619 
1620  private SqlNode expandCovariance(SqlParserPos pos,
1621  final SqlNode operand0,
1622  final SqlNode operand1,
1623  boolean pop,
1624  boolean flt,
1625  RelDataTypeFactory typeFactory) {
1626  // covar_pop(x, y) ==> avg(x * y) - avg(x) * avg(y)
1627  // covar_samp(x, y) ==> (sum(x * y) - sum(x) * avg(y))
1628  // ((case count(x) when 1 then NULL else count(x) - 1 end))
1629  final SqlNode arg0 = SqlStdOperatorTable.CAST.createCall(operand0.getParserPosition(),
1630  operand0,
1631  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1632  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1633  final SqlNode arg1 = SqlStdOperatorTable.CAST.createCall(operand1.getParserPosition(),
1634  operand1,
1635  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1636  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1637  final SqlNode mulArg = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1);
1638  final SqlNode avgArg1 = SqlStdOperatorTable.AVG.createCall(pos, arg1);
1639  if (pop) {
1640  final SqlNode avgMulArg = SqlStdOperatorTable.AVG.createCall(pos, mulArg);
1641  final SqlNode avgArg0 = SqlStdOperatorTable.AVG.createCall(pos, arg0);
1642  final SqlNode mulAvgAvg =
1643  SqlStdOperatorTable.MULTIPLY.createCall(pos, avgArg0, avgArg1);
1644  final SqlNode covarPop =
1645  SqlStdOperatorTable.MINUS.createCall(pos, avgMulArg, mulAvgAvg);
1646  return SqlStdOperatorTable.CAST.createCall(pos,
1647  covarPop,
1648  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1649  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1650  }
1651  final SqlNode sumMulArg = SqlStdOperatorTable.SUM.createCall(pos, mulArg);
1652  final SqlNode sumArg0 = SqlStdOperatorTable.SUM.createCall(pos, arg0);
1653  final SqlNode mulSumAvg =
1654  SqlStdOperatorTable.MULTIPLY.createCall(pos, sumArg0, avgArg1);
1655  final SqlNode sub = SqlStdOperatorTable.MINUS.createCall(pos, sumMulArg, mulSumAvg);
1656  final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, operand0);
1657  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
1658  final SqlNode countEqOne = SqlStdOperatorTable.EQUALS.createCall(pos, count, one);
1659  final SqlNode countMinusOne = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
1660  final SqlLiteral nul = SqlLiteral.createNull(pos);
1661  SqlNodeList whenList1 = new SqlNodeList(pos);
1662  SqlNodeList thenList1 = new SqlNodeList(pos);
1663  whenList1.add(countEqOne);
1664  thenList1.add(nul);
1665  final SqlNode int_denominator = SqlStdOperatorTable.CASE.createCall(
1666  null, pos, null, whenList1, thenList1, countMinusOne);
1667  final SqlNode denominator = SqlStdOperatorTable.CAST.createCall(pos,
1668  int_denominator,
1669  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1670  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1671  final SqlNode covarSamp =
1672  SqlStdOperatorTable.DIVIDE.createCall(pos, sub, denominator);
1673  return SqlStdOperatorTable.CAST.createCall(pos,
1674  covarSamp,
1675  SqlTypeUtil.convertTypeToSpec(typeFactory.createSqlType(
1676  flt ? SqlTypeName.FLOAT : SqlTypeName.DOUBLE)));
1677  }
1678 
1679  private SqlNode expandCorrelation(
1680  final SqlBasicCall proj_call, RelDataTypeFactory typeFactory) {
1681  // Expand correlation coefficient
1682  if (proj_call.operandCount() != 2) {
1683  return null;
1684  }
1685  boolean flt;
1686  if (proj_call.getOperator().isName("CORR", false)
1687  || proj_call.getOperator().getName().equalsIgnoreCase("CORRELATION")) {
1688  // expand correlation coefficient
1689  flt = false;
1690  } else if (proj_call.getOperator().getName().equalsIgnoreCase("CORR_FLOAT")
1691  || proj_call.getOperator().getName().equalsIgnoreCase("CORRELATION_FLOAT")) {
1692  // expand correlation coefficient
1693  flt = true;
1694  } else {
1695  return null;
1696  }
1697  // corr(x, y) ==> (avg(x * y) - avg(x) * avg(y)) / (stddev_pop(x) *
1698  // stddev_pop(y))
1699  // ==> covar_pop(x, y) / (stddev_pop(x) * stddev_pop(y))
1700  final SqlNode operand0 = proj_call.operand(0);
1701  final SqlNode operand1 = proj_call.operand(1);
1702  final SqlParserPos pos = proj_call.getParserPosition();
1703  SqlNode covariance =
1704  expandCovariance(pos, operand0, operand1, true, flt, typeFactory);
1705  SqlNode stddev0 = expandVariance(pos, operand0, true, true, flt, typeFactory);
1706  SqlNode stddev1 = expandVariance(pos, operand1, true, true, flt, typeFactory);
1707  final SqlNode mulStddev =
1708  SqlStdOperatorTable.MULTIPLY.createCall(pos, stddev0, stddev1);
1709  final SqlNumericLiteral zero = SqlLiteral.createExactNumeric("0.0", pos);
1710  final SqlNode mulStddevEqZero =
1711  SqlStdOperatorTable.EQUALS.createCall(pos, mulStddev, zero);
1712  final SqlLiteral nul = SqlLiteral.createNull(pos);
1713  SqlNodeList whenList1 = new SqlNodeList(pos);
1714  SqlNodeList thenList1 = new SqlNodeList(pos);
1715  whenList1.add(mulStddevEqZero);
1716  thenList1.add(nul);
1717  final SqlNode denominator = SqlStdOperatorTable.CASE.createCall(
1718  null, pos, null, whenList1, thenList1, mulStddev);
1719  final SqlNode expanded_proj_call =
1720  SqlStdOperatorTable.DIVIDE.createCall(pos, covariance, denominator);
1721  HEAVYDBLOGGER.debug("Expanded select_list SqlCall: " + proj_call.toString());
1722  HEAVYDBLOGGER.debug("to : " + expanded_proj_call.toString());
1723  return expanded_proj_call;
1724  }
1725 
1726  public SqlIdentifierCapturer captureIdentifiers(String sql, boolean legacy_syntax)
1727  throws SqlParseException {
1728  try {
1729  Planner planner = getPlanner();
1730  SqlNode node = parseSql(sql, legacy_syntax, planner);
1731  return captureIdentifiers(node);
1732  } catch (Exception | Error e) {
1733  HEAVYDBLOGGER.error("Error parsing sql: " + sql, e);
1734  return new SqlIdentifierCapturer();
1735  }
1736  }
1737 
1738  public SqlIdentifierCapturer captureIdentifiers(SqlNode node) throws SqlParseException {
1739  try {
1741  capturer.scan(node);
1742  capturer.selects = addDbContextIfMissing(capturer.selects);
1743  capturer.updates = addDbContextIfMissing(capturer.updates);
1744  capturer.deletes = addDbContextIfMissing(capturer.deletes);
1745  capturer.inserts = addDbContextIfMissing(capturer.inserts);
1746  return capturer;
1747  } catch (Exception | Error e) {
1748  HEAVYDBLOGGER.error("Error parsing sql: " + node, e);
1749  return new SqlIdentifierCapturer();
1750  }
1751  }
1752 
1753  private Set<ImmutableList<String>> addDbContextIfMissing(
1754  Set<ImmutableList<String>> names) {
1755  Set<ImmutableList<String>> result = new HashSet<>();
1756  for (ImmutableList<String> name : names) {
1757  if (name.size() == 1) {
1758  result.add(new ImmutableList.Builder<String>()
1759  .addAll(name)
1760  .add(dbUser.getDB())
1761  .build());
1762  } else {
1763  result.add(name);
1764  }
1765  }
1766  return result;
1767  }
1768 
1769  public int getCallCount() {
1770  return callCount;
1771  }
1772 
1773  public void updateMetaData(String schema, String table) {
1774  HEAVYDBLOGGER.debug("schema :" + schema + " table :" + table);
1775  HeavyDBSchema db = new HeavyDBSchema(
1776  dataDir, this, dbPort, null, sock_transport_properties, schema);
1777  db.updateMetaData(schema, table);
1778  }
1779 
1780  protected RelDataTypeSystem createTypeSystem() {
1781  final HeavyDBTypeSystem typeSystem = new HeavyDBTypeSystem();
1782  return typeSystem;
1783  }
1784 
1786  extends SqlBasicVisitor<Void> {
1787  @Override
1788  public Void visit(SqlCall call) {
1789  if (call instanceof SqlSelect) {
1790  SqlSelect selectNode = (SqlSelect) call;
1791  String targetString = targetExpression.toString();
1792  for (SqlNode listedNode : selectNode.getSelectList()) {
1793  if (listedNode.toString().contains(targetString)) {
1794  throw Util.FoundOne.NULL;
1795  }
1796  }
1797  }
1798  return super.visit(call);
1799  }
1800 
1801  boolean containsExpression(SqlNode node, SqlNode targetExpression) {
1802  try {
1803  this.targetExpression = targetExpression;
1804  node.accept(this);
1805  return false;
1806  } catch (Util.FoundOne e) {
1807  return true;
1808  }
1809  }
1810 
1812  }
1813 
1815  extends SqlBasicVisitor<Void> {
1816  @Override
1817  public Void visit(SqlCall call) {
1818  if (call instanceof SqlBasicCall) {
1819  SqlBasicCall basicCall = (SqlBasicCall) call;
1820  if (basicCall.getKind() == SqlKind.OR) {
1821  String targetString = targetExpression.toString();
1822  for (SqlNode listedOperand : basicCall.operands) {
1823  if (listedOperand.toString().contains(targetString)) {
1824  throw Util.FoundOne.NULL;
1825  }
1826  }
1827  }
1828  }
1829  return super.visit(call);
1830  }
1831 
1832  boolean containsExpression(SqlNode node, SqlNode targetExpression) {
1833  try {
1834  this.targetExpression = targetExpression;
1835  node.accept(this);
1836  return false;
1837  } catch (Util.FoundOne e) {
1838  return true;
1839  }
1840  }
1841 
1843  }
1844 
1845  private static class JoinOperatorChecker extends SqlBasicVisitor<Void> {
1846  Set<SqlBasicCall> targetCalls = new HashSet<>();
1847 
1848  public boolean isEqualityJoinOperator(SqlBasicCall basicCall) {
1849  if (null != basicCall) {
1850  if (basicCall.operands.length == 2
1851  && (basicCall.getKind() == SqlKind.EQUALS
1852  || basicCall.getKind() == SqlKind.NOT_EQUALS)
1853  && basicCall.operand(0) instanceof SqlIdentifier
1854  && basicCall.operand(1) instanceof SqlIdentifier) {
1855  return true;
1856  }
1857  }
1858  return false;
1859  }
1860 
1861  @Override
1862  public Void visit(SqlCall call) {
1863  if (call instanceof SqlBasicCall) {
1864  targetCalls.add((SqlBasicCall) call);
1865  }
1866  for (SqlNode node : call.getOperandList()) {
1867  if (null != node && !targetCalls.contains(node)) {
1868  node.accept(this);
1869  }
1870  }
1871  return super.visit(call);
1872  }
1873 
1874  boolean containsExpression(SqlNode node) {
1875  try {
1876  if (null != node) {
1877  node.accept(this);
1878  for (SqlBasicCall basicCall : targetCalls) {
1879  if (isEqualityJoinOperator(basicCall)) {
1880  throw Util.FoundOne.NULL;
1881  }
1882  }
1883  }
1884  return false;
1885  } catch (Util.FoundOne e) {
1886  return true;
1887  }
1888  }
1889  }
1890 
1891  // this visitor checks whether a parse tree contains at least one
1892  // specific SQL operator we have an interest in
1893  // (do not count the accurate # operators we found)
1894  private static class FindSqlOperator extends SqlBasicVisitor<Void> {
1895  @Override
1896  public Void visit(SqlCall call) {
1897  if (call instanceof SqlBasicCall) {
1898  SqlBasicCall basicCall = (SqlBasicCall) call;
1899  if (basicCall.getKind().equals(targetKind)) {
1900  throw Util.FoundOne.NULL;
1901  }
1902  }
1903  return super.visit(call);
1904  }
1905 
1906  boolean containsSqlOperator(SqlNode node, SqlKind operatorKind) {
1907  try {
1908  targetKind = operatorKind;
1909  node.accept(this);
1910  return false;
1911  } catch (Util.FoundOne e) {
1912  return true;
1913  }
1914  }
1915 
1916  private SqlKind targetKind;
1917  }
1918 
1919  public void tableAliasFinder(SqlNode sqlNode, Map<String, String> tableAliasMap) {
1920  final SqlVisitor<Void> aliasCollector = new SqlBasicVisitor<Void>() {
1921  @Override
1922  public Void visit(SqlCall call) {
1923  if (call instanceof SqlBasicCall) {
1924  SqlBasicCall basicCall = (SqlBasicCall) call;
1925  if (basicCall.getKind() == SqlKind.AS) {
1926  if (basicCall.operand(0) instanceof SqlIdentifier) {
1927  // we need to check whether basicCall's the first operand is SqlIdentifier
1928  // since sometimes it represents non column identifier like SqlSelect
1929  SqlIdentifier colNameIdentifier = (SqlIdentifier) basicCall.operand(0);
1930  String tblName = colNameIdentifier.names.size() == 1
1931  ? colNameIdentifier.names.get(0)
1932  : colNameIdentifier.names.get(1);
1933  tableAliasMap.put(basicCall.operand(1).toString(), tblName);
1934  }
1935  }
1936  }
1937  return super.visit(call);
1938  }
1939  };
1940  sqlNode.accept(aliasCollector);
1941  }
1942 }
void desugarExpression(SqlNode node, RelDataTypeFactory typeFactory)
boolean containsExpression(SqlNode node, SqlNode targetExpression)
SqlIdentifierCapturer captureIdentifiers(SqlNode node)
SqlNode expandCovariance(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
SqlNode expandVariance(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
JoinType
Definition: sqldefs.h:174
SqlSelect rewriteSimpleUpdateAsSelect(final SqlUpdate update)
SqlOrderBy desugar(SqlSelect select_node, SqlOrderBy order_by_node, RelDataTypeFactory typeFactory)
RelRoot replaceIsTrue(final RelDataTypeFactory typeFactory, RelRoot root)
std::string join(T const &container, std::string const &delim)
tuple root
Definition: setup.in.py:14
SqlNode expandCase(SqlCase old_case_node, RelDataTypeFactory typeFactory)
std::string toString(const QueryDescriptionType &type)
Definition: Types.h:64
SqlNode parseSql(String sql, final boolean legacy_syntax, Planner planner)
static final EnumSet< SqlKind > ARRAY_VALUE
boolean isColumnHashJoinable(List< String > joinColumnIdentifier, MetaConnect mc)
HashSet< ImmutableList< String > > resolveSelectIdentifiers(SqlIdentifierCapturer capturer)
SqlNode expandStringFunctions(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
RelRoot convertSqlToRelNode(final SqlNode sqlNode, final HeavyDBPlanner HeavyDBPlanner, final HeavyDBParserOptions parserOptions)
static final ThreadLocal< HeavyDBParser > CURRENT_PARSER
static final EnumSet< SqlKind > UPDATE
LogicalTableModify getDummyUpdate(SqlUpdate update)
String processSql(final SqlNode sqlNode, final HeavyDBParserOptions parserOptions)
static final EnumSet< SqlKind > IN
static final EnumSet< SqlKind > OTHER_FUNCTION
SqlIdentifierCapturer getAccessedObjects()
static Map< String, Boolean > SubqueryCorrMemo
boolean isHashJoinableType(TColumnType type)
SockTransportProperties sock_transport_properties
RelRoot rewriteUpdateAsSelect(SqlUpdate update, HeavyDBParserOptions parserOptions)
String processSql(String sql, final HeavyDBParserOptions parserOptions)
void updateMetaData(String schema, String table)
static final Context DB_CONNECTION_CONTEXT
static final EnumSet< SqlKind > DELETE
String buildRATreeAndPerformQueryOptimization(String query, final HeavyDBParserOptions parserOptions)
SqlIdentifierCapturer captureIdentifiers(String sql, boolean legacy_syntax)
torch::Tensor f(torch::Tensor x, torch::Tensor W_target, torch::Tensor b_target)
SqlNode expandVariance(final SqlParserPos pos, final SqlNode operand, boolean biased, boolean sqrt, boolean flt, RelDataTypeFactory typeFactory)
void desugar(SqlSelect select_node, RelDataTypeFactory typeFactory)
boolean containsSqlOperator(SqlNode node, SqlKind operatorKind)
static final EnumSet< SqlKind > SCALAR
HeavyDBPlanner.CompletionResult getCompletionHints(String sql, int cursor, List< String > visible_tables)
RelRoot queryToRelNode(final String sql, final HeavyDBParserOptions parserOptions)
SqlNode expandCovariance(SqlParserPos pos, final SqlNode operand0, final SqlNode operand1, boolean pop, boolean flt, RelDataTypeFactory typeFactory)
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
SqlNodeList expand(final SqlNodeList group_by_list, final java.util.Map< String, SqlNode > id_to_expr, RelDataTypeFactory typeFactory)
Set< ImmutableList< String > > addDbContextIfMissing(Set< ImmutableList< String >> names)
SqlNode expandCorrelation(final SqlBasicCall proj_call, RelDataTypeFactory typeFactory)
Pair< String, SqlIdentifierCapturer > process(String sql, final HeavyDBParserOptions parserOptions)
static final SqlArrayValueConstructorAllowingEmpty ARRAY_VALUE_CONSTRUCTOR
static final EnumSet< SqlKind > EXISTS
void tableAliasFinder(SqlNode sqlNode, Map< String, String > tableAliasMap)
boolean isCorrelated(SqlNode expression)
HeavyDBParser(String dataDir, final Supplier< HeavyDBSqlOperatorTable > dbSqlOperatorTable, int dbPort, SockTransportProperties skT)
SqlNode expand(final SqlNode node, final java.util.Map< String, SqlNode > id_to_expr, RelDataTypeFactory typeFactory)
HeavyDBPlanner getPlanner(final boolean allowSubQueryExpansion, final boolean isWatchdogEnabled, final boolean isDistributedMode)
final Supplier< HeavyDBSqlOperatorTable > dbSqlOperatorTable