OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
QueryPhysicalInputsCollector.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
#include "QueryPhysicalInputsCollector.h"

#include "RelAlgDag.h"
#include "RelAlgVisitor.h"
#include "RexVisitor.h"
#include "Visitors/RelRexDagVisitor.h"
23 
24 namespace {
25 
26 using PhysicalInputSet = std::unordered_set<PhysicalInput>;
27 
28 class RelAlgPhysicalInputsVisitor : public RelAlgVisitor<PhysicalInputSet> {
29  public:
30  PhysicalInputSet visitCompound(const RelCompound* compound) const override;
31  PhysicalInputSet visitFilter(const RelFilter* filter) const override;
32  PhysicalInputSet visitJoin(const RelJoin* join) const override;
33  PhysicalInputSet visitLeftDeepInnerJoin(const RelLeftDeepInnerJoin*) const override;
34  PhysicalInputSet visitProject(const RelProject* project) const override;
35  PhysicalInputSet visitSort(const RelSort* sort) const override;
36 
37  protected:
38  PhysicalInputSet aggregateResult(const PhysicalInputSet& aggregate,
39  const PhysicalInputSet& next_result) const override;
40 };
41 
42 class RexPhysicalInputsVisitor : public RexVisitor<PhysicalInputSet> {
43  public:
44  PhysicalInputSet visitInput(const RexInput* input) const override {
45  const auto source_ra = input->getSourceNode();
46  const auto scan_ra = dynamic_cast<const RelScan*>(source_ra);
47  if (!scan_ra) {
48  const auto join_ra = dynamic_cast<const RelJoin*>(source_ra);
49  if (join_ra) {
50  const auto node_inputs = get_node_output(join_ra);
51  CHECK_LT(input->getIndex(), node_inputs.size());
52  return visitInput(&node_inputs[input->getIndex()]);
53  }
54  return PhysicalInputSet{};
55  }
56  const auto scan_td = scan_ra->getTableDescriptor();
57  CHECK(scan_td);
58  const int col_id = input->getIndex() + 1;
59  const int table_id = scan_td->tableId;
60  CHECK_GT(table_id, 0);
61  return {{col_id, table_id}};
62  }
63 
64  PhysicalInputSet visitSubQuery(const RexSubQuery* subquery) const override {
65  const auto ra = subquery->getRelAlg();
66  CHECK(ra);
68  return visitor.visit(ra);
69  }
70 
71  PhysicalInputSet visitOperator(const RexOperator* oper) const override {
73  if (auto window_oper = dynamic_cast<const RexWindowFunctionOperator*>(oper)) {
74  for (const auto& partition_key : window_oper->getPartitionKeys()) {
75  if (auto input = dynamic_cast<const RexInput*>(partition_key.get())) {
76  const auto source_node = input->getSourceNode();
77  if (auto filter_node = dynamic_cast<const RelFilter*>(source_node)) {
78  // Partitions utilize string dictionary translation in the hash join framework
79  // if the partition key is a dictionary encoded string. Ensure we reach the
80  // source for all partition keys, so we can access string dictionaries for the
81  // partition keys while we build the partition (hash) table
82  CHECK_EQ(filter_node->inputCount(), size_t(1));
83  const auto parent_node = filter_node->getInput(0);
84  const auto node_inputs = get_node_output(parent_node);
85  CHECK_LT(input->getIndex(), node_inputs.size());
86  result = aggregateResult(result, visitInput(&node_inputs[input->getIndex()]));
87  }
88  result = aggregateResult(result, visit(input));
89  }
90  }
91  }
92  for (size_t i = 0; i < oper->size(); i++) {
93  result = aggregateResult(result, visit(oper->getOperand(i)));
94  }
95  return result;
96  }
97 
98  protected:
100  const PhysicalInputSet& next_result) const override {
101  auto result = aggregate;
102  result.insert(next_result.begin(), next_result.end());
103  return result;
104  }
105 };
106 
107 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitCompound(
108  const RelCompound* compound) const {
110  for (size_t i = 0; i < compound->getScalarSourcesSize(); ++i) {
111  const auto rex = compound->getScalarSource(i);
112  CHECK(rex);
113  RexPhysicalInputsVisitor visitor;
114  const auto rex_phys_inputs = visitor.visit(rex);
115  result.insert(rex_phys_inputs.begin(), rex_phys_inputs.end());
116  }
117  const auto filter = compound->getFilterExpr();
118  if (filter) {
119  RexPhysicalInputsVisitor visitor;
120  const auto filter_phys_inputs = visitor.visit(filter);
121  result.insert(filter_phys_inputs.begin(), filter_phys_inputs.end());
122  }
123  return result;
124 }
125 
126 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitFilter(const RelFilter* filter) const {
127  const auto condition = filter->getCondition();
128  CHECK(condition);
129  RexPhysicalInputsVisitor visitor;
130  return visitor.visit(condition);
131 }
132 
133 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitJoin(const RelJoin* join) const {
134  const auto condition = join->getCondition();
135  if (!condition) {
136  return PhysicalInputSet{};
137  }
138  RexPhysicalInputsVisitor visitor;
139  return visitor.visit(condition);
140 }
141 
142 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitLeftDeepInnerJoin(
143  const RelLeftDeepInnerJoin* left_deep_inner_join) const {
145  const auto condition = left_deep_inner_join->getInnerCondition();
146  RexPhysicalInputsVisitor visitor;
147  if (condition) {
148  result = visitor.visit(condition);
149  }
150  CHECK_GE(left_deep_inner_join->inputCount(), size_t(2));
151  for (size_t nesting_level = 1; nesting_level <= left_deep_inner_join->inputCount() - 1;
152  ++nesting_level) {
153  const auto outer_condition = left_deep_inner_join->getOuterCondition(nesting_level);
154  if (outer_condition) {
155  const auto outer_result = visitor.visit(outer_condition);
156  result.insert(outer_result.begin(), outer_result.end());
157  }
158  }
159  return result;
160 }
161 
162 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitProject(
163  const RelProject* project) const {
165  for (size_t i = 0; i < project->size(); ++i) {
166  const auto rex = project->getProjectAt(i);
167  CHECK(rex);
168  RexPhysicalInputsVisitor visitor;
169  const auto rex_phys_inputs = visitor.visit(rex);
170  result.insert(rex_phys_inputs.begin(), rex_phys_inputs.end());
171  }
172  return result;
173 }
174 
175 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitSort(const RelSort* sort) const {
176  CHECK_EQ(sort->inputCount(), size_t(1));
177  return visit(sort->getInput(0));
178 }
179 
180 PhysicalInputSet RelAlgPhysicalInputsVisitor::aggregateResult(
181  const PhysicalInputSet& aggregate,
182  const PhysicalInputSet& next_result) const {
183  auto result = aggregate;
184  result.insert(next_result.begin(), next_result.end());
185  return result;
186 }
187 
189  public:
191  using TableIds = std::unordered_set<int>;
192 
193  static TableIds getTableIds(RelAlgNode const* node) {
195  visitor.visit(node);
196  return std::move(visitor.table_ids_);
197  }
198 
199  private:
201 
202  void visit(RelScan const* scan) override {
203  table_ids_.insert(scan->getTableDescriptor()->tableId);
204  }
205 
206  // Only RelScan nodes have physical table ids, so we may prune any nodes that can never
207  // have a RelScan node as a descendant.
208  void visit(RelLogicalValues const*) override {}
209 };
210 
211 } // namespace
212 
213 std::unordered_set<PhysicalInput> get_physical_inputs(const RelAlgNode* ra) {
214  RelAlgPhysicalInputsVisitor phys_inputs_visitor;
215  return phys_inputs_visitor.visit(ra);
216 }
217 
218 std::unordered_set<int> get_physical_table_inputs(const RelAlgNode* ra) {
219  return RelAlgPhysicalTableInputsVisitor::getTableIds(ra);
220 }
221 
222 std::ostream& operator<<(std::ostream& os, PhysicalInput const& physical_input) {
223  return os << '(' << physical_input.col_id << ',' << physical_input.table_id << ')';
224 }
#define CHECK_EQ(x, y)
Definition: Logger.h:297
virtual void visit(RelAlgNode const *)
size_t size() const override
Definition: RelAlgDag.h:1156
const RexScalar * getFilterExpr() const
Definition: RelAlgDag.h:1826
const RexScalar * getOuterCondition(const size_t nesting_level) const
size_t size() const
Definition: RelAlgDag.h:270
std::ostream & operator<<(std::ostream &os, const SessionInfo &session_info)
Definition: SessionInfo.cpp:57
const RexScalar * getOperand(const size_t idx) const
Definition: RelAlgDag.h:272
const RexScalar * getCondition() const
Definition: RelAlgDag.h:1678
std::string join(T const &container, std::string const &delim)
DEVICE void sort(ARGS &&...args)
Definition: gpu_enabled.h:105
#define CHECK_GE(x, y)
Definition: Logger.h:302
const RexScalar * getCondition() const
Definition: RelAlgDag.h:1454
#define CHECK_GT(x, y)
Definition: Logger.h:301
PhysicalInputSet visitSubQuery(const RexSubQuery *subquery) const override
const size_t getScalarSourcesSize() const
Definition: RelAlgDag.h:1840
unsigned getIndex() const
Definition: RelAlgDag.h:77
PhysicalInputSet aggregateResult(const PhysicalInputSet &aggregate, const PhysicalInputSet &next_result) const override
T visit(const RelAlgNode *rel_alg) const
Definition: RelAlgVisitor.h:25
const RelAlgNode * getInput(const size_t idx) const
Definition: RelAlgDag.h:892
const RexScalar * getProjectAt(const size_t idx) const
Definition: RelAlgDag.h:1186
#define CHECK_LT(x, y)
Definition: Logger.h:299
const RelAlgNode * getSourceNode() const
Definition: RelAlgDag.h:389
const RexScalar * getInnerCondition() const
#define CHECK(condition)
Definition: Logger.h:289
Find out all the physical inputs (columns) a query is using.
std::unordered_set< PhysicalInput > get_physical_inputs(const RelAlgNode *ra)
const size_t inputCount() const
Definition: RelAlgDag.h:890
const TableDescriptor * getTableDescriptor() const
Definition: RelAlgDag.h:982
std::unordered_set< int > get_physical_table_inputs(const RelAlgNode *ra)
RANodeOutput get_node_output(const RelAlgNode *ra_node)
Definition: RelAlgDag.cpp:370
const RexScalar * getScalarSource(const size_t i) const
Definition: RelAlgDag.h:1842