OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
OuterJoinOptViaNullRejectionRule.java
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package org.apache.calcite.rel.rules;
18 
19 import org.apache.calcite.plan.RelOptRuleCall;
20 import org.apache.calcite.plan.hep.HepRelVertex;
21 import org.apache.calcite.rel.RelNode;
22 import org.apache.calcite.rel.core.Join;
23 import org.apache.calcite.rel.core.JoinRelType;
24 import org.apache.calcite.rel.logical.LogicalFilter;
25 import org.apache.calcite.rel.logical.LogicalJoin;
26 import org.apache.calcite.rel.logical.LogicalProject;
27 import org.apache.calcite.rel.logical.LogicalTableScan;
28 import org.apache.calcite.rex.RexCall;
29 import org.apache.calcite.rex.RexInputRef;
30 import org.apache.calcite.rex.RexLiteral;
31 import org.apache.calcite.rex.RexNode;
32 import org.apache.calcite.sql.SqlKind;
33 import org.apache.calcite.tools.RelBuilder;
34 import org.apache.calcite.tools.RelBuilderFactory;
35 import org.apache.calcite.util.mapping.Mappings;
36 import org.slf4j.Logger;
37 import org.slf4j.LoggerFactory;
38 
39 import java.util.ArrayList;
40 import java.util.HashMap;
41 import java.util.HashSet;
42 import java.util.List;
43 import java.util.Map;
44 import java.util.Set;
45 
47  // goal: relax full outer join to either left or inner joins
48  // consider two tables 'foo(a int, b int)' and 'bar(c int, d int)'
49  // foo = {(1,3), (2,4), (NULL, 5)} // bar = {(1,2), (4, 3), (NULL, 5)}
50 
51  // 1. full outer join -> left
52  // : select * from foo full outer join bar on a = c where a is not null;
53  // = select * from foo left outer join bar on a = c where a is not null;
54 
55  // 2. full outer join -> inner
56  // : select * from foo full outer join bar on a = c
57  //   where a is not null and c is not null;
58  // = select * from foo join bar on a = c; (or select * from foo, bar where a = c;)
59 
60  // 3. left outer join --> inner
61  // : select * from foo left outer join bar on a = c where c is not null;
62  // = select * from foo join bar on a = c; (or select * from foo, bar where a = c;)
63 
64  // null rejection: "col IS NOT NULL" or "col > NULL_INDICATOR" in WHERE clause
65  // e.g., "col > 1" rejects every tuple whose value in column col is NULL
66 
67  // todo(yoonmin): runtime query optimization via statistics
68  // in fact, we can optimize a broader range of queries having outer joins
69  // by using filter predicates on join tables (but not on join cols),
70  // because such filter conditions could affect the join tables and
71  // make the join cols null-rejected
72 
73  public static Set<String> visitedJoinMemo = new HashSet<>();
74  final static Logger HEAVYDBLOGGER =
75  LoggerFactory.getLogger(OuterJoinOptViaNullRejectionRule.class);
76 
/**
 * Creates the rule. It matches any relational node whose first input is a Join,
 * and starts with an empty visited-join memo.
 *
 * @param relBuilderFactory factory handed to the superclass constructor
 */
public OuterJoinOptViaNullRejectionRule(RelBuilderFactory relBuilderFactory) {
  super(
      operand(RelNode.class, operand(Join.class, null, any())),
      relBuilderFactory,
      "OuterJoinOptViaNullRejectionRule");
  // the memo is static, so reset it whenever a new rule instance is built
  clearMemo();
}
83 
84  void clearMemo() {
85  visitedJoinMemo.clear();
86  }
87 
88  @Override
89  public void onMatch(RelOptRuleCall call) {
90  RelNode parentNode = call.rel(0);
91  LogicalJoin join = (LogicalJoin) call.rel(1);
92  String condString = join.getCondition().toString();
93  if (visitedJoinMemo.contains(condString)) {
94  return;
95  } else {
96  visitedJoinMemo.add(condString);
97  }
98  if (!(join.getCondition() instanceof RexCall)) {
99  return; // an inner join
100  }
101  if (join.getJoinType() == JoinRelType.INNER || join.getJoinType() == JoinRelType.SEMI
102  || join.getJoinType() == JoinRelType.ANTI) {
103  return; // non target
104  }
105  RelNode joinLeftChild = ((HepRelVertex) join.getLeft()).getCurrentRel();
106  RelNode joinRightChild = ((HepRelVertex) join.getRight()).getCurrentRel();
107  if (joinLeftChild instanceof LogicalProject) {
108  return; // disable this opt when LHS has subquery (i.e., filter push-down)
109  }
110  if (!(joinRightChild instanceof LogicalTableScan)) {
111  return; // disable this opt when RHS has subquery (i.e., filter push-down)
112  }
113  // an outer join contains its join cond in itself,
114  // not in a filter as typical inner join op. does
115  RexCall joinCond = (RexCall) join.getCondition();
116  Set<Integer> leftJoinCols = new HashSet<>();
117  Set<Integer> rightJoinCols = new HashSet<>();
118  Map<Integer, String> leftJoinColToColNameMap = new HashMap<>();
119  Map<Integer, String> rightJoinColToColNameMap = new HashMap<>();
120  Set<Integer> originalLeftJoinCols = new HashSet<>();
121  Set<Integer> originalRightJoinCols = new HashSet<>();
122  Map<Integer, String> originalLeftJoinColToColNameMap = new HashMap<>();
123  Map<Integer, String> originalRightJoinColToColNameMap = new HashMap<>();
124  List<RexCall> capturedFilterPredFromJoin = new ArrayList<>();
125  if (joinCond.getKind() == SqlKind.EQUALS) {
126  addJoinCols(joinCond,
127  join,
128  leftJoinCols,
129  rightJoinCols,
130  leftJoinColToColNameMap,
131  rightJoinColToColNameMap,
132  originalLeftJoinCols,
133  originalRightJoinCols,
134  originalLeftJoinColToColNameMap,
135  originalRightJoinColToColNameMap);
136  // we only consider ANDED exprs
137  } else if (joinCond.getKind() == SqlKind.AND) {
138  for (RexNode n : joinCond.getOperands()) {
139  if (n instanceof RexCall) {
140  RexCall op = (RexCall) n;
141  if (op.getOperands().size() > 2
142  && op.getOperands().get(1) instanceof RexLiteral) {
143  // try to capture literal comparison of join column located in the cur join
144  // node
145  capturedFilterPredFromJoin.add(op);
146  continue;
147  }
148  addJoinCols(op,
149  join,
150  leftJoinCols,
151  rightJoinCols,
152  leftJoinColToColNameMap,
153  rightJoinColToColNameMap,
154  originalLeftJoinCols,
155  originalRightJoinCols,
156  originalLeftJoinColToColNameMap,
157  originalRightJoinColToColNameMap);
158  }
159  }
160  }
161 
162  if (leftJoinCols.isEmpty() || rightJoinCols.isEmpty()) {
163  return;
164  }
165 
166  // find filter node(s)
167  RelNode root = call.getPlanner().getRoot();
168  List<LogicalFilter> collectedFilterNodes = new ArrayList<>();
169  RelNode curNode = root;
170  final RelBuilder relBuilder = call.builder();
171  // collect filter nodes
172  collectFilterCondition(curNode, collectedFilterNodes);
173  if (collectedFilterNodes.isEmpty()) {
174  // we have a last chance to take a look at this join condition itself
175  // i.e., the filter preds lay with the join conditions in the same join node
176  // but for now we disable the optimization to avoid unexpected plan issue
177  return;
178  }
179 
180  // check whether join column has filter predicate(s)
181  // and collect join column info used in target join nodes to be translated
182  Set<Integer> nullRejectedLeftJoinCols = new HashSet<>();
183  Set<Integer> nullRejectedRightJoinCols = new HashSet<>();
184  boolean hasExprsConnectedViaOR = false;
185  for (LogicalFilter filter : collectedFilterNodes) {
186  RexNode node = filter.getCondition();
187  if (node instanceof RexCall) {
188  RexCall curExpr = (RexCall) node;
189  // we only consider ANDED exprs
190  if (curExpr.getKind() == SqlKind.OR) {
191  hasExprsConnectedViaOR = true;
192  break;
193  }
194  if (curExpr.getKind() == SqlKind.AND) {
195  for (RexNode n : curExpr.getOperands()) {
196  if (n instanceof RexCall) {
197  RexCall c = (RexCall) n;
198  if (isCandidateFilterPred(c)) {
199  RexInputRef col = (RexInputRef) c.getOperands().get(0);
200  int colId = col.getIndex();
201  boolean leftFilter = leftJoinCols.contains(colId);
202  boolean rightFilter = rightJoinCols.contains(colId);
203  if (leftFilter && rightFilter) {
204  // here we currently do not have a concrete column tracing logic
205  // so it may become a source of plan issue, so we disable this opt
206  return;
207  }
209  filter,
210  nullRejectedLeftJoinCols,
211  nullRejectedRightJoinCols,
212  leftJoinColToColNameMap,
213  rightJoinColToColNameMap);
214  }
215  }
216  }
217  } else {
218  if (curExpr instanceof RexCall) {
219  if (isCandidateFilterPred(curExpr)) {
220  RexInputRef col = (RexInputRef) curExpr.getOperands().get(0);
221  int colId = col.getIndex();
222  boolean leftFilter = leftJoinCols.contains(colId);
223  boolean rightFilter = rightJoinCols.contains(colId);
224  if (leftFilter && rightFilter) {
225  // here we currently do not have a concrete column tracing logic
226  // so it may become a source of plan issue, so we disable this opt
227  return;
228  }
229  addNullRejectedJoinCols(curExpr,
230  filter,
231  nullRejectedLeftJoinCols,
232  nullRejectedRightJoinCols,
233  leftJoinColToColNameMap,
234  rightJoinColToColNameMap);
235  }
236  }
237  }
238  }
239  }
240 
241  // we skip to optimize this query since analyzing complex filter exprs
242  // connected via OR condition is complex and risky
243  if (hasExprsConnectedViaOR) {
244  return;
245  }
246 
247  if (!capturedFilterPredFromJoin.isEmpty()) {
248  for (RexCall c : capturedFilterPredFromJoin) {
249  RexInputRef col = (RexInputRef) c.getOperands().get(0);
250  int colId = col.getIndex();
251  String colName = join.getRowType().getFieldNames().get(colId);
252  Boolean l = false;
253  Boolean r = false;
254  if (originalLeftJoinColToColNameMap.containsKey(colId)
255  && originalLeftJoinColToColNameMap.get(colId).equals(colName)) {
256  l = true;
257  }
258  if (originalRightJoinColToColNameMap.containsKey(colId)
259  && originalRightJoinColToColNameMap.get(colId).equals(colName)) {
260  r = true;
261  }
262  if (l && !r) {
263  nullRejectedLeftJoinCols.add(colId);
264  } else if (r && !l) {
265  nullRejectedRightJoinCols.add(colId);
266  } else if (r && l) {
267  return;
268  }
269  }
270  }
271 
272  Boolean leftNullRejected = false;
273  Boolean rightNullRejected = false;
274  if (!nullRejectedLeftJoinCols.isEmpty()
275  && leftJoinCols.equals(nullRejectedLeftJoinCols)) {
276  leftNullRejected = true;
277  }
278  if (!nullRejectedRightJoinCols.isEmpty()
279  && rightJoinCols.equals(nullRejectedRightJoinCols)) {
280  rightNullRejected = true;
281  }
282 
283  // relax outer join condition depending on null rejected cols
284  RelNode newJoinNode = null;
285  Boolean needTransform = false;
286  if (join.getJoinType() == JoinRelType.FULL) {
287  // 1) full -> left
288  if (leftNullRejected && !rightNullRejected) {
289  newJoinNode = join.copy(join.getTraitSet(),
290  join.getCondition(),
291  join.getLeft(),
292  join.getRight(),
293  JoinRelType.LEFT,
294  join.isSemiJoinDone());
295  needTransform = true;
296  }
297 
298  // 2) full -> inner
299  if (leftNullRejected && rightNullRejected) {
300  newJoinNode = join.copy(join.getTraitSet(),
301  join.getCondition(),
302  join.getLeft(),
303  join.getRight(),
304  JoinRelType.INNER,
305  join.isSemiJoinDone());
306  needTransform = true;
307  }
308  } else if (join.getJoinType() == JoinRelType.LEFT) {
309  // 3) left -> inner
310  if (rightNullRejected) {
311  newJoinNode = join.copy(join.getTraitSet(),
312  join.getCondition(),
313  join.getLeft(),
314  join.getRight(),
315  JoinRelType.INNER,
316  join.isSemiJoinDone());
317  needTransform = true;
318  }
319  }
320  if (needTransform) {
321  relBuilder.push(newJoinNode);
322  parentNode.replaceInput(0, newJoinNode);
323  call.transformTo(parentNode);
324  }
325  return;
326  }
327 
328  void addJoinCols(RexCall joinCond,
329  LogicalJoin joinOp,
330  Set<Integer> leftJoinCols,
331  Set<Integer> rightJoinCols,
332  Map<Integer, String> leftJoinColToColNameMap,
333  Map<Integer, String> rightJoinColToColNameMap,
334  Set<Integer> originalLeftJoinCols,
335  Set<Integer> originalRightJoinCols,
336  Map<Integer, String> originalLeftJoinColToColNameMap,
337  Map<Integer, String> originalRightJoinColToColNameMap) {
338  if (joinCond.getOperands().size() != 2
339  || !(joinCond.getOperands().get(0) instanceof RexInputRef)
340  || !(joinCond.getOperands().get(1) instanceof RexInputRef)) {
341  return;
342  }
343  RexInputRef leftJoinCol = (RexInputRef) joinCond.getOperands().get(0);
344  RexInputRef rightJoinCol = (RexInputRef) joinCond.getOperands().get(1);
345  originalLeftJoinCols.add(leftJoinCol.getIndex());
346  originalRightJoinCols.add(rightJoinCol.getIndex());
347  originalLeftJoinColToColNameMap.put(leftJoinCol.getIndex(),
348  joinOp.getRowType().getFieldNames().get(leftJoinCol.getIndex()));
349  originalRightJoinColToColNameMap.put(rightJoinCol.getIndex(),
350  joinOp.getRowType().getFieldNames().get(rightJoinCol.getIndex()));
351  if (leftJoinCol.getIndex() > rightJoinCol.getIndex()) {
352  leftJoinCol = (RexInputRef) joinCond.getOperands().get(1);
353  rightJoinCol = (RexInputRef) joinCond.getOperands().get(0);
354  }
355  int originalLeftColOffset = traceColOffset(joinOp.getLeft(), leftJoinCol, 0);
356  int originalRightColOffset = traceColOffset(joinOp.getRight(),
357  rightJoinCol,
358  joinOp.getLeft().getRowType().getFieldCount());
359  if (originalLeftColOffset != -1) {
360  return;
361  }
362  int leftColOffset =
363  originalLeftColOffset == -1 ? leftJoinCol.getIndex() : originalLeftColOffset;
364  int rightColOffset = originalRightColOffset == -1 ? rightJoinCol.getIndex()
365  : originalRightColOffset;
366  String leftJoinColName = joinOp.getRowType().getFieldNames().get(leftColOffset);
367  String rightJoinColName =
368  joinOp.getRowType().getFieldNames().get(rightJoinCol.getIndex());
369  leftJoinCols.add(leftColOffset);
370  rightJoinCols.add(rightColOffset);
371  leftJoinColToColNameMap.put(leftColOffset, leftJoinColName);
372  rightJoinColToColNameMap.put(rightColOffset, rightJoinColName);
373  return;
374  }
375 
376  void addNullRejectedJoinCols(RexCall call,
377  LogicalFilter targetFilter,
378  Set<Integer> nullRejectedLeftJoinCols,
379  Set<Integer> nullRejectedRightJoinCols,
380  Map<Integer, String> leftJoinColToColNameMap,
381  Map<Integer, String> rightJoinColToColNameMap) {
382  if (isCandidateFilterPred(call) && call.getOperands().get(0) instanceof RexInputRef) {
383  RexInputRef col = (RexInputRef) call.getOperands().get(0);
384  int colId = col.getIndex();
385  String colName = targetFilter.getRowType().getFieldNames().get(colId);
386  Boolean l = false;
387  Boolean r = false;
388  if (leftJoinColToColNameMap.containsKey(colId)
389  && leftJoinColToColNameMap.get(colId).equals(colName)) {
390  l = true;
391  }
392  if (rightJoinColToColNameMap.containsKey(colId)
393  && rightJoinColToColNameMap.get(colId).equals(colName)) {
394  r = true;
395  }
396  if (l && !r) {
397  nullRejectedLeftJoinCols.add(colId);
398  return;
399  }
400  if (r && !l) {
401  nullRejectedRightJoinCols.add(colId);
402  return;
403  }
404  }
405  }
406 
407  void collectFilterCondition(RelNode curNode, List<LogicalFilter> collectedFilterNodes) {
408  if (curNode instanceof HepRelVertex) {
409  curNode = ((HepRelVertex) curNode).getCurrentRel();
410  }
411  if (curNode instanceof LogicalFilter) {
412  collectedFilterNodes.add((LogicalFilter) curNode);
413  }
414  if (curNode.getInputs().size() == 0) {
415  // end of the query plan, move out
416  return;
417  }
418  for (int i = 0; i < curNode.getInputs().size(); i++) {
419  collectFilterCondition(curNode.getInput(i), collectedFilterNodes);
420  }
421  }
422 
423  void collectProjectNode(RelNode curNode, List<LogicalProject> collectedProject) {
424  if (curNode instanceof HepRelVertex) {
425  curNode = ((HepRelVertex) curNode).getCurrentRel();
426  }
427  if (curNode instanceof LogicalProject) {
428  collectedProject.add((LogicalProject) curNode);
429  }
430  if (curNode.getInputs().size() == 0) {
431  // end of the query plan, move out
432  return;
433  }
434  for (int i = 0; i < curNode.getInputs().size(); i++) {
435  collectProjectNode(curNode.getInput(i), collectedProject);
436  }
437  }
438 
439  int traceColOffset(RelNode curNode, RexInputRef colRef, int startOffset) {
440  int colOffset = -1;
441  ArrayList<LogicalProject> collectedProjectNodes = new ArrayList<>();
442  collectProjectNode(curNode, collectedProjectNodes);
443  // the nearest project node that may permute the column offset
444  if (!collectedProjectNodes.isEmpty()) {
445  // get the closest project node from the cur join node's target child
446  LogicalProject projectNode = collectedProjectNodes.get(0);
447  Mappings.TargetMapping targetMapping = projectNode.getMapping();
448  if (null != colRef && null != targetMapping) {
449  // try to track the original col offset
450  int base_offset = colRef.getIndex() - startOffset;
451 
452  if (base_offset >= 0 && base_offset < targetMapping.getSourceCount()) {
453  colOffset = targetMapping.getSourceOpt(base_offset);
454  }
455  }
456  }
457  return colOffset;
458  }
459 
460  boolean isComparisonOp(RexCall c) {
461  SqlKind opKind = c.getKind();
462  return (SqlKind.BINARY_COMPARISON.contains(opKind)
463  || SqlKind.BINARY_EQUALITY.contains(opKind));
464  }
465 
466  boolean isNotNullFilter(RexCall c) {
467  return (c.op.kind == SqlKind.IS_NOT_NULL && c.operands.size() == 1);
468  }
469 
470  boolean isCandidateFilterPred(RexCall c) {
471  return (isNotNullFilter(c)
472  || (c.operands.size() == 2 && isComparisonOp(c)
473  && c.operands.get(0) instanceof RexInputRef
474  && c.operands.get(1) instanceof RexLiteral));
475  }
476 }
void collectFilterCondition(RelNode curNode, List< LogicalFilter > collectedFilterNodes)
void addNullRejectedJoinCols(RexCall call, LogicalFilter targetFilter, Set< Integer > nullRejectedLeftJoinCols, Set< Integer > nullRejectedRightJoinCols, Map< Integer, String > leftJoinColToColNameMap, Map< Integer, String > rightJoinColToColNameMap)
std::string join(T const &container, std::string const &delim)
tuple root
Definition: setup.in.py:14
void addJoinCols(RexCall joinCond, LogicalJoin joinOp, Set< Integer > leftJoinCols, Set< Integer > rightJoinCols, Map< Integer, String > leftJoinColToColNameMap, Map< Integer, String > rightJoinColToColNameMap, Set< Integer > originalLeftJoinCols, Set< Integer > originalRightJoinCols, Map< Integer, String > originalLeftJoinColToColNameMap, Map< Integer, String > originalRightJoinColToColNameMap)
std::string toString(const ExecutorDeviceType &device_type)
int traceColOffset(RelNode curNode, RexInputRef colRef, int startOffset)
void collectProjectNode(RelNode curNode, List< LogicalProject > collectedProject)
constexpr double n
Definition: Utm.h:38