OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelRexDagVisitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /*
18  * Q: Why are std::arrays, rather than std::unordered_maps, used to map each
19  * std::type_index to its handler?
20  *
21  * A: The handler tables are static variables, so they should be trivially destructible.
22  * See https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
23  */
24 
25 #include "RelRexDagVisitor.h"
26 #include "Logger/Logger.h"
27 
28 #include <algorithm>
29 #include <typeindex>
30 
31 template <typename T, typename... Ts>
33  RelRexDagVisitor::Handlers<T, sizeof...(Ts)> handlers{
34  {{std::type_index(typeid(Ts)), &RelRexDagVisitor::cast<T, Ts>}...}};
35  std::sort(handlers.begin(), handlers.end());
36  return handlers;
37 }
38 
39 // RelAlgNode types
40 void RelRexDagVisitor::castAndVisit(RelAlgNode const* rel_alg_node) {
41  // Array that pairs std::type_index(typeid(*rel_alg_node)) -> method pointer.
42  static auto const handlers = make_handlers<RelAlgNode,
45  RelFilter,
46  RelJoin,
50  RelModify,
51  RelProject,
52  RelScan,
53  RelSort,
56  static_assert(std::is_trivially_destructible_v<decltype(handlers)>);
57  // Will throw std::bad_typeid if rel_alg_node == nullptr.
58  auto const& type_index = std::type_index(typeid(*rel_alg_node));
59  auto const itr = std::lower_bound(handlers.cbegin(), handlers.cend(), type_index);
60  if (itr != handlers.cend() && itr->type_index == type_index) {
61  (this->*itr->handler)(rel_alg_node);
62  } else {
63  LOG(FATAL) << "Unhandled RelAlgNode type: "
64  << rel_alg_node->toString(RelRexToStringConfig::defaults());
65  }
66 }
67 
68 void RelRexDagVisitor::visit(RelAlgNode const* rel_alg_node) {
69  if (cache_.emplace(static_cast<Cache::value_type>(rel_alg_node)).second) {
70  castAndVisit(rel_alg_node);
71  for (size_t i = 0; i < rel_alg_node->inputCount(); ++i) {
72  visit(rel_alg_node->getInput(i));
73  }
74  }
75 }
76 
77 void RelRexDagVisitor::visit(RelCompound const* rel_compound) {
78  if (rel_compound->getFilterExpr()) {
79  visit(rel_compound->getFilterExpr());
80  }
81  for (size_t i = 0; i < rel_compound->getScalarSourcesSize(); ++i) {
82  visit(rel_compound->getScalarSource(i));
83  }
84 }
85 
86 void RelRexDagVisitor::visit(RelFilter const* rel_filter) {
87  visit(rel_filter->getCondition());
88 }
89 
90 void RelRexDagVisitor::visit(RelJoin const* rel_join) {
91  visit(rel_join->getCondition());
92 }
93 
94 void RelRexDagVisitor::visit(RelLeftDeepInnerJoin const* rel_left_deep_inner_join) {
95  visit(rel_left_deep_inner_join->getInnerCondition());
96  for (size_t level = 1; level < rel_left_deep_inner_join->inputCount(); ++level) {
97  if (auto* outer_condition = rel_left_deep_inner_join->getOuterCondition(level)) {
98  visit(outer_condition);
99  }
100  }
101 }
102 
103 void RelRexDagVisitor::visit(RelLogicalValues const* rel_logical_values) {
104  for (size_t row_idx = 0; row_idx < rel_logical_values->getNumRows(); ++row_idx) {
105  for (size_t col_idx = 0; col_idx < rel_logical_values->getRowsSize(); ++col_idx) {
106  visit(rel_logical_values->getValueAt(row_idx, col_idx));
107  }
108  }
109 }
110 
111 void RelRexDagVisitor::visit(RelProject const* rel_projection) {
112  for (size_t i = 0; i < rel_projection->size(); ++i) {
113  visit(rel_projection->getProjectAt(i));
114  }
115 }
116 
117 void RelRexDagVisitor::visit(RelTableFunction const* rel_table_function) {
118  for (size_t i = 0; i < rel_table_function->getTableFuncInputsSize(); ++i) {
119  visit(rel_table_function->getTableFuncInputAt(i));
120  }
121 }
122 
123 void RelRexDagVisitor::visit(RelTranslatedJoin const* rel_translated_join) {
124  visit(rel_translated_join->getLHS());
125  visit(rel_translated_join->getRHS());
126  if (auto* outer_join_condition = rel_translated_join->getOuterJoinCond()) {
127  visit(outer_join_condition);
128  }
129 }
130 
131 // RexScalar types
132 void RelRexDagVisitor::visit(RexScalar const* rex_scalar) {
133  // Array that pairs std::type_index(typeid(*rex_scalar)) -> method pointer.
134  static auto const handlers = make_handlers<RexScalar,
136  RexCase,
138  RexInput,
139  RexLiteral,
140  RexOperator,
141  RexRef,
142  RexSubQuery,
144  static_assert(std::is_trivially_destructible_v<decltype(handlers)>);
145  if (cache_.emplace(static_cast<Cache::value_type>(rex_scalar)).second) {
146  // Will throw std::bad_typeid if rex_scalar == nullptr.
147  auto const& type_index = std::type_index(typeid(*rex_scalar));
148  auto const itr = std::lower_bound(handlers.cbegin(), handlers.cend(), type_index);
149  if (itr != handlers.cend() && itr->type_index == type_index) {
150  (this->*itr->handler)(rex_scalar);
151  } else {
152  LOG(FATAL) << "Unhandled RexScalar type: "
153  << rex_scalar->toString(RelRexToStringConfig::defaults());
154  }
155  }
156 }
157 
159  RexWindowFunctionOperator const* rex_window_function_operator) {
160  for (const auto& partition_key : rex_window_function_operator->getPartitionKeys()) {
161  visit(partition_key.get());
162  }
163  for (const auto& order_key : rex_window_function_operator->getOrderKeys()) {
164  visit(order_key.get());
165  }
166 }
167 
168 void RelRexDagVisitor::visit(RexCase const* rex_case) {
169  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
170  visit(rex_case->getWhen(i));
171  visit(rex_case->getThen(i));
172  }
173  if (rex_case->getElse()) {
174  visit(rex_case->getElse());
175  }
176 }
177 
178 void RelRexDagVisitor::visit(RexFunctionOperator const* rex_function_operator) {
179  for (size_t i = 0; i < rex_function_operator->size(); ++i) {
180  visit(rex_function_operator->getOperand(i));
181  }
182 }
183 
184 void RelRexDagVisitor::visit(RexOperator const* rex_operator) {
185  for (size_t i = 0; i < rex_operator->size(); ++i) {
186  visit(rex_operator->getOperand(i));
187  }
188 }
189 
190 void RelRexDagVisitor::visit(RexSubQuery const* rex_sub_query) {
191  visit(rex_sub_query->getRelAlg());
192 }
193 
194 void RelRexDagVisitor::visit(RexInput const* rex_input) {
195  visit(rex_input->getSourceNode());
196 }
const RexScalar * getThen(const size_t idx) const
Definition: RelAlgDag.h:440
virtual void visit(RelAlgNode const *)
size_t size() const override
Definition: RelAlgDag.h:1320
const RexScalar * getFilterExpr() const
Definition: RelAlgDag.h:2096
const RexScalar * getElse() const
Definition: RelAlgDag.h:445
const RexScalar * getOuterCondition(const size_t nesting_level) const
#define LOG(tag)
Definition: Logger.h:285
size_t size() const
Definition: RelAlgDag.h:364
const RexScalar * getOperand(const size_t idx) const
Definition: RelAlgDag.h:366
size_t getNumRows() const
Definition: RelAlgDag.h:2720
const RexScalar * getCondition() const
Definition: RelAlgDag.h:1898
DEVICE void sort(ARGS &&...args)
Definition: gpu_enabled.h:105
const RexScalar * getWhen(const size_t idx) const
Definition: RelAlgDag.h:435
const RexScalar * getCondition() const
Definition: RelAlgDag.h:1652
std::array< TypeHandler< RelRexDagVisitor, T >, N > Handlers
size_t getRowsSize() const
Definition: RelAlgDag.h:2712
const size_t getScalarSourcesSize() const
Definition: RelAlgDag.h:2110
const RelAlgNode * getRHS() const
Definition: RelAlgDag.h:1808
static Handlers< T, sizeof...(Ts)> make_handlers()
size_t branchCount() const
Definition: RelAlgDag.h:433
const RelAlgNode * getInput(const size_t idx) const
Definition: RelAlgDag.h:877
size_t getTableFuncInputsSize() const
Definition: RelAlgDag.h:2560
virtual std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const =0
const RexScalar * getProjectAt(const size_t idx) const
Definition: RelAlgDag.h:1352
const ConstRexScalarPtrVector & getPartitionKeys() const
Definition: RelAlgDag.h:643
virtual std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const =0
const RelAlgNode * getLHS() const
Definition: RelAlgDag.h:1807
DEVICE auto lower_bound(ARGS &&...args)
Definition: gpu_enabled.h:78
const RexScalar * getOuterJoinCond() const
Definition: RelAlgDag.h:1813
static RelRexToStringConfig defaults()
Definition: RelAlgDag.h:78
const RelAlgNode * getSourceNode() const
Definition: RelAlgDag.h:1056
void castAndVisit(RelAlgNode const *)
const RexScalar * getTableFuncInputAt(const size_t idx) const
Definition: RelAlgDag.h:2566
const RexScalar * getValueAt(const size_t row_idx, const size_t col_idx) const
Definition: RelAlgDag.h:2705
const RexScalar * getInnerCondition() const
const ConstRexScalarPtrVector & getOrderKeys() const
Definition: RelAlgDag.h:653
const size_t inputCount() const
Definition: RelAlgDag.h:875
const RexScalar * getScalarSource(const size_t i) const
Definition: RelAlgDag.h:2112