OmniSciDB  c07336695a
RexVisitor.h
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef QUERYENGINE_REXVISITOR_H
18 #define QUERYENGINE_REXVISITOR_H
19 
21 
22 #include <memory>
23 
24 template <class T>
26  public:
27  virtual T visit(const RexScalar* rex_scalar) const {
28  const auto rex_input = dynamic_cast<const RexInput*>(rex_scalar);
29  if (rex_input) {
30  return visitInput(rex_input);
31  }
32  const auto rex_literal = dynamic_cast<const RexLiteral*>(rex_scalar);
33  if (rex_literal) {
34  return visitLiteral(rex_literal);
35  }
36  const auto rex_subquery = dynamic_cast<const RexSubQuery*>(rex_scalar);
37  if (rex_subquery) {
38  return visitSubQuery(rex_subquery);
39  }
40  const auto rex_operator = dynamic_cast<const RexOperator*>(rex_scalar);
41  if (rex_operator) {
42  return visitOperator(rex_operator);
43  }
44  const auto rex_case = dynamic_cast<const RexCase*>(rex_scalar);
45  if (rex_case) {
46  return visitCase(rex_case);
47  }
48  const auto rex_ref = dynamic_cast<const RexRef*>(rex_scalar);
49  if (rex_ref) {
50  return visitRef(rex_ref);
51  }
52  CHECK(false);
53  return defaultResult();
54  }
55 
56  virtual T visitInput(const RexInput*) const = 0;
57 
58  virtual T visitLiteral(const RexLiteral*) const = 0;
59 
60  virtual T visitSubQuery(const RexSubQuery*) const = 0;
61 
62  virtual T visitRef(const RexRef*) const = 0;
63 
64  virtual T visitOperator(const RexOperator* rex_operator) const = 0;
65 
66  virtual T visitCase(const RexCase* rex_case) const = 0;
67 
68  protected:
69  virtual T defaultResult() const = 0;
70 };
71 
72 template <class T>
73 class RexVisitor : public RexVisitorBase<T> {
74  public:
75  T visitInput(const RexInput*) const override { return defaultResult(); }
76 
77  T visitLiteral(const RexLiteral*) const override { return defaultResult(); }
78 
79  T visitSubQuery(const RexSubQuery*) const override { return defaultResult(); }
80 
81  T visitRef(const RexRef*) const override { return defaultResult(); }
82 
83  T visitOperator(const RexOperator* rex_operator) const override {
84  const size_t operand_count = rex_operator->size();
85  T result = defaultResult();
86  for (size_t i = 0; i < operand_count; ++i) {
87  const auto operand = rex_operator->getOperand(i);
88  T operandResult = RexVisitorBase<T>::visit(operand);
89  result = aggregateResult(result, operandResult);
90  }
91  const auto rex_window_func_operator =
92  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
93  if (rex_window_func_operator) {
94  return visitWindowFunctionOperator(rex_window_func_operator, result);
95  }
96  return result;
97  }
98 
99  T visitCase(const RexCase* rex_case) const override {
100  T result = defaultResult();
101  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
102  const auto when = rex_case->getWhen(i);
103  result = aggregateResult(result, RexVisitorBase<T>::visit(when));
104  const auto then = rex_case->getThen(i);
105  result = aggregateResult(result, RexVisitorBase<T>::visit(then));
106  }
107  if (rex_case->getElse()) {
108  result = aggregateResult(result, RexVisitorBase<T>::visit(rex_case->getElse()));
109  }
110  return result;
111  }
112 
113  protected:
114  virtual T aggregateResult(const T& aggregate, const T& next_result) const {
115  return next_result;
116  }
117 
118  T defaultResult() const override { return T{}; }
119 
120  private:
121  T visitWindowFunctionOperator(const RexWindowFunctionOperator* rex_window_func_operator,
122  const T operands_visit_result) const {
123  T result = operands_visit_result;
124  for (const auto& key : rex_window_func_operator->getPartitionKeys()) {
125  T partial_result = RexVisitorBase<T>::visit(key.get());
126  result = aggregateResult(result, partial_result);
127  }
128  for (const auto& key : rex_window_func_operator->getOrderKeys()) {
129  T partial_result = RexVisitorBase<T>::visit(key.get());
130  result = aggregateResult(result, partial_result);
131  }
132  return result;
133  }
134 };
135 
136 class RexDeepCopyVisitor : public RexVisitorBase<std::unique_ptr<const RexScalar>> {
137  protected:
138  using RetType = std::unique_ptr<const RexScalar>;
139 
140  RetType visitInput(const RexInput* input) const override { return input->deepCopy(); }
141 
142  RetType visitLiteral(const RexLiteral* literal) const override {
143  return literal->deepCopy();
144  }
145 
146  RetType visitSubQuery(const RexSubQuery* subquery) const override {
147  return subquery->deepCopy();
148  }
149 
150  RetType visitRef(const RexRef* ref) const override { return ref->deepCopy(); }
151 
152  RetType visitOperator(const RexOperator* rex_operator) const override {
153  const size_t operand_count = rex_operator->size();
154  std::vector<RetType> new_opnds;
155  for (size_t i = 0; i < operand_count; ++i) {
156  new_opnds.push_back(visit(rex_operator->getOperand(i)));
157  }
158  const auto rex_window_function_operator =
159  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
160  if (rex_window_function_operator) {
161  const auto& partition_keys = rex_window_function_operator->getPartitionKeys();
162  std::vector<std::unique_ptr<const RexScalar>> disambiguated_partition_keys;
163  for (const auto& partition_key : partition_keys) {
164  disambiguated_partition_keys.emplace_back(visit(partition_key.get()));
165  }
166  std::vector<std::unique_ptr<const RexScalar>> disambiguated_order_keys;
167  const auto& order_keys = rex_window_function_operator->getOrderKeys();
168  for (const auto& order_key : order_keys) {
169  disambiguated_order_keys.emplace_back(visit(order_key.get()));
170  }
171  return rex_window_function_operator->disambiguatedOperands(
172  new_opnds,
173  disambiguated_partition_keys,
174  disambiguated_order_keys,
175  rex_window_function_operator->getCollation());
176  }
177  return rex_operator->getDisambiguated(new_opnds);
178  }
179 
180  RetType visitCase(const RexCase* rex_case) const override {
181  std::vector<std::pair<RetType, RetType>> new_pair_list;
182  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
183  new_pair_list.emplace_back(visit(rex_case->getWhen(i)),
184  visit(rex_case->getThen(i)));
185  }
186  auto new_else = visit(rex_case->getElse());
187  return boost::make_unique<RexCase>(new_pair_list, new_else);
188  }
189 
190  private:
191  RetType defaultResult() const override { return nullptr; }
192 };
193 
194 template <bool bAllowMissing>
196  public:
197  RexInputRenumber(const std::unordered_map<size_t, size_t>& new_numbering)
198  : old_to_new_idx_(new_numbering) {}
199  RetType visitInput(const RexInput* input) const override {
200  auto renum_it = old_to_new_idx_.find(input->getIndex());
201  if (bAllowMissing) {
202  if (renum_it != old_to_new_idx_.end()) {
203  return boost::make_unique<RexInput>(input->getSourceNode(), renum_it->second);
204  } else {
205  return input->deepCopy();
206  }
207  } else {
208  CHECK(renum_it != old_to_new_idx_.end());
209  return boost::make_unique<RexInput>(input->getSourceNode(), renum_it->second);
210  }
211  }
212 
213  private:
214  const std::unordered_map<size_t, size_t>& old_to_new_idx_;
215 };
216 
217 #endif // QUERYENGINE_REXVISITOR_H
const ConstRexScalarPtrVector & getPartitionKeys() const
T defaultResult() const override
Definition: RexVisitor.h:118
T visitCase(const RexCase *rex_case) const override
Definition: RexVisitor.h:99
RetType visitInput(const RexInput *input) const override
Definition: RexVisitor.h:199
const ConstRexScalarPtrVector & getOrderKeys() const
virtual T visitRef(const RexRef *) const =0
size_t branchCount() const
RetType visitOperator(const RexOperator *rex_operator) const override
Definition: RexVisitor.h:152
virtual T visitLiteral(const RexLiteral *) const =0
T visitOperator(const RexOperator *rex_operator) const override
Definition: RexVisitor.h:83
RexInputRenumber(const std::unordered_map< size_t, size_t > &new_numbering)
Definition: RexVisitor.h:197
virtual T visitCase(const RexCase *rex_case) const =0
RetType visitLiteral(const RexLiteral *literal) const override
Definition: RexVisitor.h:142
std::unique_ptr< RexInput > deepCopy() const
const RexScalar * getWhen(const size_t idx) const
const std::unordered_map< size_t, size_t > & old_to_new_idx_
Definition: RexVisitor.h:214
T visitRef(const RexRef *) const override
Definition: RexVisitor.h:81
virtual T defaultResult() const =0
const RexScalar * getThen(const size_t idx) const
RetType visitSubQuery(const RexSubQuery *subquery) const override
Definition: RexVisitor.h:146
const RelAlgNode * getSourceNode() const
virtual T visitOperator(const RexOperator *rex_operator) const =0
virtual T visitInput(const RexInput *) const =0
virtual T visitSubQuery(const RexSubQuery *) const =0
T visitWindowFunctionOperator(const RexWindowFunctionOperator *rex_window_func_operator, const T operands_visit_result) const
Definition: RexVisitor.h:121
RetType visitInput(const RexInput *input) const override
Definition: RexVisitor.h:140
virtual std::unique_ptr< const RexOperator > getDisambiguated(std::vector< std::unique_ptr< const RexScalar >> &operands) const
const RexScalar * getOperand(const size_t idx) const
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
T visitInput(const RexInput *) const override
Definition: RexVisitor.h:75
virtual T aggregateResult(const T &aggregate, const T &next_result) const
Definition: RexVisitor.h:114
T visitSubQuery(const RexSubQuery *) const override
Definition: RexVisitor.h:79
RetType visitRef(const RexRef *ref) const override
Definition: RexVisitor.h:150
#define CHECK(condition)
Definition: Logger.h:187
RetType visitCase(const RexCase *rex_case) const override
Definition: RexVisitor.h:180
std::unique_ptr< RexRef > deepCopy() const
RetType defaultResult() const override
Definition: RexVisitor.h:191
std::unique_ptr< const RexScalar > RetType
Definition: RexVisitor.h:138
T visitLiteral(const RexLiteral *) const override
Definition: RexVisitor.h:77
const RexScalar * getElse() const