OmniSciDB  b24e664e58
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
RexVisitor.h
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef QUERYENGINE_REXVISITOR_H
18 #define QUERYENGINE_REXVISITOR_H
19 
21 
22 #include <memory>
23 
24 template <class T>
26  public:
27  virtual T visit(const RexScalar* rex_scalar) const {
28  CHECK(rex_scalar);
29  const auto rex_input = dynamic_cast<const RexInput*>(rex_scalar);
30  if (rex_input) {
31  return visitInput(rex_input);
32  }
33  const auto rex_literal = dynamic_cast<const RexLiteral*>(rex_scalar);
34  if (rex_literal) {
35  return visitLiteral(rex_literal);
36  }
37  const auto rex_subquery = dynamic_cast<const RexSubQuery*>(rex_scalar);
38  if (rex_subquery) {
39  return visitSubQuery(rex_subquery);
40  }
41  const auto rex_operator = dynamic_cast<const RexOperator*>(rex_scalar);
42  if (rex_operator) {
43  return visitOperator(rex_operator);
44  }
45  const auto rex_case = dynamic_cast<const RexCase*>(rex_scalar);
46  if (rex_case) {
47  return visitCase(rex_case);
48  }
49  const auto rex_ref = dynamic_cast<const RexRef*>(rex_scalar);
50  if (rex_ref) {
51  return visitRef(rex_ref);
52  }
53  LOG(FATAL) << "No visit path for " << rex_scalar->toString();
54  return defaultResult();
55  }
56 
57  virtual T visitInput(const RexInput*) const = 0;
58 
59  virtual T visitLiteral(const RexLiteral*) const = 0;
60 
61  virtual T visitSubQuery(const RexSubQuery*) const = 0;
62 
63  virtual T visitRef(const RexRef*) const = 0;
64 
65  virtual T visitOperator(const RexOperator* rex_operator) const = 0;
66 
67  virtual T visitCase(const RexCase* rex_case) const = 0;
68 
69  protected:
70  virtual T defaultResult() const = 0;
71 };
72 
73 template <class T>
74 class RexVisitor : public RexVisitorBase<T> {
75  public:
76  T visitInput(const RexInput*) const override { return defaultResult(); }
77 
78  T visitLiteral(const RexLiteral*) const override { return defaultResult(); }
79 
80  T visitSubQuery(const RexSubQuery*) const override { return defaultResult(); }
81 
82  T visitRef(const RexRef*) const override { return defaultResult(); }
83 
84  T visitOperator(const RexOperator* rex_operator) const override {
85  const size_t operand_count = rex_operator->size();
86  T result = defaultResult();
87  for (size_t i = 0; i < operand_count; ++i) {
88  const auto operand = rex_operator->getOperand(i);
89  T operandResult = RexVisitorBase<T>::visit(operand);
90  result = aggregateResult(result, operandResult);
91  }
92  const auto rex_window_func_operator =
93  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
94  if (rex_window_func_operator) {
95  return visitWindowFunctionOperator(rex_window_func_operator, result);
96  }
97  return result;
98  }
99 
100  T visitCase(const RexCase* rex_case) const override {
101  T result = defaultResult();
102  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
103  const auto when = rex_case->getWhen(i);
104  result = aggregateResult(result, RexVisitorBase<T>::visit(when));
105  const auto then = rex_case->getThen(i);
106  result = aggregateResult(result, RexVisitorBase<T>::visit(then));
107  }
108  if (rex_case->getElse()) {
109  result = aggregateResult(result, RexVisitorBase<T>::visit(rex_case->getElse()));
110  }
111  return result;
112  }
113 
114  protected:
115  virtual T aggregateResult(const T& aggregate, const T& next_result) const {
116  return next_result;
117  }
118 
119  T defaultResult() const override { return T{}; }
120 
121  private:
122  T visitWindowFunctionOperator(const RexWindowFunctionOperator* rex_window_func_operator,
123  const T operands_visit_result) const {
124  T result = operands_visit_result;
125  for (const auto& key : rex_window_func_operator->getPartitionKeys()) {
126  T partial_result = RexVisitorBase<T>::visit(key.get());
127  result = aggregateResult(result, partial_result);
128  }
129  for (const auto& key : rex_window_func_operator->getOrderKeys()) {
130  T partial_result = RexVisitorBase<T>::visit(key.get());
131  result = aggregateResult(result, partial_result);
132  }
133  return result;
134  }
135 };
136 
137 class RexDeepCopyVisitor : public RexVisitorBase<std::unique_ptr<const RexScalar>> {
138  protected:
139  using RetType = std::unique_ptr<const RexScalar>;
140 
141  RetType visitInput(const RexInput* input) const override { return input->deepCopy(); }
142 
143  RetType visitLiteral(const RexLiteral* literal) const override {
144  return literal->deepCopy();
145  }
146 
147  RetType visitSubQuery(const RexSubQuery* subquery) const override {
148  return subquery->deepCopy();
149  }
150 
151  RetType visitRef(const RexRef* ref) const override { return ref->deepCopy(); }
152 
153  RetType visitOperator(const RexOperator* rex_operator) const override {
154  const auto rex_window_function_operator =
155  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
156  if (rex_window_function_operator) {
157  return visitWindowFunctionOperator(rex_window_function_operator);
158  }
159 
160  const size_t operand_count = rex_operator->size();
161  std::vector<RetType> new_opnds;
162  for (size_t i = 0; i < operand_count; ++i) {
163  new_opnds.push_back(visit(rex_operator->getOperand(i)));
164  }
165  return rex_operator->getDisambiguated(new_opnds);
166  }
167 
169  const RexWindowFunctionOperator* rex_window_function_operator) const {
170  const size_t operand_count = rex_window_function_operator->size();
171  std::vector<RetType> new_opnds;
172  for (size_t i = 0; i < operand_count; ++i) {
173  new_opnds.push_back(visit(rex_window_function_operator->getOperand(i)));
174  }
175 
176  const auto& partition_keys = rex_window_function_operator->getPartitionKeys();
177  std::vector<std::unique_ptr<const RexScalar>> disambiguated_partition_keys;
178  for (const auto& partition_key : partition_keys) {
179  disambiguated_partition_keys.emplace_back(visit(partition_key.get()));
180  }
181  std::vector<std::unique_ptr<const RexScalar>> disambiguated_order_keys;
182  const auto& order_keys = rex_window_function_operator->getOrderKeys();
183  for (const auto& order_key : order_keys) {
184  disambiguated_order_keys.emplace_back(visit(order_key.get()));
185  }
186  return rex_window_function_operator->disambiguatedOperands(
187  new_opnds,
188  disambiguated_partition_keys,
189  disambiguated_order_keys,
190  rex_window_function_operator->getCollation());
191  }
192 
193  RetType visitCase(const RexCase* rex_case) const override {
194  std::vector<std::pair<RetType, RetType>> new_pair_list;
195  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
196  new_pair_list.emplace_back(visit(rex_case->getWhen(i)),
197  visit(rex_case->getThen(i)));
198  }
199  auto new_else = visit(rex_case->getElse());
200  return std::make_unique<RexCase>(new_pair_list, new_else);
201  }
202 
203  private:
204  RetType defaultResult() const override { return nullptr; }
205 };
206 
207 template <bool bAllowMissing>
209  public:
210  RexInputRenumber(const std::unordered_map<size_t, size_t>& new_numbering)
211  : old_to_new_idx_(new_numbering) {}
212  RetType visitInput(const RexInput* input) const override {
213  auto renum_it = old_to_new_idx_.find(input->getIndex());
214  if (bAllowMissing) {
215  if (renum_it != old_to_new_idx_.end()) {
216  return std::make_unique<RexInput>(input->getSourceNode(), renum_it->second);
217  } else {
218  return input->deepCopy();
219  }
220  } else {
221  CHECK(renum_it != old_to_new_idx_.end());
222  return std::make_unique<RexInput>(input->getSourceNode(), renum_it->second);
223  }
224  }
225 
226  private:
227  const std::unordered_map<size_t, size_t>& old_to_new_idx_;
228 };
229 
230 #endif // QUERYENGINE_REXVISITOR_H
const RexScalar * getThen(const size_t idx) const
T defaultResult() const override
Definition: RexVisitor.h:119
T visitCase(const RexCase *rex_case) const override
Definition: RexVisitor.h:100
const RexScalar * getElse() const
std::unique_ptr< RexRef > deepCopy() const
#define LOG(tag)
Definition: Logger.h:185
size_t size() const
const RexScalar * getOperand(const size_t idx) const
RetType visitInput(const RexInput *input) const override
Definition: RexVisitor.h:212
const std::vector< SortField > & getCollation() const
virtual T visitRef(const RexRef *) const =0
RetType visitOperator(const RexOperator *rex_operator) const override
Definition: RexVisitor.h:153
virtual T visitLiteral(const RexLiteral *) const =0
T visitOperator(const RexOperator *rex_operator) const override
Definition: RexVisitor.h:84
const RexScalar * getWhen(const size_t idx) const
RexInputRenumber(const std::unordered_map< size_t, size_t > &new_numbering)
Definition: RexVisitor.h:210
virtual T visitCase(const RexCase *rex_case) const =0
RetType visitWindowFunctionOperator(const RexWindowFunctionOperator *rex_window_function_operator) const
Definition: RexVisitor.h:168
RetType visitLiteral(const RexLiteral *literal) const override
Definition: RexVisitor.h:143
virtual std::unique_ptr< const RexOperator > getDisambiguated(std::vector< std::unique_ptr< const RexScalar >> &operands) const
virtual T aggregateResult(const T &aggregate, const T &next_result) const
Definition: RexVisitor.h:115
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
const std::unordered_map< size_t, size_t > & old_to_new_idx_
Definition: RexVisitor.h:227
unsigned getIndex() const
T visitRef(const RexRef *) const override
Definition: RexVisitor.h:82
virtual T defaultResult() const =0
CHECK(cgen_state)
RetType visitSubQuery(const RexSubQuery *subquery) const override
Definition: RexVisitor.h:147
std::unique_ptr< const RexOperator > disambiguatedOperands(ConstRexScalarPtrVector &operands, ConstRexScalarPtrVector &partition_keys, ConstRexScalarPtrVector &order_keys, const std::vector< SortField > &collation) const
virtual T visitOperator(const RexOperator *rex_operator) const =0
virtual T visitInput(const RexInput *) const =0
virtual T visitSubQuery(const RexSubQuery *) const =0
size_t branchCount() const
RetType visitInput(const RexInput *input) const override
Definition: RexVisitor.h:141
T visitWindowFunctionOperator(const RexWindowFunctionOperator *rex_window_func_operator, const T operands_visit_result) const
Definition: RexVisitor.h:122
std::unique_ptr< RexInput > deepCopy() const
const ConstRexScalarPtrVector & getPartitionKeys() const
T visitInput(const RexInput *) const override
Definition: RexVisitor.h:76
const RelAlgNode * getSourceNode() const
T visitSubQuery(const RexSubQuery *) const override
Definition: RexVisitor.h:80
RetType visitRef(const RexRef *ref) const override
Definition: RexVisitor.h:151
virtual std::string toString() const =0
const ConstRexScalarPtrVector & getOrderKeys() const
RetType visitCase(const RexCase *rex_case) const override
Definition: RexVisitor.h:193
RetType defaultResult() const override
Definition: RexVisitor.h:204
std::unique_ptr< const RexScalar > RetType
Definition: RexVisitor.h:139
T visitLiteral(const RexLiteral *) const override
Definition: RexVisitor.h:78