OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ScalarExprVisitor.h
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef QUERYENGINE_SCALAREXPRVISITOR_H
18 #define QUERYENGINE_SCALAREXPRVISITOR_H
19 
20 #include "../Analyzer/Analyzer.h"
21 
22 template <class T>
24  public:
25  T visit(const Analyzer::Expr* expr) const {
26  CHECK(expr);
27  visitBegin();
28  const auto var = dynamic_cast<const Analyzer::Var*>(expr);
29  if (var) {
30  return visitVar(var);
31  }
32  const auto column_var = dynamic_cast<const Analyzer::ColumnVar*>(expr);
33  if (column_var) {
34  return visitColumnVar(column_var);
35  }
36  const auto column_var_tuple = dynamic_cast<const Analyzer::ExpressionTuple*>(expr);
37  if (column_var_tuple) {
38  return visitColumnVarTuple(column_var_tuple);
39  }
40  const auto constant = dynamic_cast<const Analyzer::Constant*>(expr);
41  if (constant) {
42  return visitConstant(constant);
43  }
44  const auto uoper = dynamic_cast<const Analyzer::UOper*>(expr);
45  if (uoper) {
46  return visitUOper(uoper);
47  }
48  const auto bin_oper = dynamic_cast<const Analyzer::BinOper*>(expr);
49  if (bin_oper) {
50  return visitBinOper(bin_oper);
51  }
52  const auto geo_expr = dynamic_cast<const Analyzer::GeoExpr*>(expr);
53  if (geo_expr) {
54  return visitGeoExpr(geo_expr);
55  }
56  const auto in_values = dynamic_cast<const Analyzer::InValues*>(expr);
57  if (in_values) {
58  return visitInValues(in_values);
59  }
60  const auto in_integer_set = dynamic_cast<const Analyzer::InIntegerSet*>(expr);
61  if (in_integer_set) {
62  return visitInIntegerSet(in_integer_set);
63  }
64  const auto char_length = dynamic_cast<const Analyzer::CharLengthExpr*>(expr);
65  if (char_length) {
67  }
68  const auto key_for_string = dynamic_cast<const Analyzer::KeyForStringExpr*>(expr);
69  if (key_for_string) {
70  return visitKeyForString(key_for_string);
71  }
72  const auto sample_ratio = dynamic_cast<const Analyzer::SampleRatioExpr*>(expr);
73  if (sample_ratio) {
75  }
76  const auto width_bucket = dynamic_cast<const Analyzer::WidthBucketExpr*>(expr);
77  if (width_bucket) {
79  }
80  const auto ml_predict = dynamic_cast<const Analyzer::MLPredictExpr*>(expr);
81  if (ml_predict) {
82  return visitMLPredict(ml_predict);
83  }
84  const auto pca_project = dynamic_cast<const Analyzer::PCAProjectExpr*>(expr);
85  if (pca_project) {
86  return visitPCAProject(pca_project);
87  }
88  const auto string_oper = dynamic_cast<const Analyzer::StringOper*>(expr);
89  if (string_oper) {
90  return visitStringOper(string_oper);
91  }
92  const auto cardinality = dynamic_cast<const Analyzer::CardinalityExpr*>(expr);
93  if (cardinality) {
94  return visitCardinality(cardinality);
95  }
96  const auto like_expr = dynamic_cast<const Analyzer::LikeExpr*>(expr);
97  if (like_expr) {
98  return visitLikeExpr(like_expr);
99  }
100  const auto regexp_expr = dynamic_cast<const Analyzer::RegexpExpr*>(expr);
101  if (regexp_expr) {
102  return visitRegexpExpr(regexp_expr);
103  }
104  const auto case_ = dynamic_cast<const Analyzer::CaseExpr*>(expr);
105  if (case_) {
106  return visitCaseExpr(case_);
107  }
108  const auto datetrunc = dynamic_cast<const Analyzer::DatetruncExpr*>(expr);
109  if (datetrunc) {
110  return visitDatetruncExpr(datetrunc);
111  }
112  const auto extract = dynamic_cast<const Analyzer::ExtractExpr*>(expr);
113  if (extract) {
114  return visitExtractExpr(extract);
115  }
116  const auto window_func = dynamic_cast<const Analyzer::WindowFunction*>(expr);
117  if (window_func) {
118  return visitWindowFunction(window_func);
119  }
120  const auto func_with_custom_type_handling =
121  dynamic_cast<const Analyzer::FunctionOperWithCustomTypeHandling*>(expr);
122  if (func_with_custom_type_handling) {
123  return visitFunctionOperWithCustomTypeHandling(func_with_custom_type_handling);
124  }
125  const auto func = dynamic_cast<const Analyzer::FunctionOper*>(expr);
126  if (func) {
127  return visitFunctionOper(func);
128  }
129  const auto array = dynamic_cast<const Analyzer::ArrayExpr*>(expr);
130  if (array) {
131  return visitArrayOper(array);
132  }
133  const auto geo_uop = dynamic_cast<const Analyzer::GeoUOper*>(expr);
134  if (geo_uop) {
135  return visitGeoUOper(geo_uop);
136  }
137  const auto geo_binop = dynamic_cast<const Analyzer::GeoBinOper*>(expr);
138  if (geo_binop) {
139  return visitGeoBinOper(geo_binop);
140  }
141  const auto datediff = dynamic_cast<const Analyzer::DatediffExpr*>(expr);
142  if (datediff) {
143  return visitDatediffExpr(datediff);
144  }
145  const auto dateadd = dynamic_cast<const Analyzer::DateaddExpr*>(expr);
146  if (dateadd) {
147  return visitDateaddExpr(dateadd);
148  }
149  const auto likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
150  if (likelihood) {
151  return visitLikelihood(likelihood);
152  }
153  const auto offset_in_fragment = dynamic_cast<const Analyzer::OffsetInFragment*>(expr);
154  if (offset_in_fragment) {
155  return visitOffsetInFragment(offset_in_fragment);
156  }
157  const auto agg = dynamic_cast<const Analyzer::AggExpr*>(expr);
158  if (agg) {
159  return visitAggExpr(agg);
160  }
161  const auto range_join_oper = dynamic_cast<const Analyzer::RangeOper*>(expr);
162  if (range_join_oper) {
163  return visitRangeJoinOper(range_join_oper);
164  }
165  return defaultResult();
166  }
167 
168  protected:
169  virtual T visitVar(const Analyzer::Var*) const { return defaultResult(); }
170 
171  virtual T visitColumnVar(const Analyzer::ColumnVar*) const { return defaultResult(); }
172 
174  return defaultResult();
175  }
176 
177  virtual T visitConstant(const Analyzer::Constant*) const { return defaultResult(); }
178 
179  virtual T visitUOper(const Analyzer::UOper* uoper) const {
180  T result = defaultResult();
181  result = aggregateResult(result, visit(uoper->get_operand()));
182  return result;
183  }
184 
185  virtual T visitBinOper(const Analyzer::BinOper* bin_oper) const {
186  T result = defaultResult();
187  result = aggregateResult(result, visit(bin_oper->get_left_operand()));
188  result = aggregateResult(result, visit(bin_oper->get_right_operand()));
189  return result;
190  }
191 
192  virtual T visitGeoExpr(const Analyzer::GeoExpr* geo_expr) const {
193  T result = defaultResult();
194  const auto geo_expr_children = geo_expr->getChildExprs();
195  for (const auto expr : geo_expr_children) {
196  result = aggregateResult(result, visit(expr));
197  }
198  return result;
199  }
200 
201  virtual T visitInValues(const Analyzer::InValues* in_values) const {
202  T result = visit(in_values->get_arg());
203  const auto& value_list = in_values->get_value_list();
204  for (const auto& in_value : value_list) {
205  result = aggregateResult(result, visit(in_value.get()));
206  }
207  return result;
208  }
209 
210  virtual T visitInIntegerSet(const Analyzer::InIntegerSet* in_integer_set) const {
211  return visit(in_integer_set->get_arg());
212  }
213 
215  T result = defaultResult();
216  result = aggregateResult(result, visit(char_length->get_arg()));
217  return result;
218  }
219 
220  virtual T visitKeyForString(const Analyzer::KeyForStringExpr* key_for_string) const {
221  T result = defaultResult();
222  result = aggregateResult(result, visit(key_for_string->get_arg()));
223  return result;
224  }
225 
227  T result = defaultResult();
228  result = aggregateResult(result, visit(sample_ratio->get_arg()));
229  return result;
230  }
231 
232  virtual T visitStringOper(const Analyzer::StringOper* string_oper) const {
233  T result = defaultResult();
234  for (const auto& arg : string_oper->getOwnArgs()) {
235  result = aggregateResult(result, visit(arg.get()));
236  }
237  return result;
238  }
239 
240  virtual T visitCardinality(const Analyzer::CardinalityExpr* cardinality) const {
241  T result = defaultResult();
242  result = aggregateResult(result, visit(cardinality->get_arg()));
243  return result;
244  }
245 
246  virtual T visitLikeExpr(const Analyzer::LikeExpr* like) const {
247  T result = defaultResult();
248  result = aggregateResult(result, visit(like->get_arg()));
249  result = aggregateResult(result, visit(like->get_like_expr()));
250  if (like->get_escape_expr()) {
251  result = aggregateResult(result, visit(like->get_escape_expr()));
252  }
253  return result;
254  }
255 
256  virtual T visitRegexpExpr(const Analyzer::RegexpExpr* regexp) const {
257  T result = defaultResult();
258  result = aggregateResult(result, visit(regexp->get_arg()));
259  result = aggregateResult(result, visit(regexp->get_pattern_expr()));
260  if (regexp->get_escape_expr()) {
261  result = aggregateResult(result, visit(regexp->get_escape_expr()));
262  }
263  return result;
264  }
265 
267  T result = defaultResult();
268  result = aggregateResult(result, visit(width_bucket_expr->get_target_value()));
269  result = aggregateResult(result, visit(width_bucket_expr->get_lower_bound()));
270  result = aggregateResult(result, visit(width_bucket_expr->get_upper_bound()));
271  result = aggregateResult(result, visit(width_bucket_expr->get_partition_count()));
272  return result;
273  }
274 
275  virtual T visitMLPredict(const Analyzer::MLPredictExpr* ml_predict_expr) const {
276  T result = defaultResult();
277  result = aggregateResult(result, visit(ml_predict_expr->get_model_value()));
278  const auto& regressor_values = ml_predict_expr->get_regressor_values();
279  for (const auto& regressor_value : regressor_values) {
280  result = aggregateResult(result, visit(regressor_value.get()));
281  }
282  return result;
283  }
284 
285  virtual T visitPCAProject(const Analyzer::PCAProjectExpr* pca_project_expr) const {
286  T result = defaultResult();
287  result = aggregateResult(result, visit(pca_project_expr->get_model_value()));
288  const auto& feature_values = pca_project_expr->get_feature_values();
289  for (const auto& feature_value : feature_values) {
290  result = aggregateResult(result, visit(feature_value.get()));
291  }
292  result = aggregateResult(result, visit(pca_project_expr->get_pc_dimension_value()));
293  return result;
294  }
295 
296  virtual T visitCaseExpr(const Analyzer::CaseExpr* case_) const {
297  T result = defaultResult();
298  const auto& expr_pair_list = case_->get_expr_pair_list();
299  for (const auto& expr_pair : expr_pair_list) {
300  result = aggregateResult(result, visit(expr_pair.first.get()));
301  result = aggregateResult(result, visit(expr_pair.second.get()));
302  }
303  result = aggregateResult(result, visit(case_->get_else_expr()));
304  return result;
305  }
306 
307  virtual T visitDatetruncExpr(const Analyzer::DatetruncExpr* datetrunc) const {
308  T result = defaultResult();
309  result = aggregateResult(result, visit(datetrunc->get_from_expr()));
310  return result;
311  }
312 
313  virtual T visitExtractExpr(const Analyzer::ExtractExpr* extract) const {
314  T result = defaultResult();
315  result = aggregateResult(result, visit(extract->get_from_expr()));
316  return result;
317  }
318 
320  const Analyzer::FunctionOperWithCustomTypeHandling* func_oper) const {
321  return visitFunctionOper(func_oper);
322  }
323 
324  virtual T visitArrayOper(Analyzer::ArrayExpr const* array_expr) const {
325  T result = defaultResult();
326  for (size_t i = 0; i < array_expr->getElementCount(); ++i) {
327  result = aggregateResult(result, visit(array_expr->getElement(i)));
328  }
329  return result;
330  }
331 
332  virtual T visitGeoUOper(const Analyzer::GeoUOper* geo_expr) const {
333  T result = defaultResult();
334  for (const auto& arg : geo_expr->getArgs0()) {
335  result = aggregateResult(result, visit(arg.get()));
336  }
337  return result;
338  }
339 
340  virtual T visitGeoBinOper(const Analyzer::GeoBinOper* geo_expr) const {
341  T result = defaultResult();
342  for (const auto& arg : geo_expr->getArgs0()) {
343  result = aggregateResult(result, visit(arg.get()));
344  }
345  for (const auto& arg : geo_expr->getArgs1()) {
346  result = aggregateResult(result, visit(arg.get()));
347  }
348  return result;
349  }
350 
351  virtual T visitFunctionOper(const Analyzer::FunctionOper* func_oper) const {
352  T result = defaultResult();
353  for (size_t i = 0; i < func_oper->getArity(); ++i) {
354  result = aggregateResult(result, visit(func_oper->getArg(i)));
355  }
356  return result;
357  }
358 
359  virtual T visitWindowFunction(const Analyzer::WindowFunction* window_func) const {
360  T result = defaultResult();
361  for (const auto& arg : window_func->getArgs()) {
362  result = aggregateResult(result, visit(arg.get()));
363  }
364  for (const auto& partition_key : window_func->getPartitionKeys()) {
365  result = aggregateResult(result, visit(partition_key.get()));
366  }
367  for (const auto& order_key : window_func->getOrderKeys()) {
368  result = aggregateResult(result, visit(order_key.get()));
369  }
370  return result;
371  }
372 
373  virtual T visitDatediffExpr(const Analyzer::DatediffExpr* datediff) const {
374  T result = defaultResult();
375  result = aggregateResult(result, visit(datediff->get_start_expr()));
376  result = aggregateResult(result, visit(datediff->get_end_expr()));
377  return result;
378  }
379 
380  virtual T visitDateaddExpr(const Analyzer::DateaddExpr* dateadd) const {
381  T result = defaultResult();
382  result = aggregateResult(result, visit(dateadd->get_number_expr()));
383  result = aggregateResult(result, visit(dateadd->get_datetime_expr()));
384  return result;
385  }
386 
387  virtual T visitLikelihood(const Analyzer::LikelihoodExpr* likelihood) const {
388  return visit(likelihood->get_arg());
389  }
390 
392  return defaultResult();
393  }
394 
395  virtual T visitAggExpr(const Analyzer::AggExpr* agg) const {
396  T result = defaultResult();
397  if (agg->get_arg()) {
398  return aggregateResult(result, visit(agg->get_arg()));
399  }
400  return defaultResult();
401  }
402 
403  virtual T visitRangeJoinOper(const Analyzer::RangeOper* range_oper) const {
404  T result = defaultResult();
405  result = aggregateResult(result, visit(range_oper->get_left_operand()));
406  result = aggregateResult(result, visit(range_oper->get_right_operand()));
407  return result;
408  }
409 
410  protected:
411  virtual T aggregateResult(const T& aggregate, const T& next_result) const {
412  return next_result;
413  }
414 
415  virtual void visitBegin() const {}
416 
417  virtual T defaultResult() const { return T{}; }
418 };
419 
420 #endif // QUERYENGINE_SCALAREXPRVISITOR_H
virtual T visitAggExpr(const Analyzer::AggExpr *agg) const
const Expr * get_from_expr() const
Definition: Analyzer.h:1432
const Expr * get_partition_count() const
Definition: Analyzer.h:1201
virtual T aggregateResult(const T &aggregate, const T &next_result) const
virtual T visitOffsetInFragment(const Analyzer::OffsetInFragment *) const
const Expr * get_else_expr() const
Definition: Analyzer.h:1387
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs0() const
Definition: Analyzer.h:2962
const Expr * get_escape_expr() const
Definition: Analyzer.h:1064
size_t getArity() const
Definition: Analyzer.h:2615
const Expr * get_escape_expr() const
Definition: Analyzer.h:1136
const Expr * get_right_operand() const
Definition: Analyzer.h:456
virtual T visitGeoBinOper(const Analyzer::GeoBinOper *geo_expr) const
const Expr * get_arg() const
Definition: Analyzer.h:1133
Expr * get_arg() const
Definition: Analyzer.h:1330
const Expr * get_arg() const
Definition: Analyzer.h:1267
T visit(const Analyzer::Expr *expr) const
virtual T visitCardinality(const Analyzer::CardinalityExpr *cardinality) const
virtual T visitGeoUOper(const Analyzer::GeoUOper *geo_expr) const
virtual std::vector< Analyzer::Expr * > getChildExprs() const
Definition: Analyzer.h:3091
const std::vector< std::shared_ptr< Analyzer::Expr > > & getOrderKeys() const
Definition: Analyzer.h:2802
virtual T visitExtractExpr(const Analyzer::ExtractExpr *extract) const
const Expr * get_left_operand() const
Definition: Analyzer.h:552
const Expr * get_arg() const
Definition: Analyzer.h:1061
virtual T visitLikelihood(const Analyzer::LikelihoodExpr *likelihood) const
virtual T visitCharLength(const Analyzer::CharLengthExpr *char_length) const
virtual T visitVar(const Analyzer::Var *) const
RUNTIME_EXPORT ALWAYS_INLINE DEVICE int32_t width_bucket(const double target_value, const double lower_bound, const double upper_bound, const double scale_factor, const int32_t partition_count)
virtual T visitGeoExpr(const Analyzer::GeoExpr *geo_expr) const
virtual T visitPCAProject(const Analyzer::PCAProjectExpr *pca_project_expr) const
virtual T visitLikeExpr(const Analyzer::LikeExpr *like) const
const Expr * get_pc_dimension_value() const
Definition: Analyzer.h:792
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs() const
Definition: Analyzer.h:2796
virtual void visitBegin() const
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs0() const
Definition: Analyzer.h:2932
virtual T visitColumnVar(const Analyzer::ColumnVar *) const
virtual T visitWindowFunction(const Analyzer::WindowFunction *window_func) const
const Expr * get_start_expr() const
Definition: Analyzer.h:1525
const Expr * get_right_operand() const
Definition: Analyzer.h:553
size_t getElementCount() const
Definition: Analyzer.h:2892
RUNTIME_EXPORT ALWAYS_INLINE DEVICE int32_t char_length(const char *str, const int32_t str_len)
virtual T visitDatediffExpr(const Analyzer::DatediffExpr *datediff) const
const Expr * get_pattern_expr() const
Definition: Analyzer.h:1135
virtual T visitDateaddExpr(const Analyzer::DateaddExpr *dateadd) const
Expression class for string functions. The "arg" constructor parameter must be an expression that reso...
Definition: Analyzer.h:1601
const Expr * get_from_expr() const
Definition: Analyzer.h:1567
const Expr * get_datetime_expr() const
Definition: Analyzer.h:1478
virtual T visitMLPredict(const Analyzer::MLPredictExpr *ml_predict_expr) const
const Expr * get_like_expr() const
Definition: Analyzer.h:1063
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:2617
virtual T visitRangeJoinOper(const Analyzer::RangeOper *range_oper) const
const Expr * get_operand() const
Definition: Analyzer.h:384
const Expr * get_arg() const
Definition: Analyzer.h:962
const Expr * get_model_value() const
Definition: Analyzer.h:788
const Expr * get_arg() const
Definition: Analyzer.h:868
const std::list< std::shared_ptr< Analyzer::Expr > > & get_value_list() const
Definition: Analyzer.h:646
virtual T visitConstant(const Analyzer::Constant *) const
virtual T visitKeyForString(const Analyzer::KeyForStringExpr *key_for_string) const
virtual T visitCaseExpr(const Analyzer::CaseExpr *case_) const
virtual T visitBinOper(const Analyzer::BinOper *bin_oper) const
const Expr * get_target_value() const
Definition: Analyzer.h:1198
const std::vector< std::shared_ptr< Analyzer::Expr > > & get_feature_values() const
Definition: Analyzer.h:789
const Expr * get_arg() const
Definition: Analyzer.h:917
virtual T visitInValues(const Analyzer::InValues *in_values) const
virtual T visitInIntegerSet(const Analyzer::InIntegerSet *in_integer_set) const
const Expr * get_end_expr() const
Definition: Analyzer.h:1526
#define CHECK(condition)
Definition: Logger.h:291
virtual T visitFunctionOperWithCustomTypeHandling(const Analyzer::FunctionOperWithCustomTypeHandling *func_oper) const
virtual T visitFunctionOper(const Analyzer::FunctionOper *func_oper) const
const Expr * get_model_value() const
Definition: Analyzer.h:713
virtual T visitArrayOper(Analyzer::ArrayExpr const *array_expr) const
virtual T defaultResult() const
const Expr * get_left_operand() const
Definition: Analyzer.h:455
virtual T visitRegexpExpr(const Analyzer::RegexpExpr *regexp) const
virtual T visitSampleRatio(const Analyzer::SampleRatioExpr *sample_ratio) const
virtual T visitUOper(const Analyzer::UOper *uoper) const
const Expr * get_lower_bound() const
Definition: Analyzer.h:1199
virtual T visitColumnVarTuple(const Analyzer::ExpressionTuple *) const
const Expr * get_arg() const
Definition: Analyzer.h:693
virtual T visitStringOper(const Analyzer::StringOper *string_oper) const
const std::vector< std::shared_ptr< Analyzer::Expr > > & get_regressor_values() const
Definition: Analyzer.h:714
virtual T visitDatetruncExpr(const Analyzer::DatetruncExpr *datetrunc) const
const Expr * get_upper_bound() const
Definition: Analyzer.h:1200
const std::vector< std::shared_ptr< Analyzer::Expr > > & getPartitionKeys() const
Definition: Analyzer.h:2798
virtual T visitWidthBucket(const Analyzer::WidthBucketExpr *width_bucket_expr) const
const Expr * get_arg() const
Definition: Analyzer.h:1007
RUNTIME_EXPORT ALWAYS_INLINE DEVICE bool sample_ratio(const double proportion, const int64_t row_offset)
const Expr * get_arg() const
Definition: Analyzer.h:644
std::vector< std::shared_ptr< Analyzer::Expr > > getOwnArgs() const
Definition: Analyzer.h:1698
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs1() const
Definition: Analyzer.h:2963
const std::list< std::pair< std::shared_ptr< Analyzer::Expr >, std::shared_ptr< Analyzer::Expr > > > & get_expr_pair_list() const
Definition: Analyzer.h:1384
const Expr * get_number_expr() const
Definition: Analyzer.h:1477
const Analyzer::Expr * getElement(const size_t i) const
Definition: Analyzer.h:2896
RUNTIME_EXPORT ALWAYS_INLINE DEVICE int32_t width_bucket_expr(const double target_value, const bool reversed, const double lower_bound, const double upper_bound, const int32_t partition_count)