OmniSciDB  c07336695a
WindowExpressionRewrite.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 namespace {
20 
21 // Returns true iff the case expression has an else null branch.
22 bool matches_else_null(const Analyzer::CaseExpr* case_expr) {
23  const auto else_null =
24  dynamic_cast<const Analyzer::Constant*>(case_expr->get_else_expr());
25  return else_null && else_null->get_is_null();
26 }
27 
28 // Returns true iff the expression is a big integer greater than 0.
29 bool matches_gt_bigint_zero(const Analyzer::BinOper* window_gt_zero) {
30  if (window_gt_zero->get_optype() != kGT) {
31  return false;
32  }
33  const auto zero =
34  dynamic_cast<const Analyzer::Constant*>(window_gt_zero->get_right_operand());
35  return zero && zero->get_type_info().get_type() == kBIGINT &&
36  zero->get_constval().bigintval == 0;
37 }
38 
39 // Returns true iff the sum and the count match in type and arguments. Used to replace
40 // combination can be replaced with an explicit average.
42  const Analyzer::WindowFunction* count_window_expr) {
43  CHECK_EQ(count_window_expr->get_type_info().get_type(), kBIGINT);
44  return expr_list_match(sum_window_expr->getArgs(), count_window_expr->getArgs());
45 }
46 
48  return kind == SqlWindowFunctionKind::SUM_INTERNAL ||
50 }
51 
52 } // namespace
53 
54 std::shared_ptr<Analyzer::WindowFunction> rewrite_sum_window(const Analyzer::Expr* expr) {
55  const auto case_expr = dynamic_cast<const Analyzer::CaseExpr*>(expr);
56  if (!case_expr || !matches_else_null(case_expr)) {
57  return nullptr;
58  }
59  const auto& expr_pair_list = case_expr->get_expr_pair_list();
60  if (expr_pair_list.size() != 1) {
61  return nullptr;
62  }
63  const auto& expr_pair = expr_pair_list.front();
64  const auto window_gt_zero =
65  dynamic_cast<const Analyzer::BinOper*>(expr_pair.first.get());
66  if (!window_gt_zero || !matches_gt_bigint_zero(window_gt_zero)) {
67  return nullptr;
68  }
69  const auto sum_window_expr =
70  std::dynamic_pointer_cast<Analyzer::WindowFunction>(remove_cast(expr_pair.second));
71  if (!sum_window_expr || !is_sum_kind(sum_window_expr->getKind())) {
72  return nullptr;
73  }
74  const auto count_window_expr =
75  std::dynamic_pointer_cast<const Analyzer::WindowFunction>(
76  remove_cast(window_gt_zero->get_own_left_operand()));
77  if (!count_window_expr ||
78  count_window_expr->getKind() != SqlWindowFunctionKind::COUNT) {
79  return nullptr;
80  }
81  if (!window_sum_and_count_match(sum_window_expr.get(), count_window_expr.get())) {
82  return nullptr;
83  }
84  CHECK(sum_window_expr);
85  auto sum_ti = sum_window_expr->get_type_info();
86  if (sum_ti.is_integer()) {
87  sum_ti = SQLTypeInfo(kBIGINT, sum_ti.get_notnull());
88  }
89  return makeExpr<Analyzer::WindowFunction>(sum_ti,
91  sum_window_expr->getArgs(),
92  sum_window_expr->getPartitionKeys(),
93  sum_window_expr->getOrderKeys(),
94  sum_window_expr->getCollation());
95 }
96 
97 std::shared_ptr<Analyzer::WindowFunction> rewrite_avg_window(const Analyzer::Expr* expr) {
98  const auto cast_expr = dynamic_cast<const Analyzer::UOper*>(expr);
99  const auto div_expr = dynamic_cast<const Analyzer::BinOper*>(
100  cast_expr && cast_expr->get_optype() == kCAST ? cast_expr->get_operand() : expr);
101  if (!div_expr || div_expr->get_optype() != kDIVIDE) {
102  return nullptr;
103  }
104  const auto sum_window_expr = rewrite_sum_window(div_expr->get_left_operand());
105  if (!sum_window_expr) {
106  return nullptr;
107  }
108  const auto cast_count_window =
109  dynamic_cast<const Analyzer::UOper*>(div_expr->get_right_operand());
110  if (cast_count_window && cast_count_window->get_optype() != kCAST) {
111  return nullptr;
112  }
113  const auto count_window = dynamic_cast<const Analyzer::WindowFunction*>(
114  cast_count_window ? cast_count_window->get_operand()
115  : div_expr->get_right_operand());
116  if (!count_window || count_window->getKind() != SqlWindowFunctionKind::COUNT) {
117  return nullptr;
118  }
119  CHECK_EQ(count_window->get_type_info().get_type(), kBIGINT);
120  if (cast_count_window && cast_count_window->get_type_info().get_type() !=
121  sum_window_expr->get_type_info().get_type()) {
122  return nullptr;
123  }
124  if (!expr_list_match(sum_window_expr.get()->getArgs(), count_window->getArgs())) {
125  return nullptr;
126  }
127  return makeExpr<Analyzer::WindowFunction>(SQLTypeInfo(kDOUBLE),
129  sum_window_expr->getArgs(),
130  sum_window_expr->getPartitionKeys(),
131  sum_window_expr->getOrderKeys(),
132  sum_window_expr->getCollation());
133 }
#define CHECK_EQ(x, y)
Definition: Logger.h:195
bool matches_gt_bigint_zero(const Analyzer::BinOper *window_gt_zero)
std::shared_ptr< Analyzer::Expr > remove_cast(const std::shared_ptr< Analyzer::Expr > &expr)
Definition: Analyzer.cpp:2993
std::shared_ptr< Analyzer::WindowFunction > rewrite_avg_window(const Analyzer::Expr *expr)
bool window_sum_and_count_match(const Analyzer::WindowFunction *sum_window_expr, const Analyzer::WindowFunction *count_window_expr)
std::shared_ptr< Analyzer::WindowFunction > rewrite_sum_window(const Analyzer::Expr *expr)
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:319
Definition: sqldefs.h:49
SQLOps get_optype() const
Definition: Analyzer.h:432
bool matches_else_null(const Analyzer::CaseExpr *case_expr)
const Expr * get_else_expr() const
Definition: Analyzer.h:1044
bool expr_list_match(const std::vector< std::shared_ptr< Analyzer::Expr >> &lhs, const std::vector< std::shared_ptr< Analyzer::Expr >> &rhs)
Definition: Analyzer.cpp:2980
const std::list< std::pair< std::shared_ptr< Analyzer::Expr >, std::shared_ptr< Analyzer::Expr > > > & get_expr_pair_list() const
Definition: Analyzer.h:1041
bool is_sum_kind(const SqlWindowFunctionKind kind)
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:819
Definition: sqldefs.h:34
bool get_is_null() const
Definition: Analyzer.h:327
SqlWindowFunctionKind
Definition: sqldefs.h:73
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs() const
Definition: Analyzer.h:1341
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:77
#define CHECK(condition)
Definition: Logger.h:187
const Expr * get_right_operand() const
Definition: Analyzer.h:436