OmniSciDB  0b528656ed
ExpressionRewrite.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ExpressionRewrite.h"
18 #include "../Analyzer/Analyzer.h"
19 #include "../Parser/ParserNode.h"
20 #include "../Shared/sqldefs.h"
21 #include "DeepCopyVisitor.h"
22 #include "Execute.h"
23 #include "RelAlgTranslator.h"
24 #include "ScalarExprVisitor.h"
25 #include "Shared/Logger.h"
27 
28 #include <boost/locale/conversion.hpp>
29 #include <unordered_set>
30 
31 namespace {
32 
33 class OrToInVisitor : public ScalarExprVisitor<std::shared_ptr<Analyzer::InValues>> {
34  protected:
35  std::shared_ptr<Analyzer::InValues> visitBinOper(
36  const Analyzer::BinOper* bin_oper) const override {
37  switch (bin_oper->get_optype()) {
38  case kEQ: {
39  const auto rhs_owned = bin_oper->get_own_right_operand();
40  auto rhs_no_cast = extract_cast_arg(rhs_owned.get());
41  if (!dynamic_cast<const Analyzer::Constant*>(rhs_no_cast)) {
42  return nullptr;
43  }
44  const auto arg = bin_oper->get_own_left_operand();
45  const auto& arg_ti = arg->get_type_info();
46  auto rhs = rhs_no_cast->deep_copy()->add_cast(arg_ti);
47  return makeExpr<Analyzer::InValues>(
48  arg, std::list<std::shared_ptr<Analyzer::Expr>>{rhs});
49  }
50  case kOR: {
51  return aggregateResult(visit(bin_oper->get_left_operand()),
52  visit(bin_oper->get_right_operand()));
53  }
54  default:
55  break;
56  }
57  return nullptr;
58  }
59 
60  std::shared_ptr<Analyzer::InValues> visitUOper(
61  const Analyzer::UOper* uoper) const override {
62  return nullptr;
63  }
64 
65  std::shared_ptr<Analyzer::InValues> visitInValues(
66  const Analyzer::InValues*) const override {
67  return nullptr;
68  }
69 
70  std::shared_ptr<Analyzer::InValues> visitInIntegerSet(
71  const Analyzer::InIntegerSet*) const override {
72  return nullptr;
73  }
74 
75  std::shared_ptr<Analyzer::InValues> visitCharLength(
76  const Analyzer::CharLengthExpr*) const override {
77  return nullptr;
78  }
79 
80  std::shared_ptr<Analyzer::InValues> visitKeyForString(
81  const Analyzer::KeyForStringExpr*) const override {
82  return nullptr;
83  }
84 
85  std::shared_ptr<Analyzer::InValues> visitCardinality(
86  const Analyzer::CardinalityExpr*) const override {
87  return nullptr;
88  }
89 
90  std::shared_ptr<Analyzer::InValues> visitLikeExpr(
91  const Analyzer::LikeExpr*) const override {
92  return nullptr;
93  }
94 
95  std::shared_ptr<Analyzer::InValues> visitRegexpExpr(
96  const Analyzer::RegexpExpr*) const override {
97  return nullptr;
98  }
99 
100  std::shared_ptr<Analyzer::InValues> visitCaseExpr(
101  const Analyzer::CaseExpr*) const override {
102  return nullptr;
103  }
104 
105  std::shared_ptr<Analyzer::InValues> visitDatetruncExpr(
106  const Analyzer::DatetruncExpr*) const override {
107  return nullptr;
108  }
109 
110  std::shared_ptr<Analyzer::InValues> visitDatediffExpr(
111  const Analyzer::DatediffExpr*) const override {
112  return nullptr;
113  }
114 
115  std::shared_ptr<Analyzer::InValues> visitDateaddExpr(
116  const Analyzer::DateaddExpr*) const override {
117  return nullptr;
118  }
119 
120  std::shared_ptr<Analyzer::InValues> visitExtractExpr(
121  const Analyzer::ExtractExpr*) const override {
122  return nullptr;
123  }
124 
125  std::shared_ptr<Analyzer::InValues> visitLikelihood(
126  const Analyzer::LikelihoodExpr*) const override {
127  return nullptr;
128  }
129 
130  std::shared_ptr<Analyzer::InValues> visitAggExpr(
131  const Analyzer::AggExpr*) const override {
132  return nullptr;
133  }
134 
135  std::shared_ptr<Analyzer::InValues> aggregateResult(
136  const std::shared_ptr<Analyzer::InValues>& lhs,
137  const std::shared_ptr<Analyzer::InValues>& rhs) const override {
138  if (!lhs || !rhs) {
139  return nullptr;
140  }
141 
142  if (lhs->get_arg()->get_type_info() == rhs->get_arg()->get_type_info() &&
143  (*lhs->get_arg() == *rhs->get_arg())) {
144  auto union_values = lhs->get_value_list();
145  const auto& rhs_values = rhs->get_value_list();
146  union_values.insert(union_values.end(), rhs_values.begin(), rhs_values.end());
147  return makeExpr<Analyzer::InValues>(lhs->get_own_arg(), union_values);
148  }
149  return nullptr;
150  }
151 };
152 
154  protected:
155  std::shared_ptr<Analyzer::Expr> visitBinOper(
156  const Analyzer::BinOper* bin_oper) const override {
157  OrToInVisitor simple_visitor;
158  if (bin_oper->get_optype() == kOR) {
159  auto rewritten = simple_visitor.visit(bin_oper);
160  if (rewritten) {
161  return rewritten;
162  }
163  }
164  auto lhs = bin_oper->get_own_left_operand();
165  auto rhs = bin_oper->get_own_right_operand();
166  auto rewritten_lhs = visit(lhs.get());
167  auto rewritten_rhs = visit(rhs.get());
168  return makeExpr<Analyzer::BinOper>(bin_oper->get_type_info(),
169  bin_oper->get_contains_agg(),
170  bin_oper->get_optype(),
171  bin_oper->get_qualifier(),
172  rewritten_lhs ? rewritten_lhs : lhs,
173  rewritten_rhs ? rewritten_rhs : rhs);
174  }
175 };
176 
178  protected:
180 
181  RetType visitArrayOper(const Analyzer::ArrayExpr* array_expr) const override {
182  std::vector<std::shared_ptr<Analyzer::Expr>> args_copy;
183  for (size_t i = 0; i < array_expr->getElementCount(); ++i) {
184  auto const element_expr_ptr = visit(array_expr->getElement(i));
185  auto const& element_expr_type_info = element_expr_ptr->get_type_info();
186 
187  if (!element_expr_type_info.is_string() ||
188  element_expr_type_info.get_compression() != kENCODING_NONE) {
189  args_copy.push_back(element_expr_ptr);
190  } else {
191  auto transient_dict_type_info = element_expr_type_info;
192 
193  transient_dict_type_info.set_compression(kENCODING_DICT);
194  transient_dict_type_info.set_comp_param(TRANSIENT_DICT_ID);
195  transient_dict_type_info.set_fixed_size();
196  args_copy.push_back(element_expr_ptr->add_cast(transient_dict_type_info));
197  }
198  }
199 
200  const auto& type_info = array_expr->get_type_info();
201  return makeExpr<Analyzer::ArrayExpr>(
202  type_info, args_copy, array_expr->isNull(), array_expr->isLocalAlloc());
203  }
204 };
205 
207  template <typename T>
208  bool foldComparison(SQLOps optype, T t1, T t2) const {
209  switch (optype) {
210  case kEQ:
211  return t1 == t2;
212  case kNE:
213  return t1 != t2;
214  case kLT:
215  return t1 < t2;
216  case kLE:
217  return t1 <= t2;
218  case kGT:
219  return t1 > t2;
220  case kGE:
221  return t1 >= t2;
222  default:
223  break;
224  }
225  throw std::runtime_error("Unable to fold");
226  return false;
227  }
228 
229  template <typename T>
230  bool foldLogic(SQLOps optype, T t1, T t2) const {
231  switch (optype) {
232  case kAND:
233  return t1 && t2;
234  case kOR:
235  return t1 || t2;
236  case kNOT:
237  return !t1;
238  default:
239  break;
240  }
241  throw std::runtime_error("Unable to fold");
242  return false;
243  }
244 
245  template <typename T>
246  T foldArithmetic(SQLOps optype, T t1, T t2) const {
247  bool t2_is_zero = (t2 == (t2 - t2));
248  bool t2_is_negative = (t2 < (t2 - t2));
249  switch (optype) {
250  case kPLUS:
251  // The MIN limit for float and double is the smallest representable value,
252  // not the lowest negative value! Switching to C++11 lowest.
253  if ((t2_is_negative && t1 < std::numeric_limits<T>::lowest() - t2) ||
254  (!t2_is_negative && t1 > std::numeric_limits<T>::max() - t2)) {
255  num_overflows_++;
256  throw std::runtime_error("Plus overflow");
257  }
258  return t1 + t2;
259  case kMINUS:
260  if ((t2_is_negative && t1 > std::numeric_limits<T>::max() + t2) ||
261  (!t2_is_negative && t1 < std::numeric_limits<T>::lowest() + t2)) {
262  num_overflows_++;
263  throw std::runtime_error("Minus overflow");
264  }
265  return t1 - t2;
266  case kMULTIPLY: {
267  if (t2_is_zero) {
268  return t2;
269  }
270  auto ct1 = t1;
271  auto ct2 = t2;
272  // Need to keep t2's sign on the left
273  if (t2_is_negative) {
274  if (t1 == std::numeric_limits<T>::lowest() ||
275  t2 == std::numeric_limits<T>::lowest()) {
276  // negation could overflow - bail
277  num_overflows_++;
278  throw std::runtime_error("Mul neg overflow");
279  }
280  ct1 = -t1; // ct1 gets t2's negativity
281  ct2 = -t2; // ct2 is now positive
282  }
283  // Don't check overlow if we are folding FP mul by a fraction
284  bool ct2_is_fraction = (ct2 < (ct2 / ct2));
285  if (!ct2_is_fraction) {
286  if (ct1 > std::numeric_limits<T>::max() / ct2 ||
287  ct1 < std::numeric_limits<T>::lowest() / ct2) {
288  num_overflows_++;
289  throw std::runtime_error("Mul overflow");
290  }
291  }
292  return t1 * t2;
293  }
294  case kDIVIDE:
295  if (t2_is_zero) {
296  throw std::runtime_error("Will not fold division by zero");
297  }
298  return t1 / t2;
299  default:
300  break;
301  }
302  throw std::runtime_error("Unable to fold");
303  }
304 
305  bool foldOper(SQLOps optype,
306  SQLTypes type,
307  Datum lhs,
308  Datum rhs,
309  Datum& result,
310  SQLTypes& result_type) const {
311  result_type = type;
312 
313  try {
314  switch (type) {
315  case kBOOLEAN:
316  if (IS_COMPARISON(optype)) {
317  result.boolval = foldComparison<bool>(optype, lhs.boolval, rhs.boolval);
318  result_type = kBOOLEAN;
319  return true;
320  }
321  if (IS_LOGIC(optype)) {
322  result.boolval = foldLogic<bool>(optype, lhs.boolval, rhs.boolval);
323  result_type = kBOOLEAN;
324  return true;
325  }
326  CHECK(!IS_ARITHMETIC(optype));
327  break;
328  case kTINYINT:
329  if (IS_COMPARISON(optype)) {
330  result.boolval =
331  foldComparison<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
332  result_type = kBOOLEAN;
333  return true;
334  }
335  if (IS_ARITHMETIC(optype)) {
336  result.tinyintval =
337  foldArithmetic<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
338  result_type = kTINYINT;
339  return true;
340  }
341  CHECK(!IS_LOGIC(optype));
342  break;
343  case kSMALLINT:
344  if (IS_COMPARISON(optype)) {
345  result.boolval =
346  foldComparison<int16_t>(optype, lhs.smallintval, rhs.smallintval);
347  result_type = kBOOLEAN;
348  return true;
349  }
350  if (IS_ARITHMETIC(optype)) {
351  result.smallintval =
352  foldArithmetic<int16_t>(optype, lhs.smallintval, rhs.smallintval);
353  result_type = kSMALLINT;
354  return true;
355  }
356  CHECK(!IS_LOGIC(optype));
357  break;
358  case kINT:
359  if (IS_COMPARISON(optype)) {
360  result.boolval = foldComparison<int32_t>(optype, lhs.intval, rhs.intval);
361  result_type = kBOOLEAN;
362  return true;
363  }
364  if (IS_ARITHMETIC(optype)) {
365  result.intval = foldArithmetic<int32_t>(optype, lhs.intval, rhs.intval);
366  result_type = kINT;
367  return true;
368  }
369  CHECK(!IS_LOGIC(optype));
370  break;
371  case kBIGINT:
372  if (IS_COMPARISON(optype)) {
373  result.boolval =
374  foldComparison<int64_t>(optype, lhs.bigintval, rhs.bigintval);
375  result_type = kBOOLEAN;
376  return true;
377  }
378  if (IS_ARITHMETIC(optype)) {
379  result.bigintval =
380  foldArithmetic<int64_t>(optype, lhs.bigintval, rhs.bigintval);
381  result_type = kBIGINT;
382  return true;
383  }
384  CHECK(!IS_LOGIC(optype));
385  break;
386  case kFLOAT:
387  if (IS_COMPARISON(optype)) {
388  result.boolval = foldComparison<float>(optype, lhs.floatval, rhs.floatval);
389  result_type = kBOOLEAN;
390  return true;
391  }
392  if (IS_ARITHMETIC(optype)) {
393  result.floatval = foldArithmetic<float>(optype, lhs.floatval, rhs.floatval);
394  result_type = kFLOAT;
395  return true;
396  }
397  CHECK(!IS_LOGIC(optype));
398  break;
399  case kDOUBLE:
400  if (IS_COMPARISON(optype)) {
401  result.boolval = foldComparison<double>(optype, lhs.doubleval, rhs.doubleval);
402  result_type = kBOOLEAN;
403  return true;
404  }
405  if (IS_ARITHMETIC(optype)) {
406  result.doubleval =
407  foldArithmetic<double>(optype, lhs.doubleval, rhs.doubleval);
408  result_type = kDOUBLE;
409  return true;
410  }
411  CHECK(!IS_LOGIC(optype));
412  break;
413  default:
414  break;
415  }
416  } catch (...) {
417  return false;
418  }
419  return false;
420  }
421 
422  std::shared_ptr<Analyzer::Expr> visitUOper(
423  const Analyzer::UOper* uoper) const override {
424  const auto unvisited_operand = uoper->get_operand();
425  const auto optype = uoper->get_optype();
426  const auto& ti = uoper->get_type_info();
427  if (optype == kCAST) {
428  // Cache the cast type so it could be used in operand rewriting/folding
429  casts_.insert({unvisited_operand, ti});
430  }
431  const auto operand = visit(unvisited_operand);
432  const auto& operand_ti = operand->get_type_info();
433  const auto operand_type =
434  operand_ti.is_decimal() ? decimal_to_int_type(operand_ti) : operand_ti.get_type();
435  const auto const_operand =
436  std::dynamic_pointer_cast<const Analyzer::Constant>(operand);
437 
438  if (const_operand) {
439  const auto operand_datum = const_operand->get_constval();
440  Datum zero_datum = {};
441  Datum result_datum = {};
442  SQLTypes result_type;
443  switch (optype) {
444  case kNOT: {
445  if (foldOper(kEQ,
446  operand_type,
447  zero_datum,
448  operand_datum,
449  result_datum,
450  result_type)) {
451  CHECK_EQ(result_type, kBOOLEAN);
452  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
453  }
454  break;
455  }
456  case kUMINUS: {
457  if (foldOper(kMINUS,
458  operand_type,
459  zero_datum,
460  operand_datum,
461  result_datum,
462  result_type)) {
463  if (!operand_ti.is_decimal()) {
464  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
465  }
466  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
467  }
468  break;
469  }
470  case kCAST: {
471  // Trying to fold number to number casts only
472  if (!ti.is_number() || !operand_ti.is_number()) {
473  break;
474  }
475  // Disallowing folding of FP to DECIMAL casts for now:
476  // allowing them would make this test pass:
477  // update dectest set d=cast( 1234.0 as float );
478  // which is expected to throw in Update.ImplicitCastToNumericTypes
479  // due to cast codegen currently not supporting these casts
480  if (ti.is_decimal() && operand_ti.is_fp()) {
481  break;
482  }
483  auto operand_copy = const_operand->deep_copy();
484  auto cast_operand = operand_copy->add_cast(ti);
485  auto const_cast_operand =
486  std::dynamic_pointer_cast<const Analyzer::Constant>(cast_operand);
487  if (const_cast_operand) {
488  auto const_cast_datum = const_cast_operand->get_constval();
489  return makeExpr<Analyzer::Constant>(ti, false, const_cast_datum);
490  }
491  }
492  default:
493  break;
494  }
495  }
496 
497  return makeExpr<Analyzer::UOper>(
498  uoper->get_type_info(), uoper->get_contains_agg(), optype, operand);
499  }
500 
501  std::shared_ptr<Analyzer::Expr> visitBinOper(
502  const Analyzer::BinOper* bin_oper) const override {
503  const auto optype = bin_oper->get_optype();
504  auto ti = bin_oper->get_type_info();
505  auto left_operand = bin_oper->get_own_left_operand();
506  auto right_operand = bin_oper->get_own_right_operand();
507 
508  // Check if bin_oper result is cast to a larger int or fp type
509  if (casts_.find(bin_oper) != casts_.end()) {
510  const auto cast_ti = casts_[bin_oper];
511  const auto& lhs_ti = bin_oper->get_left_operand()->get_type_info();
512  // Propagate cast down to the operands for folding
513  if ((cast_ti.is_integer() || cast_ti.is_fp()) && lhs_ti.is_integer() &&
514  cast_ti.get_size() > lhs_ti.get_size() &&
515  (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY)) {
516  // Before folding, cast the operands to the bigger type to avoid overflows.
517  // Currently upcasting smaller integer types to larger integers or double.
518  left_operand = left_operand->deep_copy()->add_cast(cast_ti);
519  right_operand = right_operand->deep_copy()->add_cast(cast_ti);
520  ti = cast_ti;
521  }
522  }
523 
524  const auto lhs = visit(left_operand.get());
525  const auto rhs = visit(right_operand.get());
526 
527  auto const_lhs = std::dynamic_pointer_cast<Analyzer::Constant>(lhs);
528  auto const_rhs = std::dynamic_pointer_cast<Analyzer::Constant>(rhs);
529  const auto& lhs_ti = lhs->get_type_info();
530  const auto& rhs_ti = rhs->get_type_info();
531  auto lhs_type = lhs_ti.is_decimal() ? decimal_to_int_type(lhs_ti) : lhs_ti.get_type();
532  auto rhs_type = rhs_ti.is_decimal() ? decimal_to_int_type(rhs_ti) : rhs_ti.get_type();
533 
534  if (const_lhs && const_rhs && lhs_type == rhs_type) {
535  auto lhs_datum = const_lhs->get_constval();
536  auto rhs_datum = const_rhs->get_constval();
537  Datum result_datum = {};
538  SQLTypes result_type;
539  if (foldOper(optype, lhs_type, lhs_datum, rhs_datum, result_datum, result_type)) {
540  // Fold all ops that don't take in decimal operands, and also decimal comparisons
541  if (!lhs_ti.is_decimal() || IS_COMPARISON(optype)) {
542  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
543  }
544  // Decimal arithmetic has been done as kBIGINT. Selectively fold some decimal ops,
545  // using result_datum and BinOper expr typeinfo which was adjusted for these ops.
546  if (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY) {
547  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
548  }
549  }
550  }
551 
552  if (optype == kAND && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
553  if (const_rhs && !const_rhs->get_is_null()) {
554  auto rhs_datum = const_rhs->get_constval();
555  if (rhs_datum.boolval == false) {
556  Datum d;
557  d.boolval = false;
558  // lhs && false --> false
559  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
560  }
561  // lhs && true --> lhs
562  return lhs;
563  }
564  if (const_lhs && !const_lhs->get_is_null()) {
565  auto lhs_datum = const_lhs->get_constval();
566  if (lhs_datum.boolval == false) {
567  Datum d;
568  d.boolval = false;
569  // false && rhs --> false
570  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
571  }
572  // true && rhs --> rhs
573  return rhs;
574  }
575  }
576  if (optype == kOR && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
577  if (const_rhs && !const_rhs->get_is_null()) {
578  auto rhs_datum = const_rhs->get_constval();
579  if (rhs_datum.boolval == true) {
580  Datum d;
581  d.boolval = true;
582  // lhs || true --> true
583  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
584  }
585  // lhs || false --> lhs
586  return lhs;
587  }
588  if (const_lhs && !const_lhs->get_is_null()) {
589  auto lhs_datum = const_lhs->get_constval();
590  if (lhs_datum.boolval == true) {
591  Datum d;
592  d.boolval = true;
593  // true || rhs --> true
594  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
595  }
596  // false || rhs --> rhs
597  return rhs;
598  }
599  }
600  if (*lhs == *rhs) {
601  // Tautologies: v=v; v<=v; v>=v
602  if (optype == kEQ || optype == kLE || optype == kGE) {
603  Datum d;
604  d.boolval = true;
605  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
606  }
607  // Contradictions: v!=v; v<v; v>v
608  if (optype == kNE || optype == kLT || optype == kGT) {
609  Datum d;
610  d.boolval = false;
611  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
612  }
613  // v-v
614  if (optype == kMINUS) {
615  Datum d = {};
616  return makeExpr<Analyzer::Constant>(lhs_type, false, d);
617  }
618  }
619  // Convert fp division by a constant to multiplication by 1/constant
620  if (optype == kDIVIDE && const_rhs && rhs_ti.is_fp()) {
621  auto rhs_datum = const_rhs->get_constval();
622  std::shared_ptr<Analyzer::Expr> recip_rhs = nullptr;
623  if (rhs_ti.get_type() == kFLOAT) {
624  if (rhs_datum.floatval == 1.0) {
625  return lhs;
626  }
627  auto f = std::fabs(rhs_datum.floatval);
628  if (f > 1.0 || (f != 0.0 && 1.0 < f * std::numeric_limits<float>::max())) {
629  rhs_datum.floatval = 1.0 / rhs_datum.floatval;
630  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
631  }
632  } else if (rhs_ti.get_type() == kDOUBLE) {
633  if (rhs_datum.doubleval == 1.0) {
634  return lhs;
635  }
636  auto d = std::fabs(rhs_datum.doubleval);
637  if (d > 1.0 || (d != 0.0 && 1.0 < d * std::numeric_limits<double>::max())) {
638  rhs_datum.doubleval = 1.0 / rhs_datum.doubleval;
639  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
640  }
641  }
642  if (recip_rhs) {
643  return makeExpr<Analyzer::BinOper>(ti,
644  bin_oper->get_contains_agg(),
645  kMULTIPLY,
646  bin_oper->get_qualifier(),
647  lhs,
648  recip_rhs);
649  }
650  }
651 
652  return makeExpr<Analyzer::BinOper>(ti,
653  bin_oper->get_contains_agg(),
654  bin_oper->get_optype(),
655  bin_oper->get_qualifier(),
656  lhs,
657  rhs);
658  }
659 
// Handles a LowerExpr: when its argument is an Analyzer::Constant, the constant's
// string value is lower-cased with boost::locale::to_lower; otherwise a new
// LowerExpr over the same argument is returned.
// NOTE(review): the statement that consumes the lower-cased string (original line
// 665, the start of the call closed on the line below) is absent from this listing;
// restore it from the upstream file before compiling.
660  std::shared_ptr<Analyzer::Expr> visitLower(
661  const Analyzer::LowerExpr* lower_expr) const override {
662  const auto constant_arg_expr =
663  dynamic_cast<const Analyzer::Constant*>(lower_expr->get_arg());
664  if (constant_arg_expr) {
666  boost::locale::to_lower(*constant_arg_expr->get_constval().stringval));
667  }
668  return makeExpr<Analyzer::LowerExpr>(lower_expr->get_own_arg());
669  }
670 
671  protected:
672  mutable std::unordered_map<const Analyzer::Expr*, const SQLTypeInfo> casts_;
673  mutable int32_t num_overflows_;
674 
675  public:
676  ConstantFoldingVisitor() : num_overflows_(0) {}
677  int32_t get_num_overflows() { return num_overflows_; }
678  void reset_num_overflows() { num_overflows_ = 0; }
679 };
680 
682  const auto with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
683  if (!with_likelihood) {
684  return expr;
685  }
686  return with_likelihood->get_arg();
687 }
688 
689 } // namespace
690 
692  return ArrayElementStringLiteralEncodingVisitor().visit(expr);
693 }
694 
696  const auto sum_window = rewrite_sum_window(expr);
697  if (sum_window) {
698  return sum_window;
699  }
700  const auto avg_window = rewrite_avg_window(expr);
701  if (avg_window) {
702  return avg_window;
703  }
704  const auto expr_no_likelihood = strip_likelihood(expr);
705  // The following check is not strictly needed, but seems silly to transform a
706  // simple string comparison to an IN just to codegen the same thing anyway.
707 
708  RecursiveOrToInVisitor visitor;
709  auto rewritten_expr = visitor.visit(expr_no_likelihood);
710  const auto expr_with_likelihood =
711  std::dynamic_pointer_cast<const Analyzer::LikelihoodExpr>(rewritten_expr);
712  if (expr_with_likelihood) {
713  // Add back likelihood
714  return std::make_shared<Analyzer::LikelihoodExpr>(
715  rewritten_expr, expr_with_likelihood->get_likelihood());
716  }
717  return rewritten_expr;
718 }
719 
namespace {

// Names of the geospatial functions that rewrite_overlaps_conjunction accepts.
const std::unordered_set<std::string> overlaps_supported_functions{
    // ST_Contains variants
    "ST_Contains_Polygon_Point",
    "ST_Contains_Polygon_Polygon",
    "ST_Contains_Polygon_MultiPolygon",
    "ST_Contains_MultiPolygon_Point",
    "ST_Contains_MultiPolygon_Polygon",
    "ST_Contains_MultiPolygon_MultiPolygon",
    // ST_Intersects variants
    "ST_Intersects_Polygon_Point",
    "ST_Intersects_Polygon_Polygon",
    "ST_Intersects_Polygon_MultiPolygon",
    "ST_Intersects_MultiPolygon_Polygon",
    "ST_Intersects_MultiPolygon_MultiPolygon",
    // Explicit overlaps
    "ST_Overlaps",
};

// Subset of the supported functions that is only rewritten when
// g_enable_hashjoin_many_to_many is set.
const std::unordered_set<std::string> requires_many_to_many{
    "ST_Contains_Polygon_Polygon",
    "ST_Contains_Polygon_MultiPolygon",
    "ST_Contains_MultiPolygon_Polygon",
    "ST_Contains_MultiPolygon_MultiPolygon",
    "ST_Intersects_Polygon_Polygon",
    "ST_Intersects_Polygon_MultiPolygon",
    "ST_Intersects_MultiPolygon_Polygon",
    "ST_Intersects_MultiPolygon_MultiPolygon",
};

}  // namespace
747 
748 boost::optional<OverlapsJoinConjunction> rewrite_overlaps_conjunction(
749  const std::shared_ptr<Analyzer::Expr> expr) {
750  auto func_oper = dynamic_cast<Analyzer::FunctionOper*>(expr.get());
751  if (func_oper) {
752  const auto needs_many_many = [func_oper]() {
753  return requires_many_to_many.find(func_oper->getName()) !=
754  requires_many_to_many.end();
755  };
756  // TODO(adb): consider converting unordered set to an unordered map, potentially
757  // storing the rewrite function we want to apply in the map
758  if (overlaps_supported_functions.find(func_oper->getName()) !=
760  if (!g_enable_hashjoin_many_to_many && needs_many_many()) {
761  LOG(WARNING) << "Many-to-many hashjoin support is disabled, unable to rewrite "
762  << func_oper->toString() << " to use accelerated geo join.";
763  return boost::none;
764  }
765 
766  DeepCopyVisitor deep_copy_visitor;
767  if (func_oper->getName() == "ST_Overlaps") {
768  CHECK_GE(func_oper->getArity(), size_t(2));
769  // return empty quals, overlaps join quals
770  // TODO(adb): we will likely want to actually check for true overlaps, but this
771  // works for now
772 
773  auto lhs = func_oper->getOwnArg(0);
774  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
775  CHECK(rewritten_lhs);
776 
777  auto rhs = func_oper->getOwnArg(1);
778  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
779  CHECK(rewritten_rhs);
780 
781  auto overlaps_oper = makeExpr<Analyzer::BinOper>(
782  kBOOLEAN, kOVERLAPS, kONE, rewritten_lhs, rewritten_rhs);
783  return OverlapsJoinConjunction{{}, {overlaps_oper}};
784  }
785 
786  // TODO(jclay): This will work for Poly_Poly,but needs to change for others.
787  CHECK_GE(func_oper->getArity(), size_t(4));
788  if (func_oper->getName() == "ST_Contains_Polygon_Polygon" ||
789  func_oper->getName() == "ST_Intersects_Polygon_Polygon" ||
790  func_oper->getName() == "ST_Intersects_MultiPolygon_MultiPolygon" ||
791  func_oper->getName() == "ST_Intersects_MultiPolygon_Polygon" ||
792  func_oper->getName() == "ST_Intersects_Polygon_MultiPolygon") {
793  auto lhs = func_oper->getOwnArg(3);
794  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
795  CHECK(rewritten_lhs);
796  auto rhs = func_oper->getOwnArg(1);
797  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
798  CHECK(rewritten_rhs);
799 
800  auto overlaps_oper = makeExpr<Analyzer::BinOper>(
801  kBOOLEAN, kOVERLAPS, kONE, rewritten_lhs, rewritten_rhs);
802 
803  VLOG(1) << "Successfully converted to overlaps join";
804  return OverlapsJoinConjunction{{expr}, {overlaps_oper}};
805  }
806 
807  auto lhs = func_oper->getOwnArg(2);
808  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
809  CHECK(rewritten_lhs);
810  const auto& lhs_ti = rewritten_lhs->get_type_info();
811 
812  if (!lhs_ti.is_geometry()) {
813  // TODO(adb): If ST_Contains is passed geospatial literals instead of columns, the
814  // function will be expanded during translation rather than during code
815  // generation. While this scenario does not make sense for the overlaps join, we
816  // need to detect and abort the overlaps rewrite. Adding a GeospatialConstant
817  // dervied class to the Analyzer may prove to be a better way to handle geo
818  // literals, but for now we ensure the LHS type is a geospatial type, which would
819  // mean the function has not been expanded to the physical types, yet.
820 
821  LOG(INFO) << "Unable to rewrite " << func_oper->getName()
822  << " to overlaps conjunction. LHS input type is not a geospatial "
823  "type. Are both inputs geospatial columns?\n"
824  << func_oper->toString();
825 
826  return boost::none;
827  }
828 
829  // Read the bounds arg from the ST_Contains FuncOper (second argument)instead of the
830  // poly column (first argument)
831  auto rhs = func_oper->getOwnArg(1);
832  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
833  CHECK(rewritten_rhs);
834 
835  // Check for compatible join ordering. If the join ordering does not match expected
836  // ordering for overlaps, the join builder will fail.
837  std::set<int> lhs_rte_idx;
838  lhs->collect_rte_idx(lhs_rte_idx);
839  CHECK(!lhs_rte_idx.empty());
840  std::set<int> rhs_rte_idx;
841  rhs->collect_rte_idx(rhs_rte_idx);
842  CHECK(!rhs_rte_idx.empty());
843 
844  if (lhs_rte_idx.size() > 1 || rhs_rte_idx.size() > 1 || lhs_rte_idx > rhs_rte_idx) {
845  LOG(INFO) << "Unable to rewrite " << func_oper->getName()
846  << " to overlaps conjunction. Cannot build hash table over LHS type. "
847  "Check join order.\n"
848  << func_oper->toString();
849  return boost::none;
850  }
851 
852  VLOG(1) << "Rewritten to use overlaps join with lhs as "
853  << rewritten_lhs->toString() << " and rhs as " << rewritten_rhs->toString();
854 
855  auto overlaps_oper = makeExpr<Analyzer::BinOper>(
856  kBOOLEAN, kOVERLAPS, kONE, rewritten_lhs, rewritten_rhs);
857 
858  VLOG(1) << "Successfully converted to overlaps join";
859  return OverlapsJoinConjunction{{expr}, {overlaps_oper}};
860  } else {
861  VLOG(1) << "Overlaps join not enabled for " << func_oper->getName();
862  }
863  }
864  return boost::none;
865 }
866 
876  public:
878  for (const auto& join_condition : join_quals) {
879  for (const auto& qual : join_condition.quals) {
880  auto qual_bin_oper = dynamic_cast<Analyzer::BinOper*>(qual.get());
881  if (qual_bin_oper) {
882  join_qual_pairs.emplace_back(qual_bin_oper->get_left_operand(),
883  qual_bin_oper->get_right_operand());
884  }
885  }
886  }
887  }
888 
889  bool visitFunctionOper(const Analyzer::FunctionOper* func_oper) const override {
890  if (overlaps_supported_functions.find(func_oper->getName()) !=
892  const auto lhs = func_oper->getArg(2);
893  const auto rhs = func_oper->getArg(1);
894  for (const auto& qual_pair : join_qual_pairs) {
895  if (*lhs == *qual_pair.first && *rhs == *qual_pair.second) {
896  return true;
897  }
898  }
899  }
900  return false;
901  }
902 
903  bool defaultResult() const override { return false; }
904 
905  private:
906  std::vector<std::pair<const Analyzer::Expr*, const Analyzer::Expr*>> join_qual_pairs;
907 };
908 
909 std::list<std::shared_ptr<Analyzer::Expr>> strip_join_covered_filter_quals(
910  const std::list<std::shared_ptr<Analyzer::Expr>>& quals,
911  const JoinQualsPerNestingLevel& join_quals) {
913  return quals;
914  }
915 
916  if (join_quals.empty()) {
917  return quals;
918  }
919 
920  std::list<std::shared_ptr<Analyzer::Expr>> quals_to_return;
921 
922  JoinCoveredQualVisitor visitor(join_quals);
923  for (const auto& qual : quals) {
924  if (!visitor.visit(qual.get())) {
925  // Not a covered qual, don't elide it from the filtered count
926  quals_to_return.push_back(qual);
927  }
928  }
929 
930  return quals_to_return;
931 }
932 
933 std::shared_ptr<Analyzer::Expr> fold_expr(const Analyzer::Expr* expr) {
934  if (!expr) {
935  return nullptr;
936  }
937  const auto expr_no_likelihood = strip_likelihood(expr);
938  ConstantFoldingVisitor visitor;
939  auto rewritten_expr = visitor.visit(expr_no_likelihood);
940  if (visitor.get_num_overflows() > 0 && rewritten_expr->get_type_info().is_integer() &&
941  rewritten_expr->get_type_info().get_type() != kBIGINT) {
942  auto rewritten_expr_const =
943  std::dynamic_pointer_cast<const Analyzer::Constant>(rewritten_expr);
944  if (!rewritten_expr_const) {
945  // Integer expression didn't fold completely the first time due to
946  // overflows in smaller type subexpressions, trying again with a cast
947  const auto& ti = SQLTypeInfo(kBIGINT, false);
948  auto bigint_expr_no_likelihood = expr_no_likelihood->deep_copy()->add_cast(ti);
949  auto rewritten_expr_take2 = visitor.visit(bigint_expr_no_likelihood.get());
950  auto rewritten_expr_take2_const =
951  std::dynamic_pointer_cast<Analyzer::Constant>(rewritten_expr_take2);
952  if (rewritten_expr_take2_const) {
953  // Managed to fold, switch to the new constant
954  rewritten_expr = rewritten_expr_take2_const;
955  }
956  }
957  }
958  const auto expr_with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
959  if (expr_with_likelihood) {
960  // Add back likelihood
961  return std::make_shared<Analyzer::LikelihoodExpr>(
962  rewritten_expr, expr_with_likelihood->get_likelihood());
963  }
964  return rewritten_expr;
965 }
int8_t tinyintval
Definition: sqltypes.h:133
SQLQualifier get_qualifier() const
Definition: Analyzer.h:442
const std::shared_ptr< Analyzer::Expr > get_own_arg() const
Definition: Analyzer.h:753
Analyzer::ExpressionPtr rewrite_array_elements(Analyzer::Expr const *expr)
std::string to_lower(const std::string &str)
#define CHECK_EQ(x, y)
Definition: Logger.h:205
#define IS_LOGIC(X)
Definition: sqldefs.h:60
std::shared_ptr< Analyzer::InValues > visitLikeExpr(const Analyzer::LikeExpr *) const override
std::shared_ptr< Analyzer::InValues > visitDateaddExpr(const Analyzer::DateaddExpr *) const override
bool foldOper(SQLOps optype, SQLTypes type, Datum lhs, Datum rhs, Datum &result, SQLTypes &result_type) const
size_t getElementCount() const
Definition: Analyzer.h:1444
SQLTypes
Definition: sqltypes.h:39
std::shared_ptr< Analyzer::WindowFunction > rewrite_avg_window(const Analyzer::Expr *expr)
bool g_strip_join_covered_quals
Definition: Execute.cpp:95
std::string getName() const
Definition: Analyzer.h:1314
std::shared_ptr< Analyzer::InValues > visitKeyForString(const Analyzer::KeyForStringExpr *) const override
std::shared_ptr< Analyzer::WindowFunction > rewrite_sum_window(const Analyzer::Expr *expr)
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:1318
std::shared_ptr< Analyzer::InValues > visitBinOper(const Analyzer::BinOper *bin_oper) const override
#define LOG(tag)
Definition: Logger.h:188
bool boolval
Definition: sqltypes.h:132
SQLOps
Definition: sqldefs.h:29
static const std::unordered_set< std::string > requires_many_to_many
Definition: sqldefs.h:35
Definition: sqldefs.h:36
Definition: sqldefs.h:38
bool isLocalAlloc() const
Definition: Analyzer.h:1445
#define CHECK_GE(x, y)
Definition: Logger.h:210
Definition: sqldefs.h:49
Definition: sqldefs.h:30
std::shared_ptr< Analyzer::Expr > ExpressionPtr
Definition: Analyzer.h:181
const Analyzer::Expr * extract_cast_arg(const Analyzer::Expr *expr)
Definition: Execute.h:129
Definition: sqldefs.h:41
std::vector< JoinCondition > JoinQualsPerNestingLevel
Analyzer::ExpressionPtr rewrite_expr(const Analyzer::Expr *expr)
static const std::unordered_set< std::string > overlaps_supported_functions
std::shared_ptr< Analyzer::InValues > aggregateResult(const std::shared_ptr< Analyzer::InValues > &lhs, const std::shared_ptr< Analyzer::InValues > &rhs) const override
std::list< std::shared_ptr< Analyzer::Expr > > strip_join_covered_filter_quals(const std::list< std::shared_ptr< Analyzer::Expr >> &quals, const JoinQualsPerNestingLevel &join_quals)
std::shared_ptr< Analyzer::InValues > visitUOper(const Analyzer::UOper *uoper) const override
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
bool is_decimal() const
Definition: sqltypes.h:411
std::shared_ptr< Analyzer::Expr > RetType
int32_t intval
Definition: sqltypes.h:135
SQLOps get_optype() const
Definition: Analyzer.h:440
std::shared_ptr< Analyzer::InValues > visitCharLength(const Analyzer::CharLengthExpr *) const override
std::shared_ptr< Analyzer::InValues > visitInIntegerSet(const Analyzer::InIntegerSet *) const override
float floatval
Definition: sqltypes.h:137
Datum get_constval() const
Definition: Analyzer.h:336
std::shared_ptr< Analyzer::Expr > visitUOper(const Analyzer::UOper *uoper) const override
const Analyzer::Expr * getElement(const size_t i) const
Definition: Analyzer.h:1448
bool g_enable_hashjoin_many_to_many
Definition: Execute.cpp:92
std::shared_ptr< Analyzer::InValues > visitRegexpExpr(const Analyzer::RegexpExpr *) const override
int64_t bigintval
Definition: sqltypes.h:136
std::shared_ptr< Analyzer::InValues > visitAggExpr(const Analyzer::AggExpr *) const override
bool visitFunctionOper(const Analyzer::FunctionOper *func_oper) const override
Definition: sqldefs.h:37
int16_t smallintval
Definition: sqltypes.h:134
bool defaultResult() const override
static std::shared_ptr< Analyzer::Expr > analyzeValue(const std::string &)
Definition: ParserNode.cpp:102
boost::optional< OverlapsJoinConjunction > rewrite_overlaps_conjunction(const std::shared_ptr< Analyzer::Expr > expr)
RetType visitArrayOper(const Analyzer::ArrayExpr *array_expr) const override
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const
Definition: Analyzer.h:445
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
SQLTypes decimal_to_int_type(const SQLTypeInfo &ti)
Definition: Datum.cpp:303
Definition: sqldefs.h:34
const std::shared_ptr< Analyzer::Expr > get_own_right_operand() const
Definition: Analyzer.h:448
Definition: sqldefs.h:40
std::vector< std::pair< const Analyzer::Expr *, const Analyzer::Expr * > > join_qual_pairs
Definition: sqldefs.h:69
#define TRANSIENT_DICT_ID
Definition: sqltypes.h:196
#define IS_ARITHMETIC(X)
Definition: sqldefs.h:61
Definition: sqldefs.h:32
Expression class for the LOWER (lowercase) string function. The "arg" constructor parameter must be a...
Definition: Analyzer.h:747
T visit(const Analyzer::Expr *expr) const
std::unordered_map< const Analyzer::Expr *, const SQLTypeInfo > casts_
std::shared_ptr< Analyzer::InValues > visitDatediffExpr(const Analyzer::DatediffExpr *) const override
std::shared_ptr< Analyzer::InValues > visitLikelihood(const Analyzer::LikelihoodExpr *) const override
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
std::shared_ptr< Analyzer::Expr > visitLower(const Analyzer::LowerExpr *lower_expr) const override
const Expr * get_arg() const
Definition: Analyzer.h:751
std::shared_ptr< Analyzer::InValues > visitInValues(const Analyzer::InValues *) const override
#define CHECK(condition)
Definition: Logger.h:197
Definition: sqldefs.h:33
bool get_contains_agg() const
Definition: Analyzer.h:81
std::shared_ptr< Analyzer::InValues > visitCaseExpr(const Analyzer::CaseExpr *) const override
Definition: sqltypes.h:46
JoinCoveredQualVisitor(const JoinQualsPerNestingLevel &join_quals)
std::shared_ptr< Analyzer::InValues > visitDatetruncExpr(const Analyzer::DatetruncExpr *) const override
bool isNull() const
Definition: Analyzer.h:1446
Definition: sqldefs.h:39
#define VLOG(n)
Definition: Logger.h:291
std::shared_ptr< Analyzer::Expr > fold_expr(const Analyzer::Expr *expr)
SQLOps get_optype() const
Definition: Analyzer.h:371
#define IS_COMPARISON(X)
Definition: sqldefs.h:57
double doubleval
Definition: sqltypes.h:138
std::shared_ptr< Analyzer::InValues > visitExtractExpr(const Analyzer::ExtractExpr *) const override
const Analyzer::Expr * strip_likelihood(const Analyzer::Expr *expr)
const Expr * get_right_operand() const
Definition: Analyzer.h:444
std::shared_ptr< Analyzer::InValues > visitCardinality(const Analyzer::CardinalityExpr *) const override
const Expr * get_left_operand() const
Definition: Analyzer.h:443
const Expr * get_operand() const
Definition: Analyzer.h:372