OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
WindowExpressionRewrite.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 namespace {
20 
21 // Returns true iff the case expression has an else null branch.
22 bool matches_else_null(const Analyzer::CaseExpr* case_expr) {
23  const auto else_null =
24  dynamic_cast<const Analyzer::Constant*>(case_expr->get_else_expr());
25  return else_null && else_null->get_is_null();
26 }
27 
28 // Returns true iff the expression is a big integer greater than 0.
29 bool matches_gt_bigint_zero(const Analyzer::BinOper* window_gt_zero) {
30  if (window_gt_zero->get_optype() != kGT) {
31  return false;
32  }
33  const auto zero =
34  dynamic_cast<const Analyzer::Constant*>(window_gt_zero->get_right_operand());
35  return zero && zero->get_type_info().get_type() == kBIGINT &&
36  zero->get_constval().bigintval == 0;
37 }
38 
39 // Returns true iff the sum and the count match in type and arguments. Used to replace
40 // combination can be replaced with an explicit average.
42  const Analyzer::WindowFunction* count_window_expr) {
43  CHECK_EQ(count_window_expr->get_type_info().get_type(), kBIGINT);
44  return expr_list_match(sum_window_expr->getArgs(), count_window_expr->getArgs());
45 }
46 
// Tail of `bool is_sum_kind(const SqlWindowFunctionKind kind)`: a disjunction over
// window function kinds whose first alternative is SUM_INTERNAL.
// NOTE(review): this listing omits the function's signature line and the second
// operand of the `||` (listing lines 47 and 49) -- presumably another SUM kind;
// restore both from the original file before compiling.
48  return kind == SqlWindowFunctionKind::SUM_INTERNAL ||
50 }
51 
52 } // namespace
53 
// Tries to recognize an expression of the shape
//   CASE WHEN <count window> > 0 THEN <sum window> ELSE NULL END
// where both window functions are examined after remove_cast. On a match, returns a
// new window function built from the sum window's arguments, partition keys, order
// keys, frame bound type, frame bounds (deep-copied) and collation. Returns nullptr
// when `expr` does not have this shape.
// NOTE(review): listing line 91 (an argument between `sum_ti` and the args,
// presumably the window function kind) is absent from this listing -- confirm
// against the original file.
54 std::shared_ptr<Analyzer::WindowFunction> rewrite_sum_window(const Analyzer::Expr* expr) {
55  const auto case_expr = dynamic_cast<const Analyzer::CaseExpr*>(expr);
56  if (!case_expr || !matches_else_null(case_expr)) {
57  return nullptr;
58  }
    // Only a CASE with exactly one WHEN/THEN pair qualifies.
59  const auto& expr_pair_list = case_expr->get_expr_pair_list();
60  if (expr_pair_list.size() != 1) {
61  return nullptr;
62  }
63  const auto& expr_pair = expr_pair_list.front();
    // The WHEN condition must be a `> 0` comparison against a BIGINT zero constant.
64  const auto window_gt_zero =
65  dynamic_cast<const Analyzer::BinOper*>(expr_pair.first.get());
66  if (!window_gt_zero || !matches_gt_bigint_zero(window_gt_zero)) {
67  return nullptr;
68  }
    // The THEN result, after remove_cast, must be a sum-kind window function.
69  const auto sum_window_expr =
70  std::dynamic_pointer_cast<Analyzer::WindowFunction>(remove_cast(expr_pair.second));
71  if (!sum_window_expr || !is_sum_kind(sum_window_expr->getKind())) {
72  return nullptr;
73  }
    // The left side of the comparison, after remove_cast, must be a COUNT window
    // function.
74  const auto count_window_expr =
75  std::dynamic_pointer_cast<const Analyzer::WindowFunction>(
76  remove_cast(window_gt_zero->get_own_left_operand()));
77  if (!count_window_expr ||
78  count_window_expr->getKind() != SqlWindowFunctionKind::COUNT) {
79  return nullptr;
80  }
    // Sum and count must be taken over matching argument lists.
81  if (!window_sum_and_count_match(sum_window_expr.get(), count_window_expr.get())) {
82  return nullptr;
83  }
    // Non-null is already guaranteed by the check on sum_window_expr above.
84  CHECK(sum_window_expr);
    // An integer sum type is replaced by BIGINT, carrying over the not-null flag.
85  auto sum_ti = sum_window_expr->get_type_info();
86  if (sum_ti.is_integer()) {
87  sum_ti = SQLTypeInfo(kBIGINT, sum_ti.get_notnull());
88  }
89  return makeExpr<Analyzer::WindowFunction>(
90  sum_ti,
92  sum_window_expr->getArgs(),
93  sum_window_expr->getPartitionKeys(),
94  sum_window_expr->getOrderKeys(),
95  sum_window_expr->getFrameBoundType(),
96  sum_window_expr->getFrameStartBound()->deep_copy(),
97  sum_window_expr->getFrameEndBound()->deep_copy(),
98  sum_window_expr->getCollation());
99 }
100 
// Tries to recognize `<sum pattern> / <count window>`, optionally wrapped in a CAST,
// where <sum pattern> is anything rewrite_sum_window accepts and the divisor is a
// COUNT window function (optionally under a CAST) with the same argument list. On a
// match, returns a new window function built from the rewritten sum's arguments,
// partition keys, order keys, frame bound type, frame bounds (deep-copied) and
// collation. Returns nullptr otherwise.
// NOTE(review): listing lines 132-133 (the leading arguments of the makeExpr call,
// presumably the result type and the window function kind) are absent from this
// listing -- confirm against the original file.
101 std::shared_ptr<Analyzer::WindowFunction> rewrite_avg_window(const Analyzer::Expr* expr) {
    // Look through an outer CAST, if there is one, to find the division.
102  const auto cast_expr = dynamic_cast<const Analyzer::UOper*>(expr);
103  const auto div_expr = dynamic_cast<const Analyzer::BinOper*>(
104  cast_expr && cast_expr->get_optype() == kCAST ? cast_expr->get_operand() : expr);
105  if (!div_expr || div_expr->get_optype() != kDIVIDE) {
106  return nullptr;
107  }
    // The dividend must match the conditional-sum pattern.
108  const auto sum_window_expr = rewrite_sum_window(div_expr->get_left_operand());
109  if (!sum_window_expr) {
110  return nullptr;
111  }
    // The divisor may be wrapped in a unary operator, but only a CAST is accepted.
112  const auto cast_count_window =
113  dynamic_cast<const Analyzer::UOper*>(div_expr->get_right_operand());
114  if (cast_count_window && cast_count_window->get_optype() != kCAST) {
115  return nullptr;
116  }
    // The divisor itself (under the CAST, if present) must be a COUNT window function.
117  const auto count_window = dynamic_cast<const Analyzer::WindowFunction*>(
118  cast_count_window ? cast_count_window->get_operand()
119  : div_expr->get_right_operand());
120  if (!count_window || count_window->getKind() != SqlWindowFunctionKind::COUNT) {
121  return nullptr;
122  }
123  CHECK_EQ(count_window->get_type_info().get_type(), kBIGINT);
    // A CAST on the count must target the same type as the rewritten sum.
124  if (cast_count_window && cast_count_window->get_type_info().get_type() !=
125  sum_window_expr->get_type_info().get_type()) {
126  return nullptr;
127  }
    // Sum and count must share the same argument list.
128  if (!expr_list_match(sum_window_expr.get()->getArgs(), count_window->getArgs())) {
129  return nullptr;
130  }
131  return makeExpr<Analyzer::WindowFunction>(
134  sum_window_expr->getArgs(),
135  sum_window_expr->getPartitionKeys(),
136  sum_window_expr->getOrderKeys(),
137  sum_window_expr->getFrameBoundType(),
138  sum_window_expr->getFrameStartBound()->deep_copy(),
139  sum_window_expr->getFrameEndBound()->deep_copy(),
140  sum_window_expr->getCollation());
141 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
const Expr * get_else_expr() const
Definition: Analyzer.h:1387
bool matches_gt_bigint_zero(const Analyzer::BinOper *window_gt_zero)
std::shared_ptr< Analyzer::Expr > remove_cast(const std::shared_ptr< Analyzer::Expr > &expr)
Definition: Analyzer.cpp:4611
std::shared_ptr< Analyzer::WindowFunction > rewrite_avg_window(const Analyzer::Expr *expr)
bool window_sum_and_count_match(const Analyzer::WindowFunction *sum_window_expr, const Analyzer::WindowFunction *count_window_expr)
std::shared_ptr< Analyzer::WindowFunction > rewrite_sum_window(const Analyzer::Expr *expr)
const Expr * get_right_operand() const
Definition: Analyzer.h:456
bool get_is_null() const
Definition: Analyzer.h:347
Definition: sqldefs.h:51
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:391
SQLOps get_optype() const
Definition: Analyzer.h:452
bool matches_else_null(const Analyzer::CaseExpr *case_expr)
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs() const
Definition: Analyzer.h:2927
bool expr_list_match(const std::vector< std::shared_ptr< Analyzer::Expr >> &lhs, const std::vector< std::shared_ptr< Analyzer::Expr >> &rhs)
Definition: Analyzer.cpp:4598
bool is_sum_kind(const SqlWindowFunctionKind kind)
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
Definition: sqldefs.h:36
SqlWindowFunctionKind
Definition: sqldefs.h:129
#define CHECK(condition)
Definition: Logger.h:291
const std::list< std::pair< std::shared_ptr< Analyzer::Expr >, std::shared_ptr< Analyzer::Expr > > > & get_expr_pair_list() const
Definition: Analyzer.h:1384