17 package com.mapd.calcite.rel.rules;
19 import org.apache.calcite.plan.RelOptCluster;
20 import org.apache.calcite.plan.RelOptRuleCall;
21 import org.apache.calcite.plan.RelOptUtil;
22 import org.apache.calcite.plan.RelRule;
23 import org.apache.calcite.plan.hep.HepRelVertex;
24 import org.apache.calcite.rel.RelNode;
25 import org.apache.calcite.rel.logical.LogicalFilter;
26 import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
27 import org.apache.calcite.rel.metadata.RelColumnMapping;
28 import org.apache.calcite.rel.rules.TransformationRule;
29 import org.apache.calcite.rel.type.RelDataTypeField;
30 import org.apache.calcite.rex.RexBuilder;
31 import org.apache.calcite.rex.RexNode;
32 import org.apache.calcite.rex.RexUtil;
33 import org.apache.calcite.tools.RelBuilder;
34 import org.apache.calcite.tools.RelBuilderFactory;
35 import org.apache.calcite.util.ImmutableBitSet;
37 import java.util.ArrayList;
38 import java.util.Arrays;
39 import java.util.HashMap;
40 import java.util.HashSet;
41 import java.util.List;
53 extends RelRule<FilterTableFunctionMultiInputTransposeRule.Config>
54 implements TransformationRule {
62 this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(
Config.class));
68 public void onMatch(RelOptRuleCall call) {
69 final Boolean debugMode =
false;
70 LogicalFilter filter = call.rel(0);
71 LogicalTableFunctionScan funcRel = call.rel(1);
72 Set<RelColumnMapping> columnMappings = funcRel.getColumnMappings();
73 if (columnMappings == null || columnMappings.isEmpty()) {
87 List<RelNode> funcInputs = funcRel.getInputs();
88 final Integer numFuncInputs = funcInputs.size();
89 if (numFuncInputs < 1) {
90 debugPrint(
"RETURN: funcInputs.size()=" + funcInputs.size(), debugMode);
94 List<HashMap<Integer, Integer>> columnMaps =
95 new ArrayList<HashMap<Integer, Integer>>(numFuncInputs);
96 for (
Integer i = 0; i < numFuncInputs; i++) {
97 columnMaps.add(i,
new HashMap<Integer, Integer>());
100 for (RelColumnMapping mapping : columnMappings) {
101 debugPrint(
"iInputRel.iInputColumn: mapping.iOutputColumn=" + mapping.iInputRel
102 +
"." + mapping.iInputColumn +
": " + mapping.iOutputColumn,
104 if (mapping.derived) {
107 columnMaps.get(mapping.iInputRel).put(mapping.iOutputColumn, mapping.iInputColumn);
110 final List<RelNode> newFuncInputs =
new ArrayList<>();
111 final RelOptCluster cluster = funcRel.getCluster();
112 final RexNode condition = filter.getCondition();
113 debugPrint(
"condition=" + condition, debugMode);
117 List<RelDataTypeField> outputFields = funcRel.getRowType().getFieldList();
118 final Integer numOutputs = outputFields.size();
122 List<RexNode> outputConjunctivePredicates = RelOptUtil.conjunctions(condition);
123 final Integer numConjunctivePredicates = outputConjunctivePredicates.size();
124 int[] outputColPushdownCount =
new int[numOutputs];
125 int[] successfulFilterPushDowns =
new int[numConjunctivePredicates];
126 int[] failedFilterPushDowns =
new int[numConjunctivePredicates];
129 Boolean didPushDown =
false;
132 for (RelNode funcInput : funcInputs) {
133 final List<RelDataTypeField> inputFields = funcInput.getRowType().getFieldList();
134 debugPrint(
"inputFields=" + inputFields, debugMode);
135 List<RelDataTypeField> validInputFields =
new ArrayList<RelDataTypeField>();
136 List<RelDataTypeField> validOutputFields =
new ArrayList<RelDataTypeField>();
137 int[] adjustments =
new int[numOutputs];
138 List<RexNode> filtersToBePushedDown =
new ArrayList<>();
139 Set<Integer> uniquePushedDownOutputIdxs =
new HashSet<Integer>();
140 Set<Integer> seenOutputIdxs =
new HashSet<Integer>();
143 columnMaps.get(inputRelIdx).entrySet()) {
144 final Integer inputColIdx = outputInputColMapping.getValue();
145 final Integer outputColIdx = outputInputColMapping.getKey();
146 validInputFields.add(inputFields.get(inputColIdx));
147 validOutputFields.add(outputFields.get(outputColIdx));
148 adjustments[outputColIdx] = inputColIdx - outputColIdx;
150 debugPrint(
"validInputFields: " + validInputFields, debugMode);
151 debugPrint(
"validOutputFields: " + validOutputFields, debugMode);
152 debugPrint(
"adjustments=" + Arrays.toString(adjustments), debugMode);
153 Boolean anyFilterRefsPartiallyMapToInputs =
false;
154 List<Boolean> subFiltersDidMapToAnyInputs =
new ArrayList<Boolean>();
156 for (RexNode conjunctiveFilter : outputConjunctivePredicates) {
157 ImmutableBitSet filterRefs = RelOptUtil.InputFinder.bits(conjunctiveFilter);
158 final List<Integer> filterRefColIdxList = filterRefs.toList();
159 Boolean anyFilterColsPresentInInput =
false;
160 Boolean allFilterColsPresentInInput =
true;
161 for (
Integer filterRefColIdx : filterRefColIdxList) {
162 debugPrint(
"filterColIdx: " + filterRefColIdx, debugMode);
163 if (!(columnMaps.get(inputRelIdx).containsKey(filterRefColIdx))) {
164 allFilterColsPresentInInput =
false;
166 anyFilterColsPresentInInput =
true;
167 uniquePushedDownOutputIdxs.add(filterRefColIdx);
168 seenOutputIdxs.add(filterRefColIdx);
171 subFiltersDidMapToAnyInputs.add(anyFilterColsPresentInInput);
172 if (anyFilterColsPresentInInput) {
173 if (allFilterColsPresentInInput) {
174 filtersToBePushedDown.add(conjunctiveFilter);
180 anyFilterRefsPartiallyMapToInputs =
true;
186 +
" Any filter refs partially map to inputs: "
187 + anyFilterRefsPartiallyMapToInputs,
190 "# Filters to be pushed down: " + filtersToBePushedDown.size(), debugMode);
195 if (anyFilterRefsPartiallyMapToInputs) {
196 for (
Integer filterIdx = 0; filterIdx < numConjunctivePredicates; filterIdx++) {
197 if (subFiltersDidMapToAnyInputs.get(filterIdx)) {
198 failedFilterPushDowns[filterIdx]++;
201 debugPrint(
"Failed to push down input: " + inputRelIdx, debugMode);
202 newFuncInputs.add(funcInput);
204 if (filtersToBePushedDown.isEmpty()) {
205 debugPrint(
"No filters to push down: " + inputRelIdx, debugMode);
206 newFuncInputs.add(funcInput);
208 debugPrint(
"Func input at pushdown: " + funcInput, debugMode);
209 if (funcInput instanceof HepRelVertex
210 && ((HepRelVertex) funcInput).getCurrentRel()
211 instanceof LogicalFilter) {
212 debugPrint(
"Filter existed on input node", debugMode);
213 final HepRelVertex inputHepRelVertex = (HepRelVertex) funcInput;
214 final LogicalFilter inputFilter =
215 (LogicalFilter) (inputHepRelVertex.getCurrentRel());
216 final RexNode inputCondition = inputFilter.getCondition();
217 final List<RexNode> inputConjunctivePredicates =
218 RelOptUtil.conjunctions(inputCondition);
219 if (inputConjunctivePredicates.size() > 0) {
220 RexBuilder rexBuilder = filter.getCluster().getRexBuilder();
221 RexNode pushdownCondition = RexUtil.composeConjunction(
222 rexBuilder, filtersToBePushedDown,
false);
223 final RexNode newPushdownCondition = pushdownCondition.accept(
224 new RelOptUtil.RexInputConverter(rexBuilder,
228 final List<RexNode> newPushdownConjunctivePredicates =
229 RelOptUtil.conjunctions(newPushdownCondition);
230 final Integer numOriginalPushdownConjunctivePredicates =
231 newPushdownConjunctivePredicates.size();
232 debugPrint(
"Output predicates: " + newPushdownConjunctivePredicates,
234 debugPrint(
"Input predicates: " + inputConjunctivePredicates, debugMode);
235 newPushdownConjunctivePredicates.removeAll(inputConjunctivePredicates);
237 if (newPushdownConjunctivePredicates.isEmpty()) {
238 debugPrint(
"All filters existed on input node", debugMode);
239 newFuncInputs.add(funcInput);
242 if (newPushdownConjunctivePredicates.size()
243 < numOriginalPushdownConjunctivePredicates) {
244 debugPrint(
"Some predicates eliminated.", debugMode);
247 debugPrint(
"# Filters to be pushed down after prune: "
248 + filtersToBePushedDown.size(),
251 debugPrint(
"No filter detected on input node", debugMode);
254 RexBuilder rexBuilder = filter.getCluster().getRexBuilder();
255 RexNode pushdownCondition =
256 RexUtil.composeConjunction(rexBuilder, filtersToBePushedDown,
false);
258 debugPrint(
"Trying to push down filter", debugMode);
259 final RexNode newCondition =
260 pushdownCondition.accept(
new RelOptUtil.RexInputConverter(rexBuilder,
265 newFuncInputs.add(LogicalFilter.create(funcInput, newCondition));
266 for (
Integer pushedDownOutputIdx : uniquePushedDownOutputIdxs) {
267 outputColPushdownCount[pushedDownOutputIdx]++;
269 for (
Integer filterIdx = 0; filterIdx < numConjunctivePredicates;
271 if (subFiltersDidMapToAnyInputs.get(filterIdx)) {
272 successfulFilterPushDowns[filterIdx]++;
275 }
catch (java.lang.ArrayIndexOutOfBoundsException e) {
283 debugPrint(
"Did not push down - returning", debugMode);
287 List<RexNode> remainingFilters =
new ArrayList<>();
288 for (
Integer filterIdx = 0; filterIdx < numConjunctivePredicates; filterIdx++) {
289 if (successfulFilterPushDowns[filterIdx] == 0
290 || failedFilterPushDowns[filterIdx] > 0) {
291 remainingFilters.add(outputConjunctivePredicates.get(filterIdx));
295 debugPrint(
"Remaining filters: " + remainingFilters, debugMode);
297 LogicalTableFunctionScan newTableFuncRel = LogicalTableFunctionScan.create(cluster,
300 funcRel.getElementType(),
301 funcRel.getRowType(),
304 final RelBuilder relBuilder = call.builder();
305 relBuilder.push(newTableFuncRel);
306 if (!remainingFilters.isEmpty()) {
307 relBuilder.filter(remainingFilters);
309 final RelNode outputNode = relBuilder.build();
310 debugPrint(RelOptUtil.toString(outputNode), debugMode);
311 call.transformTo(outputNode);
316 System.out.println(msg);
323 EMPTY.withOperandSupplier(b0
324 -> b0.operand(LogicalFilter.class)
326 -> b1.operand(LogicalTableFunctionScan.class)