OmniSciDB  8fa3bf436f
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ExpressionRewrite.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ExpressionRewrite.h"
18 #include "../Analyzer/Analyzer.h"
19 #include "../Parser/ParserNode.h"
20 #include "../Shared/sqldefs.h"
21 #include "DeepCopyVisitor.h"
22 #include "Execute.h"
23 #include "Logger/Logger.h"
24 #include "RelAlgTranslator.h"
25 #include "ScalarExprVisitor.h"
27 
28 #include <boost/locale/conversion.hpp>
29 #include <unordered_set>
30 
31 namespace {
32 
33 class OrToInVisitor : public ScalarExprVisitor<std::shared_ptr<Analyzer::InValues>> {
34  protected:
35  std::shared_ptr<Analyzer::InValues> visitBinOper(
36  const Analyzer::BinOper* bin_oper) const override {
37  switch (bin_oper->get_optype()) {
38  case kEQ: {
39  const auto rhs_owned = bin_oper->get_own_right_operand();
40  auto rhs_no_cast = extract_cast_arg(rhs_owned.get());
41  if (!dynamic_cast<const Analyzer::Constant*>(rhs_no_cast)) {
42  return nullptr;
43  }
44  const auto arg = bin_oper->get_own_left_operand();
45  const auto& arg_ti = arg->get_type_info();
46  auto rhs = rhs_no_cast->deep_copy()->add_cast(arg_ti);
47  return makeExpr<Analyzer::InValues>(
48  arg, std::list<std::shared_ptr<Analyzer::Expr>>{rhs});
49  }
50  case kOR: {
51  return aggregateResult(visit(bin_oper->get_left_operand()),
52  visit(bin_oper->get_right_operand()));
53  }
54  default:
55  break;
56  }
57  return nullptr;
58  }
59 
60  std::shared_ptr<Analyzer::InValues> visitUOper(
61  const Analyzer::UOper* uoper) const override {
62  return nullptr;
63  }
64 
65  std::shared_ptr<Analyzer::InValues> visitInValues(
66  const Analyzer::InValues*) const override {
67  return nullptr;
68  }
69 
70  std::shared_ptr<Analyzer::InValues> visitInIntegerSet(
71  const Analyzer::InIntegerSet*) const override {
72  return nullptr;
73  }
74 
75  std::shared_ptr<Analyzer::InValues> visitCharLength(
76  const Analyzer::CharLengthExpr*) const override {
77  return nullptr;
78  }
79 
80  std::shared_ptr<Analyzer::InValues> visitKeyForString(
81  const Analyzer::KeyForStringExpr*) const override {
82  return nullptr;
83  }
84 
85  std::shared_ptr<Analyzer::InValues> visitSampleRatio(
86  const Analyzer::SampleRatioExpr*) const override {
87  return nullptr;
88  }
89 
90  std::shared_ptr<Analyzer::InValues> visitCardinality(
91  const Analyzer::CardinalityExpr*) const override {
92  return nullptr;
93  }
94 
95  std::shared_ptr<Analyzer::InValues> visitLikeExpr(
96  const Analyzer::LikeExpr*) const override {
97  return nullptr;
98  }
99 
100  std::shared_ptr<Analyzer::InValues> visitRegexpExpr(
101  const Analyzer::RegexpExpr*) const override {
102  return nullptr;
103  }
104 
105  std::shared_ptr<Analyzer::InValues> visitCaseExpr(
106  const Analyzer::CaseExpr*) const override {
107  return nullptr;
108  }
109 
110  std::shared_ptr<Analyzer::InValues> visitDatetruncExpr(
111  const Analyzer::DatetruncExpr*) const override {
112  return nullptr;
113  }
114 
115  std::shared_ptr<Analyzer::InValues> visitDatediffExpr(
116  const Analyzer::DatediffExpr*) const override {
117  return nullptr;
118  }
119 
120  std::shared_ptr<Analyzer::InValues> visitDateaddExpr(
121  const Analyzer::DateaddExpr*) const override {
122  return nullptr;
123  }
124 
125  std::shared_ptr<Analyzer::InValues> visitExtractExpr(
126  const Analyzer::ExtractExpr*) const override {
127  return nullptr;
128  }
129 
130  std::shared_ptr<Analyzer::InValues> visitLikelihood(
131  const Analyzer::LikelihoodExpr*) const override {
132  return nullptr;
133  }
134 
135  std::shared_ptr<Analyzer::InValues> visitAggExpr(
136  const Analyzer::AggExpr*) const override {
137  return nullptr;
138  }
139 
140  std::shared_ptr<Analyzer::InValues> aggregateResult(
141  const std::shared_ptr<Analyzer::InValues>& lhs,
142  const std::shared_ptr<Analyzer::InValues>& rhs) const override {
143  if (!lhs || !rhs) {
144  return nullptr;
145  }
146 
147  if (lhs->get_arg()->get_type_info() == rhs->get_arg()->get_type_info() &&
148  (*lhs->get_arg() == *rhs->get_arg())) {
149  auto union_values = lhs->get_value_list();
150  const auto& rhs_values = rhs->get_value_list();
151  union_values.insert(union_values.end(), rhs_values.begin(), rhs_values.end());
152  return makeExpr<Analyzer::InValues>(lhs->get_own_arg(), union_values);
153  }
154  return nullptr;
155  }
156 };
157 
159  protected:
160  std::shared_ptr<Analyzer::Expr> visitBinOper(
161  const Analyzer::BinOper* bin_oper) const override {
162  OrToInVisitor simple_visitor;
163  if (bin_oper->get_optype() == kOR) {
164  auto rewritten = simple_visitor.visit(bin_oper);
165  if (rewritten) {
166  return rewritten;
167  }
168  }
169  auto lhs = bin_oper->get_own_left_operand();
170  auto rhs = bin_oper->get_own_right_operand();
171  auto rewritten_lhs = visit(lhs.get());
172  auto rewritten_rhs = visit(rhs.get());
173  return makeExpr<Analyzer::BinOper>(bin_oper->get_type_info(),
174  bin_oper->get_contains_agg(),
175  bin_oper->get_optype(),
176  bin_oper->get_qualifier(),
177  rewritten_lhs ? rewritten_lhs : lhs,
178  rewritten_rhs ? rewritten_rhs : rhs);
179  }
180 };
181 
183  protected:
185 
186  RetType visitArrayOper(const Analyzer::ArrayExpr* array_expr) const override {
187  std::vector<std::shared_ptr<Analyzer::Expr>> args_copy;
188  for (size_t i = 0; i < array_expr->getElementCount(); ++i) {
189  auto const element_expr_ptr = visit(array_expr->getElement(i));
190  auto const& element_expr_type_info = element_expr_ptr->get_type_info();
191 
192  if (!element_expr_type_info.is_string() ||
193  element_expr_type_info.get_compression() != kENCODING_NONE) {
194  args_copy.push_back(element_expr_ptr);
195  } else {
196  auto transient_dict_type_info = element_expr_type_info;
197 
198  transient_dict_type_info.set_compression(kENCODING_DICT);
199  transient_dict_type_info.set_comp_param(TRANSIENT_DICT_ID);
200  transient_dict_type_info.set_fixed_size();
201  args_copy.push_back(element_expr_ptr->add_cast(transient_dict_type_info));
202  }
203  }
204 
205  const auto& type_info = array_expr->get_type_info();
206  return makeExpr<Analyzer::ArrayExpr>(
207  type_info, args_copy, array_expr->isNull(), array_expr->isLocalAlloc());
208  }
209 };
210 
212  template <typename T>
213  bool foldComparison(SQLOps optype, T t1, T t2) const {
214  switch (optype) {
215  case kEQ:
216  return t1 == t2;
217  case kNE:
218  return t1 != t2;
219  case kLT:
220  return t1 < t2;
221  case kLE:
222  return t1 <= t2;
223  case kGT:
224  return t1 > t2;
225  case kGE:
226  return t1 >= t2;
227  default:
228  break;
229  }
230  throw std::runtime_error("Unable to fold");
231  return false;
232  }
233 
234  template <typename T>
235  bool foldLogic(SQLOps optype, T t1, T t2) const {
236  switch (optype) {
237  case kAND:
238  return t1 && t2;
239  case kOR:
240  return t1 || t2;
241  case kNOT:
242  return !t1;
243  default:
244  break;
245  }
246  throw std::runtime_error("Unable to fold");
247  return false;
248  }
249 
250  template <typename T>
251  T foldArithmetic(SQLOps optype, T t1, T t2) const {
252  bool t2_is_zero = (t2 == (t2 - t2));
253  bool t2_is_negative = (t2 < (t2 - t2));
254  switch (optype) {
255  case kPLUS:
256  // The MIN limit for float and double is the smallest representable value,
257  // not the lowest negative value! Switching to C++11 lowest.
258  if ((t2_is_negative && t1 < std::numeric_limits<T>::lowest() - t2) ||
259  (!t2_is_negative && t1 > std::numeric_limits<T>::max() - t2)) {
260  num_overflows_++;
261  throw std::runtime_error("Plus overflow");
262  }
263  return t1 + t2;
264  case kMINUS:
265  if ((t2_is_negative && t1 > std::numeric_limits<T>::max() + t2) ||
266  (!t2_is_negative && t1 < std::numeric_limits<T>::lowest() + t2)) {
267  num_overflows_++;
268  throw std::runtime_error("Minus overflow");
269  }
270  return t1 - t2;
271  case kMULTIPLY: {
272  if (t2_is_zero) {
273  return t2;
274  }
275  auto ct1 = t1;
276  auto ct2 = t2;
277  // Need to keep t2's sign on the left
278  if (t2_is_negative) {
279  if (t1 == std::numeric_limits<T>::lowest() ||
280  t2 == std::numeric_limits<T>::lowest()) {
281  // negation could overflow - bail
282  num_overflows_++;
283  throw std::runtime_error("Mul neg overflow");
284  }
285  ct1 = -t1; // ct1 gets t2's negativity
286  ct2 = -t2; // ct2 is now positive
287  }
288  // Don't check overlow if we are folding FP mul by a fraction
289  bool ct2_is_fraction = (ct2 < (ct2 / ct2));
290  if (!ct2_is_fraction) {
291  if (ct1 > std::numeric_limits<T>::max() / ct2 ||
292  ct1 < std::numeric_limits<T>::lowest() / ct2) {
293  num_overflows_++;
294  throw std::runtime_error("Mul overflow");
295  }
296  }
297  return t1 * t2;
298  }
299  case kDIVIDE:
300  if (t2_is_zero) {
301  throw std::runtime_error("Will not fold division by zero");
302  }
303  return t1 / t2;
304  default:
305  break;
306  }
307  throw std::runtime_error("Unable to fold");
308  }
309 
310  bool foldOper(SQLOps optype,
311  SQLTypes type,
312  Datum lhs,
313  Datum rhs,
314  Datum& result,
315  SQLTypes& result_type) const {
316  result_type = type;
317 
318  try {
319  switch (type) {
320  case kBOOLEAN:
321  if (IS_COMPARISON(optype)) {
322  result.boolval = foldComparison<bool>(optype, lhs.boolval, rhs.boolval);
323  result_type = kBOOLEAN;
324  return true;
325  }
326  if (IS_LOGIC(optype)) {
327  result.boolval = foldLogic<bool>(optype, lhs.boolval, rhs.boolval);
328  result_type = kBOOLEAN;
329  return true;
330  }
331  CHECK(!IS_ARITHMETIC(optype));
332  break;
333  case kTINYINT:
334  if (IS_COMPARISON(optype)) {
335  result.boolval =
336  foldComparison<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
337  result_type = kBOOLEAN;
338  return true;
339  }
340  if (IS_ARITHMETIC(optype)) {
341  result.tinyintval =
342  foldArithmetic<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
343  result_type = kTINYINT;
344  return true;
345  }
346  CHECK(!IS_LOGIC(optype));
347  break;
348  case kSMALLINT:
349  if (IS_COMPARISON(optype)) {
350  result.boolval =
351  foldComparison<int16_t>(optype, lhs.smallintval, rhs.smallintval);
352  result_type = kBOOLEAN;
353  return true;
354  }
355  if (IS_ARITHMETIC(optype)) {
356  result.smallintval =
357  foldArithmetic<int16_t>(optype, lhs.smallintval, rhs.smallintval);
358  result_type = kSMALLINT;
359  return true;
360  }
361  CHECK(!IS_LOGIC(optype));
362  break;
363  case kINT:
364  if (IS_COMPARISON(optype)) {
365  result.boolval = foldComparison<int32_t>(optype, lhs.intval, rhs.intval);
366  result_type = kBOOLEAN;
367  return true;
368  }
369  if (IS_ARITHMETIC(optype)) {
370  result.intval = foldArithmetic<int32_t>(optype, lhs.intval, rhs.intval);
371  result_type = kINT;
372  return true;
373  }
374  CHECK(!IS_LOGIC(optype));
375  break;
376  case kBIGINT:
377  if (IS_COMPARISON(optype)) {
378  result.boolval =
379  foldComparison<int64_t>(optype, lhs.bigintval, rhs.bigintval);
380  result_type = kBOOLEAN;
381  return true;
382  }
383  if (IS_ARITHMETIC(optype)) {
384  result.bigintval =
385  foldArithmetic<int64_t>(optype, lhs.bigintval, rhs.bigintval);
386  result_type = kBIGINT;
387  return true;
388  }
389  CHECK(!IS_LOGIC(optype));
390  break;
391  case kFLOAT:
392  if (IS_COMPARISON(optype)) {
393  result.boolval = foldComparison<float>(optype, lhs.floatval, rhs.floatval);
394  result_type = kBOOLEAN;
395  return true;
396  }
397  if (IS_ARITHMETIC(optype)) {
398  result.floatval = foldArithmetic<float>(optype, lhs.floatval, rhs.floatval);
399  result_type = kFLOAT;
400  return true;
401  }
402  CHECK(!IS_LOGIC(optype));
403  break;
404  case kDOUBLE:
405  if (IS_COMPARISON(optype)) {
406  result.boolval = foldComparison<double>(optype, lhs.doubleval, rhs.doubleval);
407  result_type = kBOOLEAN;
408  return true;
409  }
410  if (IS_ARITHMETIC(optype)) {
411  result.doubleval =
412  foldArithmetic<double>(optype, lhs.doubleval, rhs.doubleval);
413  result_type = kDOUBLE;
414  return true;
415  }
416  CHECK(!IS_LOGIC(optype));
417  break;
418  default:
419  break;
420  }
421  } catch (...) {
422  return false;
423  }
424  return false;
425  }
426 
427  std::shared_ptr<Analyzer::Expr> visitUOper(
428  const Analyzer::UOper* uoper) const override {
429  const auto unvisited_operand = uoper->get_operand();
430  const auto optype = uoper->get_optype();
431  const auto& ti = uoper->get_type_info();
432  if (optype == kCAST) {
433  // Cache the cast type so it could be used in operand rewriting/folding
434  casts_.insert({unvisited_operand, ti});
435  }
436  const auto operand = visit(unvisited_operand);
437  const auto& operand_ti = operand->get_type_info();
438  const auto operand_type =
439  operand_ti.is_decimal() ? decimal_to_int_type(operand_ti) : operand_ti.get_type();
440  const auto const_operand =
441  std::dynamic_pointer_cast<const Analyzer::Constant>(operand);
442 
443  if (const_operand) {
444  const auto operand_datum = const_operand->get_constval();
445  Datum zero_datum = {};
446  Datum result_datum = {};
447  SQLTypes result_type;
448  switch (optype) {
449  case kNOT: {
450  if (foldOper(kEQ,
451  operand_type,
452  zero_datum,
453  operand_datum,
454  result_datum,
455  result_type)) {
456  CHECK_EQ(result_type, kBOOLEAN);
457  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
458  }
459  break;
460  }
461  case kUMINUS: {
462  if (foldOper(kMINUS,
463  operand_type,
464  zero_datum,
465  operand_datum,
466  result_datum,
467  result_type)) {
468  if (!operand_ti.is_decimal()) {
469  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
470  }
471  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
472  }
473  break;
474  }
475  case kCAST: {
476  // Trying to fold number to number casts only
477  if (!ti.is_number() || !operand_ti.is_number()) {
478  break;
479  }
480  // Disallowing folding of FP to DECIMAL casts for now:
481  // allowing them would make this test pass:
482  // update dectest set d=cast( 1234.0 as float );
483  // which is expected to throw in Update.ImplicitCastToNumericTypes
484  // due to cast codegen currently not supporting these casts
485  if (ti.is_decimal() && operand_ti.is_fp()) {
486  break;
487  }
488  auto operand_copy = const_operand->deep_copy();
489  auto cast_operand = operand_copy->add_cast(ti);
490  auto const_cast_operand =
491  std::dynamic_pointer_cast<const Analyzer::Constant>(cast_operand);
492  if (const_cast_operand) {
493  auto const_cast_datum = const_cast_operand->get_constval();
494  return makeExpr<Analyzer::Constant>(ti, false, const_cast_datum);
495  }
496  }
497  default:
498  break;
499  }
500  }
501 
502  return makeExpr<Analyzer::UOper>(
503  uoper->get_type_info(), uoper->get_contains_agg(), optype, operand);
504  }
505 
506  std::shared_ptr<Analyzer::Expr> visitBinOper(
507  const Analyzer::BinOper* bin_oper) const override {
508  const auto optype = bin_oper->get_optype();
509  auto ti = bin_oper->get_type_info();
510  auto left_operand = bin_oper->get_own_left_operand();
511  auto right_operand = bin_oper->get_own_right_operand();
512 
513  // Check if bin_oper result is cast to a larger int or fp type
514  if (casts_.find(bin_oper) != casts_.end()) {
515  const auto cast_ti = casts_[bin_oper];
516  const auto& lhs_ti = bin_oper->get_left_operand()->get_type_info();
517  // Propagate cast down to the operands for folding
518  if ((cast_ti.is_integer() || cast_ti.is_fp()) && lhs_ti.is_integer() &&
519  cast_ti.get_size() > lhs_ti.get_size() &&
520  (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY)) {
521  // Before folding, cast the operands to the bigger type to avoid overflows.
522  // Currently upcasting smaller integer types to larger integers or double.
523  left_operand = left_operand->deep_copy()->add_cast(cast_ti);
524  right_operand = right_operand->deep_copy()->add_cast(cast_ti);
525  ti = cast_ti;
526  }
527  }
528 
529  const auto lhs = visit(left_operand.get());
530  const auto rhs = visit(right_operand.get());
531 
532  auto const_lhs = std::dynamic_pointer_cast<Analyzer::Constant>(lhs);
533  auto const_rhs = std::dynamic_pointer_cast<Analyzer::Constant>(rhs);
534  const auto& lhs_ti = lhs->get_type_info();
535  const auto& rhs_ti = rhs->get_type_info();
536  auto lhs_type = lhs_ti.is_decimal() ? decimal_to_int_type(lhs_ti) : lhs_ti.get_type();
537  auto rhs_type = rhs_ti.is_decimal() ? decimal_to_int_type(rhs_ti) : rhs_ti.get_type();
538 
539  if (const_lhs && const_rhs && lhs_type == rhs_type) {
540  auto lhs_datum = const_lhs->get_constval();
541  auto rhs_datum = const_rhs->get_constval();
542  Datum result_datum = {};
543  SQLTypes result_type;
544  if (foldOper(optype, lhs_type, lhs_datum, rhs_datum, result_datum, result_type)) {
545  // Fold all ops that don't take in decimal operands, and also decimal comparisons
546  if (!lhs_ti.is_decimal() || IS_COMPARISON(optype)) {
547  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
548  }
549  // Decimal arithmetic has been done as kBIGINT. Selectively fold some decimal ops,
550  // using result_datum and BinOper expr typeinfo which was adjusted for these ops.
551  if (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY) {
552  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
553  }
554  }
555  }
556 
557  if (optype == kAND && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
558  if (const_rhs && !const_rhs->get_is_null()) {
559  auto rhs_datum = const_rhs->get_constval();
560  if (rhs_datum.boolval == false) {
561  Datum d;
562  d.boolval = false;
563  // lhs && false --> false
564  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
565  }
566  // lhs && true --> lhs
567  return lhs;
568  }
569  if (const_lhs && !const_lhs->get_is_null()) {
570  auto lhs_datum = const_lhs->get_constval();
571  if (lhs_datum.boolval == false) {
572  Datum d;
573  d.boolval = false;
574  // false && rhs --> false
575  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
576  }
577  // true && rhs --> rhs
578  return rhs;
579  }
580  }
581  if (optype == kOR && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
582  if (const_rhs && !const_rhs->get_is_null()) {
583  auto rhs_datum = const_rhs->get_constval();
584  if (rhs_datum.boolval == true) {
585  Datum d;
586  d.boolval = true;
587  // lhs || true --> true
588  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
589  }
590  // lhs || false --> lhs
591  return lhs;
592  }
593  if (const_lhs && !const_lhs->get_is_null()) {
594  auto lhs_datum = const_lhs->get_constval();
595  if (lhs_datum.boolval == true) {
596  Datum d;
597  d.boolval = true;
598  // true || rhs --> true
599  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
600  }
601  // false || rhs --> rhs
602  return rhs;
603  }
604  }
605  if (*lhs == *rhs) {
606  // Tautologies: v=v; v<=v; v>=v
607  if (optype == kEQ || optype == kLE || optype == kGE) {
608  Datum d;
609  d.boolval = true;
610  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
611  }
612  // Contradictions: v!=v; v<v; v>v
613  if (optype == kNE || optype == kLT || optype == kGT) {
614  Datum d;
615  d.boolval = false;
616  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
617  }
618  // v-v
619  if (optype == kMINUS) {
620  Datum d = {};
621  return makeExpr<Analyzer::Constant>(lhs_type, false, d);
622  }
623  }
624  // Convert fp division by a constant to multiplication by 1/constant
625  if (optype == kDIVIDE && const_rhs && rhs_ti.is_fp()) {
626  auto rhs_datum = const_rhs->get_constval();
627  std::shared_ptr<Analyzer::Expr> recip_rhs = nullptr;
628  if (rhs_ti.get_type() == kFLOAT) {
629  if (rhs_datum.floatval == 1.0) {
630  return lhs;
631  }
632  auto f = std::fabs(rhs_datum.floatval);
633  if (f > 1.0 || (f != 0.0 && 1.0 < f * std::numeric_limits<float>::max())) {
634  rhs_datum.floatval = 1.0 / rhs_datum.floatval;
635  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
636  }
637  } else if (rhs_ti.get_type() == kDOUBLE) {
638  if (rhs_datum.doubleval == 1.0) {
639  return lhs;
640  }
641  auto d = std::fabs(rhs_datum.doubleval);
642  if (d > 1.0 || (d != 0.0 && 1.0 < d * std::numeric_limits<double>::max())) {
643  rhs_datum.doubleval = 1.0 / rhs_datum.doubleval;
644  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
645  }
646  }
647  if (recip_rhs) {
648  return makeExpr<Analyzer::BinOper>(ti,
649  bin_oper->get_contains_agg(),
650  kMULTIPLY,
651  bin_oper->get_qualifier(),
652  lhs,
653  recip_rhs);
654  }
655  }
656 
657  return makeExpr<Analyzer::BinOper>(ti,
658  bin_oper->get_contains_agg(),
659  bin_oper->get_optype(),
660  bin_oper->get_qualifier(),
661  lhs,
662  rhs);
663  }
664 
665  std::shared_ptr<Analyzer::Expr> visitLower(
666  const Analyzer::LowerExpr* lower_expr) const override {
667  const auto constant_arg_expr =
668  dynamic_cast<const Analyzer::Constant*>(lower_expr->get_arg());
669  if (constant_arg_expr) {
671  boost::locale::to_lower(*constant_arg_expr->get_constval().stringval));
672  }
673  return makeExpr<Analyzer::LowerExpr>(lower_expr->get_own_arg());
674  }
675 
676  protected:
677  mutable std::unordered_map<const Analyzer::Expr*, const SQLTypeInfo> casts_;
678  mutable int32_t num_overflows_;
679 
680  public:
681  ConstantFoldingVisitor() : num_overflows_(0) {}
682  int32_t get_num_overflows() { return num_overflows_; }
683  void reset_num_overflows() { num_overflows_ = 0; }
684 };
685 
687  const auto with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
688  if (!with_likelihood) {
689  return expr;
690  }
691  return with_likelihood->get_arg();
692 }
693 
694 } // namespace
695 
697  return ArrayElementStringLiteralEncodingVisitor().visit(expr);
698 }
699 
701  const auto sum_window = rewrite_sum_window(expr);
702  if (sum_window) {
703  return sum_window;
704  }
705  const auto avg_window = rewrite_avg_window(expr);
706  if (avg_window) {
707  return avg_window;
708  }
709  const auto expr_no_likelihood = strip_likelihood(expr);
710  // The following check is not strictly needed, but seems silly to transform a
711  // simple string comparison to an IN just to codegen the same thing anyway.
712 
713  RecursiveOrToInVisitor visitor;
714  auto rewritten_expr = visitor.visit(expr_no_likelihood);
715  const auto expr_with_likelihood =
716  std::dynamic_pointer_cast<const Analyzer::LikelihoodExpr>(rewritten_expr);
717  if (expr_with_likelihood) {
718  // Add back likelihood
719  return std::make_shared<Analyzer::LikelihoodExpr>(
720  rewritten_expr, expr_with_likelihood->get_likelihood());
721  }
722  return rewritten_expr;
723 }
724 
725 namespace {
726 
// Geo function names (as returned by FunctionOper::getName()) that are
// candidates for the overlaps hash-join rewrite and for join-covered-qual
// detection. The "ST_c..." names are the compressed-coordinates variants.
// Internal linkage already comes from the enclosing unnamed namespace.
const std::unordered_set<std::string> overlaps_supported_functions = {
    "ST_Contains_MultiPolygon_Point",
    "ST_Contains_Polygon_Point",
    "ST_cContains_MultiPolygon_Point",  // compressed coords version
    "ST_cContains_Polygon_Point",
    "ST_Contains_Polygon_Polygon",
    "ST_Contains_Polygon_MultiPolygon",
    "ST_Contains_MultiPolygon_MultiPolygon",
    "ST_Contains_MultiPolygon_Polygon",
    "ST_Intersects_Polygon_Point",
    "ST_Intersects_Polygon_Polygon",
    "ST_Intersects_Polygon_MultiPolygon",
    "ST_Intersects_MultiPolygon_MultiPolygon",
    "ST_Intersects_MultiPolygon_Polygon",
    "ST_Intersects_MultiPolygon_Point",
    "ST_Approx_Overlaps_MultiPolygon_Point",
    "ST_Overlaps"};

// Subset of the functions above whose overlaps rewrite needs many-to-many hash
// join support; the rewrite is refused when g_enable_hashjoin_many_to_many is
// off.
const std::unordered_set<std::string> requires_many_to_many = {
    "ST_Contains_Polygon_Polygon",
    "ST_Contains_Polygon_MultiPolygon",
    "ST_Contains_MultiPolygon_MultiPolygon",
    "ST_Contains_MultiPolygon_Polygon",
    "ST_Intersects_Polygon_Polygon",
    "ST_Intersects_Polygon_MultiPolygon",
    "ST_Intersects_MultiPolygon_MultiPolygon",
    "ST_Intersects_MultiPolygon_Polygon"};
754 
755 } // namespace
756 
757 boost::optional<OverlapsJoinConjunction> rewrite_overlaps_conjunction(
758  const std::shared_ptr<Analyzer::Expr> expr) {
759  auto func_oper = dynamic_cast<Analyzer::FunctionOper*>(expr.get());
760  if (func_oper) {
761  const auto needs_many_many = [func_oper]() {
762  return requires_many_to_many.find(func_oper->getName()) !=
763  requires_many_to_many.end();
764  };
765  // TODO(adb): consider converting unordered set to an unordered map, potentially
766  // storing the rewrite function we want to apply in the map
767  if (overlaps_supported_functions.find(func_oper->getName()) !=
769  if (!g_enable_hashjoin_many_to_many && needs_many_many()) {
770  LOG(WARNING) << "Many-to-many hashjoin support is disabled, unable to rewrite "
771  << func_oper->toString() << " to use accelerated geo join.";
772  return boost::none;
773  }
774 
775  DeepCopyVisitor deep_copy_visitor;
776  if (func_oper->getName() == "ST_Overlaps") {
777  CHECK_GE(func_oper->getArity(), size_t(2));
778  // return empty quals, overlaps join quals
779  // TODO(adb): we will likely want to actually check for true overlaps, but this
780  // works for now
781 
782  auto lhs = func_oper->getOwnArg(0);
783  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
784  CHECK(rewritten_lhs);
785 
786  auto rhs = func_oper->getOwnArg(1);
787  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
788  CHECK(rewritten_rhs);
789 
790  auto overlaps_oper = makeExpr<Analyzer::BinOper>(
791  kBOOLEAN, kOVERLAPS, kONE, rewritten_lhs, rewritten_rhs);
792  return OverlapsJoinConjunction{{}, {overlaps_oper}};
793  }
794 
795  // TODO(jclay): This will work for Poly_Poly,but needs to change for others.
796  CHECK_GE(func_oper->getArity(), size_t(4));
797  if (func_oper->getName() == "ST_Contains_Polygon_Polygon" ||
798  func_oper->getName() == "ST_Intersects_Polygon_Polygon" ||
799  func_oper->getName() == "ST_Intersects_MultiPolygon_MultiPolygon" ||
800  func_oper->getName() == "ST_Intersects_MultiPolygon_Polygon" ||
801  func_oper->getName() == "ST_Intersects_Polygon_MultiPolygon") {
802  auto lhs = func_oper->getOwnArg(3);
803  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
804  CHECK(rewritten_lhs);
805  auto rhs = func_oper->getOwnArg(1);
806  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
807  CHECK(rewritten_rhs);
808 
809  auto overlaps_oper = makeExpr<Analyzer::BinOper>(
810  kBOOLEAN, kOVERLAPS, kONE, rewritten_lhs, rewritten_rhs);
811 
812  VLOG(1) << "Successfully converted to overlaps join";
813  return OverlapsJoinConjunction{{expr}, {overlaps_oper}};
814  }
815 
816  auto lhs = func_oper->getOwnArg(2);
817  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
818  CHECK(rewritten_lhs);
819  const auto& lhs_ti = rewritten_lhs->get_type_info();
820 
821  if (!lhs_ti.is_geometry() && !is_constructed_point(rewritten_lhs.get())) {
822  // TODO(adb): If ST_Contains is passed geospatial literals instead of columns, the
823  // function will be expanded during translation rather than during code
824  // generation. While this scenario does not make sense for the overlaps join, we
825  // need to detect and abort the overlaps rewrite. Adding a GeospatialConstant
826  // dervied class to the Analyzer may prove to be a better way to handle geo
827  // literals, but for now we ensure the LHS type is a geospatial type, which would
828  // mean the function has not been expanded to the physical types, yet.
829 
830  LOG(INFO) << "Unable to rewrite " << func_oper->getName()
831  << " to overlaps conjunction. LHS input type is neither a geospatial "
832  "column nor a constructed point\n"
833  << func_oper->toString();
834 
835  return boost::none;
836  }
837 
838  // Read the bounds arg from the ST_Contains FuncOper (second argument)instead of the
839  // poly column (first argument)
840  auto rhs = func_oper->getOwnArg(1);
841  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
842  CHECK(rewritten_rhs);
843 
844  // Check for compatible join ordering. If the join ordering does not match expected
845  // ordering for overlaps, the join builder will fail.
846  std::set<int> lhs_rte_idx;
847  lhs->collect_rte_idx(lhs_rte_idx);
848  CHECK(!lhs_rte_idx.empty());
849  std::set<int> rhs_rte_idx;
850  rhs->collect_rte_idx(rhs_rte_idx);
851  CHECK(!rhs_rte_idx.empty());
852 
853  if (lhs_rte_idx.size() > 1 || rhs_rte_idx.size() > 1 || lhs_rte_idx > rhs_rte_idx) {
854  LOG(INFO) << "Unable to rewrite " << func_oper->getName()
855  << " to overlaps conjunction. Cannot build hash table over LHS type. "
856  "Check join order.\n"
857  << func_oper->toString();
858  return boost::none;
859  }
860 
861  VLOG(1) << "Rewritten to use overlaps join with lhs as "
862  << rewritten_lhs->toString() << " and rhs as " << rewritten_rhs->toString();
863 
864  auto overlaps_oper = makeExpr<Analyzer::BinOper>(
865  kBOOLEAN, kOVERLAPS, kONE, rewritten_lhs, rewritten_rhs);
866 
867  VLOG(1) << "Successfully converted to overlaps join";
868  if (func_oper->getName() == "ST_Approx_Overlaps_MultiPolygon_Point"sv) {
869  return OverlapsJoinConjunction{{}, {overlaps_oper}};
870  } else {
871  return OverlapsJoinConjunction{{expr}, {overlaps_oper}};
872  }
873  } else {
874  VLOG(1) << "Overlaps join not enabled for " << func_oper->getName();
875  }
876  }
877  return boost::none;
878 }
879 
889  public:
891  for (const auto& join_condition : join_quals) {
892  for (const auto& qual : join_condition.quals) {
893  auto qual_bin_oper = dynamic_cast<Analyzer::BinOper*>(qual.get());
894  if (qual_bin_oper) {
895  join_qual_pairs.emplace_back(qual_bin_oper->get_left_operand(),
896  qual_bin_oper->get_right_operand());
897  }
898  }
899  }
900  }
901 
902  bool visitFunctionOper(const Analyzer::FunctionOper* func_oper) const override {
903  if (overlaps_supported_functions.find(func_oper->getName()) !=
905  const auto lhs = func_oper->getArg(2);
906  const auto rhs = func_oper->getArg(1);
907  for (const auto& qual_pair : join_qual_pairs) {
908  if (*lhs == *qual_pair.first && *rhs == *qual_pair.second) {
909  return true;
910  }
911  }
912  }
913  return false;
914  }
915 
916  bool defaultResult() const override { return false; }
917 
918  private:
919  std::vector<std::pair<const Analyzer::Expr*, const Analyzer::Expr*>> join_qual_pairs;
920 };
921 
922 std::list<std::shared_ptr<Analyzer::Expr>> strip_join_covered_filter_quals(
923  const std::list<std::shared_ptr<Analyzer::Expr>>& quals,
924  const JoinQualsPerNestingLevel& join_quals) {
926  return quals;
927  }
928 
929  if (join_quals.empty()) {
930  return quals;
931  }
932 
933  std::list<std::shared_ptr<Analyzer::Expr>> quals_to_return;
934 
935  JoinCoveredQualVisitor visitor(join_quals);
936  for (const auto& qual : quals) {
937  if (!visitor.visit(qual.get())) {
938  // Not a covered qual, don't elide it from the filtered count
939  quals_to_return.push_back(qual);
940  }
941  }
942 
943  return quals_to_return;
944 }
945 
946 std::shared_ptr<Analyzer::Expr> fold_expr(const Analyzer::Expr* expr) {
947  if (!expr) {
948  return nullptr;
949  }
950  const auto expr_no_likelihood = strip_likelihood(expr);
951  ConstantFoldingVisitor visitor;
952  auto rewritten_expr = visitor.visit(expr_no_likelihood);
953  if (visitor.get_num_overflows() > 0 && rewritten_expr->get_type_info().is_integer() &&
954  rewritten_expr->get_type_info().get_type() != kBIGINT) {
955  auto rewritten_expr_const =
956  std::dynamic_pointer_cast<const Analyzer::Constant>(rewritten_expr);
957  if (!rewritten_expr_const) {
958  // Integer expression didn't fold completely the first time due to
959  // overflows in smaller type subexpressions, trying again with a cast
960  const auto& ti = SQLTypeInfo(kBIGINT, false);
961  auto bigint_expr_no_likelihood = expr_no_likelihood->deep_copy()->add_cast(ti);
962  auto rewritten_expr_take2 = visitor.visit(bigint_expr_no_likelihood.get());
963  auto rewritten_expr_take2_const =
964  std::dynamic_pointer_cast<Analyzer::Constant>(rewritten_expr_take2);
965  if (rewritten_expr_take2_const) {
966  // Managed to fold, switch to the new constant
967  rewritten_expr = rewritten_expr_take2_const;
968  }
969  }
970  }
971  const auto expr_with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
972  if (expr_with_likelihood) {
973  // Add back likelihood
974  return std::make_shared<Analyzer::LikelihoodExpr>(
975  rewritten_expr, expr_with_likelihood->get_likelihood());
976  }
977  return rewritten_expr;
978 }
979 
981  const Analyzer::ColumnVar* val_side,
982  const int max_rte_covered) {
983  if (key_side->get_table_id() == val_side->get_table_id() &&
984  key_side->get_rte_idx() == val_side->get_rte_idx() &&
985  key_side->get_rte_idx() > max_rte_covered) {
986  return true;
987  }
988  return false;
989 }
990 
992  std::unordered_map<int, llvm::Value*>& scan_idx_to_hash_pos) {
993  int ret = INT32_MIN;
994  for (auto& kv : scan_idx_to_hash_pos) {
995  if (kv.first > ret) {
996  ret = kv.first;
997  }
998  }
999  return ret;
1000 }
int8_t tinyintval
Definition: sqltypes.h:206
int get_table_id() const
Definition: Analyzer.h:194
Analyzer::ExpressionPtr rewrite_array_elements(Analyzer::Expr const *expr)
std::string to_lower(const std::string &str)
#define CHECK_EQ(x, y)
Definition: Logger.h:211
#define IS_LOGIC(X)
Definition: sqldefs.h:60
std::shared_ptr< Analyzer::InValues > visitLikeExpr(const Analyzer::LikeExpr *) const override
std::shared_ptr< Analyzer::InValues > visitDateaddExpr(const Analyzer::DateaddExpr *) const override
const std::shared_ptr< Analyzer::Expr > get_own_arg() const
Definition: Analyzer.h:797
bool self_join_not_covered_by_left_deep_tree(const Analyzer::ColumnVar *key_side, const Analyzer::ColumnVar *val_side, const int max_rte_covered)
SQLTypes
Definition: sqltypes.h:37
std::shared_ptr< Analyzer::WindowFunction > rewrite_avg_window(const Analyzer::Expr *expr)
bool g_strip_join_covered_quals
Definition: Execute.cpp:100
std::shared_ptr< Analyzer::InValues > visitKeyForString(const Analyzer::KeyForStringExpr *) const override
tuple d
Definition: test_fsi.py:9
std::shared_ptr< Analyzer::WindowFunction > rewrite_sum_window(const Analyzer::Expr *expr)
std::shared_ptr< Analyzer::InValues > visitBinOper(const Analyzer::BinOper *bin_oper) const override
#define LOG(tag)
Definition: Logger.h:194
bool boolval
Definition: sqltypes.h:205
const Expr * get_right_operand() const
Definition: Analyzer.h:443
bool is_constructed_point(const Analyzer::Expr *expr)
Definition: Execute.h:1176
SQLOps
Definition: sqldefs.h:29
static const std::unordered_set< std::string > requires_many_to_many
Definition: sqldefs.h:35
Definition: sqldefs.h:36
Definition: sqldefs.h:38
#define CHECK_GE(x, y)
Definition: Logger.h:216
bool get_contains_agg() const
Definition: Analyzer.h:80
Definition: sqldefs.h:49
Definition: sqldefs.h:30
std::shared_ptr< Analyzer::Expr > ExpressionPtr
Definition: Analyzer.h:180
const Analyzer::Expr * extract_cast_arg(const Analyzer::Expr *expr)
Definition: Execute.h:202
Definition: sqldefs.h:41
std::vector< JoinCondition > JoinQualsPerNestingLevel
T visit(const Analyzer::Expr *expr) const
Analyzer::ExpressionPtr rewrite_expr(const Analyzer::Expr *expr)
static const std::unordered_set< std::string > overlaps_supported_functions
bool isNull() const
Definition: Analyzer.h:1490
std::shared_ptr< Analyzer::InValues > aggregateResult(const std::shared_ptr< Analyzer::InValues > &lhs, const std::shared_ptr< Analyzer::InValues > &rhs) const override
std::list< std::shared_ptr< Analyzer::Expr > > strip_join_covered_filter_quals(const std::list< std::shared_ptr< Analyzer::Expr >> &quals, const JoinQualsPerNestingLevel &join_quals)
std::shared_ptr< Analyzer::InValues > visitUOper(const Analyzer::UOper *uoper) const override
const int get_max_rte_scan_table(std::unordered_map< int, llvm::Value * > &scan_idx_to_hash_pos)
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
std::shared_ptr< Analyzer::Expr > RetType
int32_t intval
Definition: sqltypes.h:208
std::shared_ptr< Analyzer::InValues > visitCharLength(const Analyzer::CharLengthExpr *) const override
std::shared_ptr< Analyzer::InValues > visitInIntegerSet(const Analyzer::InIntegerSet *) const override
SQLOps get_optype() const
Definition: Analyzer.h:439
float floatval
Definition: sqltypes.h:210
std::shared_ptr< Analyzer::Expr > visitUOper(const Analyzer::UOper *uoper) const override
#define INT32_MIN
bool g_enable_hashjoin_many_to_many
Definition: Execute.cpp:97
std::shared_ptr< Analyzer::InValues > visitRegexpExpr(const Analyzer::RegexpExpr *) const override
int64_t bigintval
Definition: sqltypes.h:209
std::shared_ptr< Analyzer::InValues > visitAggExpr(const Analyzer::AggExpr *) const override
bool visitFunctionOper(const Analyzer::FunctionOper *func_oper) const override
Definition: sqldefs.h:37
bool foldOper(SQLOps optype, SQLTypes type, Datum lhs, Datum rhs, Datum &result, SQLTypes &result_type) const
int16_t smallintval
Definition: sqltypes.h:207
bool defaultResult() const override
static std::shared_ptr< Analyzer::Expr > analyzeValue(const std::string &)
Definition: ParserNode.cpp:118
boost::optional< OverlapsJoinConjunction > rewrite_overlaps_conjunction(const std::shared_ptr< Analyzer::Expr > expr)
RetType visitArrayOper(const Analyzer::ArrayExpr *array_expr) const override
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:78
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
size_t getElementCount() const
Definition: Analyzer.h:1488
SQLTypes decimal_to_int_type(const SQLTypeInfo &ti)
Definition: Datum.cpp:419
Definition: sqldefs.h:34
Definition: sqldefs.h:40
std::vector< std::pair< const Analyzer::Expr *, const Analyzer::Expr * > > join_qual_pairs
Definition: sqldefs.h:69
#define TRANSIENT_DICT_ID
Definition: sqltypes.h:253
bool isLocalAlloc() const
Definition: Analyzer.h:1489
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:1362
#define IS_ARITHMETIC(X)
Definition: sqldefs.h:61
int get_rte_idx() const
Definition: Analyzer.h:196
const Expr * get_operand() const
Definition: Analyzer.h:371
Datum get_constval() const
Definition: Analyzer.h:335
Definition: sqldefs.h:32
Expression class for the LOWER (lowercase) string function. The "arg" constructor parameter must be a...
Definition: Analyzer.h:791
std::unordered_map< const Analyzer::Expr *, const SQLTypeInfo > casts_
std::shared_ptr< Analyzer::InValues > visitDatediffExpr(const Analyzer::DatediffExpr *) const override
std::shared_ptr< Analyzer::InValues > visitLikelihood(const Analyzer::LikelihoodExpr *) const override
std::shared_ptr< Analyzer::Expr > visitLower(const Analyzer::LowerExpr *lower_expr) const override
std::shared_ptr< Analyzer::InValues > visitInValues(const Analyzer::InValues *) const override
#define CHECK(condition)
Definition: Logger.h:203
Definition: sqldefs.h:33
char * f
const Expr * get_left_operand() const
Definition: Analyzer.h:442
std::shared_ptr< Analyzer::InValues > visitCaseExpr(const Analyzer::CaseExpr *) const override
Definition: sqltypes.h:44
JoinCoveredQualVisitor(const JoinQualsPerNestingLevel &join_quals)
const std::shared_ptr< Analyzer::Expr > get_own_right_operand() const
Definition: Analyzer.h:447
std::string getName() const
Definition: Analyzer.h:1358
bool is_decimal() const
Definition: sqltypes.h:492
std::shared_ptr< Analyzer::InValues > visitDatetruncExpr(const Analyzer::DatetruncExpr *) const override
Definition: sqldefs.h:39
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const
Definition: Analyzer.h:444
SQLOps get_optype() const
Definition: Analyzer.h:370
#define VLOG(n)
Definition: Logger.h:297
const Expr * get_arg() const
Definition: Analyzer.h:795
std::shared_ptr< Analyzer::Expr > fold_expr(const Analyzer::Expr *expr)
#define IS_COMPARISON(X)
Definition: sqldefs.h:57
double doubleval
Definition: sqltypes.h:211
std::shared_ptr< Analyzer::InValues > visitExtractExpr(const Analyzer::ExtractExpr *) const override
const Analyzer::Expr * getElement(const size_t i) const
Definition: Analyzer.h:1492
SQLQualifier get_qualifier() const
Definition: Analyzer.h:441
const Analyzer::Expr * strip_likelihood(const Analyzer::Expr *expr)
std::shared_ptr< Analyzer::InValues > visitSampleRatio(const Analyzer::SampleRatioExpr *) const override
std::shared_ptr< Analyzer::InValues > visitCardinality(const Analyzer::CardinalityExpr *) const override