OmniSciDB  5ade3759e0
ExpressionRewrite.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ExpressionRewrite.h"
18 #include "../Analyzer/Analyzer.h"
19 #include "../Shared/sqldefs.h"
20 #include "DeepCopyVisitor.h"
21 #include "Execute.h"
22 #include "RelAlgTranslator.h"
23 #include "ScalarExprVisitor.h"
24 #include "Shared/Logger.h"
26 
27 #include <unordered_set>
28 
29 namespace {
30 
31 class OrToInVisitor : public ScalarExprVisitor<std::shared_ptr<Analyzer::InValues>> {
32  protected:
33  std::shared_ptr<Analyzer::InValues> visitBinOper(
34  const Analyzer::BinOper* bin_oper) const override {
35  switch (bin_oper->get_optype()) {
36  case kEQ: {
37  const auto rhs_owned = bin_oper->get_own_right_operand();
38  auto rhs_no_cast = extract_cast_arg(rhs_owned.get());
39  if (!dynamic_cast<const Analyzer::Constant*>(rhs_no_cast)) {
40  return nullptr;
41  }
42  const auto arg = bin_oper->get_own_left_operand();
43  const auto& arg_ti = arg->get_type_info();
44  auto rhs = rhs_no_cast->deep_copy()->add_cast(arg_ti);
45  return makeExpr<Analyzer::InValues>(
46  arg, std::list<std::shared_ptr<Analyzer::Expr>>{rhs});
47  }
48  case kOR: {
49  return aggregateResult(visit(bin_oper->get_left_operand()),
50  visit(bin_oper->get_right_operand()));
51  }
52  default:
53  break;
54  }
55  return nullptr;
56  }
57 
58  std::shared_ptr<Analyzer::InValues> visitUOper(
59  const Analyzer::UOper* uoper) const override {
60  return nullptr;
61  }
62 
63  std::shared_ptr<Analyzer::InValues> visitInValues(
64  const Analyzer::InValues*) const override {
65  return nullptr;
66  }
67 
68  std::shared_ptr<Analyzer::InValues> visitInIntegerSet(
69  const Analyzer::InIntegerSet*) const override {
70  return nullptr;
71  }
72 
73  std::shared_ptr<Analyzer::InValues> visitCharLength(
74  const Analyzer::CharLengthExpr*) const override {
75  return nullptr;
76  }
77 
78  std::shared_ptr<Analyzer::InValues> visitKeyForString(
79  const Analyzer::KeyForStringExpr*) const override {
80  return nullptr;
81  }
82 
83  std::shared_ptr<Analyzer::InValues> visitCardinality(
84  const Analyzer::CardinalityExpr*) const override {
85  return nullptr;
86  }
87 
88  std::shared_ptr<Analyzer::InValues> visitLikeExpr(
89  const Analyzer::LikeExpr*) const override {
90  return nullptr;
91  }
92 
93  std::shared_ptr<Analyzer::InValues> visitRegexpExpr(
94  const Analyzer::RegexpExpr*) const override {
95  return nullptr;
96  }
97 
98  std::shared_ptr<Analyzer::InValues> visitCaseExpr(
99  const Analyzer::CaseExpr*) const override {
100  return nullptr;
101  }
102 
103  std::shared_ptr<Analyzer::InValues> visitDatetruncExpr(
104  const Analyzer::DatetruncExpr*) const override {
105  return nullptr;
106  }
107 
108  std::shared_ptr<Analyzer::InValues> visitDatediffExpr(
109  const Analyzer::DatediffExpr*) const override {
110  return nullptr;
111  }
112 
113  std::shared_ptr<Analyzer::InValues> visitDateaddExpr(
114  const Analyzer::DateaddExpr*) const override {
115  return nullptr;
116  }
117 
118  std::shared_ptr<Analyzer::InValues> visitExtractExpr(
119  const Analyzer::ExtractExpr*) const override {
120  return nullptr;
121  }
122 
123  std::shared_ptr<Analyzer::InValues> visitLikelihood(
124  const Analyzer::LikelihoodExpr*) const override {
125  return nullptr;
126  }
127 
128  std::shared_ptr<Analyzer::InValues> visitAggExpr(
129  const Analyzer::AggExpr*) const override {
130  return nullptr;
131  }
132 
133  std::shared_ptr<Analyzer::InValues> aggregateResult(
134  const std::shared_ptr<Analyzer::InValues>& lhs,
135  const std::shared_ptr<Analyzer::InValues>& rhs) const override {
136  if (!lhs || !rhs) {
137  return nullptr;
138  }
139  if (!(*lhs->get_arg() == *rhs->get_arg())) {
140  return nullptr;
141  }
142  auto union_values = lhs->get_value_list();
143  const auto& rhs_values = rhs->get_value_list();
144  union_values.insert(union_values.end(), rhs_values.begin(), rhs_values.end());
145  return makeExpr<Analyzer::InValues>(lhs->get_own_arg(), union_values);
146  }
147 };
148 
150  protected:
151  std::shared_ptr<Analyzer::Expr> visitBinOper(
152  const Analyzer::BinOper* bin_oper) const override {
153  OrToInVisitor simple_visitor;
154  if (bin_oper->get_optype() == kOR) {
155  auto rewritten = simple_visitor.visit(bin_oper);
156  if (rewritten) {
157  return rewritten;
158  }
159  }
160  auto lhs = bin_oper->get_own_left_operand();
161  auto rhs = bin_oper->get_own_right_operand();
162  auto rewritten_lhs = visit(lhs.get());
163  auto rewritten_rhs = visit(rhs.get());
164  return makeExpr<Analyzer::BinOper>(bin_oper->get_type_info(),
165  bin_oper->get_contains_agg(),
166  bin_oper->get_optype(),
167  bin_oper->get_qualifier(),
168  rewritten_lhs ? rewritten_lhs : lhs,
169  rewritten_rhs ? rewritten_rhs : rhs);
170  }
171 };
172 
174  protected:
176 
177  RetType visitArrayOper(const Analyzer::ArrayExpr* array_expr) const override {
178  std::vector<std::shared_ptr<Analyzer::Expr>> args_copy;
179  for (size_t i = 0; i < array_expr->getElementCount(); ++i) {
180  auto const element_expr_ptr = visit(array_expr->getElement(i));
181  auto const& element_expr_type_info = element_expr_ptr->get_type_info();
182 
183  if (!element_expr_type_info.is_string() ||
184  element_expr_type_info.get_compression() != kENCODING_NONE) {
185  args_copy.push_back(element_expr_ptr);
186  } else {
187  auto transient_dict_type_info = element_expr_type_info;
188 
189  transient_dict_type_info.set_compression(kENCODING_DICT);
190  transient_dict_type_info.set_comp_param(TRANSIENT_DICT_ID);
191  transient_dict_type_info.set_fixed_size();
192  args_copy.push_back(element_expr_ptr->add_cast(transient_dict_type_info));
193  }
194  }
195 
196  const auto& type_info = array_expr->get_type_info();
197  return makeExpr<Analyzer::ArrayExpr>(
198  type_info, args_copy, array_expr->getExprIndex(), array_expr->isLocalAlloc());
199  }
200 };
201 
203  template <typename T>
204  bool foldComparison(SQLOps optype, T t1, T t2) const {
205  switch (optype) {
206  case kEQ:
207  return t1 == t2;
208  case kNE:
209  return t1 != t2;
210  case kLT:
211  return t1 < t2;
212  case kLE:
213  return t1 <= t2;
214  case kGT:
215  return t1 > t2;
216  case kGE:
217  return t1 >= t2;
218  default:
219  break;
220  }
221  throw std::runtime_error("Unable to fold");
222  return false;
223  }
224 
225  template <typename T>
226  bool foldLogic(SQLOps optype, T t1, T t2) const {
227  switch (optype) {
228  case kAND:
229  return t1 && t2;
230  case kOR:
231  return t1 || t2;
232  case kNOT:
233  return !t1;
234  default:
235  break;
236  }
237  throw std::runtime_error("Unable to fold");
238  return false;
239  }
240 
241  template <typename T>
242  T foldArithmetic(SQLOps optype, T t1, T t2) const {
243  bool t2_is_zero = (t2 == (t2 - t2));
244  bool t2_is_negative = (t2 < (t2 - t2));
245  switch (optype) {
246  case kPLUS:
247  // The MIN limit for float and double is the smallest representable value,
248  // not the lowest negative value! Switching to C++11 lowest.
249  if ((t2_is_negative && t1 < std::numeric_limits<T>::lowest() - t2) ||
250  (!t2_is_negative && t1 > std::numeric_limits<T>::max() - t2)) {
251  num_overflows_++;
252  throw std::runtime_error("Plus overflow");
253  }
254  return t1 + t2;
255  case kMINUS:
256  if ((t2_is_negative && t1 > std::numeric_limits<T>::max() + t2) ||
257  (!t2_is_negative && t1 < std::numeric_limits<T>::lowest() + t2)) {
258  num_overflows_++;
259  throw std::runtime_error("Minus overflow");
260  }
261  return t1 - t2;
262  case kMULTIPLY: {
263  if (t2_is_zero) {
264  return t2;
265  }
266  auto ct1 = t1;
267  auto ct2 = t2;
268  // Need to keep t2's sign on the left
269  if (t2_is_negative) {
270  if (t1 == std::numeric_limits<T>::lowest() ||
271  t2 == std::numeric_limits<T>::lowest()) {
272  // negation could overflow - bail
273  num_overflows_++;
274  throw std::runtime_error("Mul neg overflow");
275  }
276  ct1 = -t1; // ct1 gets t2's negativity
277  ct2 = -t2; // ct2 is now positive
278  }
279  // Don't check overlow if we are folding FP mul by a fraction
280  bool ct2_is_fraction = (ct2 < (ct2 / ct2));
281  if (!ct2_is_fraction) {
282  if (ct1 > std::numeric_limits<T>::max() / ct2 ||
283  ct1 < std::numeric_limits<T>::lowest() / ct2) {
284  num_overflows_++;
285  throw std::runtime_error("Mul overflow");
286  }
287  }
288  return t1 * t2;
289  }
290  case kDIVIDE:
291  if (t2_is_zero) {
292  throw std::runtime_error("Will not fold division by zero");
293  }
294  return t1 / t2;
295  default:
296  break;
297  }
298  throw std::runtime_error("Unable to fold");
299  }
300 
301  bool foldOper(SQLOps optype,
302  SQLTypes type,
303  Datum lhs,
304  Datum rhs,
305  Datum& result,
306  SQLTypes& result_type) const {
307  result_type = type;
308 
309  try {
310  switch (type) {
311  case kBOOLEAN:
312  if (IS_COMPARISON(optype)) {
313  result.boolval = foldComparison<bool>(optype, lhs.boolval, rhs.boolval);
314  result_type = kBOOLEAN;
315  return true;
316  }
317  if (IS_LOGIC(optype)) {
318  result.boolval = foldLogic<bool>(optype, lhs.boolval, rhs.boolval);
319  result_type = kBOOLEAN;
320  return true;
321  }
322  CHECK(!IS_ARITHMETIC(optype));
323  break;
324  case kTINYINT:
325  if (IS_COMPARISON(optype)) {
326  result.boolval =
327  foldComparison<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
328  result_type = kBOOLEAN;
329  return true;
330  }
331  if (IS_ARITHMETIC(optype)) {
332  result.tinyintval =
333  foldArithmetic<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
334  result_type = kTINYINT;
335  return true;
336  }
337  CHECK(!IS_LOGIC(optype));
338  break;
339  case kSMALLINT:
340  if (IS_COMPARISON(optype)) {
341  result.boolval =
342  foldComparison<int16_t>(optype, lhs.smallintval, rhs.smallintval);
343  result_type = kBOOLEAN;
344  return true;
345  }
346  if (IS_ARITHMETIC(optype)) {
347  result.smallintval =
348  foldArithmetic<int16_t>(optype, lhs.smallintval, rhs.smallintval);
349  result_type = kSMALLINT;
350  return true;
351  }
352  CHECK(!IS_LOGIC(optype));
353  break;
354  case kINT:
355  if (IS_COMPARISON(optype)) {
356  result.boolval = foldComparison<int32_t>(optype, lhs.intval, rhs.intval);
357  result_type = kBOOLEAN;
358  return true;
359  }
360  if (IS_ARITHMETIC(optype)) {
361  result.intval = foldArithmetic<int32_t>(optype, lhs.intval, rhs.intval);
362  result_type = kINT;
363  return true;
364  }
365  CHECK(!IS_LOGIC(optype));
366  break;
367  case kBIGINT:
368  if (IS_COMPARISON(optype)) {
369  result.boolval =
370  foldComparison<int64_t>(optype, lhs.bigintval, rhs.bigintval);
371  result_type = kBOOLEAN;
372  return true;
373  }
374  if (IS_ARITHMETIC(optype)) {
375  result.bigintval =
376  foldArithmetic<int64_t>(optype, lhs.bigintval, rhs.bigintval);
377  result_type = kBIGINT;
378  return true;
379  }
380  CHECK(!IS_LOGIC(optype));
381  break;
382  case kFLOAT:
383  if (IS_COMPARISON(optype)) {
384  result.boolval = foldComparison<float>(optype, lhs.floatval, rhs.floatval);
385  result_type = kBOOLEAN;
386  return true;
387  }
388  if (IS_ARITHMETIC(optype)) {
389  result.floatval = foldArithmetic<float>(optype, lhs.floatval, rhs.floatval);
390  result_type = kFLOAT;
391  return true;
392  }
393  CHECK(!IS_LOGIC(optype));
394  break;
395  case kDOUBLE:
396  if (IS_COMPARISON(optype)) {
397  result.boolval = foldComparison<double>(optype, lhs.doubleval, rhs.doubleval);
398  result_type = kBOOLEAN;
399  return true;
400  }
401  if (IS_ARITHMETIC(optype)) {
402  result.doubleval =
403  foldArithmetic<double>(optype, lhs.doubleval, rhs.doubleval);
404  result_type = kDOUBLE;
405  return true;
406  }
407  CHECK(!IS_LOGIC(optype));
408  break;
409  default:
410  break;
411  }
412  } catch (...) {
413  return false;
414  }
415  return false;
416  }
417 
418  std::shared_ptr<Analyzer::Expr> visitUOper(
419  const Analyzer::UOper* uoper) const override {
420  const auto unvisited_operand = uoper->get_operand();
421  const auto optype = uoper->get_optype();
422  const auto& ti = uoper->get_type_info();
423  if (optype == kCAST) {
424  // Cache the cast type so it could be used in operand rewriting/folding
425  casts_.insert({unvisited_operand, ti});
426  }
427  const auto operand = visit(unvisited_operand);
428  const auto& operand_ti = operand->get_type_info();
429  const auto operand_type =
430  operand_ti.is_decimal() ? decimal_to_int_type(operand_ti) : operand_ti.get_type();
431  const auto const_operand =
432  std::dynamic_pointer_cast<const Analyzer::Constant>(operand);
433 
434  if (const_operand) {
435  const auto operand_datum = const_operand->get_constval();
436  Datum zero_datum = {};
437  Datum result_datum = {};
438  SQLTypes result_type;
439  switch (optype) {
440  case kNOT: {
441  if (foldOper(kEQ,
442  operand_type,
443  zero_datum,
444  operand_datum,
445  result_datum,
446  result_type)) {
447  CHECK_EQ(result_type, kBOOLEAN);
448  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
449  }
450  break;
451  }
452  case kUMINUS: {
453  if (foldOper(kMINUS,
454  operand_type,
455  zero_datum,
456  operand_datum,
457  result_datum,
458  result_type)) {
459  if (!operand_ti.is_decimal()) {
460  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
461  }
462  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
463  }
464  break;
465  }
466  case kCAST: {
467  // Trying to fold number to number casts only
468  if (!ti.is_number() || !operand_ti.is_number()) {
469  break;
470  }
471  // Disallowing folding of FP to DECIMAL casts for now:
472  // allowing them would make this test pass:
473  // update dectest set d=cast( 1234.0 as float );
474  // which is expected to throw in Update.ImplicitCastToNumericTypes
475  // due to cast codegen currently not supporting these casts
476  if (ti.is_decimal() && operand_ti.is_fp()) {
477  break;
478  }
479  auto operand_copy = const_operand->deep_copy();
480  auto cast_operand = operand_copy->add_cast(ti);
481  auto const_cast_operand =
482  std::dynamic_pointer_cast<const Analyzer::Constant>(cast_operand);
483  if (const_cast_operand) {
484  auto const_cast_datum = const_cast_operand->get_constval();
485  return makeExpr<Analyzer::Constant>(ti, false, const_cast_datum);
486  }
487  }
488  default:
489  break;
490  }
491  }
492 
493  return makeExpr<Analyzer::UOper>(
494  uoper->get_type_info(), uoper->get_contains_agg(), optype, operand);
495  }
496 
497  std::shared_ptr<Analyzer::Expr> visitBinOper(
498  const Analyzer::BinOper* bin_oper) const override {
499  const auto optype = bin_oper->get_optype();
500  auto ti = bin_oper->get_type_info();
501  auto left_operand = bin_oper->get_own_left_operand();
502  auto right_operand = bin_oper->get_own_right_operand();
503 
504  // Check if bin_oper result is cast to a larger int or fp type
505  if (casts_.find(bin_oper) != casts_.end()) {
506  const auto cast_ti = casts_[bin_oper];
507  const auto& lhs_ti = bin_oper->get_left_operand()->get_type_info();
508  // Propagate cast down to the operands for folding
509  if ((cast_ti.is_integer() || cast_ti.is_fp()) && lhs_ti.is_integer() &&
510  cast_ti.get_size() > lhs_ti.get_size() &&
511  (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY)) {
512  // Before folding, cast the operands to the bigger type to avoid overflows.
513  // Currently upcasting smaller integer types to larger integers or double.
514  left_operand = left_operand->deep_copy()->add_cast(cast_ti);
515  right_operand = right_operand->deep_copy()->add_cast(cast_ti);
516  ti = cast_ti;
517  }
518  }
519 
520  const auto lhs = visit(left_operand.get());
521  const auto rhs = visit(right_operand.get());
522 
523  auto const_lhs = std::dynamic_pointer_cast<Analyzer::Constant>(lhs);
524  auto const_rhs = std::dynamic_pointer_cast<Analyzer::Constant>(rhs);
525  const auto& lhs_ti = lhs->get_type_info();
526  const auto& rhs_ti = rhs->get_type_info();
527  auto lhs_type = lhs_ti.is_decimal() ? decimal_to_int_type(lhs_ti) : lhs_ti.get_type();
528  auto rhs_type = rhs_ti.is_decimal() ? decimal_to_int_type(rhs_ti) : rhs_ti.get_type();
529 
530  if (const_lhs && const_rhs && lhs_type == rhs_type) {
531  auto lhs_datum = const_lhs->get_constval();
532  auto rhs_datum = const_rhs->get_constval();
533  Datum result_datum = {};
534  SQLTypes result_type;
535  if (foldOper(optype, lhs_type, lhs_datum, rhs_datum, result_datum, result_type)) {
536  // Fold all ops that don't take in decimal operands, and also decimal comparisons
537  if (!lhs_ti.is_decimal() || IS_COMPARISON(optype)) {
538  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
539  }
540  // Decimal arithmetic has been done as kBIGINT. Selectively fold some decimal ops,
541  // using result_datum and BinOper expr typeinfo which was adjusted for these ops.
542  if (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY) {
543  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
544  }
545  }
546  }
547 
548  if (optype == kAND && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
549  if (const_rhs) {
550  auto rhs_datum = const_rhs->get_constval();
551  if (rhs_datum.boolval == false) {
552  Datum d;
553  d.boolval = false;
554  // lhs && false --> false
555  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
556  }
557  // lhs && true --> lhs
558  return lhs;
559  }
560  if (const_lhs) {
561  auto lhs_datum = const_lhs->get_constval();
562  if (lhs_datum.boolval == false) {
563  Datum d;
564  d.boolval = false;
565  // false && rhs --> false
566  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
567  }
568  // true && rhs --> rhs
569  return rhs;
570  }
571  }
572  if (optype == kOR && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
573  if (const_rhs) {
574  auto rhs_datum = const_rhs->get_constval();
575  if (rhs_datum.boolval == true) {
576  Datum d;
577  d.boolval = true;
578  // lhs || true --> true
579  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
580  }
581  // lhs || false --> lhs
582  return lhs;
583  }
584  if (const_lhs) {
585  auto lhs_datum = const_lhs->get_constval();
586  if (lhs_datum.boolval == true) {
587  Datum d;
588  d.boolval = true;
589  // true || rhs --> true
590  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
591  }
592  // false || rhs --> rhs
593  return rhs;
594  }
595  }
596  if (*lhs == *rhs) {
597  // Tautologies: v=v; v<=v; v>=v
598  if (optype == kEQ || optype == kLE || optype == kGE) {
599  Datum d;
600  d.boolval = true;
601  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
602  }
603  // Contradictions: v!=v; v<v; v>v
604  if (optype == kNE || optype == kLT || optype == kGT) {
605  Datum d;
606  d.boolval = false;
607  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
608  }
609  // v-v
610  if (optype == kMINUS) {
611  Datum d = {};
612  return makeExpr<Analyzer::Constant>(lhs_type, false, d);
613  }
614  }
615  // Convert fp division by a constant to multiplication by 1/constant
616  if (optype == kDIVIDE && const_rhs && rhs_ti.is_fp()) {
617  auto rhs_datum = const_rhs->get_constval();
618  std::shared_ptr<Analyzer::Expr> recip_rhs = nullptr;
619  if (rhs_ti.get_type() == kFLOAT) {
620  if (rhs_datum.floatval == 1.0) {
621  return lhs;
622  }
623  auto f = std::fabs(rhs_datum.floatval);
624  if (f > 1.0 || (f != 0.0 && 1.0 < f * std::numeric_limits<float>::max())) {
625  rhs_datum.floatval = 1.0 / rhs_datum.floatval;
626  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
627  }
628  } else if (rhs_ti.get_type() == kDOUBLE) {
629  if (rhs_datum.doubleval == 1.0) {
630  return lhs;
631  }
632  auto d = std::fabs(rhs_datum.doubleval);
633  if (d > 1.0 || (d != 0.0 && 1.0 < d * std::numeric_limits<double>::max())) {
634  rhs_datum.doubleval = 1.0 / rhs_datum.doubleval;
635  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
636  }
637  }
638  if (recip_rhs) {
639  return makeExpr<Analyzer::BinOper>(ti,
640  bin_oper->get_contains_agg(),
641  kMULTIPLY,
642  bin_oper->get_qualifier(),
643  lhs,
644  recip_rhs);
645  }
646  }
647 
648  return makeExpr<Analyzer::BinOper>(ti,
649  bin_oper->get_contains_agg(),
650  bin_oper->get_optype(),
651  bin_oper->get_qualifier(),
652  lhs,
653  rhs);
654  }
655 
656  protected:
657  mutable std::unordered_map<const Analyzer::Expr*, const SQLTypeInfo> casts_;
658  mutable int32_t num_overflows_;
659 
660  public:
661  ConstantFoldingVisitor() : num_overflows_(0) {}
662  int32_t get_num_overflows() { return num_overflows_; }
663  void reset_num_overflows() { num_overflows_ = 0; }
664 };
665 
667  const auto with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
668  if (!with_likelihood) {
669  return expr;
670  }
671  return with_likelihood->get_arg();
672 }
673 
674 } // namespace
675 
677  return ArrayElementStringLiteralEncodingVisitor().visit(expr);
678 }
679 
681  const auto sum_window = rewrite_sum_window(expr);
682  if (sum_window) {
683  return sum_window;
684  }
685  const auto avg_window = rewrite_avg_window(expr);
686  if (avg_window) {
687  return avg_window;
688  }
689  const auto expr_no_likelihood = strip_likelihood(expr);
690  // The following check is not strictly needed, but seems silly to transform a
691  // simple string comparison to an IN just to codegen the same thing anyway.
692 
693  RecursiveOrToInVisitor visitor;
694  auto rewritten_expr = visitor.visit(expr_no_likelihood);
695  const auto expr_with_likelihood =
696  std::dynamic_pointer_cast<const Analyzer::LikelihoodExpr>(rewritten_expr);
697  if (expr_with_likelihood) {
698  // Add back likelihood
699  return std::make_shared<Analyzer::LikelihoodExpr>(
700  rewritten_expr, expr_with_likelihood->get_likelihood());
701  }
702  return rewritten_expr;
703 }
704 
namespace {

// Names of the function operators that the overlaps-join rewrite recognizes.
static const std::unordered_set<std::string> overlaps_supported_functions{
    "ST_Contains_MultiPolygon_Point",
    "ST_Contains_Polygon_Point",
};

}  // namespace
712 
713 boost::optional<OverlapsJoinConjunction> rewrite_overlaps_conjunction(
714  const std::shared_ptr<Analyzer::Expr> expr) {
715  auto func_oper = dynamic_cast<Analyzer::FunctionOper*>(expr.get());
716  if (func_oper) {
717  // TODO(adb): consider converting unordered set to an unordered map, potentially
718  // storing the rewrite function we want to apply in the map
719  if (overlaps_supported_functions.find(func_oper->getName()) !=
721  CHECK_GE(func_oper->getArity(), size_t(3));
722 
723  DeepCopyVisitor deep_copy_visitor;
724  auto lhs = func_oper->getOwnArg(2);
725  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
726  CHECK(rewritten_lhs);
727  const auto& lhs_ti = rewritten_lhs->get_type_info();
728  if (!lhs_ti.is_geometry()) {
729  // TODO(adb): If ST_Contains is passed geospatial literals instead of columns, the
730  // function will be expanded during translation rather than during code
731  // generation. While this scenario does not make sense for the overlaps join, we
732  // need to detect and abort the overlaps rewrite. Adding a GeospatialConstant
733  // dervied class to the Analyzer may prove to be a better way to handle geo
734  // literals, but for now we ensure the LHS type is a geospatial type, which would
735  // mean the function has not been expanded to the physical types, yet.
736  LOG(WARNING) << "Failed to rewrite " << func_oper->getName()
737  << " to overlaps conjunction. LHS input type is not a geospatial "
738  "type. Are both inputs geospatial columns?";
739  return boost::none;
740  }
741 
742  // Read the bounds arg from the ST_Contains FuncOper (second argument)instead of the
743  // poly column (first argument)
744  auto rhs = func_oper->getOwnArg(1);
745  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
746  CHECK(rewritten_rhs);
747 
748  auto overlaps_oper = makeExpr<Analyzer::BinOper>(
749  kBOOLEAN, kOVERLAPS, kONE, rewritten_lhs, rewritten_rhs);
750 
751  return OverlapsJoinConjunction{{expr}, {overlaps_oper}};
752  }
753  }
754  return boost::none;
755 }
756 
766  public:
768  for (const auto& join_condition : join_quals) {
769  for (const auto& qual : join_condition.quals) {
770  auto qual_bin_oper = dynamic_cast<Analyzer::BinOper*>(qual.get());
771  if (qual_bin_oper) {
772  join_qual_pairs.emplace_back(qual_bin_oper->get_left_operand(),
773  qual_bin_oper->get_right_operand());
774  }
775  }
776  }
777  }
778 
779  bool visitFunctionOper(const Analyzer::FunctionOper* func_oper) const override {
780  if (overlaps_supported_functions.find(func_oper->getName()) !=
782  const auto lhs = func_oper->getArg(2);
783  const auto rhs = func_oper->getArg(1);
784  for (const auto qual_pair : join_qual_pairs) {
785  if (*lhs == *qual_pair.first && *rhs == *qual_pair.second) {
786  return true;
787  }
788  }
789  }
790  return false;
791  }
792 
793  bool defaultResult() const override { return false; }
794 
795  private:
796  std::vector<std::pair<const Analyzer::Expr*, const Analyzer::Expr*>> join_qual_pairs;
797 };
798 
799 std::list<std::shared_ptr<Analyzer::Expr>> strip_join_covered_filter_quals(
800  const std::list<std::shared_ptr<Analyzer::Expr>>& quals,
801  const JoinQualsPerNestingLevel& join_quals) {
803  return quals;
804  }
805 
806  if (join_quals.empty()) {
807  return quals;
808  }
809 
810  std::list<std::shared_ptr<Analyzer::Expr>> quals_to_return;
811 
812  JoinCoveredQualVisitor visitor(join_quals);
813  for (const auto& qual : quals) {
814  if (!visitor.visit(qual.get())) {
815  // Not a covered qual, don't elide it from the filtered count
816  quals_to_return.push_back(qual);
817  }
818  }
819 
820  return quals_to_return;
821 }
822 
823 std::shared_ptr<Analyzer::Expr> fold_expr(const Analyzer::Expr* expr) {
824  if (!expr) {
825  return nullptr;
826  }
827  const auto expr_no_likelihood = strip_likelihood(expr);
828  ConstantFoldingVisitor visitor;
829  auto rewritten_expr = visitor.visit(expr_no_likelihood);
830  if (visitor.get_num_overflows() > 0 && rewritten_expr->get_type_info().is_integer() &&
831  rewritten_expr->get_type_info().get_type() != kBIGINT) {
832  auto rewritten_expr_const =
833  std::dynamic_pointer_cast<const Analyzer::Constant>(rewritten_expr);
834  if (!rewritten_expr_const) {
835  // Integer expression didn't fold completely the first time due to
836  // overflows in smaller type subexpressions, trying again with a cast
837  const auto& ti = SQLTypeInfo(kBIGINT, false);
838  auto bigint_expr_no_likelihood = expr_no_likelihood->deep_copy()->add_cast(ti);
839  auto rewritten_expr_take2 = visitor.visit(bigint_expr_no_likelihood.get());
840  auto rewritten_expr_take2_const =
841  std::dynamic_pointer_cast<Analyzer::Constant>(rewritten_expr_take2);
842  if (rewritten_expr_take2_const) {
843  // Managed to fold, switch to the new constant
844  rewritten_expr = rewritten_expr_take2_const;
845  }
846  }
847  }
848  const auto expr_with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
849  if (expr_with_likelihood) {
850  // Add back likelihood
851  return std::make_shared<Analyzer::LikelihoodExpr>(
852  rewritten_expr, expr_with_likelihood->get_likelihood());
853  }
854  return rewritten_expr;
855 }
int8_t tinyintval
Definition: sqltypes.h:123
SQLQualifier get_qualifier() const
Definition: Analyzer.h:434
Analyzer::ExpressionPtr rewrite_array_elements(Analyzer::Expr const *expr)
#define CHECK_EQ(x, y)
Definition: Logger.h:195
#define IS_LOGIC(X)
Definition: sqldefs.h:60
std::shared_ptr< Analyzer::InValues > visitLikeExpr(const Analyzer::LikeExpr *) const override
std::shared_ptr< Analyzer::InValues > visitDateaddExpr(const Analyzer::DateaddExpr *) const override
void d(const SQLTypes expected_type, const std::string &str)
Definition: ImportTest.cpp:268
bool foldOper(SQLOps optype, SQLTypes type, Datum lhs, Datum rhs, Datum &result, SQLTypes &result_type) const
size_t getElementCount() const
Definition: Analyzer.h:1380
SQLTypes
Definition: sqltypes.h:40
std::shared_ptr< Analyzer::WindowFunction > rewrite_avg_window(const Analyzer::Expr *expr)
bool g_strip_join_covered_quals
Definition: Execute.cpp:88
std::string getName() const
Definition: Analyzer.h:1250
std::shared_ptr< Analyzer::InValues > visitKeyForString(const Analyzer::KeyForStringExpr *) const override
std::shared_ptr< Analyzer::WindowFunction > rewrite_sum_window(const Analyzer::Expr *expr)
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:1254
std::shared_ptr< Analyzer::InValues > visitBinOper(const Analyzer::BinOper *bin_oper) const override
#define LOG(tag)
Definition: Logger.h:182
bool boolval
Definition: sqltypes.h:122
SQLOps
Definition: sqldefs.h:29
Definition: sqldefs.h:35
Definition: sqldefs.h:36
Definition: sqldefs.h:38
bool isLocalAlloc() const
Definition: Analyzer.h:1382
#define CHECK_GE(x, y)
Definition: Logger.h:200
Definition: sqldefs.h:49
Definition: sqldefs.h:30
std::shared_ptr< Analyzer::Expr > ExpressionPtr
Definition: Analyzer.h:179
const Analyzer::Expr * extract_cast_arg(const Analyzer::Expr *expr)
Definition: Execute.h:149
Definition: sqldefs.h:41
std::vector< JoinCondition > JoinQualsPerNestingLevel
Analyzer::ExpressionPtr rewrite_expr(const Analyzer::Expr *expr)
int32_t getExprIndex() const
Definition: Analyzer.h:1381
static const std::unordered_set< std::string > overlaps_supported_functions
std::shared_ptr< Analyzer::InValues > aggregateResult(const std::shared_ptr< Analyzer::InValues > &lhs, const std::shared_ptr< Analyzer::InValues > &rhs) const override
std::list< std::shared_ptr< Analyzer::Expr > > strip_join_covered_filter_quals(const std::list< std::shared_ptr< Analyzer::Expr >> &quals, const JoinQualsPerNestingLevel &join_quals)
std::shared_ptr< Analyzer::InValues > visitUOper(const Analyzer::UOper *uoper) const override
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
std::shared_ptr< Analyzer::Expr > RetType
int32_t intval
Definition: sqltypes.h:125
SQLOps get_optype() const
Definition: Analyzer.h:432
std::shared_ptr< Analyzer::InValues > visitCharLength(const Analyzer::CharLengthExpr *) const override
std::shared_ptr< Analyzer::InValues > visitInIntegerSet(const Analyzer::InIntegerSet *) const override
float floatval
Definition: sqltypes.h:127
Datum get_constval() const
Definition: Analyzer.h:328
std::shared_ptr< Analyzer::Expr > visitUOper(const Analyzer::UOper *uoper) const override
const Analyzer::Expr * getElement(const size_t i) const
Definition: Analyzer.h:1384
std::shared_ptr< Analyzer::InValues > visitRegexpExpr(const Analyzer::RegexpExpr *) const override
bool is_decimal() const
Definition: sqltypes.h:453
int64_t bigintval
Definition: sqltypes.h:126
std::shared_ptr< Analyzer::InValues > visitAggExpr(const Analyzer::AggExpr *) const override
bool visitFunctionOper(const Analyzer::FunctionOper *func_oper) const override
Definition: sqldefs.h:37
int16_t smallintval
Definition: sqltypes.h:124
bool defaultResult() const override
boost::optional< OverlapsJoinConjunction > rewrite_overlaps_conjunction(const std::shared_ptr< Analyzer::Expr > expr)
RetType visitArrayOper(const Analyzer::ArrayExpr *array_expr) const override
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:823
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const
Definition: Analyzer.h:437
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
SQLTypes decimal_to_int_type(const SQLTypeInfo &ti)
Definition: Datum.cpp:268
Definition: sqldefs.h:34
const std::shared_ptr< Analyzer::Expr > get_own_right_operand() const
Definition: Analyzer.h:440
Definition: sqldefs.h:40
std::vector< std::pair< const Analyzer::Expr *, const Analyzer::Expr * > > join_qual_pairs
Definition: sqldefs.h:69
#define TRANSIENT_DICT_ID
Definition: sqltypes.h:186
#define IS_ARITHMETIC(X)
Definition: sqldefs.h:61
Definition: sqldefs.h:32
T visit(const Analyzer::Expr *expr) const
std::unordered_map< const Analyzer::Expr *, const SQLTypeInfo > casts_
std::shared_ptr< Analyzer::InValues > visitDatediffExpr(const Analyzer::DatediffExpr *) const override
std::shared_ptr< Analyzer::InValues > visitLikelihood(const Analyzer::LikelihoodExpr *) const override
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:77
std::shared_ptr< Analyzer::InValues > visitInValues(const Analyzer::InValues *) const override
#define CHECK(condition)
Definition: Logger.h:187
Definition: sqldefs.h:33
bool get_contains_agg() const
Definition: Analyzer.h:79
std::shared_ptr< Analyzer::InValues > visitCaseExpr(const Analyzer::CaseExpr *) const override
Definition: sqltypes.h:47
JoinCoveredQualVisitor(const JoinQualsPerNestingLevel &join_quals)
std::shared_ptr< Analyzer::InValues > visitDatetruncExpr(const Analyzer::DatetruncExpr *) const override
Definition: sqldefs.h:39
std::shared_ptr< Analyzer::Expr > fold_expr(const Analyzer::Expr *expr)
SQLOps get_optype() const
Definition: Analyzer.h:363
#define IS_COMPARISON(X)
Definition: sqldefs.h:57
double doubleval
Definition: sqltypes.h:128
std::shared_ptr< Analyzer::InValues > visitExtractExpr(const Analyzer::ExtractExpr *) const override
const Analyzer::Expr * strip_likelihood(const Analyzer::Expr *expr)
const Expr * get_right_operand() const
Definition: Analyzer.h:436
std::shared_ptr< Analyzer::InValues > visitCardinality(const Analyzer::CardinalityExpr *) const override
const Expr * get_left_operand() const
Definition: Analyzer.h:435
const Expr * get_operand() const
Definition: Analyzer.h:364