OmniSciDB  cde582ebc3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RexVisitor.h
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef QUERYENGINE_REXVISITOR_H
18 #define QUERYENGINE_REXVISITOR_H
19 
20 #include "RelAlgDag.h"
21 
22 #include <memory>
23 
24 template <class T>
26  public:
27  virtual T visit(const RexScalar* rex_scalar) const {
28  CHECK(rex_scalar);
29  const auto rex_input = dynamic_cast<const RexInput*>(rex_scalar);
30  if (rex_input) {
31  return visitInput(rex_input);
32  }
33  const auto rex_literal = dynamic_cast<const RexLiteral*>(rex_scalar);
34  if (rex_literal) {
35  return visitLiteral(rex_literal);
36  }
37  const auto rex_subquery = dynamic_cast<const RexSubQuery*>(rex_scalar);
38  if (rex_subquery) {
39  return visitSubQuery(rex_subquery);
40  }
41  const auto rex_operator = dynamic_cast<const RexOperator*>(rex_scalar);
42  if (rex_operator) {
43  return visitOperator(rex_operator);
44  }
45  const auto rex_case = dynamic_cast<const RexCase*>(rex_scalar);
46  if (rex_case) {
47  return visitCase(rex_case);
48  }
49  const auto rex_ref = dynamic_cast<const RexRef*>(rex_scalar);
50  if (rex_ref) {
51  return visitRef(rex_ref);
52  }
53  LOG(FATAL) << "No visit path for "
54  << rex_scalar->toString(RelRexToStringConfig::defaults());
55  return defaultResult();
56  }
57 
58  virtual T visitInput(const RexInput*) const = 0;
59 
60  virtual T visitLiteral(const RexLiteral*) const = 0;
61 
62  virtual T visitSubQuery(const RexSubQuery*) const = 0;
63 
64  virtual T visitRef(const RexRef*) const = 0;
65 
66  virtual T visitOperator(const RexOperator* rex_operator) const = 0;
67 
68  virtual T visitCase(const RexCase* rex_case) const = 0;
69 
70  protected:
71  virtual T defaultResult() const = 0;
72 };
73 
74 template <class T>
75 class RexVisitor : public RexVisitorBase<T> {
76  public:
77  T visitInput(const RexInput*) const override { return defaultResult(); }
78 
79  T visitLiteral(const RexLiteral*) const override { return defaultResult(); }
80 
81  T visitSubQuery(const RexSubQuery*) const override { return defaultResult(); }
82 
83  T visitRef(const RexRef*) const override { return defaultResult(); }
84 
85  T visitOperator(const RexOperator* rex_operator) const override {
86  const size_t operand_count = rex_operator->size();
88  for (size_t i = 0; i < operand_count; ++i) {
89  const auto operand = rex_operator->getOperand(i);
90  T operandResult = RexVisitorBase<T>::visit(operand);
91  result = aggregateResult(result, operandResult);
92  }
93  const auto rex_window_func_operator =
94  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
95  if (rex_window_func_operator) {
96  return visitWindowFunctionOperator(rex_window_func_operator, result);
97  }
98  return result;
99  }
100 
101  T visitCase(const RexCase* rex_case) const override {
102  T result = defaultResult();
103  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
104  const auto when = rex_case->getWhen(i);
105  result = aggregateResult(result, RexVisitorBase<T>::visit(when));
106  const auto then = rex_case->getThen(i);
107  result = aggregateResult(result, RexVisitorBase<T>::visit(then));
108  }
109  if (rex_case->getElse()) {
110  result = aggregateResult(result, RexVisitorBase<T>::visit(rex_case->getElse()));
111  }
112  return result;
113  }
114 
115  protected:
116  virtual T aggregateResult(const T& aggregate, const T& next_result) const {
117  return next_result;
118  }
119 
120  T defaultResult() const override { return T{}; }
121 
122  private:
123  T visitWindowFunctionOperator(const RexWindowFunctionOperator* rex_window_func_operator,
124  const T operands_visit_result) const {
125  T result = operands_visit_result;
126  for (const auto& key : rex_window_func_operator->getPartitionKeys()) {
127  T partial_result = RexVisitorBase<T>::visit(key.get());
128  result = aggregateResult(result, partial_result);
129  }
130  for (const auto& key : rex_window_func_operator->getOrderKeys()) {
131  T partial_result = RexVisitorBase<T>::visit(key.get());
132  result = aggregateResult(result, partial_result);
133  }
134  return result;
135  }
136 };
137 
138 class RexDeepCopyVisitor : public RexVisitorBase<std::unique_ptr<const RexScalar>> {
139  protected:
140  using RetType = std::unique_ptr<const RexScalar>;
141 
142  RetType visitInput(const RexInput* input) const override { return input->deepCopy(); }
143 
144  RetType visitLiteral(const RexLiteral* literal) const override {
145  return literal->deepCopy();
146  }
147 
148  RetType visitSubQuery(const RexSubQuery* subquery) const override {
149  return subquery->deepCopy();
150  }
151 
152  RetType visitRef(const RexRef* ref) const override { return ref->deepCopy(); }
153 
154  RetType visitOperator(const RexOperator* rex_operator) const override {
155  const auto rex_window_function_operator =
156  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
157  if (rex_window_function_operator) {
158  return visitWindowFunctionOperator(rex_window_function_operator);
159  }
160 
161  const size_t operand_count = rex_operator->size();
162  std::vector<RetType> new_opnds;
163  for (size_t i = 0; i < operand_count; ++i) {
164  new_opnds.push_back(visit(rex_operator->getOperand(i)));
165  }
166  return rex_operator->getDisambiguated(new_opnds);
167  }
168 
170  const RexWindowFunctionOperator* rex_window_function_operator) const {
171  const size_t operand_count = rex_window_function_operator->size();
172  std::vector<RetType> new_opnds;
173  for (size_t i = 0; i < operand_count; ++i) {
174  new_opnds.push_back(visit(rex_window_function_operator->getOperand(i)));
175  }
176 
177  const auto& partition_keys = rex_window_function_operator->getPartitionKeys();
178  std::vector<std::unique_ptr<const RexScalar>> disambiguated_partition_keys;
179  for (const auto& partition_key : partition_keys) {
180  disambiguated_partition_keys.emplace_back(visit(partition_key.get()));
181  }
182  std::vector<std::unique_ptr<const RexScalar>> disambiguated_order_keys;
183  const auto& order_keys = rex_window_function_operator->getOrderKeys();
184  for (const auto& order_key : order_keys) {
185  disambiguated_order_keys.emplace_back(visit(order_key.get()));
186  }
187  return rex_window_function_operator->disambiguatedOperands(
188  new_opnds,
189  disambiguated_partition_keys,
190  disambiguated_order_keys,
191  rex_window_function_operator->getCollation());
192  }
193 
194  RetType visitCase(const RexCase* rex_case) const override {
195  std::vector<std::pair<RetType, RetType>> new_pair_list;
196  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
197  new_pair_list.emplace_back(visit(rex_case->getWhen(i)),
198  visit(rex_case->getThen(i)));
199  }
200  auto new_else = visit(rex_case->getElse());
201  return std::make_unique<RexCase>(new_pair_list, new_else);
202  }
203 
204  private:
205  RetType defaultResult() const override { return nullptr; }
206 
207  public:
208  using RowValues = std::vector<std::unique_ptr<const RexScalar>>;
209 
210  static std::vector<RowValues> copy(std::vector<RowValues> const& rhs) {
211  RexDeepCopyVisitor copier;
212  std::vector<RowValues> retval;
213  retval.reserve(rhs.size());
214  for (auto const& row : rhs) {
215  retval.push_back({});
216  retval.back().reserve(row.size());
217  for (auto const& value : row) {
218  retval.back().push_back(copier.visit(value.get()));
219  }
220  }
221  return retval;
222  }
223 };
224 
225 template <bool bAllowMissing>
227  public:
228  RexInputRenumber(const std::unordered_map<size_t, size_t>& new_numbering)
229  : old_to_new_idx_(new_numbering) {}
230  RetType visitInput(const RexInput* input) const override {
231  auto renum_it = old_to_new_idx_.find(input->getIndex());
232  if (bAllowMissing) {
233  if (renum_it != old_to_new_idx_.end()) {
234  return std::make_unique<RexInput>(input->getSourceNode(), renum_it->second);
235  } else {
236  return input->deepCopy();
237  }
238  } else {
239  CHECK(renum_it != old_to_new_idx_.end());
240  return std::make_unique<RexInput>(input->getSourceNode(), renum_it->second);
241  }
242  }
243 
244  private:
245  const std::unordered_map<size_t, size_t>& old_to_new_idx_;
246 };
247 
248 #endif // QUERYENGINE_REXVISITOR_H
virtual std::string toString(RelRexToStringConfig config) const =0
const RexScalar * getThen(const size_t idx) const
Definition: RelAlgDag.h:400
T defaultResult() const override
Definition: RexVisitor.h:120
std::vector< std::unique_ptr< const RexScalar >> RowValues
Definition: RexVisitor.h:208
T visitCase(const RexCase *rex_case) const override
Definition: RexVisitor.h:101
const RexScalar * getElse() const
Definition: RelAlgDag.h:405
std::unique_ptr< RexRef > deepCopy() const
Definition: RelAlgDag.h:708
#define LOG(tag)
Definition: Logger.h:216
size_t size() const
Definition: RelAlgDag.h:245
const RexScalar * getOperand(const size_t idx) const
Definition: RelAlgDag.h:247
RetType visitInput(const RexInput *input) const override
Definition: RexVisitor.h:230
const std::vector< SortField > & getCollation() const
Definition: RelAlgDag.h:600
virtual T visitRef(const RexRef *) const =0
RetType visitOperator(const RexOperator *rex_operator) const override
Definition: RexVisitor.h:154
virtual T visitLiteral(const RexLiteral *) const =0
T visitOperator(const RexOperator *rex_operator) const override
Definition: RexVisitor.h:85
const RexScalar * getWhen(const size_t idx) const
Definition: RelAlgDag.h:395
RexInputRenumber(const std::unordered_map< size_t, size_t > &new_numbering)
Definition: RexVisitor.h:228
virtual T visitCase(const RexCase *rex_case) const =0
RetType visitWindowFunctionOperator(const RexWindowFunctionOperator *rex_window_function_operator) const
Definition: RexVisitor.h:169
RetType visitLiteral(const RexLiteral *literal) const override
Definition: RexVisitor.h:144
virtual std::unique_ptr< const RexOperator > getDisambiguated(std::vector< std::unique_ptr< const RexScalar >> &operands) const
Definition: RelAlgDag.h:240
virtual T aggregateResult(const T &aggregate, const T &next_result) const
Definition: RexVisitor.h:116
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
const std::unordered_map< size_t, size_t > & old_to_new_idx_
Definition: RexVisitor.h:245
unsigned getIndex() const
Definition: RelAlgDag.h:72
T visitRef(const RexRef *) const override
Definition: RexVisitor.h:83
virtual T defaultResult() const =0
RetType visitSubQuery(const RexSubQuery *subquery) const override
Definition: RexVisitor.h:148
std::unique_ptr< const RexOperator > disambiguatedOperands(ConstRexScalarPtrVector &operands, ConstRexScalarPtrVector &partition_keys, ConstRexScalarPtrVector &order_keys, const std::vector< SortField > &collation) const
Definition: RelAlgDag.h:608
virtual T visitOperator(const RexOperator *rex_operator) const =0
virtual T visitInput(const RexInput *) const =0
virtual T visitSubQuery(const RexSubQuery *) const =0
size_t branchCount() const
Definition: RelAlgDag.h:393
RetType visitInput(const RexInput *input) const override
Definition: RexVisitor.h:142
T visitWindowFunctionOperator(const RexWindowFunctionOperator *rex_window_func_operator, const T operands_visit_result) const
Definition: RexVisitor.h:123
std::unique_ptr< RexInput > deepCopy() const
Definition: RelAlgDag.h:367
const ConstRexScalarPtrVector & getPartitionKeys() const
Definition: RelAlgDag.h:573
T visitInput(const RexInput *) const override
Definition: RexVisitor.h:77
static RelRexToStringConfig defaults()
Definition: RelAlgDag.h:49
const RelAlgNode * getSourceNode() const
Definition: RelAlgDag.h:351
T visitSubQuery(const RexSubQuery *) const override
Definition: RexVisitor.h:81
RetType visitRef(const RexRef *ref) const override
Definition: RexVisitor.h:152
#define CHECK(condition)
Definition: Logger.h:222
const ConstRexScalarPtrVector & getOrderKeys() const
Definition: RelAlgDag.h:583
static std::vector< RowValues > copy(std::vector< RowValues > const &rhs)
Definition: RexVisitor.h:210
RetType visitCase(const RexCase *rex_case) const override
Definition: RexVisitor.h:194
RetType defaultResult() const override
Definition: RexVisitor.h:205
std::unique_ptr< const RexScalar > RetType
Definition: RexVisitor.h:140
T visitLiteral(const RexLiteral *) const override
Definition: RexVisitor.h:79