OmniSciDB  d2f719934e
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
QueryPhysicalInputsCollector.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include "RelAlgDagBuilder.h"
20 #include "RelAlgVisitor.h"
21 #include "RexVisitor.h"
22 
23 namespace {
24 
25 using PhysicalInputSet = std::unordered_set<PhysicalInput>;
26 
27 class RelAlgPhysicalInputsVisitor : public RelAlgVisitor<PhysicalInputSet> {
28  public:
29  PhysicalInputSet visitCompound(const RelCompound* compound) const override;
30  PhysicalInputSet visitFilter(const RelFilter* filter) const override;
31  PhysicalInputSet visitJoin(const RelJoin* join) const override;
32  PhysicalInputSet visitLeftDeepInnerJoin(const RelLeftDeepInnerJoin*) const override;
33  PhysicalInputSet visitProject(const RelProject* project) const override;
34  PhysicalInputSet visitSort(const RelSort* sort) const override;
35 
36  protected:
37  PhysicalInputSet aggregateResult(const PhysicalInputSet& aggregate,
38  const PhysicalInputSet& next_result) const override;
39 };
40 
41 class RexPhysicalInputsVisitor : public RexVisitor<PhysicalInputSet> {
42  public:
43  PhysicalInputSet visitInput(const RexInput* input) const override {
44  const auto source_ra = input->getSourceNode();
45  const auto scan_ra = dynamic_cast<const RelScan*>(source_ra);
46  if (!scan_ra) {
47  const auto join_ra = dynamic_cast<const RelJoin*>(source_ra);
48  if (join_ra) {
49  const auto node_inputs = get_node_output(join_ra);
50  CHECK_LT(input->getIndex(), node_inputs.size());
51  return visitInput(&node_inputs[input->getIndex()]);
52  }
53  return PhysicalInputSet{};
54  }
55  const auto scan_td = scan_ra->getTableDescriptor();
56  CHECK(scan_td);
57  const int col_id = input->getIndex() + 1;
58  const int table_id = scan_td->tableId;
59  CHECK_GT(table_id, 0);
60  return {{col_id, table_id}};
61  }
62 
63  PhysicalInputSet visitSubQuery(const RexSubQuery* subquery) const override {
64  const auto ra = subquery->getRelAlg();
65  CHECK(ra);
67  return visitor.visit(ra);
68  }
69 
70  PhysicalInputSet visitOperator(const RexOperator* oper) const override {
72  if (auto window_oper = dynamic_cast<const RexWindowFunctionOperator*>(oper)) {
73  for (const auto& partition_key : window_oper->getPartitionKeys()) {
74  if (auto input = dynamic_cast<const RexInput*>(partition_key.get())) {
75  const auto source_node = input->getSourceNode();
76  if (auto filter_node = dynamic_cast<const RelFilter*>(source_node)) {
77  // Partitions utilize string dictionary translation in the hash join framework
78  // if the partition key is a dictionary encoded string. Ensure we reach the
79  // source for all partition keys, so we can access string dictionaries for the
80  // partition keys while we build the partition (hash) table
81  CHECK_EQ(filter_node->inputCount(), size_t(1));
82  const auto parent_node = filter_node->getInput(0);
83  const auto node_inputs = get_node_output(parent_node);
84  CHECK_LT(input->getIndex(), node_inputs.size());
85  result = aggregateResult(result, visitInput(&node_inputs[input->getIndex()]));
86  }
87  result = aggregateResult(result, visit(input));
88  }
89  }
90  }
91  for (size_t i = 0; i < oper->size(); i++) {
92  result = aggregateResult(result, visit(oper->getOperand(i)));
93  }
94  return result;
95  }
96 
97  protected:
99  const PhysicalInputSet& next_result) const override {
100  auto result = aggregate;
101  result.insert(next_result.begin(), next_result.end());
102  return result;
103  }
104 };
105 
106 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitCompound(
107  const RelCompound* compound) const {
109  for (size_t i = 0; i < compound->getScalarSourcesSize(); ++i) {
110  const auto rex = compound->getScalarSource(i);
111  CHECK(rex);
112  RexPhysicalInputsVisitor visitor;
113  const auto rex_phys_inputs = visitor.visit(rex);
114  result.insert(rex_phys_inputs.begin(), rex_phys_inputs.end());
115  }
116  const auto filter = compound->getFilterExpr();
117  if (filter) {
118  RexPhysicalInputsVisitor visitor;
119  const auto filter_phys_inputs = visitor.visit(filter);
120  result.insert(filter_phys_inputs.begin(), filter_phys_inputs.end());
121  }
122  return result;
123 }
124 
125 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitFilter(const RelFilter* filter) const {
126  const auto condition = filter->getCondition();
127  CHECK(condition);
128  RexPhysicalInputsVisitor visitor;
129  return visitor.visit(condition);
130 }
131 
132 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitJoin(const RelJoin* join) const {
133  const auto condition = join->getCondition();
134  if (!condition) {
135  return PhysicalInputSet{};
136  }
137  RexPhysicalInputsVisitor visitor;
138  return visitor.visit(condition);
139 }
140 
141 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitLeftDeepInnerJoin(
142  const RelLeftDeepInnerJoin* left_deep_inner_join) const {
144  const auto condition = left_deep_inner_join->getInnerCondition();
145  RexPhysicalInputsVisitor visitor;
146  if (condition) {
147  result = visitor.visit(condition);
148  }
149  CHECK_GE(left_deep_inner_join->inputCount(), size_t(2));
150  for (size_t nesting_level = 1; nesting_level <= left_deep_inner_join->inputCount() - 1;
151  ++nesting_level) {
152  const auto outer_condition = left_deep_inner_join->getOuterCondition(nesting_level);
153  if (outer_condition) {
154  const auto outer_result = visitor.visit(outer_condition);
155  result.insert(outer_result.begin(), outer_result.end());
156  }
157  }
158  return result;
159 }
160 
161 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitProject(
162  const RelProject* project) const {
164  for (size_t i = 0; i < project->size(); ++i) {
165  const auto rex = project->getProjectAt(i);
166  CHECK(rex);
167  RexPhysicalInputsVisitor visitor;
168  const auto rex_phys_inputs = visitor.visit(rex);
169  result.insert(rex_phys_inputs.begin(), rex_phys_inputs.end());
170  }
171  return result;
172 }
173 
174 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitSort(const RelSort* sort) const {
175  CHECK_EQ(sort->inputCount(), size_t(1));
176  return visit(sort->getInput(0));
177 }
178 
179 PhysicalInputSet RelAlgPhysicalInputsVisitor::aggregateResult(
180  const PhysicalInputSet& aggregate,
181  const PhysicalInputSet& next_result) const {
182  auto result = aggregate;
183  result.insert(next_result.begin(), next_result.end());
184  return result;
185 }
186 
187 class RelAlgPhysicalTableInputsVisitor : public RelAlgVisitor<std::unordered_set<int>> {
188  public:
189  std::unordered_set<int> visitScan(const RelScan* scan) const override {
190  return {scan->getTableDescriptor()->tableId};
191  }
192 
193  protected:
194  std::unordered_set<int> aggregateResult(
195  const std::unordered_set<int>& aggregate,
196  const std::unordered_set<int>& next_result) const override {
197  auto result = aggregate;
198  result.insert(next_result.begin(), next_result.end());
199  return result;
200  }
201 };
202 
203 } // namespace
204 
205 std::unordered_set<PhysicalInput> get_physical_inputs(const RelAlgNode* ra) {
206  RelAlgPhysicalInputsVisitor phys_inputs_visitor;
207  return phys_inputs_visitor.visit(ra);
208 }
209 
210 std::unordered_set<int> get_physical_table_inputs(const RelAlgNode* ra) {
211  RelAlgPhysicalTableInputsVisitor phys_table_inputs_visitor;
212  return phys_table_inputs_visitor.visit(ra);
213 }
214 
215 std::ostream& operator<<(std::ostream& os, PhysicalInput const& physical_input) {
216  return os << '(' << physical_input.col_id << ',' << physical_input.table_id << ')';
217 }
#define CHECK_EQ(x, y)
Definition: Logger.h:219
size_t size() const override
const RexScalar * getFilterExpr() const
const RexScalar * getOuterCondition(const size_t nesting_level) const
size_t size() const
std::ostream & operator<<(std::ostream &os, const SessionInfo &session_info)
Definition: SessionInfo.cpp:57
const RexScalar * getOperand(const size_t idx) const
const RexScalar * getCondition() const
std::string join(T const &container, std::string const &delim)
DEVICE void sort(ARGS &&...args)
Definition: gpu_enabled.h:105
#define CHECK_GE(x, y)
Definition: Logger.h:224
const RexScalar * getCondition() const
#define CHECK_GT(x, y)
Definition: Logger.h:223
PhysicalInputSet visitSubQuery(const RexSubQuery *subquery) const override
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
const size_t getScalarSourcesSize() const
unsigned getIndex() const
PhysicalInputSet aggregateResult(const PhysicalInputSet &aggregate, const PhysicalInputSet &next_result) const override
T visit(const RelAlgNode *rel_alg) const
Definition: RelAlgVisitor.h:25
const RelAlgNode * getInput(const size_t idx) const
const RexScalar * getProjectAt(const size_t idx) const
#define CHECK_LT(x, y)
Definition: Logger.h:221
const RelAlgNode * getSourceNode() const
const RexScalar * getInnerCondition() const
#define CHECK(condition)
Definition: Logger.h:211
RANodeOutput get_node_output(const RelAlgNode *ra_node)
std::unordered_set< PhysicalInput > get_physical_inputs(const RelAlgNode *ra)
const size_t inputCount() const
const TableDescriptor * getTableDescriptor() const
std::unordered_set< int > get_physical_table_inputs(const RelAlgNode *ra)
std::unordered_set< int > aggregateResult(const std::unordered_set< int > &aggregate, const std::unordered_set< int > &next_result) const override
const RexScalar * getScalarSource(const size_t i) const