OmniSciDB  addbbd5075
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ExpressionRewrite.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ExpressionRewrite.h"
18 #include "../Analyzer/Analyzer.h"
19 #include "../Parser/ParserNode.h"
20 #include "../Shared/sqldefs.h"
21 #include "DeepCopyVisitor.h"
22 #include "Execute.h"
23 #include "RelAlgTranslator.h"
24 #include "ScalarExprVisitor.h"
25 #include "Shared/Logger.h"
27 
28 #include <boost/locale/conversion.hpp>
29 #include <unordered_set>
30 
31 namespace {
32 
33 class OrToInVisitor : public ScalarExprVisitor<std::shared_ptr<Analyzer::InValues>> {
34  protected:
35  std::shared_ptr<Analyzer::InValues> visitBinOper(
36  const Analyzer::BinOper* bin_oper) const override {
37  switch (bin_oper->get_optype()) {
38  case kEQ: {
39  const auto rhs_owned = bin_oper->get_own_right_operand();
40  auto rhs_no_cast = extract_cast_arg(rhs_owned.get());
41  if (!dynamic_cast<const Analyzer::Constant*>(rhs_no_cast)) {
42  return nullptr;
43  }
44  const auto arg = bin_oper->get_own_left_operand();
45  const auto& arg_ti = arg->get_type_info();
46  auto rhs = rhs_no_cast->deep_copy()->add_cast(arg_ti);
47  return makeExpr<Analyzer::InValues>(
48  arg, std::list<std::shared_ptr<Analyzer::Expr>>{rhs});
49  }
50  case kOR: {
51  return aggregateResult(visit(bin_oper->get_left_operand()),
52  visit(bin_oper->get_right_operand()));
53  }
54  default:
55  break;
56  }
57  return nullptr;
58  }
59 
60  std::shared_ptr<Analyzer::InValues> visitUOper(
61  const Analyzer::UOper* uoper) const override {
62  return nullptr;
63  }
64 
65  std::shared_ptr<Analyzer::InValues> visitInValues(
66  const Analyzer::InValues*) const override {
67  return nullptr;
68  }
69 
70  std::shared_ptr<Analyzer::InValues> visitInIntegerSet(
71  const Analyzer::InIntegerSet*) const override {
72  return nullptr;
73  }
74 
75  std::shared_ptr<Analyzer::InValues> visitCharLength(
76  const Analyzer::CharLengthExpr*) const override {
77  return nullptr;
78  }
79 
80  std::shared_ptr<Analyzer::InValues> visitKeyForString(
81  const Analyzer::KeyForStringExpr*) const override {
82  return nullptr;
83  }
84 
85  std::shared_ptr<Analyzer::InValues> visitCardinality(
86  const Analyzer::CardinalityExpr*) const override {
87  return nullptr;
88  }
89 
90  std::shared_ptr<Analyzer::InValues> visitLikeExpr(
91  const Analyzer::LikeExpr*) const override {
92  return nullptr;
93  }
94 
95  std::shared_ptr<Analyzer::InValues> visitRegexpExpr(
96  const Analyzer::RegexpExpr*) const override {
97  return nullptr;
98  }
99 
100  std::shared_ptr<Analyzer::InValues> visitCaseExpr(
101  const Analyzer::CaseExpr*) const override {
102  return nullptr;
103  }
104 
105  std::shared_ptr<Analyzer::InValues> visitDatetruncExpr(
106  const Analyzer::DatetruncExpr*) const override {
107  return nullptr;
108  }
109 
110  std::shared_ptr<Analyzer::InValues> visitDatediffExpr(
111  const Analyzer::DatediffExpr*) const override {
112  return nullptr;
113  }
114 
115  std::shared_ptr<Analyzer::InValues> visitDateaddExpr(
116  const Analyzer::DateaddExpr*) const override {
117  return nullptr;
118  }
119 
120  std::shared_ptr<Analyzer::InValues> visitExtractExpr(
121  const Analyzer::ExtractExpr*) const override {
122  return nullptr;
123  }
124 
125  std::shared_ptr<Analyzer::InValues> visitLikelihood(
126  const Analyzer::LikelihoodExpr*) const override {
127  return nullptr;
128  }
129 
130  std::shared_ptr<Analyzer::InValues> visitAggExpr(
131  const Analyzer::AggExpr*) const override {
132  return nullptr;
133  }
134 
135  std::shared_ptr<Analyzer::InValues> aggregateResult(
136  const std::shared_ptr<Analyzer::InValues>& lhs,
137  const std::shared_ptr<Analyzer::InValues>& rhs) const override {
138  if (!lhs || !rhs) {
139  return nullptr;
140  }
141  if (!(*lhs->get_arg() == *rhs->get_arg())) {
142  return nullptr;
143  }
144  auto union_values = lhs->get_value_list();
145  const auto& rhs_values = rhs->get_value_list();
146  union_values.insert(union_values.end(), rhs_values.begin(), rhs_values.end());
147  return makeExpr<Analyzer::InValues>(lhs->get_own_arg(), union_values);
148  }
149 };
150 
152  protected:
153  std::shared_ptr<Analyzer::Expr> visitBinOper(
154  const Analyzer::BinOper* bin_oper) const override {
155  OrToInVisitor simple_visitor;
156  if (bin_oper->get_optype() == kOR) {
157  auto rewritten = simple_visitor.visit(bin_oper);
158  if (rewritten) {
159  return rewritten;
160  }
161  }
162  auto lhs = bin_oper->get_own_left_operand();
163  auto rhs = bin_oper->get_own_right_operand();
164  auto rewritten_lhs = visit(lhs.get());
165  auto rewritten_rhs = visit(rhs.get());
166  return makeExpr<Analyzer::BinOper>(bin_oper->get_type_info(),
167  bin_oper->get_contains_agg(),
168  bin_oper->get_optype(),
169  bin_oper->get_qualifier(),
170  rewritten_lhs ? rewritten_lhs : lhs,
171  rewritten_rhs ? rewritten_rhs : rhs);
172  }
173 };
174 
176  protected:
178 
179  RetType visitArrayOper(const Analyzer::ArrayExpr* array_expr) const override {
180  std::vector<std::shared_ptr<Analyzer::Expr>> args_copy;
181  for (size_t i = 0; i < array_expr->getElementCount(); ++i) {
182  auto const element_expr_ptr = visit(array_expr->getElement(i));
183  auto const& element_expr_type_info = element_expr_ptr->get_type_info();
184 
185  if (!element_expr_type_info.is_string() ||
186  element_expr_type_info.get_compression() != kENCODING_NONE) {
187  args_copy.push_back(element_expr_ptr);
188  } else {
189  auto transient_dict_type_info = element_expr_type_info;
190 
191  transient_dict_type_info.set_compression(kENCODING_DICT);
192  transient_dict_type_info.set_comp_param(TRANSIENT_DICT_ID);
193  transient_dict_type_info.set_fixed_size();
194  args_copy.push_back(element_expr_ptr->add_cast(transient_dict_type_info));
195  }
196  }
197 
198  const auto& type_info = array_expr->get_type_info();
199  return makeExpr<Analyzer::ArrayExpr>(type_info,
200  args_copy,
201  array_expr->getExprIndex(),
202  array_expr->isNull(),
203  array_expr->isLocalAlloc());
204  }
205 };
206 
208  template <typename T>
209  bool foldComparison(SQLOps optype, T t1, T t2) const {
210  switch (optype) {
211  case kEQ:
212  return t1 == t2;
213  case kNE:
214  return t1 != t2;
215  case kLT:
216  return t1 < t2;
217  case kLE:
218  return t1 <= t2;
219  case kGT:
220  return t1 > t2;
221  case kGE:
222  return t1 >= t2;
223  default:
224  break;
225  }
226  throw std::runtime_error("Unable to fold");
227  return false;
228  }
229 
230  template <typename T>
231  bool foldLogic(SQLOps optype, T t1, T t2) const {
232  switch (optype) {
233  case kAND:
234  return t1 && t2;
235  case kOR:
236  return t1 || t2;
237  case kNOT:
238  return !t1;
239  default:
240  break;
241  }
242  throw std::runtime_error("Unable to fold");
243  return false;
244  }
245 
246  template <typename T>
247  T foldArithmetic(SQLOps optype, T t1, T t2) const {
248  bool t2_is_zero = (t2 == (t2 - t2));
249  bool t2_is_negative = (t2 < (t2 - t2));
250  switch (optype) {
251  case kPLUS:
252  // The MIN limit for float and double is the smallest representable value,
253  // not the lowest negative value! Switching to C++11 lowest.
254  if ((t2_is_negative && t1 < std::numeric_limits<T>::lowest() - t2) ||
255  (!t2_is_negative && t1 > std::numeric_limits<T>::max() - t2)) {
256  num_overflows_++;
257  throw std::runtime_error("Plus overflow");
258  }
259  return t1 + t2;
260  case kMINUS:
261  if ((t2_is_negative && t1 > std::numeric_limits<T>::max() + t2) ||
262  (!t2_is_negative && t1 < std::numeric_limits<T>::lowest() + t2)) {
263  num_overflows_++;
264  throw std::runtime_error("Minus overflow");
265  }
266  return t1 - t2;
267  case kMULTIPLY: {
268  if (t2_is_zero) {
269  return t2;
270  }
271  auto ct1 = t1;
272  auto ct2 = t2;
273  // Need to keep t2's sign on the left
274  if (t2_is_negative) {
275  if (t1 == std::numeric_limits<T>::lowest() ||
276  t2 == std::numeric_limits<T>::lowest()) {
277  // negation could overflow - bail
278  num_overflows_++;
279  throw std::runtime_error("Mul neg overflow");
280  }
281  ct1 = -t1; // ct1 gets t2's negativity
282  ct2 = -t2; // ct2 is now positive
283  }
284  // Don't check overlow if we are folding FP mul by a fraction
285  bool ct2_is_fraction = (ct2 < (ct2 / ct2));
286  if (!ct2_is_fraction) {
287  if (ct1 > std::numeric_limits<T>::max() / ct2 ||
288  ct1 < std::numeric_limits<T>::lowest() / ct2) {
289  num_overflows_++;
290  throw std::runtime_error("Mul overflow");
291  }
292  }
293  return t1 * t2;
294  }
295  case kDIVIDE:
296  if (t2_is_zero) {
297  throw std::runtime_error("Will not fold division by zero");
298  }
299  return t1 / t2;
300  default:
301  break;
302  }
303  throw std::runtime_error("Unable to fold");
304  }
305 
306  bool foldOper(SQLOps optype,
307  SQLTypes type,
308  Datum lhs,
309  Datum rhs,
310  Datum& result,
311  SQLTypes& result_type) const {
312  result_type = type;
313 
314  try {
315  switch (type) {
316  case kBOOLEAN:
317  if (IS_COMPARISON(optype)) {
318  result.boolval = foldComparison<bool>(optype, lhs.boolval, rhs.boolval);
319  result_type = kBOOLEAN;
320  return true;
321  }
322  if (IS_LOGIC(optype)) {
323  result.boolval = foldLogic<bool>(optype, lhs.boolval, rhs.boolval);
324  result_type = kBOOLEAN;
325  return true;
326  }
327  CHECK(!IS_ARITHMETIC(optype));
328  break;
329  case kTINYINT:
330  if (IS_COMPARISON(optype)) {
331  result.boolval =
332  foldComparison<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
333  result_type = kBOOLEAN;
334  return true;
335  }
336  if (IS_ARITHMETIC(optype)) {
337  result.tinyintval =
338  foldArithmetic<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
339  result_type = kTINYINT;
340  return true;
341  }
342  CHECK(!IS_LOGIC(optype));
343  break;
344  case kSMALLINT:
345  if (IS_COMPARISON(optype)) {
346  result.boolval =
347  foldComparison<int16_t>(optype, lhs.smallintval, rhs.smallintval);
348  result_type = kBOOLEAN;
349  return true;
350  }
351  if (IS_ARITHMETIC(optype)) {
352  result.smallintval =
353  foldArithmetic<int16_t>(optype, lhs.smallintval, rhs.smallintval);
354  result_type = kSMALLINT;
355  return true;
356  }
357  CHECK(!IS_LOGIC(optype));
358  break;
359  case kINT:
360  if (IS_COMPARISON(optype)) {
361  result.boolval = foldComparison<int32_t>(optype, lhs.intval, rhs.intval);
362  result_type = kBOOLEAN;
363  return true;
364  }
365  if (IS_ARITHMETIC(optype)) {
366  result.intval = foldArithmetic<int32_t>(optype, lhs.intval, rhs.intval);
367  result_type = kINT;
368  return true;
369  }
370  CHECK(!IS_LOGIC(optype));
371  break;
372  case kBIGINT:
373  if (IS_COMPARISON(optype)) {
374  result.boolval =
375  foldComparison<int64_t>(optype, lhs.bigintval, rhs.bigintval);
376  result_type = kBOOLEAN;
377  return true;
378  }
379  if (IS_ARITHMETIC(optype)) {
380  result.bigintval =
381  foldArithmetic<int64_t>(optype, lhs.bigintval, rhs.bigintval);
382  result_type = kBIGINT;
383  return true;
384  }
385  CHECK(!IS_LOGIC(optype));
386  break;
387  case kFLOAT:
388  if (IS_COMPARISON(optype)) {
389  result.boolval = foldComparison<float>(optype, lhs.floatval, rhs.floatval);
390  result_type = kBOOLEAN;
391  return true;
392  }
393  if (IS_ARITHMETIC(optype)) {
394  result.floatval = foldArithmetic<float>(optype, lhs.floatval, rhs.floatval);
395  result_type = kFLOAT;
396  return true;
397  }
398  CHECK(!IS_LOGIC(optype));
399  break;
400  case kDOUBLE:
401  if (IS_COMPARISON(optype)) {
402  result.boolval = foldComparison<double>(optype, lhs.doubleval, rhs.doubleval);
403  result_type = kBOOLEAN;
404  return true;
405  }
406  if (IS_ARITHMETIC(optype)) {
407  result.doubleval =
408  foldArithmetic<double>(optype, lhs.doubleval, rhs.doubleval);
409  result_type = kDOUBLE;
410  return true;
411  }
412  CHECK(!IS_LOGIC(optype));
413  break;
414  default:
415  break;
416  }
417  } catch (...) {
418  return false;
419  }
420  return false;
421  }
422 
423  std::shared_ptr<Analyzer::Expr> visitUOper(
424  const Analyzer::UOper* uoper) const override {
425  const auto unvisited_operand = uoper->get_operand();
426  const auto optype = uoper->get_optype();
427  const auto& ti = uoper->get_type_info();
428  if (optype == kCAST) {
429  // Cache the cast type so it could be used in operand rewriting/folding
430  casts_.insert({unvisited_operand, ti});
431  }
432  const auto operand = visit(unvisited_operand);
433  const auto& operand_ti = operand->get_type_info();
434  const auto operand_type =
435  operand_ti.is_decimal() ? decimal_to_int_type(operand_ti) : operand_ti.get_type();
436  const auto const_operand =
437  std::dynamic_pointer_cast<const Analyzer::Constant>(operand);
438 
439  if (const_operand) {
440  const auto operand_datum = const_operand->get_constval();
441  Datum zero_datum = {};
442  Datum result_datum = {};
443  SQLTypes result_type;
444  switch (optype) {
445  case kNOT: {
446  if (foldOper(kEQ,
447  operand_type,
448  zero_datum,
449  operand_datum,
450  result_datum,
451  result_type)) {
452  CHECK_EQ(result_type, kBOOLEAN);
453  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
454  }
455  break;
456  }
457  case kUMINUS: {
458  if (foldOper(kMINUS,
459  operand_type,
460  zero_datum,
461  operand_datum,
462  result_datum,
463  result_type)) {
464  if (!operand_ti.is_decimal()) {
465  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
466  }
467  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
468  }
469  break;
470  }
471  case kCAST: {
472  // Trying to fold number to number casts only
473  if (!ti.is_number() || !operand_ti.is_number()) {
474  break;
475  }
476  // Disallowing folding of FP to DECIMAL casts for now:
477  // allowing them would make this test pass:
478  // update dectest set d=cast( 1234.0 as float );
479  // which is expected to throw in Update.ImplicitCastToNumericTypes
480  // due to cast codegen currently not supporting these casts
481  if (ti.is_decimal() && operand_ti.is_fp()) {
482  break;
483  }
484  auto operand_copy = const_operand->deep_copy();
485  auto cast_operand = operand_copy->add_cast(ti);
486  auto const_cast_operand =
487  std::dynamic_pointer_cast<const Analyzer::Constant>(cast_operand);
488  if (const_cast_operand) {
489  auto const_cast_datum = const_cast_operand->get_constval();
490  return makeExpr<Analyzer::Constant>(ti, false, const_cast_datum);
491  }
492  }
493  default:
494  break;
495  }
496  }
497 
498  return makeExpr<Analyzer::UOper>(
499  uoper->get_type_info(), uoper->get_contains_agg(), optype, operand);
500  }
501 
502  std::shared_ptr<Analyzer::Expr> visitBinOper(
503  const Analyzer::BinOper* bin_oper) const override {
504  const auto optype = bin_oper->get_optype();
505  auto ti = bin_oper->get_type_info();
506  auto left_operand = bin_oper->get_own_left_operand();
507  auto right_operand = bin_oper->get_own_right_operand();
508 
509  // Check if bin_oper result is cast to a larger int or fp type
510  if (casts_.find(bin_oper) != casts_.end()) {
511  const auto cast_ti = casts_[bin_oper];
512  const auto& lhs_ti = bin_oper->get_left_operand()->get_type_info();
513  // Propagate cast down to the operands for folding
514  if ((cast_ti.is_integer() || cast_ti.is_fp()) && lhs_ti.is_integer() &&
515  cast_ti.get_size() > lhs_ti.get_size() &&
516  (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY)) {
517  // Before folding, cast the operands to the bigger type to avoid overflows.
518  // Currently upcasting smaller integer types to larger integers or double.
519  left_operand = left_operand->deep_copy()->add_cast(cast_ti);
520  right_operand = right_operand->deep_copy()->add_cast(cast_ti);
521  ti = cast_ti;
522  }
523  }
524 
525  const auto lhs = visit(left_operand.get());
526  const auto rhs = visit(right_operand.get());
527 
528  auto const_lhs = std::dynamic_pointer_cast<Analyzer::Constant>(lhs);
529  auto const_rhs = std::dynamic_pointer_cast<Analyzer::Constant>(rhs);
530  const auto& lhs_ti = lhs->get_type_info();
531  const auto& rhs_ti = rhs->get_type_info();
532  auto lhs_type = lhs_ti.is_decimal() ? decimal_to_int_type(lhs_ti) : lhs_ti.get_type();
533  auto rhs_type = rhs_ti.is_decimal() ? decimal_to_int_type(rhs_ti) : rhs_ti.get_type();
534 
535  if (const_lhs && const_rhs && lhs_type == rhs_type) {
536  auto lhs_datum = const_lhs->get_constval();
537  auto rhs_datum = const_rhs->get_constval();
538  Datum result_datum = {};
539  SQLTypes result_type;
540  if (foldOper(optype, lhs_type, lhs_datum, rhs_datum, result_datum, result_type)) {
541  // Fold all ops that don't take in decimal operands, and also decimal comparisons
542  if (!lhs_ti.is_decimal() || IS_COMPARISON(optype)) {
543  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
544  }
545  // Decimal arithmetic has been done as kBIGINT. Selectively fold some decimal ops,
546  // using result_datum and BinOper expr typeinfo which was adjusted for these ops.
547  if (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY) {
548  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
549  }
550  }
551  }
552 
553  if (optype == kAND && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
554  if (const_rhs && !const_rhs->get_is_null()) {
555  auto rhs_datum = const_rhs->get_constval();
556  if (rhs_datum.boolval == false) {
557  Datum d;
558  d.boolval = false;
559  // lhs && false --> false
560  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
561  }
562  // lhs && true --> lhs
563  return lhs;
564  }
565  if (const_lhs && !const_lhs->get_is_null()) {
566  auto lhs_datum = const_lhs->get_constval();
567  if (lhs_datum.boolval == false) {
568  Datum d;
569  d.boolval = false;
570  // false && rhs --> false
571  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
572  }
573  // true && rhs --> rhs
574  return rhs;
575  }
576  }
577  if (optype == kOR && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
578  if (const_rhs && !const_rhs->get_is_null()) {
579  auto rhs_datum = const_rhs->get_constval();
580  if (rhs_datum.boolval == true) {
581  Datum d;
582  d.boolval = true;
583  // lhs || true --> true
584  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
585  }
586  // lhs || false --> lhs
587  return lhs;
588  }
589  if (const_lhs && !const_lhs->get_is_null()) {
590  auto lhs_datum = const_lhs->get_constval();
591  if (lhs_datum.boolval == true) {
592  Datum d;
593  d.boolval = true;
594  // true || rhs --> true
595  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
596  }
597  // false || rhs --> rhs
598  return rhs;
599  }
600  }
601  if (*lhs == *rhs) {
602  // Tautologies: v=v; v<=v; v>=v
603  if (optype == kEQ || optype == kLE || optype == kGE) {
604  Datum d;
605  d.boolval = true;
606  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
607  }
608  // Contradictions: v!=v; v<v; v>v
609  if (optype == kNE || optype == kLT || optype == kGT) {
610  Datum d;
611  d.boolval = false;
612  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
613  }
614  // v-v
615  if (optype == kMINUS) {
616  Datum d = {};
617  return makeExpr<Analyzer::Constant>(lhs_type, false, d);
618  }
619  }
620  // Convert fp division by a constant to multiplication by 1/constant
621  if (optype == kDIVIDE && const_rhs && rhs_ti.is_fp()) {
622  auto rhs_datum = const_rhs->get_constval();
623  std::shared_ptr<Analyzer::Expr> recip_rhs = nullptr;
624  if (rhs_ti.get_type() == kFLOAT) {
625  if (rhs_datum.floatval == 1.0) {
626  return lhs;
627  }
628  auto f = std::fabs(rhs_datum.floatval);
629  if (f > 1.0 || (f != 0.0 && 1.0 < f * std::numeric_limits<float>::max())) {
630  rhs_datum.floatval = 1.0 / rhs_datum.floatval;
631  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
632  }
633  } else if (rhs_ti.get_type() == kDOUBLE) {
634  if (rhs_datum.doubleval == 1.0) {
635  return lhs;
636  }
637  auto d = std::fabs(rhs_datum.doubleval);
638  if (d > 1.0 || (d != 0.0 && 1.0 < d * std::numeric_limits<double>::max())) {
639  rhs_datum.doubleval = 1.0 / rhs_datum.doubleval;
640  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
641  }
642  }
643  if (recip_rhs) {
644  return makeExpr<Analyzer::BinOper>(ti,
645  bin_oper->get_contains_agg(),
646  kMULTIPLY,
647  bin_oper->get_qualifier(),
648  lhs,
649  recip_rhs);
650  }
651  }
652 
653  return makeExpr<Analyzer::BinOper>(ti,
654  bin_oper->get_contains_agg(),
655  bin_oper->get_optype(),
656  bin_oper->get_qualifier(),
657  lhs,
658  rhs);
659  }
660 
661  std::shared_ptr<Analyzer::Expr> visitLower(
662  const Analyzer::LowerExpr* lower_expr) const override {
663  const auto constant_arg_expr =
664  dynamic_cast<const Analyzer::Constant*>(lower_expr->get_arg());
665  if (constant_arg_expr) {
667  boost::locale::to_lower(*constant_arg_expr->get_constval().stringval));
668  }
669  return makeExpr<Analyzer::LowerExpr>(lower_expr->get_own_arg());
670  }
671 
672  protected:
673  mutable std::unordered_map<const Analyzer::Expr*, const SQLTypeInfo> casts_;
674  mutable int32_t num_overflows_;
675 
676  public:
677  ConstantFoldingVisitor() : num_overflows_(0) {}
678  int32_t get_num_overflows() { return num_overflows_; }
679  void reset_num_overflows() { num_overflows_ = 0; }
680 };
681 
683  const auto with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
684  if (!with_likelihood) {
685  return expr;
686  }
687  return with_likelihood->get_arg();
688 }
689 
690 } // namespace
691 
693  return ArrayElementStringLiteralEncodingVisitor().visit(expr);
694 }
695 
697  const auto sum_window = rewrite_sum_window(expr);
698  if (sum_window) {
699  return sum_window;
700  }
701  const auto avg_window = rewrite_avg_window(expr);
702  if (avg_window) {
703  return avg_window;
704  }
705  const auto expr_no_likelihood = strip_likelihood(expr);
706  // The following check is not strictly needed, but seems silly to transform a
707  // simple string comparison to an IN just to codegen the same thing anyway.
708 
709  RecursiveOrToInVisitor visitor;
710  auto rewritten_expr = visitor.visit(expr_no_likelihood);
711  const auto expr_with_likelihood =
712  std::dynamic_pointer_cast<const Analyzer::LikelihoodExpr>(rewritten_expr);
713  if (expr_with_likelihood) {
714  // Add back likelihood
715  return std::make_shared<Analyzer::LikelihoodExpr>(
716  rewritten_expr, expr_with_likelihood->get_likelihood());
717  }
718  return rewritten_expr;
719 }
720 
namespace {

// Geo function names that rewrite_overlaps_conjunction can turn into an OVERLAPS
// qual and that JoinCoveredQualVisitor checks against join quals. The anonymous
// namespace already gives internal linkage, so no `static` is needed.
const std::unordered_set<std::string> overlaps_supported_functions{
    "ST_Contains_MultiPolygon_Point",
    "ST_Contains_Polygon_Point",
};

}  // namespace
728 
729 boost::optional<OverlapsJoinConjunction> rewrite_overlaps_conjunction(
730  const std::shared_ptr<Analyzer::Expr> expr) {
731  auto func_oper = dynamic_cast<Analyzer::FunctionOper*>(expr.get());
732  if (func_oper) {
733  // TODO(adb): consider converting unordered set to an unordered map, potentially
734  // storing the rewrite function we want to apply in the map
735  if (overlaps_supported_functions.find(func_oper->getName()) !=
737  CHECK_GE(func_oper->getArity(), size_t(3));
738 
739  DeepCopyVisitor deep_copy_visitor;
740  auto lhs = func_oper->getOwnArg(2);
741  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
742  CHECK(rewritten_lhs);
743  const auto& lhs_ti = rewritten_lhs->get_type_info();
744  if (!lhs_ti.is_geometry()) {
745  // TODO(adb): If ST_Contains is passed geospatial literals instead of columns, the
746  // function will be expanded during translation rather than during code
747  // generation. While this scenario does not make sense for the overlaps join, we
748  // need to detect and abort the overlaps rewrite. Adding a GeospatialConstant
749  // dervied class to the Analyzer may prove to be a better way to handle geo
750  // literals, but for now we ensure the LHS type is a geospatial type, which would
751  // mean the function has not been expanded to the physical types, yet.
752  LOG(WARNING) << "Failed to rewrite " << func_oper->getName()
753  << " to overlaps conjunction. LHS input type is not a geospatial "
754  "type. Are both inputs geospatial columns?";
755  return boost::none;
756  }
757 
758  // Read the bounds arg from the ST_Contains FuncOper (second argument)instead of the
759  // poly column (first argument)
760  auto rhs = func_oper->getOwnArg(1);
761  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
762  CHECK(rewritten_rhs);
763 
764  auto overlaps_oper = makeExpr<Analyzer::BinOper>(
765  kBOOLEAN, kOVERLAPS, kONE, rewritten_lhs, rewritten_rhs);
766 
767  return OverlapsJoinConjunction{{expr}, {overlaps_oper}};
768  }
769  }
770  return boost::none;
771 }
772 
782  public:
784  for (const auto& join_condition : join_quals) {
785  for (const auto& qual : join_condition.quals) {
786  auto qual_bin_oper = dynamic_cast<Analyzer::BinOper*>(qual.get());
787  if (qual_bin_oper) {
788  join_qual_pairs.emplace_back(qual_bin_oper->get_left_operand(),
789  qual_bin_oper->get_right_operand());
790  }
791  }
792  }
793  }
794 
795  bool visitFunctionOper(const Analyzer::FunctionOper* func_oper) const override {
796  if (overlaps_supported_functions.find(func_oper->getName()) !=
798  const auto lhs = func_oper->getArg(2);
799  const auto rhs = func_oper->getArg(1);
800  for (const auto qual_pair : join_qual_pairs) {
801  if (*lhs == *qual_pair.first && *rhs == *qual_pair.second) {
802  return true;
803  }
804  }
805  }
806  return false;
807  }
808 
809  bool defaultResult() const override { return false; }
810 
811  private:
812  std::vector<std::pair<const Analyzer::Expr*, const Analyzer::Expr*>> join_qual_pairs;
813 };
814 
815 std::list<std::shared_ptr<Analyzer::Expr>> strip_join_covered_filter_quals(
816  const std::list<std::shared_ptr<Analyzer::Expr>>& quals,
817  const JoinQualsPerNestingLevel& join_quals) {
819  return quals;
820  }
821 
822  if (join_quals.empty()) {
823  return quals;
824  }
825 
826  std::list<std::shared_ptr<Analyzer::Expr>> quals_to_return;
827 
828  JoinCoveredQualVisitor visitor(join_quals);
829  for (const auto& qual : quals) {
830  if (!visitor.visit(qual.get())) {
831  // Not a covered qual, don't elide it from the filtered count
832  quals_to_return.push_back(qual);
833  }
834  }
835 
836  return quals_to_return;
837 }
838 
839 std::shared_ptr<Analyzer::Expr> fold_expr(const Analyzer::Expr* expr) {
840  if (!expr) {
841  return nullptr;
842  }
843  const auto expr_no_likelihood = strip_likelihood(expr);
844  ConstantFoldingVisitor visitor;
845  auto rewritten_expr = visitor.visit(expr_no_likelihood);
846  if (visitor.get_num_overflows() > 0 && rewritten_expr->get_type_info().is_integer() &&
847  rewritten_expr->get_type_info().get_type() != kBIGINT) {
848  auto rewritten_expr_const =
849  std::dynamic_pointer_cast<const Analyzer::Constant>(rewritten_expr);
850  if (!rewritten_expr_const) {
851  // Integer expression didn't fold completely the first time due to
852  // overflows in smaller type subexpressions, trying again with a cast
853  const auto& ti = SQLTypeInfo(kBIGINT, false);
854  auto bigint_expr_no_likelihood = expr_no_likelihood->deep_copy()->add_cast(ti);
855  auto rewritten_expr_take2 = visitor.visit(bigint_expr_no_likelihood.get());
856  auto rewritten_expr_take2_const =
857  std::dynamic_pointer_cast<Analyzer::Constant>(rewritten_expr_take2);
858  if (rewritten_expr_take2_const) {
859  // Managed to fold, switch to the new constant
860  rewritten_expr = rewritten_expr_take2_const;
861  }
862  }
863  }
864  const auto expr_with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
865  if (expr_with_likelihood) {
866  // Add back likelihood
867  return std::make_shared<Analyzer::LikelihoodExpr>(
868  rewritten_expr, expr_with_likelihood->get_likelihood());
869  }
870  return rewritten_expr;
871 }
int8_t tinyintval
Definition: sqltypes.h:126
Analyzer::ExpressionPtr rewrite_array_elements(Analyzer::Expr const *expr)
#define CHECK_EQ(x, y)
Definition: Logger.h:201
#define IS_LOGIC(X)
Definition: sqldefs.h:60
std::shared_ptr< Analyzer::InValues > visitLikeExpr(const Analyzer::LikeExpr *) const override
std::shared_ptr< Analyzer::InValues > visitDateaddExpr(const Analyzer::DateaddExpr *) const override
const std::shared_ptr< Analyzer::Expr > get_own_arg() const
Definition: Analyzer.h:746
SQLTypes
Definition: sqltypes.h:41
std::shared_ptr< Analyzer::WindowFunction > rewrite_avg_window(const Analyzer::Expr *expr)
int32_t getExprIndex() const
Definition: Analyzer.h:1440
bool g_strip_join_covered_quals
Definition: Execute.cpp:90
std::shared_ptr< Analyzer::InValues > visitKeyForString(const Analyzer::KeyForStringExpr *) const override
std::shared_ptr< Analyzer::WindowFunction > rewrite_sum_window(const Analyzer::Expr *expr)
std::shared_ptr< Analyzer::InValues > visitBinOper(const Analyzer::BinOper *bin_oper) const override
#define LOG(tag)
Definition: Logger.h:188
bool boolval
Definition: sqltypes.h:125
const Expr * get_right_operand() const
Definition: Analyzer.h:437
SQLOps
Definition: sqldefs.h:29
Definition: sqldefs.h:35
Definition: sqldefs.h:36
Definition: sqldefs.h:38
#define CHECK_GE(x, y)
Definition: Logger.h:206
bool get_contains_agg() const
Definition: Analyzer.h:80
Definition: sqldefs.h:49
Definition: sqldefs.h:30
std::shared_ptr< Analyzer::Expr > ExpressionPtr
Definition: Analyzer.h:180
const Analyzer::Expr * extract_cast_arg(const Analyzer::Expr *expr)
Definition: Execute.h:151
Definition: sqldefs.h:41
std::vector< JoinCondition > JoinQualsPerNestingLevel
T visit(const Analyzer::Expr *expr) const
Analyzer::ExpressionPtr rewrite_expr(const Analyzer::Expr *expr)
static const std::unordered_set< std::string > overlaps_supported_functions
bool isNull() const
Definition: Analyzer.h:1442
std::shared_ptr< Analyzer::InValues > aggregateResult(const std::shared_ptr< Analyzer::InValues > &lhs, const std::shared_ptr< Analyzer::InValues > &rhs) const override
std::list< std::shared_ptr< Analyzer::Expr > > strip_join_covered_filter_quals(const std::list< std::shared_ptr< Analyzer::Expr >> &quals, const JoinQualsPerNestingLevel &join_quals)
std::shared_ptr< Analyzer::InValues > visitUOper(const Analyzer::UOper *uoper) const override
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
std::shared_ptr< Analyzer::Expr > RetType
int32_t intval
Definition: sqltypes.h:128
std::shared_ptr< Analyzer::InValues > visitCharLength(const Analyzer::CharLengthExpr *) const override
std::shared_ptr< Analyzer::InValues > visitInIntegerSet(const Analyzer::InIntegerSet *) const override
SQLOps get_optype() const
Definition: Analyzer.h:433
float floatval
Definition: sqltypes.h:130
std::shared_ptr< Analyzer::Expr > visitUOper(const Analyzer::UOper *uoper) const override
CHECK(cgen_state)
std::shared_ptr< Analyzer::InValues > visitRegexpExpr(const Analyzer::RegexpExpr *) const override
int64_t bigintval
Definition: sqltypes.h:129
std::shared_ptr< Analyzer::InValues > visitAggExpr(const Analyzer::AggExpr *) const override
bool visitFunctionOper(const Analyzer::FunctionOper *func_oper) const override
Definition: sqldefs.h:37
bool foldOper(SQLOps optype, SQLTypes type, Datum lhs, Datum rhs, Datum &result, SQLTypes &result_type) const
int16_t smallintval
Definition: sqltypes.h:127
bool defaultResult() const override
static std::shared_ptr< Analyzer::Expr > analyzeValue(const std::string &)
Definition: ParserNode.cpp:97
boost::optional< OverlapsJoinConjunction > rewrite_overlaps_conjunction(const std::shared_ptr< Analyzer::Expr > expr)
RetType visitArrayOper(const Analyzer::ArrayExpr *array_expr) const override
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:852
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:78
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
size_t getElementCount() const
Definition: Analyzer.h:1439
SQLTypes decimal_to_int_type(const SQLTypeInfo &ti)
Definition: Datum.cpp:268
Definition: sqldefs.h:34
Definition: sqldefs.h:40
std::vector< std::pair< const Analyzer::Expr *, const Analyzer::Expr * > > join_qual_pairs
Definition: sqldefs.h:69
#define TRANSIENT_DICT_ID
Definition: sqltypes.h:189
bool isLocalAlloc() const
Definition: Analyzer.h:1441
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:1311
#define IS_ARITHMETIC(X)
Definition: sqldefs.h:61
const Expr * get_operand() const
Definition: Analyzer.h:365
Datum get_constval() const
Definition: Analyzer.h:329
Definition: sqldefs.h:32
Expression class for the LOWER (lowercase) string function. The &quot;arg&quot; constructor parameter must be a...
Definition: Analyzer.h:740
std::unordered_map< const Analyzer::Expr *, const SQLTypeInfo > casts_
std::shared_ptr< Analyzer::InValues > visitDatediffExpr(const Analyzer::DatediffExpr *) const override
std::shared_ptr< Analyzer::InValues > visitLikelihood(const Analyzer::LikelihoodExpr *) const override
std::shared_ptr< Analyzer::Expr > visitLower(const Analyzer::LowerExpr *lower_expr) const override
std::shared_ptr< Analyzer::InValues > visitInValues(const Analyzer::InValues *) const override
Definition: sqldefs.h:33
const Expr * get_left_operand() const
Definition: Analyzer.h:436
std::shared_ptr< Analyzer::InValues > visitCaseExpr(const Analyzer::CaseExpr *) const override
Definition: sqltypes.h:48
JoinCoveredQualVisitor(const JoinQualsPerNestingLevel &join_quals)
const std::shared_ptr< Analyzer::Expr > get_own_right_operand() const
Definition: Analyzer.h:441
std::string getName() const
Definition: Analyzer.h:1307
bool is_decimal() const
Definition: sqltypes.h:480
std::shared_ptr< Analyzer::InValues > visitDatetruncExpr(const Analyzer::DatetruncExpr *) const override
Definition: sqldefs.h:39
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const
Definition: Analyzer.h:438
SQLOps get_optype() const
Definition: Analyzer.h:364
const Expr * get_arg() const
Definition: Analyzer.h:744
std::shared_ptr< Analyzer::Expr > fold_expr(const Analyzer::Expr *expr)
#define IS_COMPARISON(X)
Definition: sqldefs.h:57
double doubleval
Definition: sqltypes.h:131
std::shared_ptr< Analyzer::InValues > visitExtractExpr(const Analyzer::ExtractExpr *) const override
const Analyzer::Expr * getElement(const size_t i) const
Definition: Analyzer.h:1444
SQLQualifier get_qualifier() const
Definition: Analyzer.h:435
const Analyzer::Expr * strip_likelihood(const Analyzer::Expr *expr)
std::shared_ptr< Analyzer::InValues > visitCardinality(const Analyzer::CardinalityExpr *) const override