OmniSciDB  ab4938a6a3
RelAlgDagBuilder.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RelAlgDagBuilder.h"
18 #include "../Shared/sqldefs.h"
20 #include "Catalog/Catalog.h"
22 #include "JsonAccessors.h"
23 #include "RelAlgOptimizer.h"
24 #include "RelLeftDeepInnerJoin.h"
26 #include "RexVisitor.h"
27 
28 #include <rapidjson/error/en.h>
29 #include <rapidjson/error/error.h>
30 #include <rapidjson/stringbuffer.h>
31 #include <rapidjson/writer.h>
32 
33 #include <string>
34 #include <unordered_set>
35 
36 extern bool g_cluster;
37 extern bool g_enable_union;
38 
namespace {

// Starting value used to initialize the per-thread RelAlgNode id counter below.
constexpr unsigned FIRST_RA_NODE_ID = 1;

}  // namespace
44 
45 thread_local unsigned RelAlgNode::crt_id_ = FIRST_RA_NODE_ID;
46 
49 }
50 
52  const std::shared_ptr<const ExecutionResult> result) {
53  auto row_set = result->getRows();
54  CHECK(row_set);
55  CHECK_EQ(size_t(1), row_set->colCount());
56  *(type_.get()) = row_set->getColType(0);
57  (*(result_.get())) = result;
58 }
59 
60 std::unique_ptr<RexSubQuery> RexSubQuery::deepCopy() const {
61  return std::make_unique<RexSubQuery>(type_, result_, ra_->deepCopy());
62 }
63 
64 namespace {
65 
66 class RexRebindInputsVisitor : public RexVisitor<void*> {
67  public:
68  RexRebindInputsVisitor(const RelAlgNode* old_input, const RelAlgNode* new_input)
69  : old_input_(old_input), new_input_(new_input) {}
70 
71  virtual ~RexRebindInputsVisitor() = default;
72 
73  void* visitInput(const RexInput* rex_input) const override {
74  const auto old_source = rex_input->getSourceNode();
75  if (old_source == old_input_) {
76  const auto left_deep_join = dynamic_cast<const RelLeftDeepInnerJoin*>(new_input_);
77  if (left_deep_join) {
78  rebind_inputs_from_left_deep_join(rex_input, left_deep_join);
79  return nullptr;
80  }
81  rex_input->setSourceNode(new_input_);
82  }
83  return nullptr;
84  };
85 
86  private:
87  const RelAlgNode* old_input_;
89 };
90 
91 // Creates an output with n columns.
92 std::vector<RexInput> n_outputs(const RelAlgNode* node, const size_t n) {
93  std::vector<RexInput> outputs;
94  outputs.reserve(n);
95  for (size_t i = 0; i < n; ++i) {
96  outputs.emplace_back(node, i);
97  }
98  return outputs;
99 }
100 
102  public:
104  const RelAlgNode* old_input,
105  const RelAlgNode* new_input,
106  std::unordered_map<unsigned, unsigned> old_to_new_index_map)
107  : RexRebindInputsVisitor(old_input, new_input), mapping_(old_to_new_index_map) {}
108 
109  void* visitInput(const RexInput* rex_input) const override {
110  RexRebindInputsVisitor::visitInput(rex_input);
111  auto mapping_itr = mapping_.find(rex_input->getIndex());
112  CHECK(mapping_itr != mapping_.end());
113  rex_input->setIndex(mapping_itr->second);
114  return nullptr;
115  }
116 
117  private:
118  const std::unordered_map<unsigned, unsigned> mapping_;
119 };
120 
121 } // namespace
122 
124  std::shared_ptr<const RelAlgNode> old_input,
125  std::shared_ptr<const RelAlgNode> input,
126  std::optional<std::unordered_map<unsigned, unsigned>> old_to_new_index_map) {
127  RelAlgNode::replaceInput(old_input, input);
128  std::unique_ptr<RexRebindInputsVisitor> rebind_inputs;
129  if (old_to_new_index_map) {
130  rebind_inputs = std::make_unique<RexRebindReindexInputsVisitor>(
131  old_input.get(), input.get(), *old_to_new_index_map);
132  } else {
133  rebind_inputs =
134  std::make_unique<RexRebindInputsVisitor>(old_input.get(), input.get());
135  }
136  CHECK(rebind_inputs);
137  for (const auto& scalar_expr : scalar_exprs_) {
138  rebind_inputs->visit(scalar_expr.get());
139  }
140 }
141 
142 void RelProject::appendInput(std::string new_field_name,
143  std::unique_ptr<const RexScalar> new_input) {
144  fields_.emplace_back(std::move(new_field_name));
145  scalar_exprs_.emplace_back(std::move(new_input));
146 }
147 
149  const auto scan_node = dynamic_cast<const RelScan*>(ra_node);
150  if (scan_node) {
151  // Scan node has no inputs, output contains all columns in the table.
152  CHECK_EQ(size_t(0), scan_node->inputCount());
153  return n_outputs(scan_node, scan_node->size());
154  }
155  const auto project_node = dynamic_cast<const RelProject*>(ra_node);
156  if (project_node) {
157  // Project output count doesn't depend on the input
158  CHECK_EQ(size_t(1), project_node->inputCount());
159  return n_outputs(project_node, project_node->size());
160  }
161  const auto filter_node = dynamic_cast<const RelFilter*>(ra_node);
162  if (filter_node) {
163  // Filter preserves shape
164  CHECK_EQ(size_t(1), filter_node->inputCount());
165  const auto prev_out = get_node_output(filter_node->getInput(0));
166  return n_outputs(filter_node, prev_out.size());
167  }
168  const auto aggregate_node = dynamic_cast<const RelAggregate*>(ra_node);
169  if (aggregate_node) {
170  // Aggregate output count doesn't depend on the input
171  CHECK_EQ(size_t(1), aggregate_node->inputCount());
172  return n_outputs(aggregate_node, aggregate_node->size());
173  }
174  const auto compound_node = dynamic_cast<const RelCompound*>(ra_node);
175  if (compound_node) {
176  // Compound output count doesn't depend on the input
177  CHECK_EQ(size_t(1), compound_node->inputCount());
178  return n_outputs(compound_node, compound_node->size());
179  }
180  const auto join_node = dynamic_cast<const RelJoin*>(ra_node);
181  if (join_node) {
182  // Join concatenates the outputs from the inputs and the output
183  // directly references the nodes in the input.
184  CHECK_EQ(size_t(2), join_node->inputCount());
185  auto lhs_out =
186  n_outputs(join_node->getInput(0), get_node_output(join_node->getInput(0)).size());
187  const auto rhs_out =
188  n_outputs(join_node->getInput(1), get_node_output(join_node->getInput(1)).size());
189  lhs_out.insert(lhs_out.end(), rhs_out.begin(), rhs_out.end());
190  return lhs_out;
191  }
192  const auto table_func_node = dynamic_cast<const RelTableFunction*>(ra_node);
193  if (table_func_node) {
194  // Table Function output count doesn't depend on the input
195  CHECK_EQ(size_t(1), table_func_node->inputCount());
196  return n_outputs(table_func_node, table_func_node->size());
197  }
198  const auto sort_node = dynamic_cast<const RelSort*>(ra_node);
199  if (sort_node) {
200  // Sort preserves shape
201  CHECK_EQ(size_t(1), sort_node->inputCount());
202  const auto prev_out = get_node_output(sort_node->getInput(0));
203  return n_outputs(sort_node, prev_out.size());
204  }
205  const auto logical_values_node = dynamic_cast<const RelLogicalValues*>(ra_node);
206  if (logical_values_node) {
207  CHECK_EQ(size_t(0), logical_values_node->inputCount());
208  return n_outputs(logical_values_node, logical_values_node->size());
209  }
210  const auto logical_union_node = dynamic_cast<const RelLogicalUnion*>(ra_node);
211  if (logical_union_node) {
212  return n_outputs(logical_union_node, logical_union_node->size());
213  }
214  LOG(FATAL) << "Unhandled ra_node type: " << ra_node->toString();
215  return {};
216 }
217 
219  if (!isSimple()) {
220  return false;
221  }
222  CHECK_EQ(size_t(1), inputCount());
223  const auto source = getInput(0);
224  if (dynamic_cast<const RelJoin*>(source)) {
225  return false;
226  }
227  const auto source_shape = get_node_output(source);
228  if (source_shape.size() != scalar_exprs_.size()) {
229  return false;
230  }
231  for (size_t i = 0; i < scalar_exprs_.size(); ++i) {
232  const auto& scalar_expr = scalar_exprs_[i];
233  const auto input = dynamic_cast<const RexInput*>(scalar_expr.get());
234  CHECK(input);
235  CHECK_EQ(source, input->getSourceNode());
236  // We should add the additional check that input->getIndex() !=
237  // source_shape[i].getIndex(), but Calcite doesn't generate the right
238  // Sort-Project-Sort sequence when joins are involved.
239  if (input->getSourceNode() != source_shape[i].getSourceNode()) {
240  return false;
241  }
242  }
243  return true;
244 }
245 
246 namespace {
247 
248 bool isRenamedInput(const RelAlgNode* node,
249  const size_t index,
250  const std::string& new_name) {
251  CHECK_LT(index, node->size());
252  if (auto join = dynamic_cast<const RelJoin*>(node)) {
253  CHECK_EQ(size_t(2), join->inputCount());
254  const auto lhs_size = join->getInput(0)->size();
255  if (index < lhs_size) {
256  return isRenamedInput(join->getInput(0), index, new_name);
257  }
258  CHECK_GE(index, lhs_size);
259  return isRenamedInput(join->getInput(1), index - lhs_size, new_name);
260  }
261 
262  if (auto scan = dynamic_cast<const RelScan*>(node)) {
263  return new_name != scan->getFieldName(index);
264  }
265 
266  if (auto aggregate = dynamic_cast<const RelAggregate*>(node)) {
267  return new_name != aggregate->getFieldName(index);
268  }
269 
270  if (auto project = dynamic_cast<const RelProject*>(node)) {
271  return new_name != project->getFieldName(index);
272  }
273 
274  if (auto table_func = dynamic_cast<const RelTableFunction*>(node)) {
275  return new_name != table_func->getFieldName(index);
276  }
277 
278  if (auto logical_values = dynamic_cast<const RelLogicalValues*>(node)) {
279  const auto& tuple_type = logical_values->getTupleType();
280  CHECK_LT(index, tuple_type.size());
281  return new_name != tuple_type[index].get_resname();
282  }
283 
284  CHECK(dynamic_cast<const RelSort*>(node) || dynamic_cast<const RelFilter*>(node) ||
285  dynamic_cast<const RelLogicalUnion*>(node));
286  return isRenamedInput(node->getInput(0), index, new_name);
287 }
288 
289 } // namespace
290 
292  if (!isSimple()) {
293  return false;
294  }
295  CHECK_EQ(scalar_exprs_.size(), fields_.size());
296  for (size_t i = 0; i < fields_.size(); ++i) {
297  auto rex_in = dynamic_cast<const RexInput*>(scalar_exprs_[i].get());
298  CHECK(rex_in);
299  if (isRenamedInput(rex_in->getSourceNode(), rex_in->getIndex(), fields_[i])) {
300  return true;
301  }
302  }
303  return false;
304 }
305 
306 void RelJoin::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
307  std::shared_ptr<const RelAlgNode> input) {
308  RelAlgNode::replaceInput(old_input, input);
309  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
310  if (condition_) {
311  rebind_inputs.visit(condition_.get());
312  }
313 }
314 
315 void RelFilter::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
316  std::shared_ptr<const RelAlgNode> input) {
317  RelAlgNode::replaceInput(old_input, input);
318  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
319  rebind_inputs.visit(filter_.get());
320 }
321 
322 void RelCompound::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
323  std::shared_ptr<const RelAlgNode> input) {
324  RelAlgNode::replaceInput(old_input, input);
325  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
326  for (const auto& scalar_source : scalar_sources_) {
327  rebind_inputs.visit(scalar_source.get());
328  }
329  if (filter_expr_) {
330  rebind_inputs.visit(filter_expr_.get());
331  }
332 }
333 
334 std::shared_ptr<RelAlgNode> RelProject::deepCopy() const {
335  RexDeepCopyVisitor copier;
336  std::vector<std::unique_ptr<const RexScalar>> exprs_copy;
337  for (auto& expr : scalar_exprs_) {
338  exprs_copy.push_back(copier.visit(expr.get()));
339  }
340  return std::make_shared<RelProject>(exprs_copy, fields_, inputs_[0]);
341 }
342 
343 std::shared_ptr<RelAlgNode> RelLogicalValues::deepCopy() const {
344  RexDeepCopyVisitor copier;
345  std::vector<RelLogicalValues::RowValues> values_copy;
346  for (auto& row : values_) {
347  values_copy.emplace_back(RelLogicalValues::RowValues{});
348  for (auto& value : row) {
349  values_copy.back().push_back(copier.visit(value.get()));
350  }
351  }
352  return std::make_shared<RelLogicalValues>(tuple_type_, values_copy);
353 }
354 
355 std::shared_ptr<RelAlgNode> RelFilter::deepCopy() const {
356  RexDeepCopyVisitor copier;
357  auto filter_copy = copier.visit(filter_.get());
358  return std::make_shared<RelFilter>(filter_copy, inputs_[0]);
359 }
360 
361 std::shared_ptr<RelAlgNode> RelAggregate::deepCopy() const {
362  std::vector<std::unique_ptr<const RexAgg>> aggs_copy;
363  for (auto& agg : agg_exprs_) {
364  auto copy = agg->deepCopy();
365  aggs_copy.push_back(std::move(copy));
366  }
367  return std::make_shared<RelAggregate>(groupby_count_, aggs_copy, fields_, inputs_[0]);
368 }
369 
370 std::shared_ptr<RelAlgNode> RelJoin::deepCopy() const {
371  RexDeepCopyVisitor copier;
372  auto condition_copy = copier.visit(condition_.get());
373  return std::make_shared<RelJoin>(inputs_[0], inputs_[1], condition_copy, join_type_);
374 }
375 
376 std::shared_ptr<RelAlgNode> RelCompound::deepCopy() const {
377  RexDeepCopyVisitor copier;
378  auto filter_copy = filter_expr_ ? copier.visit(filter_expr_.get()) : nullptr;
379  std::unordered_map<const Rex*, const Rex*> old_to_new_target;
380  std::vector<const RexAgg*> aggs_copy;
381  for (auto& agg : agg_exprs_) {
382  auto copy = agg->deepCopy();
383  old_to_new_target.insert(std::make_pair(agg.get(), copy.get()));
384  aggs_copy.push_back(copy.release());
385  }
386  std::vector<std::unique_ptr<const RexScalar>> sources_copy;
387  for (size_t i = 0; i < scalar_sources_.size(); ++i) {
388  auto copy = copier.visit(scalar_sources_[i].get());
389  old_to_new_target.insert(std::make_pair(scalar_sources_[i].get(), copy.get()));
390  sources_copy.push_back(std::move(copy));
391  }
392  std::vector<const Rex*> target_exprs_copy;
393  for (auto target : target_exprs_) {
394  auto target_it = old_to_new_target.find(target);
395  CHECK(target_it != old_to_new_target.end());
396  target_exprs_copy.push_back(target_it->second);
397  }
398  auto new_compound = std::make_shared<RelCompound>(filter_copy,
399  target_exprs_copy,
400  groupby_count_,
401  aggs_copy,
402  fields_,
403  sources_copy,
404  is_agg_);
405  new_compound->addManagedInput(inputs_[0]);
406  return new_compound;
407 }
408 
409 std::shared_ptr<RelAlgNode> RelSort::deepCopy() const {
410  auto ret = std::make_shared<RelSort>(collation_, limit_, offset_, inputs_[0]);
411  ret->setEmptyResult(isEmptyResult());
412  return ret;
413 }
414 
415 void RelTableFunction::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
416  std::shared_ptr<const RelAlgNode> input) {
417  RelAlgNode::replaceInput(old_input, input);
418  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
419  for (const auto& target_expr : target_exprs_) {
420  rebind_inputs.visit(target_expr.get());
421  }
422  for (const auto& func_input : table_func_inputs_) {
423  rebind_inputs.visit(func_input.get());
424  }
425 }
426 
427 std::shared_ptr<RelAlgNode> RelTableFunction::deepCopy() const {
428  RexDeepCopyVisitor copier;
429 
430  std::unordered_map<const Rex*, const Rex*> old_to_new_input;
431 
432  std::vector<std::unique_ptr<const RexScalar>> table_func_inputs_copy;
433  for (auto& expr : table_func_inputs_) {
434  table_func_inputs_copy.push_back(copier.visit(expr.get()));
435  old_to_new_input.insert(
436  std::make_pair(expr.get(), table_func_inputs_copy.back().get()));
437  }
438 
439  std::vector<const Rex*> col_inputs_copy;
440  for (auto target : col_inputs_) {
441  auto target_it = old_to_new_input.find(target);
442  CHECK(target_it != old_to_new_input.end());
443  col_inputs_copy.push_back(target_it->second);
444  }
445  auto fields_copy = fields_;
446 
447  std::vector<std::unique_ptr<const RexScalar>> target_exprs_copy;
448  for (auto& expr : target_exprs_) {
449  target_exprs_copy.push_back(copier.visit(expr.get()));
450  }
451 
452  return std::make_shared<RelTableFunction>(function_name_,
453  inputs_[0],
454  fields_copy,
455  col_inputs_copy,
456  table_func_inputs_copy,
457  target_exprs_copy);
458 }
459 
namespace std {
// std::hash specialization for (RelAlgNode pointer, int) pairs.
//
// The pointer is hashed by value through std::hash instead of reading its
// storage through a `const int64_t*`: that read violated the strict aliasing
// rules and would read past the pointer object on targets where pointers are
// narrower than 64 bits. The pointer hash is then XOR-ed with the int member.
template <>
struct hash<std::pair<const RelAlgNode*, int>> {
  size_t operator()(const std::pair<const RelAlgNode*, int>& input_col) const {
    const size_t node_hash = std::hash<const RelAlgNode*>{}(input_col.first);
    return node_hash ^ static_cast<size_t>(input_col.second);
  }
};
}  // namespace std
469 
470 namespace {
471 
472 std::set<std::pair<const RelAlgNode*, int>> get_equiv_cols(const RelAlgNode* node,
473  const size_t which_col) {
474  std::set<std::pair<const RelAlgNode*, int>> work_set;
475  auto walker = node;
476  auto curr_col = which_col;
477  while (true) {
478  work_set.insert(std::make_pair(walker, curr_col));
479  if (dynamic_cast<const RelScan*>(walker) || dynamic_cast<const RelJoin*>(walker)) {
480  break;
481  }
482  CHECK_EQ(size_t(1), walker->inputCount());
483  auto only_source = walker->getInput(0);
484  if (auto project = dynamic_cast<const RelProject*>(walker)) {
485  if (auto input = dynamic_cast<const RexInput*>(project->getProjectAt(curr_col))) {
486  const auto join_source = dynamic_cast<const RelJoin*>(only_source);
487  if (join_source) {
488  CHECK_EQ(size_t(2), join_source->inputCount());
489  auto lhs = join_source->getInput(0);
490  CHECK((input->getIndex() < lhs->size() && lhs == input->getSourceNode()) ||
491  join_source->getInput(1) == input->getSourceNode());
492  } else {
493  CHECK_EQ(input->getSourceNode(), only_source);
494  }
495  curr_col = input->getIndex();
496  } else {
497  break;
498  }
499  } else if (auto aggregate = dynamic_cast<const RelAggregate*>(walker)) {
500  if (curr_col >= aggregate->getGroupByCount()) {
501  break;
502  }
503  }
504  walker = only_source;
505  }
506  return work_set;
507 }
508 
509 } // namespace
510 
511 bool RelSort::hasEquivCollationOf(const RelSort& that) const {
512  if (collation_.size() != that.collation_.size()) {
513  return false;
514  }
515 
516  for (size_t i = 0, e = collation_.size(); i < e; ++i) {
517  auto this_sort_key = collation_[i];
518  auto that_sort_key = that.collation_[i];
519  if (this_sort_key.getSortDir() != that_sort_key.getSortDir()) {
520  return false;
521  }
522  if (this_sort_key.getNullsPosition() != that_sort_key.getNullsPosition()) {
523  return false;
524  }
525  auto this_equiv_keys = get_equiv_cols(this, this_sort_key.getField());
526  auto that_equiv_keys = get_equiv_cols(&that, that_sort_key.getField());
527  std::vector<std::pair<const RelAlgNode*, int>> intersect;
528  std::set_intersection(this_equiv_keys.begin(),
529  this_equiv_keys.end(),
530  that_equiv_keys.begin(),
531  that_equiv_keys.end(),
532  std::back_inserter(intersect));
533  if (intersect.empty()) {
534  return false;
535  }
536  }
537  return true;
538 }
539 
540 // class RelLogicalUnion methods
541 
543  : RelAlgNode(std::move(inputs)), is_all_(is_all) {
544  if (!g_enable_union) {
545  throw QueryNotSupported(
546  "UNION is not supported yet. There is an experimental enable-union option "
547  "available to enable UNION ALL queries.");
548  }
549  CHECK_LE(2u, inputs_.size());
550  if (!is_all_) {
551  throw QueryNotSupported("UNION without ALL is not supported yet.");
552  }
553 }
554 
555 std::shared_ptr<RelAlgNode> RelLogicalUnion::deepCopy() const {
556  return std::make_shared<RelLogicalUnion>(*this);
557 }
558 
559 size_t RelLogicalUnion::size() const {
560  return inputs_.at(0)->size();
561 }
562 
563 std::string RelLogicalUnion::toString() const {
564  return cat("(RelLogicalUnion<", this, ">(is_all(", is_all_, ")))");
565 }
566 
567 std::string RelLogicalUnion::getFieldName(const size_t i) const {
568  if (auto const* input = dynamic_cast<RelCompound const*>(inputs_[0].get())) {
569  return input->getFieldName(i);
570  } else if (auto const* input = dynamic_cast<RelProject const*>(inputs_[0].get())) {
571  return input->getFieldName(i);
572  } else if (auto const* input = dynamic_cast<RelLogicalUnion const*>(inputs_[0].get())) {
573  return input->getFieldName(i);
574  } else if (auto const* input = dynamic_cast<RelAggregate const*>(inputs_[0].get())) {
575  return input->getFieldName(i);
576  } else if (auto const* input = dynamic_cast<RelScan const*>(inputs_[0].get())) {
577  return input->getFieldName(i);
578  } else if (auto const* input =
579  dynamic_cast<RelTableFunction const*>(inputs_[0].get())) {
580  return input->getFieldName(i);
581  }
582  UNREACHABLE() << "Unhandled input type: " << inputs_.front()->toString();
583  return {};
584 }
585 
587  std::vector<TargetMetaInfo> const& tmis0 = inputs_[0]->getOutputMetainfo();
588  std::vector<TargetMetaInfo> const& tmis1 = inputs_[1]->getOutputMetainfo();
589  if (tmis0.size() != tmis1.size()) {
590  VLOG(2) << "tmis0.size() = " << tmis0.size() << " != " << tmis1.size()
591  << " = tmis1.size()";
592  return false;
593  }
594  for (size_t i = 0; i < tmis0.size(); ++i) {
595  if (tmis0[i].get_type_info() != tmis1[i].get_type_info()) {
596  VLOG(2) << "Types do not match for UNION:\n tmis0[" << i
597  << "].get_type_info().to_string() = "
598  << tmis0[i].get_type_info().to_string() << "\n tmis1[" << i
599  << "].get_type_info().to_string() = "
600  << tmis1[i].get_type_info().to_string();
601  return false;
602  }
603  }
604  return true;
605 }
606 
607 // Rest of code requires a raw pointer, but RexInput object needs to live somewhere.
609  size_t input_idx) const {
610  if (auto const* rex_input_ptr = dynamic_cast<RexInput const*>(rex_scalar)) {
611  RexInput rex_input(*rex_input_ptr);
612  rex_input.setSourceNode(getInput(input_idx));
613  scalar_exprs_.emplace_back(std::make_shared<RexInput const>(std::move(rex_input)));
614  return scalar_exprs_.back().get();
615  }
616  return rex_scalar;
617 }
618 
619 namespace {
620 
621 unsigned node_id(const rapidjson::Value& ra_node) noexcept {
622  const auto& id = field(ra_node, "id");
623  return std::stoi(json_str(id));
624 }
625 
626 std::string json_node_to_string(const rapidjson::Value& node) noexcept {
627  rapidjson::StringBuffer buffer;
628  rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
629  node.Accept(writer);
630  return buffer.GetString();
631 }
632 
633 // The parse_* functions below de-serialize expressions as they come from Calcite.
634 // RelAlgDagBuilder will take care of making the representation easy to
635 // navigate for lower layers, for example by replacing RexAbstractInput with RexInput.
636 
637 std::unique_ptr<RexAbstractInput> parse_abstract_input(
638  const rapidjson::Value& expr) noexcept {
639  const auto& input = field(expr, "input");
640  return std::unique_ptr<RexAbstractInput>(new RexAbstractInput(json_i64(input)));
641 }
642 
643 std::unique_ptr<RexLiteral> parse_literal(const rapidjson::Value& expr) {
644  CHECK(expr.IsObject());
645  const auto& literal = field(expr, "literal");
646  const auto type = to_sql_type(json_str(field(expr, "type")));
647  const auto target_type = to_sql_type(json_str(field(expr, "target_type")));
648  const auto scale = json_i64(field(expr, "scale"));
649  const auto precision = json_i64(field(expr, "precision"));
650  const auto type_scale = json_i64(field(expr, "type_scale"));
651  const auto type_precision = json_i64(field(expr, "type_precision"));
652  if (literal.IsNull()) {
653  return std::unique_ptr<RexLiteral>(new RexLiteral(target_type));
654  }
655  switch (type) {
656  case kDECIMAL:
657  case kINTERVAL_DAY_TIME:
659  case kTIME:
660  case kTIMESTAMP:
661  case kDATE:
662  return std::unique_ptr<RexLiteral>(new RexLiteral(json_i64(literal),
663  type,
664  target_type,
665  scale,
666  precision,
667  type_scale,
668  type_precision));
669  case kDOUBLE: {
670  if (literal.IsDouble()) {
671  return std::unique_ptr<RexLiteral>(new RexLiteral(json_double(literal),
672  type,
673  target_type,
674  scale,
675  precision,
676  type_scale,
677  type_precision));
678  }
679  CHECK(literal.IsInt64());
680  return std::unique_ptr<RexLiteral>(
681  new RexLiteral(static_cast<double>(json_i64(literal)),
682  type,
683  target_type,
684  scale,
685  precision,
686  type_scale,
687  type_precision));
688  }
689  case kTEXT:
690  return std::unique_ptr<RexLiteral>(new RexLiteral(json_str(literal),
691  type,
692  target_type,
693  scale,
694  precision,
695  type_scale,
696  type_precision));
697  case kBOOLEAN:
698  return std::unique_ptr<RexLiteral>(new RexLiteral(json_bool(literal),
699  type,
700  target_type,
701  scale,
702  precision,
703  type_scale,
704  type_precision));
705  case kNULLT:
706  return std::unique_ptr<RexLiteral>(new RexLiteral(target_type));
707  default:
708  CHECK(false);
709  }
710  CHECK(false);
711  return nullptr;
712 }
713 
714 std::unique_ptr<const RexScalar> parse_scalar_expr(const rapidjson::Value& expr,
716  RelAlgDagBuilder& root_dag_builder);
717 
718 SQLTypeInfo parse_type(const rapidjson::Value& type_obj) {
719  if (type_obj.IsArray()) {
720  throw QueryNotSupported("Composite types are not currently supported.");
721  }
722  CHECK(type_obj.IsObject() && type_obj.MemberCount() >= 2)
723  << json_node_to_string(type_obj);
724  const auto type = to_sql_type(json_str(field(type_obj, "type")));
725  const auto nullable = json_bool(field(type_obj, "nullable"));
726  const auto precision_it = type_obj.FindMember("precision");
727  const int precision =
728  precision_it != type_obj.MemberEnd() ? json_i64(precision_it->value) : 0;
729  const auto scale_it = type_obj.FindMember("scale");
730  const int scale = scale_it != type_obj.MemberEnd() ? json_i64(scale_it->value) : 0;
731  SQLTypeInfo ti(type, !nullable);
732  ti.set_precision(precision);
733  ti.set_scale(scale);
734  return ti;
735 }
736 
737 std::vector<std::unique_ptr<const RexScalar>> parse_expr_array(
738  const rapidjson::Value& arr,
739  const Catalog_Namespace::Catalog& cat,
740  RelAlgDagBuilder& root_dag_builder) {
741  std::vector<std::unique_ptr<const RexScalar>> exprs;
742  for (auto it = arr.Begin(); it != arr.End(); ++it) {
743  exprs.emplace_back(parse_scalar_expr(*it, cat, root_dag_builder));
744  }
745  return exprs;
746 }
747 
749  if (name == "ROW_NUMBER") {
751  }
752  if (name == "RANK") {
754  }
755  if (name == "DENSE_RANK") {
757  }
758  if (name == "PERCENT_RANK") {
760  }
761  if (name == "CUME_DIST") {
763  }
764  if (name == "NTILE") {
766  }
767  if (name == "LAG") {
769  }
770  if (name == "LEAD") {
772  }
773  if (name == "FIRST_VALUE") {
775  }
776  if (name == "LAST_VALUE") {
778  }
779  if (name == "AVG") {
781  }
782  if (name == "MIN") {
784  }
785  if (name == "MAX") {
787  }
788  if (name == "SUM") {
790  }
791  if (name == "COUNT") {
793  }
794  if (name == "$SUM0") {
796  }
797  throw std::runtime_error("Unsupported window function: " + name);
798 }
799 
800 std::vector<std::unique_ptr<const RexScalar>> parse_window_order_exprs(
801  const rapidjson::Value& arr,
802  const Catalog_Namespace::Catalog& cat,
803  RelAlgDagBuilder& root_dag_builder) {
804  std::vector<std::unique_ptr<const RexScalar>> exprs;
805  for (auto it = arr.Begin(); it != arr.End(); ++it) {
806  exprs.emplace_back(parse_scalar_expr(field(*it, "field"), cat, root_dag_builder));
807  }
808  return exprs;
809 }
810 
811 SortDirection parse_sort_direction(const rapidjson::Value& collation) {
812  return json_str(field(collation, "direction")) == std::string("DESCENDING")
815 }
816 
817 NullSortedPosition parse_nulls_position(const rapidjson::Value& collation) {
818  return json_str(field(collation, "nulls")) == std::string("FIRST")
821 }
822 
823 std::vector<SortField> parse_window_order_collation(const rapidjson::Value& arr,
824  const Catalog_Namespace::Catalog& cat,
825  RelAlgDagBuilder& root_dag_builder) {
826  std::vector<SortField> collation;
827  size_t field_idx = 0;
828  for (auto it = arr.Begin(); it != arr.End(); ++it, ++field_idx) {
829  const auto sort_dir = parse_sort_direction(*it);
830  const auto null_pos = parse_nulls_position(*it);
831  collation.emplace_back(field_idx, sort_dir, null_pos);
832  }
833  return collation;
834 }
835 
837  const rapidjson::Value& window_bound_obj,
838  const Catalog_Namespace::Catalog& cat,
839  RelAlgDagBuilder& root_dag_builder) {
840  CHECK(window_bound_obj.IsObject());
842  window_bound.unbounded = json_bool(field(window_bound_obj, "unbounded"));
843  window_bound.preceding = json_bool(field(window_bound_obj, "preceding"));
844  window_bound.following = json_bool(field(window_bound_obj, "following"));
845  window_bound.is_current_row = json_bool(field(window_bound_obj, "is_current_row"));
846  const auto& offset_field = field(window_bound_obj, "offset");
847  if (offset_field.IsObject()) {
848  window_bound.offset = parse_scalar_expr(offset_field, cat, root_dag_builder);
849  } else {
850  CHECK(offset_field.IsNull());
851  }
852  window_bound.order_key = json_i64(field(window_bound_obj, "order_key"));
853  return window_bound;
854 }
855 
856 std::unique_ptr<const RexSubQuery> parse_subquery(const rapidjson::Value& expr,
857  const Catalog_Namespace::Catalog& cat,
858  RelAlgDagBuilder& root_dag_builder) {
859  const auto& operands = field(expr, "operands");
860  CHECK(operands.IsArray());
861  CHECK_GE(operands.Size(), unsigned(0));
862  const auto& subquery_ast = field(expr, "subquery");
863 
864  RelAlgDagBuilder subquery_dag(root_dag_builder, subquery_ast, cat, nullptr);
865  auto subquery = std::make_shared<RexSubQuery>(subquery_dag.getRootNodeShPtr());
866  root_dag_builder.registerSubquery(subquery);
867  return subquery->deepCopy();
868 }
869 
870 std::unique_ptr<RexOperator> parse_operator(const rapidjson::Value& expr,
871  const Catalog_Namespace::Catalog& cat,
872  RelAlgDagBuilder& root_dag_builder) {
873  const auto op_name = json_str(field(expr, "op"));
874  const bool is_quantifier =
875  op_name == std::string("PG_ANY") || op_name == std::string("PG_ALL");
876  const auto op = is_quantifier ? kFUNCTION : to_sql_op(op_name);
877  const auto& operators_json_arr = field(expr, "operands");
878  CHECK(operators_json_arr.IsArray());
879  auto operands = parse_expr_array(operators_json_arr, cat, root_dag_builder);
880  const auto type_it = expr.FindMember("type");
881  CHECK(type_it != expr.MemberEnd());
882  auto ti = parse_type(type_it->value);
883  if (op == kIN && expr.HasMember("subquery")) {
884  auto subquery = parse_subquery(expr, cat, root_dag_builder);
885  operands.emplace_back(std::move(subquery));
886  }
887  if (expr.FindMember("partition_keys") != expr.MemberEnd()) {
888  const auto& partition_keys_arr = field(expr, "partition_keys");
889  auto partition_keys = parse_expr_array(partition_keys_arr, cat, root_dag_builder);
890  const auto& order_keys_arr = field(expr, "order_keys");
891  auto order_keys = parse_window_order_exprs(order_keys_arr, cat, root_dag_builder);
892  const auto collation =
893  parse_window_order_collation(order_keys_arr, cat, root_dag_builder);
894  const auto kind = parse_window_function_kind(op_name);
895  const auto lower_bound =
896  parse_window_bound(field(expr, "lower_bound"), cat, root_dag_builder);
897  const auto upper_bound =
898  parse_window_bound(field(expr, "upper_bound"), cat, root_dag_builder);
899  bool is_rows = json_bool(field(expr, "is_rows"));
900  ti.set_notnull(false);
901  return std::make_unique<RexWindowFunctionOperator>(kind,
902  operands,
903  partition_keys,
904  order_keys,
905  collation,
906  lower_bound,
907  upper_bound,
908  is_rows,
909  ti);
910  }
911  return std::unique_ptr<RexOperator>(op == kFUNCTION
912  ? new RexFunctionOperator(op_name, operands, ti)
913  : new RexOperator(op, operands, ti));
914 }
915 
916 std::unique_ptr<RexCase> parse_case(const rapidjson::Value& expr,
917  const Catalog_Namespace::Catalog& cat,
918  RelAlgDagBuilder& root_dag_builder) {
919  const auto& operands = field(expr, "operands");
920  CHECK(operands.IsArray());
921  CHECK_GE(operands.Size(), unsigned(2));
922  std::unique_ptr<const RexScalar> else_expr;
923  std::vector<
924  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
925  expr_pair_list;
926  for (auto operands_it = operands.Begin(); operands_it != operands.End();) {
927  auto when_expr = parse_scalar_expr(*operands_it++, cat, root_dag_builder);
928  if (operands_it == operands.End()) {
929  else_expr = std::move(when_expr);
930  break;
931  }
932  auto then_expr = parse_scalar_expr(*operands_it++, cat, root_dag_builder);
933  expr_pair_list.emplace_back(std::move(when_expr), std::move(then_expr));
934  }
935  return std::unique_ptr<RexCase>(new RexCase(expr_pair_list, else_expr));
936 }
937 
938 std::vector<std::string> strings_from_json_array(
939  const rapidjson::Value& json_str_arr) noexcept {
940  CHECK(json_str_arr.IsArray());
941  std::vector<std::string> fields;
942  for (auto json_str_arr_it = json_str_arr.Begin(); json_str_arr_it != json_str_arr.End();
943  ++json_str_arr_it) {
944  CHECK(json_str_arr_it->IsString());
945  fields.emplace_back(json_str_arr_it->GetString());
946  }
947  return fields;
948 }
949 
950 std::vector<size_t> indices_from_json_array(
951  const rapidjson::Value& json_idx_arr) noexcept {
952  CHECK(json_idx_arr.IsArray());
953  std::vector<size_t> indices;
954  for (auto json_idx_arr_it = json_idx_arr.Begin(); json_idx_arr_it != json_idx_arr.End();
955  ++json_idx_arr_it) {
956  CHECK(json_idx_arr_it->IsInt());
957  CHECK_GE(json_idx_arr_it->GetInt(), 0);
958  indices.emplace_back(json_idx_arr_it->GetInt());
959  }
960  return indices;
961 }
962 
963 std::unique_ptr<const RexAgg> parse_aggregate_expr(const rapidjson::Value& expr) {
964  const auto agg = to_agg_kind(json_str(field(expr, "agg")));
965  const auto distinct = json_bool(field(expr, "distinct"));
966  const auto agg_ti = parse_type(field(expr, "type"));
967  const auto operands = indices_from_json_array(field(expr, "operands"));
968  if (operands.size() > 1 && (operands.size() != 2 || agg != kAPPROX_COUNT_DISTINCT)) {
969  throw QueryNotSupported("Multiple arguments for aggregates aren't supported");
970  }
971  return std::unique_ptr<const RexAgg>(new RexAgg(agg, distinct, agg_ti, operands));
972 }
973 
974 std::unique_ptr<const RexScalar> parse_scalar_expr(const rapidjson::Value& expr,
975  const Catalog_Namespace::Catalog& cat,
976  RelAlgDagBuilder& root_dag_builder) {
977  CHECK(expr.IsObject());
978  if (expr.IsObject() && expr.HasMember("input")) {
979  return std::unique_ptr<const RexScalar>(parse_abstract_input(expr));
980  }
981  if (expr.IsObject() && expr.HasMember("literal")) {
982  return std::unique_ptr<const RexScalar>(parse_literal(expr));
983  }
984  if (expr.IsObject() && expr.HasMember("op")) {
985  const auto op_str = json_str(field(expr, "op"));
986  if (op_str == std::string("CASE")) {
987  return std::unique_ptr<const RexScalar>(parse_case(expr, cat, root_dag_builder));
988  }
989  if (op_str == std::string("$SCALAR_QUERY")) {
990  return std::unique_ptr<const RexScalar>(
991  parse_subquery(expr, cat, root_dag_builder));
992  }
993  return std::unique_ptr<const RexScalar>(parse_operator(expr, cat, root_dag_builder));
994  }
995  throw QueryNotSupported("Expression node " + json_node_to_string(expr) +
996  " not supported");
997 }
998 
999 JoinType to_join_type(const std::string& join_type_name) {
1000  if (join_type_name == "inner") {
1001  return JoinType::INNER;
1002  }
1003  if (join_type_name == "left") {
1004  return JoinType::LEFT;
1005  }
1006  throw QueryNotSupported("Join type (" + join_type_name + ") not supported");
1007 }
1008 
1009 std::unique_ptr<const RexScalar> disambiguate_rex(const RexScalar*, const RANodeOutput&);
1010 
1011 std::unique_ptr<const RexOperator> disambiguate_operator(
1012  const RexOperator* rex_operator,
1013  const RANodeOutput& ra_output) noexcept {
1014  std::vector<std::unique_ptr<const RexScalar>> disambiguated_operands;
1015  for (size_t i = 0; i < rex_operator->size(); ++i) {
1016  auto operand = rex_operator->getOperand(i);
1017  if (dynamic_cast<const RexSubQuery*>(operand)) {
1018  disambiguated_operands.emplace_back(rex_operator->getOperandAndRelease(i));
1019  } else {
1020  disambiguated_operands.emplace_back(disambiguate_rex(operand, ra_output));
1021  }
1022  }
1023  const auto rex_window_function_operator =
1024  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
1025  if (rex_window_function_operator) {
1026  const auto& partition_keys = rex_window_function_operator->getPartitionKeys();
1027  std::vector<std::unique_ptr<const RexScalar>> disambiguated_partition_keys;
1028  for (const auto& partition_key : partition_keys) {
1029  disambiguated_partition_keys.emplace_back(
1030  disambiguate_rex(partition_key.get(), ra_output));
1031  }
1032  std::vector<std::unique_ptr<const RexScalar>> disambiguated_order_keys;
1033  const auto& order_keys = rex_window_function_operator->getOrderKeys();
1034  for (const auto& order_key : order_keys) {
1035  disambiguated_order_keys.emplace_back(disambiguate_rex(order_key.get(), ra_output));
1036  }
1037  return rex_window_function_operator->disambiguatedOperands(
1038  disambiguated_operands,
1039  disambiguated_partition_keys,
1040  disambiguated_order_keys,
1041  rex_window_function_operator->getCollation());
1042  }
1043  return rex_operator->getDisambiguated(disambiguated_operands);
1044 }
1045 
1046 std::unique_ptr<const RexCase> disambiguate_case(const RexCase* rex_case,
1047  const RANodeOutput& ra_output) {
1048  std::vector<
1049  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
1050  disambiguated_expr_pair_list;
1051  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
1052  auto disambiguated_when = disambiguate_rex(rex_case->getWhen(i), ra_output);
1053  auto disambiguated_then = disambiguate_rex(rex_case->getThen(i), ra_output);
1054  disambiguated_expr_pair_list.emplace_back(std::move(disambiguated_when),
1055  std::move(disambiguated_then));
1056  }
1057  std::unique_ptr<const RexScalar> disambiguated_else{
1058  disambiguate_rex(rex_case->getElse(), ra_output)};
1059  return std::unique_ptr<const RexCase>(
1060  new RexCase(disambiguated_expr_pair_list, disambiguated_else));
1061 }
1062 
1063 // The inputs used by scalar expressions are given as indices in the serialized
1064 // representation of the query. This is hard to navigate; make the relationship
1065 // explicit by creating RexInput expressions which hold a pointer to the source
1066 // relational algebra node and the index relative to the output of that node.
1067 std::unique_ptr<const RexScalar> disambiguate_rex(const RexScalar* rex_scalar,
1068  const RANodeOutput& ra_output) {
1069  const auto rex_abstract_input = dynamic_cast<const RexAbstractInput*>(rex_scalar);
1070  if (rex_abstract_input) {
1071  CHECK_LT(static_cast<size_t>(rex_abstract_input->getIndex()), ra_output.size());
1072  return std::unique_ptr<const RexInput>(
1073  new RexInput(ra_output[rex_abstract_input->getIndex()]));
1074  }
1075  const auto rex_operator = dynamic_cast<const RexOperator*>(rex_scalar);
1076  if (rex_operator) {
1077  return disambiguate_operator(rex_operator, ra_output);
1078  }
1079  const auto rex_case = dynamic_cast<const RexCase*>(rex_scalar);
1080  if (rex_case) {
1081  return disambiguate_case(rex_case, ra_output);
1082  }
1083  const auto rex_literal = dynamic_cast<const RexLiteral*>(rex_scalar);
1084  CHECK(rex_literal);
1085  return std::unique_ptr<const RexLiteral>(new RexLiteral(*rex_literal));
1086 }
1087 
1088 void bind_project_to_input(RelProject* project_node, const RANodeOutput& input) noexcept {
1089  CHECK_EQ(size_t(1), project_node->inputCount());
1090  std::vector<std::unique_ptr<const RexScalar>> disambiguated_exprs;
1091  for (size_t i = 0; i < project_node->size(); ++i) {
1092  const auto projected_expr = project_node->getProjectAt(i);
1093  if (dynamic_cast<const RexSubQuery*>(projected_expr)) {
1094  disambiguated_exprs.emplace_back(project_node->getProjectAtAndRelease(i));
1095  } else {
1096  disambiguated_exprs.emplace_back(disambiguate_rex(projected_expr, input));
1097  }
1098  }
1099  project_node->setExpressions(disambiguated_exprs);
1100 }
1101 
1103  const RANodeOutput& input) noexcept {
1104  CHECK_EQ(size_t(1), table_func_node->inputCount());
1105  std::vector<std::unique_ptr<const RexScalar>> disambiguated_exprs;
1106  for (size_t i = 0; i < table_func_node->getTableFuncInputsSize(); ++i) {
1107  const auto target_expr = table_func_node->getTableFuncInputAt(i);
1108  if (dynamic_cast<const RexSubQuery*>(target_expr)) {
1109  disambiguated_exprs.emplace_back(table_func_node->getTableFuncInputAtAndRelease(i));
1110  } else {
1111  disambiguated_exprs.emplace_back(disambiguate_rex(target_expr, input));
1112  }
1113  }
1114  table_func_node->setTableFuncInputs(disambiguated_exprs);
1115 }
1116 
1117 void bind_inputs(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) noexcept {
1118  for (auto ra_node : nodes) {
1119  const auto filter_node = std::dynamic_pointer_cast<RelFilter>(ra_node);
1120  if (filter_node) {
1121  CHECK_EQ(size_t(1), filter_node->inputCount());
1122  auto disambiguated_condition = disambiguate_rex(
1123  filter_node->getCondition(), get_node_output(filter_node->getInput(0)));
1124  filter_node->setCondition(disambiguated_condition);
1125  continue;
1126  }
1127  const auto join_node = std::dynamic_pointer_cast<RelJoin>(ra_node);
1128  if (join_node) {
1129  CHECK_EQ(size_t(2), join_node->inputCount());
1130  auto disambiguated_condition =
1131  disambiguate_rex(join_node->getCondition(), get_node_output(join_node.get()));
1132  join_node->setCondition(disambiguated_condition);
1133  continue;
1134  }
1135  const auto project_node = std::dynamic_pointer_cast<RelProject>(ra_node);
1136  if (project_node) {
1137  bind_project_to_input(project_node.get(),
1138  get_node_output(project_node->getInput(0)));
1139  continue;
1140  }
1141  const auto table_func_node = std::dynamic_pointer_cast<RelTableFunction>(ra_node);
1142  if (table_func_node) {
1143  bind_table_func_to_input(table_func_node.get(),
1144  get_node_output(table_func_node->getInput(0)));
1145  }
1146  }
1147 }
1148 
1149 void mark_nops(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) noexcept {
1150  for (auto node : nodes) {
1151  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
1152  if (!agg_node || agg_node->getAggExprsCount()) {
1153  continue;
1154  }
1155  CHECK_EQ(size_t(1), node->inputCount());
1156  const auto agg_input_node = dynamic_cast<const RelAggregate*>(node->getInput(0));
1157  if (agg_input_node && !agg_input_node->getAggExprsCount() &&
1158  agg_node->getGroupByCount() == agg_input_node->getGroupByCount()) {
1159  agg_node->markAsNop();
1160  }
1161  }
1162 }
1163 
1164 namespace {
1165 
1166 std::vector<const Rex*> reproject_targets(
1167  const RelProject* simple_project,
1168  const std::vector<const Rex*>& target_exprs) noexcept {
1169  std::vector<const Rex*> result;
1170  for (size_t i = 0; i < simple_project->size(); ++i) {
1171  const auto input_rex = dynamic_cast<const RexInput*>(simple_project->getProjectAt(i));
1172  CHECK(input_rex);
1173  CHECK_LT(static_cast<size_t>(input_rex->getIndex()), target_exprs.size());
1174  result.push_back(target_exprs[input_rex->getIndex()]);
1175  }
1176  return result;
1177 }
1178 
1185  public:
1187  const RelAlgNode* node_to_keep,
1188  const std::vector<std::unique_ptr<const RexScalar>>& scalar_sources)
1189  : node_to_keep_(node_to_keep), scalar_sources_(scalar_sources) {}
1190 
1191  // Reproject the RexInput from its current RA Node to the RA Node we intend to keep
1192  RetType visitInput(const RexInput* input) const final {
1193  if (input->getSourceNode() == node_to_keep_) {
1194  const auto index = input->getIndex();
1195  CHECK_LT(index, scalar_sources_.size());
1196  return visit(scalar_sources_[index].get());
1197  } else {
1198  return input->deepCopy();
1199  }
1200  }
1201 
1202  private:
1204  const std::vector<std::unique_ptr<const RexScalar>>& scalar_sources_;
1205 };
1206 
1207 } // namespace
1208 
1209 void create_compound(std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1210  const std::vector<size_t>& pattern) noexcept {
1211  CHECK_GE(pattern.size(), size_t(2));
1212  CHECK_LE(pattern.size(), size_t(4));
1213 
1214  std::unique_ptr<const RexScalar> filter_rex;
1215  std::vector<std::unique_ptr<const RexScalar>> scalar_sources;
1216  size_t groupby_count{0};
1217  std::vector<std::string> fields;
1218  std::vector<const RexAgg*> agg_exprs;
1219  std::vector<const Rex*> target_exprs;
1220  bool first_project{true};
1221  bool is_agg{false};
1222  RelAlgNode* last_node{nullptr};
1223 
1224  std::shared_ptr<ModifyManipulationTarget> manipulation_target;
1225 
1226  for (const auto node_idx : pattern) {
1227  const auto ra_node = nodes[node_idx];
1228  const auto ra_filter = std::dynamic_pointer_cast<RelFilter>(ra_node);
1229  if (ra_filter) {
1230  CHECK(!filter_rex);
1231  filter_rex.reset(ra_filter->getAndReleaseCondition());
1232  CHECK(filter_rex);
1233  last_node = ra_node.get();
1234  continue;
1235  }
1236  const auto ra_project = std::dynamic_pointer_cast<RelProject>(ra_node);
1237  if (ra_project) {
1238  fields = ra_project->getFields();
1239  manipulation_target = ra_project;
1240 
1241  if (first_project) {
1242  CHECK_EQ(size_t(1), ra_project->inputCount());
1243  // Rebind the input of the project to the input of the filter itself
1244  // since we know that we'll evaluate the filter on the fly, with no
1245  // intermediate buffer.
1246  const auto filter_input = dynamic_cast<const RelFilter*>(ra_project->getInput(0));
1247  if (filter_input) {
1248  CHECK_EQ(size_t(1), filter_input->inputCount());
1249  bind_project_to_input(ra_project.get(),
1250  get_node_output(filter_input->getInput(0)));
1251  }
1252  scalar_sources = ra_project->getExpressionsAndRelease();
1253  for (const auto& scalar_expr : scalar_sources) {
1254  target_exprs.push_back(scalar_expr.get());
1255  }
1256  first_project = false;
1257  } else {
1258  if (ra_project->isSimple()) {
1259  target_exprs = reproject_targets(ra_project.get(), target_exprs);
1260  } else {
1261  // TODO(adb): This is essentially a more general case of simple project, we
1262  // could likely merge the two
1263  std::vector<const Rex*> result;
1264  RexInputReplacementVisitor visitor(last_node, scalar_sources);
1265  for (size_t i = 0; i < ra_project->size(); ++i) {
1266  const auto rex = ra_project->getProjectAt(i);
1267  if (auto rex_input = dynamic_cast<const RexInput*>(rex)) {
1268  const auto index = rex_input->getIndex();
1269  CHECK_LT(index, target_exprs.size());
1270  result.push_back(target_exprs[index]);
1271  } else {
1272  scalar_sources.push_back(visitor.visit(rex));
1273  result.push_back(scalar_sources.back().get());
1274  }
1275  }
1276  target_exprs = result;
1277  }
1278  }
1279  last_node = ra_node.get();
1280  continue;
1281  }
1282  const auto ra_aggregate = std::dynamic_pointer_cast<RelAggregate>(ra_node);
1283  if (ra_aggregate) {
1284  is_agg = true;
1285  fields = ra_aggregate->getFields();
1286  agg_exprs = ra_aggregate->getAggregatesAndRelease();
1287  groupby_count = ra_aggregate->getGroupByCount();
1288  decltype(target_exprs){}.swap(target_exprs);
1289  CHECK_LE(groupby_count, scalar_sources.size());
1290  for (size_t group_idx = 0; group_idx < groupby_count; ++group_idx) {
1291  const auto rex_ref = new RexRef(group_idx + 1);
1292  target_exprs.push_back(rex_ref);
1293  scalar_sources.emplace_back(rex_ref);
1294  }
1295  for (const auto rex_agg : agg_exprs) {
1296  target_exprs.push_back(rex_agg);
1297  }
1298  last_node = ra_node.get();
1299  continue;
1300  }
1301  }
1302 
1303  auto compound_node =
1304  std::make_shared<RelCompound>(filter_rex,
1305  target_exprs,
1306  groupby_count,
1307  agg_exprs,
1308  fields,
1309  scalar_sources,
1310  is_agg,
1311  manipulation_target->isUpdateViaSelect(),
1312  manipulation_target->isDeleteViaSelect(),
1313  manipulation_target->isVarlenUpdateRequired(),
1314  manipulation_target->getModifiedTableDescriptor(),
1315  manipulation_target->getTargetColumns());
1316  auto old_node = nodes[pattern.back()];
1317  nodes[pattern.back()] = compound_node;
1318  auto first_node = nodes[pattern.front()];
1319  CHECK_EQ(size_t(1), first_node->inputCount());
1320  compound_node->addManagedInput(first_node->getAndOwnInput(0));
1321  for (size_t i = 0; i < pattern.size() - 1; ++i) {
1322  nodes[pattern[i]].reset();
1323  }
1324  for (auto node : nodes) {
1325  if (!node) {
1326  continue;
1327  }
1328  node->replaceInput(old_node, compound_node);
1329  }
1330 }
1331 
1332 class RANodeIterator : public std::vector<std::shared_ptr<RelAlgNode>>::const_iterator {
1333  using ElementType = std::shared_ptr<RelAlgNode>;
1334  using Super = std::vector<ElementType>::const_iterator;
1335  using Container = std::vector<ElementType>;
1336 
1337  public:
1338  enum class AdvancingMode { DUChain, InOrder };
1339 
1340  explicit RANodeIterator(const Container& nodes)
1341  : Super(nodes.begin()), owner_(nodes), nodeCount_([&nodes]() -> size_t {
1342  size_t non_zero_count = 0;
1343  for (const auto& node : nodes) {
1344  if (node) {
1345  ++non_zero_count;
1346  }
1347  }
1348  return non_zero_count;
1349  }()) {}
1350 
1351  explicit operator size_t() {
1352  return std::distance(owner_.begin(), *static_cast<Super*>(this));
1353  }
1354 
1355  RANodeIterator operator++() = delete;
1356 
1357  void advance(AdvancingMode mode) {
1358  Super& super = *this;
1359  switch (mode) {
1360  case AdvancingMode::DUChain: {
1361  size_t use_count = 0;
1362  Super only_use = owner_.end();
1363  for (Super nodeIt = std::next(super); nodeIt != owner_.end(); ++nodeIt) {
1364  if (!*nodeIt) {
1365  continue;
1366  }
1367  for (size_t i = 0; i < (*nodeIt)->inputCount(); ++i) {
1368  if ((*super) == (*nodeIt)->getAndOwnInput(i)) {
1369  ++use_count;
1370  if (1 == use_count) {
1371  only_use = nodeIt;
1372  } else {
1373  super = owner_.end();
1374  return;
1375  }
1376  }
1377  }
1378  }
1379  super = only_use;
1380  break;
1381  }
1382  case AdvancingMode::InOrder:
1383  for (size_t i = 0; i != owner_.size(); ++i) {
1384  if (!visited_.count(i)) {
1385  super = owner_.begin();
1386  std::advance(super, i);
1387  return;
1388  }
1389  }
1390  super = owner_.end();
1391  break;
1392  default:
1393  CHECK(false);
1394  }
1395  }
1396 
1397  bool allVisited() { return visited_.size() == nodeCount_; }
1398 
1400  visited_.insert(size_t(*this));
1401  Super& super = *this;
1402  return *super;
1403  }
1404 
1405  const ElementType* operator->() { return &(operator*()); }
1406 
1407  private:
1409  const size_t nodeCount_;
1410  std::unordered_set<size_t> visited_;
1411 };
1412 
1413 namespace {
1414 
1415 bool input_can_be_coalesced(const RelAlgNode* parent_node,
1416  const size_t index,
1417  const bool first_rex_is_input) {
1418  if (auto agg_node = dynamic_cast<const RelAggregate*>(parent_node)) {
1419  if (index == 0 && agg_node->getGroupByCount() > 0) {
1420  return true;
1421  } else {
1422  // Is an aggregated target, only allow the project to be elided if the aggregate
1423  // target is simply passed through (i.e. if the top level expression attached to
1424  // the project node is a RexInput expression)
1425  return first_rex_is_input;
1426  }
1427  }
1428  return first_rex_is_input;
1429 }
1430 
1437  public:
1438  bool visitInput(const RexInput* input) const final {
1439  // The top level expression node is checked before we apply the visitor. If we get
1440  // here, this input rex is a child of another rex node, and we handle the can be
1441  // coalesced check slightly differently
1442  return input_can_be_coalesced(input->getSourceNode(), input->getIndex(), false);
1443  }
1444 
1445  bool visitLiteral(const RexLiteral*) const final { return false; }
1446 
1447  bool visitSubQuery(const RexSubQuery*) const final { return false; }
1448 
1449  bool visitRef(const RexRef*) const final { return false; }
1450 
1451  protected:
1452  bool aggregateResult(const bool& aggregate, const bool& next_result) const final {
1453  return aggregate && next_result;
1454  }
1455 
1456  bool defaultResult() const final { return true; }
1457 };
1458 
1459 // Detect the window function SUM pattern: CASE WHEN COUNT() > 0 THEN SUM ELSE 0
1461  const auto case_operator = dynamic_cast<const RexCase*>(rex);
1462  if (case_operator && case_operator->branchCount() == 1) {
1463  const auto then_window =
1464  dynamic_cast<const RexWindowFunctionOperator*>(case_operator->getThen(0));
1465  if (then_window && then_window->getKind() == SqlWindowFunctionKind::SUM_INTERNAL) {
1466  return true;
1467  }
1468  }
1469  return false;
1470 }
1471 
1472 // Detect both window function operators and window function operators embedded in case
1473 // statements (for null handling)
1475  if (dynamic_cast<const RexWindowFunctionOperator*>(rex)) {
1476  return true;
1477  }
1478 
1479  // unwrap from casts, if they exist
1480  const auto rex_cast = dynamic_cast<const RexOperator*>(rex);
1481  if (rex_cast && rex_cast->getOperator() == kCAST) {
1482  CHECK_EQ(rex_cast->size(), size_t(1));
1483  return is_window_function_operator(rex_cast->getOperand(0));
1484  }
1485 
1486  if (is_window_function_sum(rex)) {
1487  return true;
1488  }
1489  // Check for Window Function AVG:
1490  // (CASE WHEN count > 0 THEN sum ELSE 0) / COUNT
1491  const RexOperator* divide_operator = dynamic_cast<const RexOperator*>(rex);
1492  if (divide_operator && divide_operator->getOperator() == kDIVIDE) {
1493  CHECK_EQ(divide_operator->size(), size_t(2));
1494  const auto case_operator =
1495  dynamic_cast<const RexCase*>(divide_operator->getOperand(0));
1496  const auto second_window =
1497  dynamic_cast<const RexWindowFunctionOperator*>(divide_operator->getOperand(1));
1498  if (case_operator && second_window &&
1499  second_window->getKind() == SqlWindowFunctionKind::COUNT) {
1500  if (is_window_function_sum(case_operator)) {
1501  return true;
1502  }
1503  }
1504  }
1505  return false;
1506 }
1507 
1508 } // namespace
1509 
1510 void coalesce_nodes(std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1511  const std::vector<const RelAlgNode*>& left_deep_joins) {
1512  enum class CoalesceState { Initial, Filter, FirstProject, Aggregate };
1513  std::vector<size_t> crt_pattern;
1514  CoalesceState crt_state{CoalesceState::Initial};
1515 
1516  auto reset_state = [&crt_pattern, &crt_state]() {
1517  crt_state = CoalesceState::Initial;
1518  decltype(crt_pattern)().swap(crt_pattern);
1519  };
1520 
1521  for (RANodeIterator nodeIt(nodes); !nodeIt.allVisited();) {
1522  const auto ra_node = nodeIt != nodes.end() ? *nodeIt : nullptr;
1523  switch (crt_state) {
1524  case CoalesceState::Initial: {
1525  if (std::dynamic_pointer_cast<const RelFilter>(ra_node) &&
1526  std::find(left_deep_joins.begin(), left_deep_joins.end(), ra_node.get()) ==
1527  left_deep_joins.end()) {
1528  crt_pattern.push_back(size_t(nodeIt));
1529  crt_state = CoalesceState::Filter;
1530  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1531  } else if (std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1532  crt_pattern.push_back(size_t(nodeIt));
1533  crt_state = CoalesceState::FirstProject;
1534  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1535  } else {
1536  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1537  }
1538  break;
1539  }
1540  case CoalesceState::Filter: {
1541  if (auto project_node = std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1542  if (project_node->hasWindowFunctionExpr()) {
1543  reset_state();
1544  break;
1545  }
1546  crt_pattern.push_back(size_t(nodeIt));
1547  crt_state = CoalesceState::FirstProject;
1548  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1549  } else {
1550  reset_state();
1551  }
1552  break;
1553  }
1554  case CoalesceState::FirstProject: {
1555  if (std::dynamic_pointer_cast<const RelAggregate>(ra_node)) {
1556  crt_pattern.push_back(size_t(nodeIt));
1557  crt_state = CoalesceState::Aggregate;
1558  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1559  } else {
1560  if (crt_pattern.size() >= 2) {
1561  create_compound(nodes, crt_pattern);
1562  }
1563  reset_state();
1564  }
1565  break;
1566  }
1567  case CoalesceState::Aggregate: {
1568  if (auto project_node = std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1569  // TODO(adb): overloading the simple project terminology again here
1570  bool is_simple_project{true};
1571  for (size_t i = 0; i < project_node->size(); i++) {
1572  const auto scalar_rex = project_node->getProjectAt(i);
1573  // If the top level scalar rex is an input node, we can bypass the visitor
1574  if (auto input_rex = dynamic_cast<const RexInput*>(scalar_rex)) {
1576  input_rex->getSourceNode(), input_rex->getIndex(), true)) {
1577  is_simple_project = false;
1578  break;
1579  }
1580  continue;
1581  }
1582  CoalesceSecondaryProjectVisitor visitor;
1583  if (!visitor.visit(project_node->getProjectAt(i))) {
1584  is_simple_project = false;
1585  break;
1586  }
1587  }
1588  if (is_simple_project) {
1589  crt_pattern.push_back(size_t(nodeIt));
1590  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1591  }
1592  }
1593  CHECK_GE(crt_pattern.size(), size_t(2));
1594  create_compound(nodes, crt_pattern);
1595  reset_state();
1596  break;
1597  }
1598  default:
1599  CHECK(false);
1600  }
1601  }
1602  if (crt_state == CoalesceState::FirstProject || crt_state == CoalesceState::Aggregate) {
1603  if (crt_pattern.size() >= 2) {
1604  create_compound(nodes, crt_pattern);
1605  }
1606  CHECK(!crt_pattern.empty());
1607  }
1608 }
1609 
1617 class WindowFunctionDetectionVisitor : public RexVisitor<const RexScalar*> {
1618  protected:
1619  // Detect embedded window function expressions in operators
1620  const RexScalar* visitOperator(const RexOperator* rex_operator) const final {
1621  if (is_window_function_operator(rex_operator)) {
1622  return rex_operator;
1623  }
1624 
1625  const size_t operand_count = rex_operator->size();
1626  for (size_t i = 0; i < operand_count; ++i) {
1627  const auto operand = rex_operator->getOperand(i);
1628  if (is_window_function_operator(operand)) {
1629  // Handle both RexWindowFunctionOperators and window functions built up from
1630  // multiple RexScalar objects (e.g. AVG)
1631  return operand;
1632  }
1633  const auto operandResult = visit(operand);
1634  if (operandResult) {
1635  return operandResult;
1636  }
1637  }
1638 
1639  return defaultResult();
1640  }
1641 
1642  // Detect embedded window function expressions in case statements. Note that this may
1643  // manifest as a nested case statement inside a top level case statement, as some
1644  // window functions (sum, avg) are represented as a case statement. Use the
1645  // is_window_function_operator helper to detect complete window function expressions.
1646  const RexScalar* visitCase(const RexCase* rex_case) const final {
1647  if (is_window_function_operator(rex_case)) {
1648  return rex_case;
1649  }
1650 
1651  auto result = defaultResult();
1652  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
1653  const auto when = rex_case->getWhen(i);
1654  result = is_window_function_operator(when) ? when : visit(when);
1655  if (result) {
1656  return result;
1657  }
1658  const auto then = rex_case->getThen(i);
1659  result = is_window_function_operator(then) ? then : visit(then);
1660  if (result) {
1661  return result;
1662  }
1663  }
1664  if (rex_case->getElse()) {
1665  auto else_expr = rex_case->getElse();
1666  result = is_window_function_operator(else_expr) ? else_expr : visit(else_expr);
1667  }
1668  return result;
1669  }
1670 
1671  const RexScalar* aggregateResult(const RexScalar* const& aggregate,
1672  const RexScalar* const& next_result) const final {
1673  // all methods calling aggregate result should be overriden
1674  UNREACHABLE();
1675  return nullptr;
1676  }
1677 
1678  const RexScalar* defaultResult() const final { return nullptr; }
1679 };
1680 
1690  public:
1691  RexWindowFuncReplacementVisitor(std::unique_ptr<const RexScalar> replacement_rex)
1692  : replacement_rex_(std::move(replacement_rex)) {}
1693 
1694  ~RexWindowFuncReplacementVisitor() { CHECK(replacement_rex_ == nullptr); }
1695 
1696  protected:
1697  RetType visitOperator(const RexOperator* rex_operator) const final {
1698  if (should_replace_operand(rex_operator)) {
1699  return std::move(replacement_rex_);
1700  }
1701 
1702  const auto rex_window_function_operator =
1703  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
1704  if (rex_window_function_operator) {
1705  // Deep copy the embedded window function operator
1706  return visitWindowFunctionOperator(rex_window_function_operator);
1707  }
1708 
1709  const size_t operand_count = rex_operator->size();
1710  std::vector<RetType> new_opnds;
1711  for (size_t i = 0; i < operand_count; ++i) {
1712  const auto operand = rex_operator->getOperand(i);
1713  if (should_replace_operand(operand)) {
1714  new_opnds.push_back(std::move(replacement_rex_));
1715  } else {
1716  new_opnds.emplace_back(visit(rex_operator->getOperand(i)));
1717  }
1718  }
1719  return rex_operator->getDisambiguated(new_opnds);
1720  }
1721 
1722  RetType visitCase(const RexCase* rex_case) const final {
1723  if (should_replace_operand(rex_case)) {
1724  return std::move(replacement_rex_);
1725  }
1726 
1727  std::vector<std::pair<RetType, RetType>> new_pair_list;
1728  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
1729  auto when_operand = rex_case->getWhen(i);
1730  auto then_operand = rex_case->getThen(i);
1731  new_pair_list.emplace_back(
1732  should_replace_operand(when_operand) ? std::move(replacement_rex_)
1733  : visit(when_operand),
1734  should_replace_operand(then_operand) ? std::move(replacement_rex_)
1735  : visit(then_operand));
1736  }
1737  auto new_else = should_replace_operand(rex_case->getElse())
1738  ? std::move(replacement_rex_)
1739  : visit(rex_case->getElse());
1740  return std::make_unique<RexCase>(new_pair_list, new_else);
1741  }
1742 
1743  private:
1744  bool should_replace_operand(const RexScalar* rex) const {
1745  return replacement_rex_ && is_window_function_operator(rex);
1746  }
1747 
1748  mutable std::unique_ptr<const RexScalar> replacement_rex_;
1749 };
1750 
1761  public:
1762  RexInputBackpropagationVisitor(RelProject* node) : node_(node) { CHECK(node_); }
1763 
1764  protected:
1765  RetType visitInput(const RexInput* rex_input) const final {
1766  if (rex_input->getSourceNode() != node_) {
1767  const auto cur_index = rex_input->getIndex();
1768  auto cur_source_node = rex_input->getSourceNode();
1769  std::string field_name = "";
1770  if (auto cur_project_node = dynamic_cast<const RelProject*>(cur_source_node)) {
1771  field_name = cur_project_node->getFieldName(cur_index);
1772  }
1773  node_->appendInput(field_name, rex_input->deepCopy());
1774  return std::make_unique<RexInput>(node_, node_->size() - 1);
1775  } else {
1776  return rex_input->deepCopy();
1777  }
1778  }
1779 
1780  private:
1781  mutable RelProject* node_;
1782 };
1783 
1800  std::vector<std::shared_ptr<RelAlgNode>>& nodes) {
1801  std::list<std::shared_ptr<RelAlgNode>> node_list(nodes.begin(), nodes.end());
1802 
1804  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
1805  const auto node = *node_itr;
1806  auto window_func_project_node = std::dynamic_pointer_cast<RelProject>(node);
1807  if (!window_func_project_node) {
1808  continue;
1809  }
1810 
1811  // map scalar expression index in the project node to window function ptr
1812  std::unordered_map<size_t, const RexScalar*> embedded_window_function_expressions;
1813 
1814  // Iterate the target exprs of the project node and check for window function
1815  // expressions. If an embedded expression exists, save it in the
1816  // embedded_window_function_expressions map and split the expression into a window
1817  // function expression and a parent expression in a subsequent project node
1818  for (size_t i = 0; i < window_func_project_node->size(); i++) {
1819  const auto scalar_rex = window_func_project_node->getProjectAt(i);
1820  if (is_window_function_operator(scalar_rex)) {
1821  // top level window function exprs are fine
1822  continue;
1823  }
1824 
1825  if (const auto window_func_rex = visitor.visit(scalar_rex)) {
1826  const auto ret = embedded_window_function_expressions.insert(
1827  std::make_pair(i, window_func_rex));
1828  CHECK(ret.second);
1829  }
1830  }
1831 
1832  if (!embedded_window_function_expressions.empty()) {
1833  std::vector<std::unique_ptr<const RexScalar>> new_scalar_exprs;
1834 
1835  auto window_func_scalar_exprs =
1836  window_func_project_node->getExpressionsAndRelease();
1837  for (size_t rex_idx = 0; rex_idx < window_func_scalar_exprs.size(); ++rex_idx) {
1838  const auto embedded_window_func_expr_pair =
1839  embedded_window_function_expressions.find(rex_idx);
1840  if (embedded_window_func_expr_pair ==
1841  embedded_window_function_expressions.end()) {
1842  new_scalar_exprs.emplace_back(
1843  std::make_unique<const RexInput>(window_func_project_node.get(), rex_idx));
1844  } else {
1845  const auto window_func_rex_idx = embedded_window_func_expr_pair->first;
1846  CHECK_LT(window_func_rex_idx, window_func_scalar_exprs.size());
1847 
1848  const auto& window_func_rex = embedded_window_func_expr_pair->second;
1849 
1850  RexDeepCopyVisitor copier;
1851  auto window_func_rex_copy = copier.visit(window_func_rex);
1852 
1853  auto window_func_parent_expr =
1854  window_func_scalar_exprs[window_func_rex_idx].get();
1855 
1856  // Replace window func rex with an input rex
1857  auto window_func_result_input = std::make_unique<const RexInput>(
1858  window_func_project_node.get(), window_func_rex_idx);
1859  RexWindowFuncReplacementVisitor replacer(std::move(window_func_result_input));
1860  auto new_parent_rex = replacer.visit(window_func_parent_expr);
1861 
1862  // Put the parent expr in the new scalar exprs
1863  new_scalar_exprs.emplace_back(std::move(new_parent_rex));
1864 
1865  // Put the window func expr in cur scalar exprs
1866  window_func_scalar_exprs[window_func_rex_idx] = std::move(window_func_rex_copy);
1867  }
1868  }
1869 
1870  CHECK_EQ(window_func_scalar_exprs.size(), new_scalar_exprs.size());
1871  window_func_project_node->setExpressions(window_func_scalar_exprs);
1872 
1873  // Ensure any inputs from the node containing the expression (the "new" node)
1874  // exist on the window function project node, e.g. if we had a binary operation
1875  // involving an aggregate value or column not included in the top level
1876  // projection list.
1877  RexInputBackpropagationVisitor input_visitor(window_func_project_node.get());
1878  for (size_t i = 0; i < new_scalar_exprs.size(); i++) {
1879  if (dynamic_cast<const RexInput*>(new_scalar_exprs[i].get())) {
1880  // ignore top level inputs, these were copied directly from the previous
1881  // node
1882  continue;
1883  }
1884  new_scalar_exprs[i] = input_visitor.visit(new_scalar_exprs[i].get());
1885  }
1886 
1887  // Build the new project node and insert it into the list after the project node
1888  // containing the window function
1889  auto new_project =
1890  std::make_shared<RelProject>(new_scalar_exprs,
1891  window_func_project_node->getFields(),
1892  window_func_project_node);
1893  node_list.insert(std::next(node_itr), new_project);
1894 
1895  // Rebind all the following inputs
1896  for (auto rebind_itr = std::next(node_itr, 2); rebind_itr != node_list.end();
1897  rebind_itr++) {
1898  (*rebind_itr)->replaceInput(window_func_project_node, new_project);
1899  }
1900  }
1901  }
1902  nodes.assign(node_list.begin(), node_list.end());
1903 }
1904 
1905 using RexInputSet = std::unordered_set<RexInput>;
1906 
1907 class RexInputCollector : public RexVisitor<RexInputSet> {
1908  public:
1909  RexInputSet visitInput(const RexInput* input) const override {
1910  return RexInputSet{*input};
1911  }
1912 
1913  protected:
1915  const RexInputSet& next_result) const override {
1916  auto result = aggregate;
1917  result.insert(next_result.begin(), next_result.end());
1918  return result;
1919  }
1920 };
1921 
1929 void add_window_function_pre_project(std::vector<std::shared_ptr<RelAlgNode>>& nodes) {
1930  std::list<std::shared_ptr<RelAlgNode>> node_list(nodes.begin(), nodes.end());
1931 
1932  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
1933  const auto node = *node_itr;
1934  auto window_func_project_node = std::dynamic_pointer_cast<RelProject>(node);
1935  if (!window_func_project_node) {
1936  continue;
1937  }
1938  if (!window_func_project_node->hasWindowFunctionExpr()) {
1939  // the first projection node in the query plan does not have a window function
1940  // expression -- this step is not requierd.
1941  return;
1942  }
1943 
1944  const auto prev_node_itr = std::prev(node_itr);
1945  const auto prev_node = *prev_node_itr;
1946  CHECK(prev_node);
1947 
1948  RexInputSet inputs;
1949  RexInputCollector input_collector;
1950  for (size_t i = 0; i < window_func_project_node->size(); i++) {
1951  auto new_inputs = input_collector.visit(window_func_project_node->getProjectAt(i));
1952  inputs.insert(new_inputs.begin(), new_inputs.end());
1953  }
1954 
1955  // Note: Technically not required since we are mapping old inputs to new input
1956  // indices, but makes the re-mapping of inputs easier to follow.
1957  std::vector<RexInput> sorted_inputs(inputs.begin(), inputs.end());
1958  std::sort(sorted_inputs.begin(),
1959  sorted_inputs.end(),
1960  [](const auto& a, const auto& b) { return a.getIndex() < b.getIndex(); });
1961 
1962  std::vector<std::unique_ptr<const RexScalar>> scalar_exprs;
1963  std::vector<std::string> fields;
1964  std::unordered_map<unsigned, unsigned> old_index_to_new_index;
1965  for (auto& input : sorted_inputs) {
1966  CHECK_EQ(input.getSourceNode(), prev_node.get());
1967  CHECK(old_index_to_new_index
1968  .insert(std::make_pair(input.getIndex(), scalar_exprs.size()))
1969  .second);
1970  scalar_exprs.emplace_back(input.deepCopy());
1971  fields.emplace_back("");
1972  }
1973 
1974  auto new_project = std::make_shared<RelProject>(scalar_exprs, fields, prev_node);
1975  node_list.insert(node_itr, new_project);
1976  window_func_project_node->replaceInput(
1977  prev_node, new_project, old_index_to_new_index);
1978 
1979  break;
1980  }
1981 
1982  nodes.assign(node_list.begin(), node_list.end());
1983 }
1984 
1985 int64_t get_int_literal_field(const rapidjson::Value& obj,
1986  const char field[],
1987  const int64_t default_val) noexcept {
1988  const auto it = obj.FindMember(field);
1989  if (it == obj.MemberEnd()) {
1990  return default_val;
1991  }
1992  std::unique_ptr<RexLiteral> lit(parse_literal(it->value));
1993  CHECK_EQ(kDECIMAL, lit->getType());
1994  CHECK_EQ(unsigned(0), lit->getScale());
1995  CHECK_EQ(unsigned(0), lit->getTypeScale());
1996  return lit->getVal<int64_t>();
1997 }
1998 
1999 void check_empty_inputs_field(const rapidjson::Value& node) noexcept {
2000  const auto& inputs_json = field(node, "inputs");
2001  CHECK(inputs_json.IsArray() && !inputs_json.Size());
2002 }
2003 
2005  const rapidjson::Value& scan_ra) {
2006  const auto& table_json = field(scan_ra, "table");
2007  CHECK(table_json.IsArray());
2008  CHECK_EQ(unsigned(2), table_json.Size());
2009  const auto td = cat.getMetadataForTable(table_json[1].GetString());
2010  CHECK(td);
2011  return td;
2012 }
2013 
2014 std::vector<std::string> getFieldNamesFromScanNode(const rapidjson::Value& scan_ra) {
2015  const auto& fields_json = field(scan_ra, "fieldNames");
2016  return strings_from_json_array(fields_json);
2017 }
2018 
2019 } // namespace
2020 
2022  for (const auto& expr : scalar_exprs_) {
2023  if (is_window_function_operator(expr.get())) {
2024  return true;
2025  }
2026  }
2027  return false;
2028 }
2029 namespace details {
2030 
2032  public:
2034 
2035  std::vector<std::shared_ptr<RelAlgNode>> run(const rapidjson::Value& rels,
2036  RelAlgDagBuilder& root_dag_builder) {
2037  for (auto rels_it = rels.Begin(); rels_it != rels.End(); ++rels_it) {
2038  const auto& crt_node = *rels_it;
2039  const auto id = node_id(crt_node);
2040  CHECK_EQ(static_cast<size_t>(id), nodes_.size());
2041  CHECK(crt_node.IsObject());
2042  std::shared_ptr<RelAlgNode> ra_node = nullptr;
2043  const auto rel_op = json_str(field(crt_node, "relOp"));
2044  if (rel_op == std::string("EnumerableTableScan") ||
2045  rel_op == std::string("LogicalTableScan")) {
2046  ra_node = dispatchTableScan(crt_node);
2047  } else if (rel_op == std::string("LogicalProject")) {
2048  ra_node = dispatchProject(crt_node, root_dag_builder);
2049  } else if (rel_op == std::string("LogicalFilter")) {
2050  ra_node = dispatchFilter(crt_node, root_dag_builder);
2051  } else if (rel_op == std::string("LogicalAggregate")) {
2052  ra_node = dispatchAggregate(crt_node);
2053  } else if (rel_op == std::string("LogicalJoin")) {
2054  ra_node = dispatchJoin(crt_node, root_dag_builder);
2055  } else if (rel_op == std::string("LogicalSort")) {
2056  ra_node = dispatchSort(crt_node);
2057  } else if (rel_op == std::string("LogicalValues")) {
2058  ra_node = dispatchLogicalValues(crt_node);
2059  } else if (rel_op == std::string("LogicalTableModify")) {
2060  ra_node = dispatchModify(crt_node);
2061  } else if (rel_op == std::string("LogicalTableFunctionScan")) {
2062  ra_node = dispatchTableFunction(crt_node, root_dag_builder);
2063  } else if (rel_op == std::string("LogicalUnion")) {
2064  ra_node = dispatchUnion(crt_node);
2065  } else {
2066  throw QueryNotSupported(std::string("Node ") + rel_op + " not supported yet");
2067  }
2068  nodes_.push_back(ra_node);
2069  }
2070 
2071  return std::move(nodes_);
2072  }
2073 
2074  private:
2075  std::shared_ptr<RelScan> dispatchTableScan(const rapidjson::Value& scan_ra) {
2076  check_empty_inputs_field(scan_ra);
2077  CHECK(scan_ra.IsObject());
2078  const auto td = getTableFromScanNode(cat_, scan_ra);
2079  const auto field_names = getFieldNamesFromScanNode(scan_ra);
2080  return std::make_shared<RelScan>(td, field_names);
2081  }
2082 
2083  std::shared_ptr<RelProject> dispatchProject(const rapidjson::Value& proj_ra,
2084  RelAlgDagBuilder& root_dag_builder) {
2085  const auto inputs = getRelAlgInputs(proj_ra);
2086  CHECK_EQ(size_t(1), inputs.size());
2087  const auto& exprs_json = field(proj_ra, "exprs");
2088  CHECK(exprs_json.IsArray());
2089  std::vector<std::unique_ptr<const RexScalar>> exprs;
2090  for (auto exprs_json_it = exprs_json.Begin(); exprs_json_it != exprs_json.End();
2091  ++exprs_json_it) {
2092  exprs.emplace_back(parse_scalar_expr(*exprs_json_it, cat_, root_dag_builder));
2093  }
2094  const auto& fields = field(proj_ra, "fields");
2095  return std::make_shared<RelProject>(
2096  exprs, strings_from_json_array(fields), inputs.front());
2097  }
2098 
2099  std::shared_ptr<RelFilter> dispatchFilter(const rapidjson::Value& filter_ra,
2100  RelAlgDagBuilder& root_dag_builder) {
2101  const auto inputs = getRelAlgInputs(filter_ra);
2102  CHECK_EQ(size_t(1), inputs.size());
2103  const auto id = node_id(filter_ra);
2104  CHECK(id);
2105  auto condition =
2106  parse_scalar_expr(field(filter_ra, "condition"), cat_, root_dag_builder);
2107  return std::make_shared<RelFilter>(condition, inputs.front());
2108  }
2109 
2110  std::shared_ptr<RelAggregate> dispatchAggregate(const rapidjson::Value& agg_ra) {
2111  const auto inputs = getRelAlgInputs(agg_ra);
2112  CHECK_EQ(size_t(1), inputs.size());
2113  const auto fields = strings_from_json_array(field(agg_ra, "fields"));
2114  const auto group = indices_from_json_array(field(agg_ra, "group"));
2115  for (size_t i = 0; i < group.size(); ++i) {
2116  CHECK_EQ(i, group[i]);
2117  }
2118  if (agg_ra.HasMember("groups") || agg_ra.HasMember("indicator")) {
2119  throw QueryNotSupported("GROUP BY extensions not supported");
2120  }
2121  const auto& aggs_json_arr = field(agg_ra, "aggs");
2122  CHECK(aggs_json_arr.IsArray());
2123  std::vector<std::unique_ptr<const RexAgg>> aggs;
2124  for (auto aggs_json_arr_it = aggs_json_arr.Begin();
2125  aggs_json_arr_it != aggs_json_arr.End();
2126  ++aggs_json_arr_it) {
2127  aggs.emplace_back(parse_aggregate_expr(*aggs_json_arr_it));
2128  }
2129  return std::make_shared<RelAggregate>(group.size(), aggs, fields, inputs.front());
2130  }
2131 
2132  std::shared_ptr<RelJoin> dispatchJoin(const rapidjson::Value& join_ra,
2133  RelAlgDagBuilder& root_dag_builder) {
2134  const auto inputs = getRelAlgInputs(join_ra);
2135  CHECK_EQ(size_t(2), inputs.size());
2136  const auto join_type = to_join_type(json_str(field(join_ra, "joinType")));
2137  auto filter_rex =
2138  parse_scalar_expr(field(join_ra, "condition"), cat_, root_dag_builder);
2139  return std::make_shared<RelJoin>(inputs[0], inputs[1], filter_rex, join_type);
2140  }
2141 
2142  std::shared_ptr<RelSort> dispatchSort(const rapidjson::Value& sort_ra) {
2143  const auto inputs = getRelAlgInputs(sort_ra);
2144  CHECK_EQ(size_t(1), inputs.size());
2145  std::vector<SortField> collation;
2146  const auto& collation_arr = field(sort_ra, "collation");
2147  CHECK(collation_arr.IsArray());
2148  for (auto collation_arr_it = collation_arr.Begin();
2149  collation_arr_it != collation_arr.End();
2150  ++collation_arr_it) {
2151  const size_t field_idx = json_i64(field(*collation_arr_it, "field"));
2152  const auto sort_dir = parse_sort_direction(*collation_arr_it);
2153  const auto null_pos = parse_nulls_position(*collation_arr_it);
2154  collation.emplace_back(field_idx, sort_dir, null_pos);
2155  }
2156  auto limit = get_int_literal_field(sort_ra, "fetch", -1);
2157  const auto offset = get_int_literal_field(sort_ra, "offset", 0);
2158  auto ret = std::make_shared<RelSort>(
2159  collation, limit > 0 ? limit : 0, offset, inputs.front());
2160  ret->setEmptyResult(limit == 0);
2161  return ret;
2162  }
2163 
2164  std::shared_ptr<RelModify> dispatchModify(const rapidjson::Value& logical_modify_ra) {
2165  const auto inputs = getRelAlgInputs(logical_modify_ra);
2166  CHECK_EQ(size_t(1), inputs.size());
2167 
2168  const auto table_descriptor = getTableFromScanNode(cat_, logical_modify_ra);
2169  if (table_descriptor->isView) {
2170  throw std::runtime_error("UPDATE of a view is unsupported.");
2171  }
2172 
2173  bool flattened = json_bool(field(logical_modify_ra, "flattened"));
2174  std::string op = json_str(field(logical_modify_ra, "operation"));
2175  RelModify::TargetColumnList target_column_list;
2176 
2177  if (op == "UPDATE") {
2178  const auto& update_columns = field(logical_modify_ra, "updateColumnList");
2179  CHECK(update_columns.IsArray());
2180 
2181  for (auto column_arr_it = update_columns.Begin();
2182  column_arr_it != update_columns.End();
2183  ++column_arr_it) {
2184  target_column_list.push_back(column_arr_it->GetString());
2185  }
2186  }
2187 
2188  auto modify_node = std::make_shared<RelModify>(
2189  cat_, table_descriptor, flattened, op, target_column_list, inputs[0]);
2190  switch (modify_node->getOperation()) {
2192  modify_node->applyDeleteModificationsToInputNode();
2193  break;
2194  }
2196  modify_node->applyUpdateModificationsToInputNode();
2197  break;
2198  }
2199  default:
2200  throw std::runtime_error("Unsupported RelModify operation: " +
2201  json_node_to_string(logical_modify_ra));
2202  }
2203 
2204  return modify_node;
2205  }
2206 
2207  std::shared_ptr<RelTableFunction> dispatchTableFunction(
2208  const rapidjson::Value& table_func_ra,
2209  RelAlgDagBuilder& root_dag_builder) {
2210  const auto inputs = getRelAlgInputs(table_func_ra);
2211  CHECK_EQ(size_t(1), inputs.size());
2212 
2213  const auto& invocation = field(table_func_ra, "invocation");
2214  CHECK(invocation.IsObject());
2215 
2216  const auto& operands = field(invocation, "operands");
2217  CHECK(operands.IsArray());
2218  CHECK_GE(operands.Size(), unsigned(0));
2219 
2220  std::vector<const Rex*> col_inputs;
2221  std::vector<std::unique_ptr<const RexScalar>> table_func_inputs;
2222  std::vector<std::string> fields;
2223 
2224  for (auto exprs_json_it = operands.Begin(); exprs_json_it != operands.End();
2225  ++exprs_json_it) {
2226  const auto& expr_json = *exprs_json_it;
2227  CHECK(expr_json.IsObject());
2228 
2229  if (expr_json.HasMember("op")) {
2230  const auto op_str = json_str(field(expr_json, "op"));
2231  if (op_str == "CAST" && expr_json.HasMember("type")) {
2232  const auto& expr_type = field(expr_json, "type");
2233  CHECK(expr_type.IsObject());
2234  CHECK(expr_type.HasMember("type"));
2235  const auto& expr_type_name = json_str(field(expr_type, "type"));
2236  if (expr_type_name == "CURSOR") {
2237  CHECK(expr_json.HasMember("operands"));
2238  const auto& expr_operands = field(expr_json, "operands");
2239  CHECK(expr_operands.IsArray());
2240  if (expr_operands.Size() != 1) {
2241  throw std::runtime_error(
2242  "Table functions currently only support one ResultSet input");
2243  }
2244 
2245  CHECK(expr_json.HasMember("type"));
2246  const auto& expr_types = field(invocation, "type");
2247  CHECK(expr_types.IsArray());
2248 
2249  const auto prior_node = prev(table_func_ra);
2250  CHECK(prior_node);
2251  CHECK_EQ(prior_node->size(), expr_types.Size());
2252 
2253  // Forward the values from the prior node as RexInputs
2254  for (size_t i = 0; i < prior_node->size(); i++) {
2255  table_func_inputs.emplace_back(std::make_unique<RexAbstractInput>(i));
2256  col_inputs.emplace_back(table_func_inputs.back().get());
2257  }
2258  continue;
2259  }
2260  }
2261  }
2262  table_func_inputs.emplace_back(
2263  parse_scalar_expr(*exprs_json_it, cat_, root_dag_builder));
2264  }
2265 
2266  const auto& op_name = field(invocation, "op");
2267  CHECK(op_name.IsString());
2268 
2269  std::vector<std::unique_ptr<const RexScalar>> table_function_projected_outputs;
2270  const auto& row_types = field(table_func_ra, "rowType");
2271  CHECK(row_types.IsArray());
2272  CHECK_GE(row_types.Size(), unsigned(0));
2273  const auto& row_types_array = row_types.GetArray();
2274 
2275  for (size_t i = 0; i < row_types_array.Size(); i++) {
2276  // We don't care about the type information in rowType -- replace each output with
2277  // a reference to be resolved later in the translator
2278  table_function_projected_outputs.emplace_back(std::make_unique<RexRef>(i));
2279  fields.emplace_back("");
2280  }
2281 
2282  return std::make_shared<RelTableFunction>(op_name.GetString(),
2283  inputs[0],
2284  fields,
2285  col_inputs,
2286  table_func_inputs,
2287  table_function_projected_outputs);
2288  }
2289 
2290  std::shared_ptr<RelLogicalValues> dispatchLogicalValues(
2291  const rapidjson::Value& logical_values_ra) {
2292  const auto& tuple_type_arr = field(logical_values_ra, "type");
2293  CHECK(tuple_type_arr.IsArray());
2294  std::vector<TargetMetaInfo> tuple_type;
2295  for (auto tuple_type_arr_it = tuple_type_arr.Begin();
2296  tuple_type_arr_it != tuple_type_arr.End();
2297  ++tuple_type_arr_it) {
2298  const auto component_type = parse_type(*tuple_type_arr_it);
2299  const auto component_name = json_str(field(*tuple_type_arr_it, "name"));
2300  tuple_type.emplace_back(component_name, component_type);
2301  }
2302  const auto& inputs_arr = field(logical_values_ra, "inputs");
2303  CHECK(inputs_arr.IsArray());
2304  const auto& tuples_arr = field(logical_values_ra, "tuples");
2305  CHECK(tuples_arr.IsArray());
2306 
2307  if (inputs_arr.Size()) {
2308  throw QueryNotSupported("Inputs not supported in logical values yet.");
2309  }
2310 
2311  std::vector<RelLogicalValues::RowValues> values;
2312  if (tuples_arr.Size()) {
2313  for (const auto& row : tuples_arr.GetArray()) {
2314  CHECK(row.IsArray());
2315  const auto values_json = row.GetArray();
2316  if (!values.empty()) {
2317  CHECK_EQ(values[0].size(), values_json.Size());
2318  }
2319  values.emplace_back(RelLogicalValues::RowValues{});
2320  for (const auto& value : values_json) {
2321  CHECK(value.IsObject());
2322  CHECK(value.HasMember("literal"));
2323  values.back().emplace_back(parse_literal(value));
2324  }
2325  }
2326  }
2327 
2328  return std::make_shared<RelLogicalValues>(tuple_type, values);
2329  }
2330 
2331  std::shared_ptr<RelLogicalUnion> dispatchUnion(
2332  const rapidjson::Value& logical_union_ra) {
2333  auto inputs = getRelAlgInputs(logical_union_ra);
2334  auto const& all_type_bool = field(logical_union_ra, "all");
2335  CHECK(all_type_bool.IsBool());
2336  return std::make_shared<RelLogicalUnion>(std::move(inputs), all_type_bool.GetBool());
2337  }
2338 
2339  RelAlgInputs getRelAlgInputs(const rapidjson::Value& node) {
2340  if (node.HasMember("inputs")) {
2341  const auto str_input_ids = strings_from_json_array(field(node, "inputs"));
2342  RelAlgInputs ra_inputs;
2343  for (const auto& str_id : str_input_ids) {
2344  ra_inputs.push_back(nodes_[std::stoi(str_id)]);
2345  }
2346  return ra_inputs;
2347  }
2348  return {prev(node)};
2349  }
2350 
2351  std::shared_ptr<const RelAlgNode> prev(const rapidjson::Value& crt_node) {
2352  const auto id = node_id(crt_node);
2353  CHECK(id);
2354  CHECK_EQ(static_cast<size_t>(id), nodes_.size());
2355  return nodes_.back();
2356  }
2357 
2359  std::vector<std::shared_ptr<RelAlgNode>> nodes_;
2360 };
2361 
2362 } // namespace details
2363 
2364 RelAlgDagBuilder::RelAlgDagBuilder(const std::string& query_ra,
2366  const RenderInfo* render_info)
2367  : cat_(cat), render_info_(render_info) {
2368  rapidjson::Document query_ast;
2369  query_ast.Parse(query_ra.c_str());
2370  VLOG(2) << "Parsing query RA JSON: " << query_ra;
2371  if (query_ast.HasParseError()) {
2372  query_ast.GetParseError();
2373  LOG(ERROR) << "Failed to parse RA tree from Calcite (offset "
2374  << query_ast.GetErrorOffset() << "):\n"
2375  << rapidjson::GetParseError_En(query_ast.GetParseError());
2376  VLOG(1) << "Failed to parse query RA: " << query_ra;
2377  throw std::runtime_error(
2378  "Failed to parse relational algebra tree. Possible query syntax error.");
2379  }
2380  CHECK(query_ast.IsObject());
2382  build(query_ast, *this);
2383 }
2384 
2386  const rapidjson::Value& query_ast,
2388  const RenderInfo* render_info)
2389  : cat_(cat), render_info_(render_info) {
2390  build(query_ast, root_dag_builder);
2391 }
2392 
2393 void RelAlgDagBuilder::build(const rapidjson::Value& query_ast,
2394  RelAlgDagBuilder& lead_dag_builder) {
2395  const auto& rels = field(query_ast, "rels");
2396  CHECK(rels.IsArray());
2397  try {
2398  nodes_ = details::RelAlgDispatcher(cat_).run(rels, lead_dag_builder);
2399  } catch (const QueryNotSupported&) {
2400  throw;
2401  }
2402  CHECK(!nodes_.empty());
2404 
2405  if (render_info_) {
2406  // Alter the RA for render. Do this before any flattening/optimizations are done to
2407  // the tree.
2409  }
2410 
2411  mark_nops(nodes_);
2416  std::vector<const RelAlgNode*> filtered_left_deep_joins;
2417  std::vector<const RelAlgNode*> left_deep_joins;
2418  for (const auto& node : nodes_) {
2419  const auto left_deep_join_root = get_left_deep_join_root(node);
2420  // The filter which starts a left-deep join pattern must not be coalesced
2421  // since it contains (part of) the join condition.
2422  if (left_deep_join_root) {
2423  left_deep_joins.push_back(left_deep_join_root.get());
2424  if (std::dynamic_pointer_cast<const RelFilter>(left_deep_join_root)) {
2425  filtered_left_deep_joins.push_back(left_deep_join_root.get());
2426  }
2427  }
2428  }
2429  if (filtered_left_deep_joins.empty()) {
2431  }
2432  eliminate_dead_columns(nodes_);
2434  if (g_cluster) {
2436  }
2437  coalesce_nodes(nodes_, left_deep_joins);
2438  CHECK(nodes_.back().unique());
2439  create_left_deep_join(nodes_);
2440 }
2441 
2443  std::function<void(RelAlgNode const*)> const& callback) const {
2444  for (auto const& node : nodes_) {
2445  if (node) {
2446  callback(node.get());
2447  }
2448  }
2449 }
2450 
2452  for (auto& node : nodes_) {
2453  if (node) {
2454  node->resetQueryExecutionState();
2455  }
2456  }
2457 }
2458 
2459 // Return tree with depth represented by indentations.
2460 std::string tree_string(const RelAlgNode* ra, const size_t depth) {
2461  std::string result = std::string(2 * depth, ' ') + ra->toString() + '\n';
2462  for (size_t i = 0; i < ra->inputCount(); ++i) {
2463  result += tree_string(ra->getInput(i), depth + 1);
2464  }
2465  return result;
2466 }
std::vector< std::shared_ptr< const RexScalar > > scalar_exprs_
SQLTypes to_sql_type(const std::string &type_name)
bool is_agg(const Analyzer::Expr *expr)
SQLOps getOperator() const
SQLTypeInfo parse_type(const rapidjson::Value &type_obj)
std::vector< std::string > strings_from_json_array(const rapidjson::Value &json_str_arr) noexcept
std::unique_ptr< RexCase > parse_case(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
std::shared_ptr< RelAggregate > dispatchAggregate(const rapidjson::Value &agg_ra)
#define CHECK_EQ(x, y)
Definition: Logger.h:205
const ConstRexScalarPtrVector & getPartitionKeys() const
JoinType to_join_type(const std::string &join_type_name)
const Catalog_Namespace::Catalog & cat_
std::shared_ptr< RelProject > dispatchProject(const rapidjson::Value &proj_ra, RelAlgDagBuilder &root_dag_builder)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
JoinType
Definition: sqldefs.h:107
void setSourceNode(const RelAlgNode *node) const
void setIndex(const unsigned in_index) const
void bind_inputs(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
void hoist_filter_cond_to_cross_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
class for a per-database catalog. also includes metadata for the current database and the current use...
Definition: Catalog.h:86
Definition: sqltypes.h:50
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
void sink_projected_boolean_expr_to_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::unique_ptr< const RexAgg > parse_aggregate_expr(const rapidjson::Value &expr)
void eliminate_identical_copy(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
const TableDescriptor * getMetadataForTable(const std::string &tableName, const bool populateFragmenter=true) const
Returns a pointer to a const TableDescriptor struct matching the provided tableName.
static thread_local unsigned crt_id_
std::shared_ptr< RelScan > dispatchTableScan(const rapidjson::Value &scan_ra)
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
SQLAgg to_agg_kind(const std::string &agg_name)
std::shared_ptr< RelLogicalUnion > dispatchUnion(const rapidjson::Value &logical_union_ra)
#define LOG(tag)
Definition: Logger.h:188
NullSortedPosition
std::vector< std::string > TargetColumnList
bool g_enable_union
const bool json_bool(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:49
void eachNode(std::function< void(RelAlgNode const *)> const &) const
const RexScalar * getProjectAt(const size_t idx) const
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
size_t branchCount() const
void build(const rapidjson::Value &query_ast, RelAlgDagBuilder &root_dag_builder)
RexWindowFunctionOperator::RexWindowBound parse_window_bound(const rapidjson::Value &window_bound_obj, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
std::string join(T const &container, std::string const &delim)
void bind_table_func_to_input(RelTableFunction *table_func_node, const RANodeOutput &input) noexcept
std::unique_ptr< RexOperator > parse_operator(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
bool isIdentity() const
#define UNREACHABLE()
Definition: Logger.h:241
bool isRenamedInput(const RelAlgNode *node, const size_t index, const std::string &new_name)
#define CHECK_GE(x, y)
Definition: Logger.h:210
Definition: sqldefs.h:49
void coalesce_nodes(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::vector< const RelAlgNode *> &left_deep_joins)
bool hasWindowFunctionExpr() const
std::shared_ptr< const RelAlgNode > getRootNodeShPtr() const
void appendInput(std::string new_field_name, std::unique_ptr< const RexScalar > new_input)
RetType visitOperator(const RexOperator *rex_operator) const final
std::unique_ptr< RexSubQuery > deepCopy() const
void bind_project_to_input(RelProject *project_node, const RANodeOutput &input) noexcept
const Catalog_Namespace::Catalog & cat_
std::vector< std::unique_ptr< const RexScalar > > parse_window_order_exprs(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
SqlWindowFunctionKind parse_window_function_kind(const std::string &name)
void simplify_sort(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::vector< SortField > collation_
RexRebindInputsVisitor(const RelAlgNode *old_input, const RelAlgNode *new_input)
bool hasEquivCollationOf(const RelSort &that) const
int64_t get_int_literal_field(const rapidjson::Value &obj, const char field[], const int64_t default_val) noexcept
const RexScalar * getWhen(const size_t idx) const
std::vector< std::unique_ptr< const RexScalar > > parse_expr_array(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
std::vector< const Rex * > reproject_targets(const RelProject *simple_project, const std::vector< const Rex *> &target_exprs) noexcept
This file contains the class specification and related data structures for Catalog.
bool isRenaming() const
RexInputSet visitInput(const RexInput *input) const override
RexWindowFuncReplacementVisitor(std::unique_ptr< const RexScalar > replacement_rex)
const RenderInfo * render_info_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
const RexScalar * visitOperator(const RexOperator *rex_operator) const final
std::string cat(Ts &&... args)
std::vector< std::shared_ptr< RelAlgNode > > nodes_
const RexScalar * getThen(const size_t idx) const
const TableDescriptor * getTableFromScanNode(const Catalog_Namespace::Catalog &cat, const rapidjson::Value &scan_ra)
SortDirection parse_sort_direction(const rapidjson::Value &collation)
std::shared_ptr< RelAlgNode > deepCopy() const override
const RelAlgNode * getSourceNode() const
std::unique_ptr< const RexScalar > disambiguate_rex(const RexScalar *, const RANodeOutput &)
std::string getFieldName(const size_t i) const
std::vector< SortField > parse_window_order_collation(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
SQLOps to_sql_op(const std::string &op_str)
void add_window_function_pre_project(std::vector< std::shared_ptr< RelAlgNode >> &nodes)
void set_scale(int s)
Definition: sqltypes.h:353
const int64_t json_i64(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:39
void * visitInput(const RexInput *rex_input) const override
std::shared_ptr< RelAlgNode > deepCopy() const override
std::shared_ptr< RelAlgNode > deepCopy() const override
const double json_double(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:54
std::shared_ptr< RelAlgNode > deepCopy() const override
std::unordered_set< RexInput > RexInputSet
std::shared_ptr< RelAlgNode > deepCopy() const override
std::vector< std::shared_ptr< RelAlgNode > > nodes_
RexInputReplacementVisitor(const RelAlgNode *node_to_keep, const std::vector< std::unique_ptr< const RexScalar >> &scalar_sources)
RexRebindReindexInputsVisitor(const RelAlgNode *old_input, const RelAlgNode *new_input, std::unordered_map< unsigned, unsigned > old_to_new_index_map)
const RexScalar * getOperand(const size_t idx) const
virtual void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input)
NullSortedPosition parse_nulls_position(const rapidjson::Value &collation)
std::shared_ptr< RelFilter > dispatchFilter(const rapidjson::Value &filter_ra, RelAlgDagBuilder &root_dag_builder)
std::unique_ptr< RexAbstractInput > parse_abstract_input(const rapidjson::Value &expr) noexcept
std::set< std::pair< const RelAlgNode *, int > > get_equiv_cols(const RelAlgNode *node, const size_t which_col)
bool inputMetainfoTypesMatch() const
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
std::shared_ptr< RelAlgNode > deepCopy() const override
std::unique_ptr< RexLiteral > parse_literal(const rapidjson::Value &expr)
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
SortDirection
std::vector< std::shared_ptr< const RelAlgNode > > RelAlgInputs
#define CHECK_LT(x, y)
Definition: Logger.h:207
Definition: sqltypes.h:53
Definition: sqltypes.h:54
std::unique_ptr< const RexOperator > disambiguate_operator(const RexOperator *rex_operator, const RANodeOutput &ra_output) noexcept
std::string tree_string(const RelAlgNode *ra, const size_t depth)
#define CHECK_LE(x, y)
Definition: Logger.h:208
unsigned getIndex() const
std::vector< ElementType >::const_iterator Super
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
RelLogicalUnion(RelAlgInputs, bool is_all)
std::vector< std::unique_ptr< const RexScalar > > RowValues
std::unique_ptr< const RexCase > disambiguate_case(const RexCase *rex_case, const RANodeOutput &ra_output)
std::shared_ptr< RelJoin > dispatchJoin(const rapidjson::Value &join_ra, RelAlgDagBuilder &root_dag_builder)
const size_t inputCount() const
void separate_window_function_expressions(std::vector< std::shared_ptr< RelAlgNode >> &nodes)
const RexScalar * aggregateResult(const RexScalar *const &aggregate, const RexScalar *const &next_result) const final
std::shared_ptr< RelModify > dispatchModify(const rapidjson::Value &logical_modify_ra)
virtual size_t size() const =0
void registerSubquery(std::shared_ptr< RexSubQuery > subquery)
void setExecutionResult(const std::shared_ptr< const ExecutionResult > result)
RelAlgDagBuilder()=delete
SqlWindowFunctionKind
Definition: sqldefs.h:82
void mark_nops(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
RexInputSet aggregateResult(const RexInputSet &aggregate, const RexInputSet &next_result) const override
Definition: sqldefs.h:53
std::shared_ptr< RelSort > dispatchSort(const rapidjson::Value &sort_ra)
std::vector< size_t > indices_from_json_array(const rapidjson::Value &json_idx_arr) noexcept
virtual std::string toString() const =0
#define CHECK(condition)
Definition: Logger.h:197
std::unique_ptr< const RexScalar > parse_scalar_expr(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
unsigned node_id(const rapidjson::Value &ra_node) noexcept
RANodeOutput get_node_output(const RelAlgNode *ra_node)
size_t operator()(const std::pair< const RelAlgNode *, int > &input_col) const
std::vector< std::string > getFieldNamesFromScanNode(const rapidjson::Value &scan_ra)
void create_compound(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::vector< size_t > &pattern) noexcept
const RelAlgNode * getInput(const size_t idx) const
void alterRAForRender(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const RenderInfo &render_info)
const std::vector< std::string > & getFields() const
std::shared_ptr< const RelAlgNode > prev(const rapidjson::Value &crt_node)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
std::vector< std::shared_ptr< RelAlgNode > > run(const rapidjson::Value &rels, RelAlgDagBuilder &root_dag_builder)
RelAlgDispatcher(const Catalog_Namespace::Catalog &cat)
RexScalar const * copyAndRedirectSource(RexScalar const *, size_t input_idx) const
bool g_cluster
void fold_filters(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::vector< RexInput > RANodeOutput
specifies the content in-memory of a row in the table metadata table
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)
std::unique_ptr< const RexSubQuery > parse_subquery(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
size_t size() const override
bool input_can_be_coalesced(const RelAlgNode *parent_node, const size_t index, const bool first_rex_is_input)
std::shared_ptr< RelAlgNode > deepCopy() const override
RelAlgInputs getRelAlgInputs(const rapidjson::Value &node)
std::shared_ptr< RelLogicalValues > dispatchLogicalValues(const rapidjson::Value &logical_values_ra)
std::shared_ptr< RelTableFunction > dispatchTableFunction(const rapidjson::Value &table_func_ra, RelAlgDagBuilder &root_dag_builder)
std::unique_ptr< const RexScalar > RetType
Definition: RexVisitor.h:139
std::shared_ptr< RelAlgNode > deepCopy() const override
#define VLOG(n)
Definition: Logger.h:291
size_t size() const
RelAlgInputs inputs_
void set_precision(int d)
Definition: sqltypes.h:351
std::string toString() const override
void eliminate_dead_columns(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::shared_ptr< RelAlgNode > deepCopy() const override
void check_empty_inputs_field(const rapidjson::Value &node) noexcept
std::vector< RexInput > n_outputs(const RelAlgNode *node, const size_t n)
const RexScalar * visitCase(const RexCase *rex_case) const final
const RexScalar * getElse() const
static void resetRelAlgFirstId() noexcept
std::string json_node_to_string(const rapidjson::Value &node) noexcept