OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
QueryPhysicalInputsCollector.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include "RelAlgDag.h"
20 #include "RelAlgVisitor.h"
21 #include "RexVisitor.h"
22 #include "Shared/misc.h"
24 
25 namespace {
26 
27 using PhysicalInputSet = std::unordered_set<PhysicalInput>;
28 
29 class RelAlgPhysicalInputsVisitor : public RelAlgVisitor<PhysicalInputSet> {
30  public:
31  PhysicalInputSet visitCompound(const RelCompound* compound) const override;
32  PhysicalInputSet visitFilter(const RelFilter* filter) const override;
33  PhysicalInputSet visitJoin(const RelJoin* join) const override;
34  PhysicalInputSet visitLeftDeepInnerJoin(const RelLeftDeepInnerJoin*) const override;
35  PhysicalInputSet visitProject(const RelProject* project) const override;
36  PhysicalInputSet visitSort(const RelSort* sort) const override;
37 
38  protected:
39  PhysicalInputSet aggregateResult(const PhysicalInputSet& aggregate,
40  const PhysicalInputSet& next_result) const override;
41 };
42 
43 class RexPhysicalInputsVisitor : public RexVisitor<PhysicalInputSet> {
44  public:
45  PhysicalInputSet visitInput(const RexInput* input) const override {
46  const auto source_ra = input->getSourceNode();
47  const auto scan_ra = dynamic_cast<const RelScan*>(source_ra);
48  if (!scan_ra) {
49  const auto join_ra = dynamic_cast<const RelJoin*>(source_ra);
50  if (join_ra) {
51  const auto node_inputs = get_node_output(join_ra);
52  CHECK_LT(input->getIndex(), node_inputs.size());
53  return visitInput(&node_inputs[input->getIndex()]);
54  }
55  return PhysicalInputSet{};
56  }
57  const auto scan_td = scan_ra->getTableDescriptor();
58  CHECK(scan_td);
59  const int col_id = input->getIndex() + 1;
60  const int table_id = scan_td->tableId;
61  CHECK_GT(table_id, 0);
62  auto db_id = scan_ra->getCatalog().getDatabaseId();
63  return {{col_id, table_id, db_id}};
64  }
65 
66  PhysicalInputSet visitSubQuery(const RexSubQuery* subquery) const override {
67  const auto ra = subquery->getRelAlg();
68  CHECK(ra);
70  return visitor.visit(ra);
71  }
72 
73  PhysicalInputSet visitOperator(const RexOperator* oper) const override {
75  if (auto window_oper = dynamic_cast<const RexWindowFunctionOperator*>(oper)) {
76  for (const auto& partition_key : window_oper->getPartitionKeys()) {
77  if (auto input = dynamic_cast<const RexInput*>(partition_key.get())) {
78  const auto source_node = input->getSourceNode();
79  if (auto filter_node = dynamic_cast<const RelFilter*>(source_node)) {
80  // Partitions utilize string dictionary translation in the hash join framework
81  // if the partition key is a dictionary encoded string. Ensure we reach the
82  // source for all partition keys, so we can access string dictionaries for the
83  // partition keys while we build the partition (hash) table
84  CHECK_EQ(filter_node->inputCount(), size_t(1));
85  const auto parent_node = filter_node->getInput(0);
86  const auto node_inputs = get_node_output(parent_node);
87  CHECK_LT(input->getIndex(), node_inputs.size());
88  result = aggregateResult(result, visitInput(&node_inputs[input->getIndex()]));
89  }
90  result = aggregateResult(result, visit(input));
91  }
92  }
93  }
94  for (size_t i = 0; i < oper->size(); i++) {
95  result = aggregateResult(result, visit(oper->getOperand(i)));
96  }
97  return result;
98  }
99 
100  protected:
102  const PhysicalInputSet& next_result) const override {
103  auto result = aggregate;
104  result.insert(next_result.begin(), next_result.end());
105  return result;
106  }
107 };
108 
109 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitCompound(
110  const RelCompound* compound) const {
112  for (size_t i = 0; i < compound->getScalarSourcesSize(); ++i) {
113  const auto rex = compound->getScalarSource(i);
114  CHECK(rex);
115  RexPhysicalInputsVisitor visitor;
116  const auto rex_phys_inputs = visitor.visit(rex);
117  result.insert(rex_phys_inputs.begin(), rex_phys_inputs.end());
118  }
119  const auto filter = compound->getFilterExpr();
120  if (filter) {
121  RexPhysicalInputsVisitor visitor;
122  const auto filter_phys_inputs = visitor.visit(filter);
123  result.insert(filter_phys_inputs.begin(), filter_phys_inputs.end());
124  }
125  return result;
126 }
127 
128 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitFilter(const RelFilter* filter) const {
129  const auto condition = filter->getCondition();
130  CHECK(condition);
131  RexPhysicalInputsVisitor visitor;
132  return visitor.visit(condition);
133 }
134 
135 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitJoin(const RelJoin* join) const {
136  const auto condition = join->getCondition();
137  if (!condition) {
138  return PhysicalInputSet{};
139  }
140  RexPhysicalInputsVisitor visitor;
141  return visitor.visit(condition);
142 }
143 
144 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitLeftDeepInnerJoin(
145  const RelLeftDeepInnerJoin* left_deep_inner_join) const {
147  const auto condition = left_deep_inner_join->getInnerCondition();
148  RexPhysicalInputsVisitor visitor;
149  if (condition) {
150  result = visitor.visit(condition);
151  }
152  CHECK_GE(left_deep_inner_join->inputCount(), size_t(2));
153  for (size_t nesting_level = 1; nesting_level <= left_deep_inner_join->inputCount() - 1;
154  ++nesting_level) {
155  const auto outer_condition = left_deep_inner_join->getOuterCondition(nesting_level);
156  if (outer_condition) {
157  const auto outer_result = visitor.visit(outer_condition);
158  result.insert(outer_result.begin(), outer_result.end());
159  }
160  }
161  return result;
162 }
163 
164 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitProject(
165  const RelProject* project) const {
167  for (size_t i = 0; i < project->size(); ++i) {
168  const auto rex = project->getProjectAt(i);
169  CHECK(rex);
170  RexPhysicalInputsVisitor visitor;
171  const auto rex_phys_inputs = visitor.visit(rex);
172  result.insert(rex_phys_inputs.begin(), rex_phys_inputs.end());
173  }
174  return result;
175 }
176 
177 PhysicalInputSet RelAlgPhysicalInputsVisitor::visitSort(const RelSort* sort) const {
178  CHECK_EQ(sort->inputCount(), size_t(1));
179  return visit(sort->getInput(0));
180 }
181 
182 PhysicalInputSet RelAlgPhysicalInputsVisitor::aggregateResult(
183  const PhysicalInputSet& aggregate,
184  const PhysicalInputSet& next_result) const {
185  auto result = aggregate;
186  result.insert(next_result.begin(), next_result.end());
187  return result;
188 }
189 
// RelAlgPhysicalTableInputsVisitor: collects the (database id, table id) key of every
// RelScan it visits, returned through getTableIds().
// NOTE(review): this listing dropped several lines of this class. The missing lines
// are the `class RelAlgPhysicalTableInputsVisitor : public ...` header (line 190), the
// line after `public:` (192), the declaration of the local `visitor` (196) and the
// declaration of the `table_ids_` member (202). The base class is not visible here;
// confirm against the real source.
191  public:
193  using TableKeys = std::unordered_set<shared::TableKey>;
194 
// Visits the DAG rooted at `node` and returns the collected table keys, moving them
// out of the local visitor (presumably a RelAlgPhysicalTableInputsVisitor declared on
// the missing line 196 -- verify).
195  static TableKeys getTableIds(RelAlgNode const* node) {
197  visitor.visit(node);
198  return std::move(visitor.table_ids_);
199  }
200 
201  private:
203 
// Records the key {database id, table id} of the table this scan reads.
204  void visit(RelScan const* scan) override {
205  table_ids_.insert(
206  {scan->getCatalog().getDatabaseId(), scan->getTableDescriptor()->tableId});
207  }
208 
209  // Only RelScan nodes have physical table ids, so we may prune any nodes that can never
210  // have a RelScan node as a descendant.
211  void visit(RelLogicalValues const*) override {}
212 };
213 
214 } // namespace
215 
216 std::unordered_set<PhysicalInput> get_physical_inputs(const RelAlgNode* ra) {
217  RelAlgPhysicalInputsVisitor phys_inputs_visitor;
218  return phys_inputs_visitor.visit(ra);
219 }
220 
221 std::unordered_set<shared::TableKey> get_physical_table_inputs(const RelAlgNode* ra) {
222  return RelAlgPhysicalTableInputsVisitor::getTableIds(ra);
223 }
224 
225 std::ostream& operator<<(std::ostream& os, PhysicalInput const& physical_input) {
226  return os << '(' << physical_input.col_id << ',' << physical_input.table_id << ','
227  << physical_input.db_id << ')';
228 }
229 
// Returns the hash of this PhysicalInput (presumably what lets
// std::unordered_set<PhysicalInput> work, via a std::hash specialization in the
// header -- confirm).
// NOTE(review): the line that computes `hash_value` (source line 231) is missing from
// this listing. It presumably combines col_id, table_id and db_id (e.g. via
// shared::compute_hash); verify against the real source.
230 size_t PhysicalInput::hash() const {
232  return hash_value;
233 }
234 
235 bool PhysicalInput::operator==(const PhysicalInput& that) const {
236  return col_id == that.col_id && table_id == that.table_id && db_id == that.db_id;
237 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
virtual void visit(RelAlgNode const *)
size_t size() const override
Definition: RelAlgDag.h:1320
const RexScalar * getFilterExpr() const
Definition: RelAlgDag.h:2096
const Catalog_Namespace::Catalog & getCatalog() const
Definition: RelAlgDag.h:1119
bool operator==(const PhysicalInput &that) const
const RexScalar * getOuterCondition(const size_t nesting_level) const
size_t size() const
Definition: RelAlgDag.h:364
std::ostream & operator<<(std::ostream &os, const SessionInfo &session_info)
Definition: SessionInfo.cpp:57
const RexScalar * getOperand(const size_t idx) const
Definition: RelAlgDag.h:366
const RexScalar * getCondition() const
Definition: RelAlgDag.h:1898
std::string join(T const &container, std::string const &delim)
DEVICE void sort(ARGS &&...args)
Definition: gpu_enabled.h:105
#define CHECK_GE(x, y)
Definition: Logger.h:306
const RexScalar * getCondition() const
Definition: RelAlgDag.h:1652
#define CHECK_GT(x, y)
Definition: Logger.h:305
PhysicalInputSet visitSubQuery(const RexSubQuery *subquery) const override
const size_t getScalarSourcesSize() const
Definition: RelAlgDag.h:2110
unsigned getIndex() const
Definition: RelAlgDag.h:174
PhysicalInputSet aggregateResult(const PhysicalInputSet &aggregate, const PhysicalInputSet &next_result) const override
T visit(const RelAlgNode *rel_alg) const
Definition: RelAlgVisitor.h:25
const RelAlgNode * getInput(const size_t idx) const
Definition: RelAlgDag.h:877
int getDatabaseId() const
Definition: Catalog.h:326
const RexScalar * getProjectAt(const size_t idx) const
Definition: RelAlgDag.h:1352
#define CHECK_LT(x, y)
Definition: Logger.h:303
std::unordered_set< shared::TableKey > get_physical_table_inputs(const RelAlgNode *ra)
const RelAlgNode * getSourceNode() const
Definition: RelAlgDag.h:1056
std::size_t hash_value(RexAbstractInput const &rex_ab_input)
Definition: RelAlgDag.cpp:3525
const RexScalar * getInnerCondition() const
#define CHECK(condition)
Definition: Logger.h:291
Find out all the physical inputs (columns) a query is using.
std::unordered_set< PhysicalInput > get_physical_inputs(const RelAlgNode *ra)
size_t compute_hash(int32_t item_1, int32_t item_2)
Definition: misc.cpp:141
const size_t inputCount() const
Definition: RelAlgDag.h:875
const TableDescriptor * getTableDescriptor() const
Definition: RelAlgDag.h:1117
RANodeOutput get_node_output(const RelAlgNode *ra_node)
Definition: RelAlgDag.cpp:371
const RexScalar * getScalarSource(const size_t i) const
Definition: RelAlgDag.h:2112