OmniSciDB  1dac507f6e
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
QueryPhysicalInputsCollector.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
20 #include "RelAlgVisitor.h"
21 #include "RexVisitor.h"
22 
23 namespace {
24 
25 using PhysicalInputSet = std::unordered_set<PhysicalInput>;
26 
27 class RelAlgPhysicalInputsVisitor : public RelAlgVisitor<PhysicalInputSet> {
28  public:
29  PhysicalInputSet visitCompound(const RelCompound* compound) const override;
30  PhysicalInputSet visitFilter(const RelFilter* filter) const override;
31  PhysicalInputSet visitJoin(const RelJoin* join) const override;
32  PhysicalInputSet visitLeftDeepInnerJoin(const RelLeftDeepInnerJoin*) const override;
33  PhysicalInputSet visitProject(const RelProject* project) const override;
34  PhysicalInputSet visitSort(const RelSort* sort) const override;
35 
36  protected:
37  PhysicalInputSet aggregateResult(const PhysicalInputSet& aggregate,
38  const PhysicalInputSet& next_result) const override;
39 };
40 
41 class RexPhysicalInputsVisitor : public RexVisitor<PhysicalInputSet> {
42  public:
43  PhysicalInputSet visitInput(const RexInput* input) const override {
44  const auto source_ra = input->getSourceNode();
45  const auto scan_ra = dynamic_cast<const RelScan*>(source_ra);
46  if (!scan_ra) {
47  const auto join_ra = dynamic_cast<const RelJoin*>(source_ra);
48  if (join_ra) {
49  const auto node_inputs = get_node_output(join_ra);
50  CHECK_LT(input->getIndex(), node_inputs.size());
51  return visitInput(&node_inputs[input->getIndex()]);
52  }
53  return PhysicalInputSet{};
54  }
55  const auto scan_td = scan_ra->getTableDescriptor();
56  CHECK(scan_td);
57  const int col_id = input->getIndex() + 1;
58  const int table_id = scan_td->tableId;
59  CHECK_GT(table_id, 0);
60  return {{col_id, table_id}};
61  }
62 
63  PhysicalInputSet visitSubQuery(const RexSubQuery* subquery) const override {
64  const auto ra = subquery->getRelAlg();
65  CHECK(ra);
67  return visitor.visit(ra);
68  }
69 
70  PhysicalInputSet visitOperator(const RexOperator* oper) const override {
72  if (auto window_oper = dynamic_cast<const RexWindowFunctionOperator*>(oper)) {
73  for (const auto& partition_key : window_oper->getPartitionKeys()) {
74  if (auto input = dynamic_cast<const RexInput*>(partition_key.get())) {
75  const auto source_node = input->getSourceNode();
76  if (auto filter_node = dynamic_cast<const RelFilter*>(source_node)) {
77  // Partitions utilize string dictionary translation in the hash join framework
78  // if the partition key is a dictionary encoded string. Ensure we reach the
79  // source for all partition keys, so we can access string dictionaries for the
80  // partition keys while we build the partition (hash) table
81  CHECK_EQ(filter_node->inputCount(), size_t(1));
82  const auto parent_node = filter_node->getInput(0);
83  const auto node_inputs = get_node_output(parent_node);
84  CHECK_LT(input->getIndex(), node_inputs.size());
85  result = aggregateResult(result, visitInput(&node_inputs[input->getIndex()]));
86  }
87  result = aggregateResult(result, visit(input));
88  }
89  }
90  return result;
91  }
92  for (size_t i = 0; i < oper->size(); i++) {
93  result = aggregateResult(result, visit(oper->getOperand(i)));
94  }
95  return result;
96  }
97 
98  protected:
100  const PhysicalInputSet& next_result) const override {
101  auto result = aggregate;
102  result.insert(next_result.begin(), next_result.end());
103  return result;
104  }
105 };
106 
107 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitCompound(
108  const RelCompound* compound) const {
110  for (size_t i = 0; i < compound->getScalarSourcesSize(); ++i) {
111  const auto rex = compound->getScalarSource(i);
112  CHECK(rex);
113  RexPhysicalInputsVisitor visitor;
114  const auto rex_phys_inputs = visitor.visit(rex);
115  result.insert(rex_phys_inputs.begin(), rex_phys_inputs.end());
116  }
117  const auto filter = compound->getFilterExpr();
118  if (filter) {
119  RexPhysicalInputsVisitor visitor;
120  const auto filter_phys_inputs = visitor.visit(filter);
121  result.insert(filter_phys_inputs.begin(), filter_phys_inputs.end());
122  }
123  return result;
124 }
125 
126 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitFilter(const RelFilter* filter) const {
127  const auto condition = filter->getCondition();
128  CHECK(condition);
129  RexPhysicalInputsVisitor visitor;
130  return visitor.visit(condition);
131 }
132 
133 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitJoin(const RelJoin* join) const {
134  const auto condition = join->getCondition();
135  if (!condition) {
136  return PhysicalInputSet{};
137  }
138  RexPhysicalInputsVisitor visitor;
139  return visitor.visit(condition);
140 }
141 
142 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitLeftDeepInnerJoin(
143  const RelLeftDeepInnerJoin* left_deep_inner_join) const {
145  const auto condition = left_deep_inner_join->getInnerCondition();
146  RexPhysicalInputsVisitor visitor;
147  if (condition) {
148  result = visitor.visit(condition);
149  }
150  CHECK_GE(left_deep_inner_join->inputCount(), size_t(2));
151  for (size_t nesting_level = 1; nesting_level <= left_deep_inner_join->inputCount() - 1;
152  ++nesting_level) {
153  const auto outer_condition = left_deep_inner_join->getOuterCondition(nesting_level);
154  if (outer_condition) {
155  const auto outer_result = visitor.visit(outer_condition);
156  result.insert(outer_result.begin(), outer_result.end());
157  }
158  }
159  return result;
160 }
161 
162 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitProject(
163  const RelProject* project) const {
165  for (size_t i = 0; i < project->size(); ++i) {
166  const auto rex = project->getProjectAt(i);
167  CHECK(rex);
168  RexPhysicalInputsVisitor visitor;
169  const auto rex_phys_inputs = visitor.visit(rex);
170  result.insert(rex_phys_inputs.begin(), rex_phys_inputs.end());
171  }
172  return result;
173 }
174 
175 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitSort(const RelSort* sort) const {
176  CHECK_EQ(sort->inputCount(), size_t(1));
177  return visit(sort->getInput(0));
178 }
179 
180 PhysicalInputSet RelAlgPhysicalInputsVisitor::aggregateResult(
181  const PhysicalInputSet& aggregate,
182  const PhysicalInputSet& next_result) const {
183  auto result = aggregate;
184  result.insert(next_result.begin(), next_result.end());
185  return result;
186 }
187 
188 class RelAlgPhysicalTableInputsVisitor : public RelAlgVisitor<std::unordered_set<int>> {
189  public:
190  std::unordered_set<int> visitScan(const RelScan* scan) const override {
191  return {scan->getTableDescriptor()->tableId};
192  }
193 
194  protected:
195  std::unordered_set<int> aggregateResult(
196  const std::unordered_set<int>& aggregate,
197  const std::unordered_set<int>& next_result) const override {
198  auto result = aggregate;
199  result.insert(next_result.begin(), next_result.end());
200  return result;
201  }
202 };
203 
204 } // namespace
205 
206 std::unordered_set<PhysicalInput> get_physical_inputs(const RelAlgNode* ra) {
207  RelAlgPhysicalInputsVisitor phys_inputs_visitor;
208  return phys_inputs_visitor.visit(ra);
209 }
210 
211 std::unordered_set<int> get_physical_table_inputs(const RelAlgNode* ra) {
212  RelAlgPhysicalTableInputsVisitor phys_table_inputs_visitor;
213  return phys_table_inputs_visitor.visit(ra);
214 }
#define CHECK_EQ(x, y)
Definition: Logger.h:198
size_t size() const override
const RexScalar * getFilterExpr() const
const RexScalar * getOuterCondition(const size_t nesting_level) const
size_t size() const
const RexScalar * getOperand(const size_t idx) const
const RexScalar * getCondition() const
std::string join(T const &container, std::string const &delim)
#define CHECK_GE(x, y)
Definition: Logger.h:203
const RexScalar * getCondition() const
#define CHECK_GT(x, y)
Definition: Logger.h:202
PhysicalInputSet visitSubQuery(const RexSubQuery *subquery) const override
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
const size_t getScalarSourcesSize() const
unsigned getIndex() const
CHECK(cgen_state)
PhysicalInputSet aggregateResult(const PhysicalInputSet &aggregate, const PhysicalInputSet &next_result) const override
T visit(const RelAlgNode *rel_alg) const
Definition: RelAlgVisitor.h:25
const RelAlgNode * getInput(const size_t idx) const
const RexScalar * getProjectAt(const size_t idx) const
#define CHECK_LT(x, y)
Definition: Logger.h:200
const RelAlgNode * getSourceNode() const
RANodeOutput get_node_output(const RelAlgNode *ra_node)
const RexScalar * getInnerCondition() const
std::unordered_set< PhysicalInput > get_physical_inputs(const RelAlgNode *ra)
const size_t inputCount() const
const TableDescriptor * getTableDescriptor() const
std::unordered_set< int > get_physical_table_inputs(const RelAlgNode *ra)
std::unordered_set< int > aggregateResult(const std::unordered_set< int > &aggregate, const std::unordered_set< int > &next_result) const override
const RexScalar * getScalarSource(const size_t i) const