OmniSciDB  cde582ebc3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelAlgDag.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RelAlgDag.h"
19 #include "Catalog/Catalog.h"
21 #include "JsonAccessors.h"
22 #include "RelAlgOptimizer.h"
23 #include "RelLeftDeepInnerJoin.h"
24 #include "RexVisitor.h"
25 #include "Shared/sqldefs.h"
26 
27 #include <rapidjson/error/en.h>
28 #include <rapidjson/error/error.h>
29 #include <rapidjson/stringbuffer.h>
30 #include <rapidjson/writer.h>
31 
32 #include <string>
33 #include <unordered_set>
34 
35 extern bool g_cluster;
36 extern bool g_enable_union;
37 
38 namespace {
39 
40 const unsigned FIRST_RA_NODE_ID = 1;
41 
42 } // namespace
43 
44 thread_local unsigned RelAlgNode::crt_id_ = FIRST_RA_NODE_ID;
45 
48 }
49 
51  const std::shared_ptr<const ExecutionResult> result) {
52  auto row_set = result->getRows();
53  CHECK(row_set);
54  CHECK_EQ(size_t(1), row_set->colCount());
55  *(type_.get()) = row_set->getColType(0);
56  (*(result_.get())) = result;
57 }
58 
59 std::unique_ptr<RexSubQuery> RexSubQuery::deepCopy() const {
60  return std::make_unique<RexSubQuery>(type_, result_, ra_->deepCopy());
61 }
62 
63 unsigned RexSubQuery::getId() const {
64  return ra_->getId();
65 }
66 
67 namespace {
68 
69 class RexRebindInputsVisitor : public RexVisitor<void*> {
70  public:
71  RexRebindInputsVisitor(const RelAlgNode* old_input, const RelAlgNode* new_input)
72  : old_input_(old_input), new_input_(new_input) {}
73 
74  virtual ~RexRebindInputsVisitor() = default;
75 
76  void* visitInput(const RexInput* rex_input) const override {
77  const auto old_source = rex_input->getSourceNode();
78  if (old_source == old_input_) {
79  const auto left_deep_join = dynamic_cast<const RelLeftDeepInnerJoin*>(new_input_);
80  if (left_deep_join) {
81  rebind_inputs_from_left_deep_join(rex_input, left_deep_join);
82  return nullptr;
83  }
84  rex_input->setSourceNode(new_input_);
85  }
86  return nullptr;
87  };
88 
89  private:
90  const RelAlgNode* old_input_;
92 };
93 
94 // Creates an output with n columns.
95 std::vector<RexInput> n_outputs(const RelAlgNode* node, const size_t n) {
96  std::vector<RexInput> outputs;
97  outputs.reserve(n);
98  for (size_t i = 0; i < n; ++i) {
99  outputs.emplace_back(node, i);
100  }
101  return outputs;
102 }
103 
105  public:
107  const RelAlgNode* old_input,
108  const RelAlgNode* new_input,
109  std::unordered_map<unsigned, unsigned> old_to_new_index_map)
110  : RexRebindInputsVisitor(old_input, new_input), mapping_(old_to_new_index_map) {}
111 
112  void* visitInput(const RexInput* rex_input) const override {
113  RexRebindInputsVisitor::visitInput(rex_input);
114  auto mapping_itr = mapping_.find(rex_input->getIndex());
115  CHECK(mapping_itr != mapping_.end());
116  rex_input->setIndex(mapping_itr->second);
117  return nullptr;
118  }
119 
120  private:
121  const std::unordered_map<unsigned, unsigned> mapping_;
122 };
123 
125  : public RexVisitorBase<std::unique_ptr<const RexScalar>> {
126  public:
// Which window-function clause a pushed-down key expression belongs to.
enum class WindowExprType {
  PARTITION_KEY,
  ORDER_KEY
};
129  std::shared_ptr<RelProject> new_project,
130  std::vector<std::unique_ptr<const RexScalar>>& scalar_exprs_for_new_project,
131  std::vector<std::string>& fields_for_new_project,
132  std::unordered_map<size_t, size_t>& expr_offset_cache)
133  : new_project_(new_project)
134  , scalar_exprs_for_new_project_(scalar_exprs_for_new_project)
135  , fields_for_new_project_(fields_for_new_project)
136  , expr_offset_cache_(expr_offset_cache)
137  , found_case_expr_window_operand_(false)
138  , has_partition_expr_(false) {}
139 
140  size_t pushDownExpressionImpl(const RexScalar* expr) const {
141  auto hash = expr->toHash();
142  auto it = expr_offset_cache_.find(hash);
143  auto new_offset = -1;
144  if (it == expr_offset_cache_.end()) {
145  CHECK(
146  expr_offset_cache_.emplace(hash, scalar_exprs_for_new_project_.size()).second);
147  new_offset = scalar_exprs_for_new_project_.size();
148  fields_for_new_project_.emplace_back("");
149  scalar_exprs_for_new_project_.emplace_back(deep_copier_.visit(expr));
150  } else {
151  // we already pushed down the same expression, so reuse it
152  new_offset = it->second;
153  }
154  return new_offset;
155  }
156 
158  size_t expr_offset) const {
159  // given window expr offset and inner expr's offset,
160  // return a (push-downed) expression's offset from the new projection node
161  switch (type) {
162  case WindowExprType::PARTITION_KEY: {
163  auto it = pushed_down_partition_key_offset_.find(expr_offset);
164  CHECK(it != pushed_down_partition_key_offset_.end());
165  return it->second;
166  }
167  case WindowExprType::ORDER_KEY: {
168  auto it = pushed_down_order_key_offset_.find(expr_offset);
169  CHECK(it != pushed_down_order_key_offset_.end());
170  return it->second;
171  }
172  default:
173  UNREACHABLE();
174  return std::nullopt;
175  }
176  }
177 
179  const RexWindowFunctionOperator* window_expr) const {
180  // step 1. push "all" target expressions of the window_func_project_node down to the
181  // new child projection
182  // each window expr is a separate target expression of the projection node
183  // and they have their own inner expression related to partition / order clauses
184  // so we capture their offsets to correctly rebind their input
185  pushed_down_window_operands_offset_.clear();
186  pushed_down_partition_key_offset_.clear();
187  pushed_down_order_key_offset_.clear();
188  for (size_t offset = 0; offset < window_expr->size(); ++offset) {
189  auto expr = window_expr->getOperand(offset);
190  auto literal_expr = dynamic_cast<const RexLiteral*>(expr);
191  auto case_expr = dynamic_cast<const RexCase*>(expr);
192  if (case_expr) {
193  // when columnar output is enabled, pushdown case expr can incur an issue
194  // during columnarization, so we record this and try to force rowwise-output
195  // until we fix the issue
196  // todo (yoonmin) : relax this
197  found_case_expr_window_operand_ = true;
198  }
199  if (!literal_expr) {
200  auto new_offset = pushDownExpressionImpl(expr);
201  pushed_down_window_operands_offset_.emplace(offset, new_offset);
202  }
203  }
204  size_t offset = 0;
205  for (const auto& partition_key : window_expr->getPartitionKeys()) {
206  auto new_offset = pushDownExpressionImpl(partition_key.get());
207  pushed_down_partition_key_offset_.emplace(offset, new_offset);
208  ++offset;
209  }
210  has_partition_expr_ = !window_expr->getPartitionKeys().empty();
211  offset = 0;
212  for (const auto& order_key : window_expr->getOrderKeys()) {
213  auto new_offset = pushDownExpressionImpl(order_key.get());
214  pushed_down_order_key_offset_.emplace(offset, new_offset);
215  ++offset;
216  }
217 
218  // step 2. rebind projected targets of the window_func_project_node with the new
219  // project node
220  std::vector<std::unique_ptr<const RexScalar>> window_operands;
221  auto deconst_window_expr = const_cast<RexWindowFunctionOperator*>(window_expr);
222  for (size_t idx = 0; idx < window_expr->size(); ++idx) {
223  auto it = pushed_down_window_operands_offset_.find(idx);
224  if (it != pushed_down_window_operands_offset_.end()) {
225  auto new_input = std::make_unique<const RexInput>(new_project_.get(), it->second);
226  CHECK(new_input);
227  window_operands.emplace_back(std::move(new_input));
228  } else {
229  auto copied_expr = deep_copier_.visit(window_expr->getOperand(idx));
230  window_operands.emplace_back(std::move(copied_expr));
231  }
232  }
233  deconst_window_expr->replaceOperands(std::move(window_operands));
234 
235  for (size_t idx = 0; idx < window_expr->getPartitionKeys().size(); ++idx) {
236  auto new_offset = getOffsetForPushedDownExpr(WindowExprType::PARTITION_KEY, idx);
237  CHECK(new_offset);
238  auto new_input = std::make_unique<const RexInput>(new_project_.get(), *new_offset);
239  CHECK(new_input);
240  deconst_window_expr->replacePartitionKey(idx, std::move(new_input));
241  }
242 
243  for (size_t idx = 0; idx < window_expr->getOrderKeys().size(); ++idx) {
244  auto new_offset = getOffsetForPushedDownExpr(WindowExprType::ORDER_KEY, idx);
245  CHECK(new_offset);
246  auto new_input = std::make_unique<const RexInput>(new_project_.get(), *new_offset);
247  CHECK(new_input);
248  deconst_window_expr->replaceOrderKey(idx, std::move(new_input));
249  }
250  }
251 
252  std::unique_ptr<const RexScalar> visitInput(const RexInput* rex_input) const override {
253  auto new_offset = pushDownExpressionImpl(rex_input);
254  CHECK_LT(new_offset, scalar_exprs_for_new_project_.size());
255  auto hash = rex_input->toHash();
256  auto it = expr_offset_cache_.find(hash);
257  CHECK(it != expr_offset_cache_.end());
258  CHECK_EQ(new_offset, it->second);
259  auto new_input = std::make_unique<const RexInput>(new_project_.get(), new_offset);
260  CHECK(new_input);
261  return new_input;
262  }
263 
264  std::unique_ptr<const RexScalar> visitLiteral(
265  const RexLiteral* rex_literal) const override {
266  return deep_copier_.visit(rex_literal);
267  }
268 
269  std::unique_ptr<const RexScalar> visitRef(const RexRef* rex_ref) const override {
270  return deep_copier_.visit(rex_ref);
271  }
272 
273  std::unique_ptr<const RexScalar> visitSubQuery(
274  const RexSubQuery* rex_subquery) const override {
275  return deep_copier_.visit(rex_subquery);
276  }
277 
278  std::unique_ptr<const RexScalar> visitCase(const RexCase* rex_case) const override {
279  std::vector<
280  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
281  new_expr_pair_list;
282  std::unique_ptr<const RexScalar> new_else_expr;
283  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
284  const auto when = rex_case->getWhen(i);
285  auto new_when = PushDownGenericExpressionInWindowFunction::visit(when);
286  const auto then = rex_case->getThen(i);
287  auto new_then = PushDownGenericExpressionInWindowFunction::visit(then);
288  new_expr_pair_list.emplace_back(std::move(new_when), std::move(new_then));
289  }
290  if (rex_case->getElse()) {
291  new_else_expr = deep_copier_.visit(rex_case->getElse());
292  }
293  auto new_case = std::make_unique<const RexCase>(new_expr_pair_list, new_else_expr);
294  return new_case;
295  }
296 
297  std::unique_ptr<const RexScalar> visitOperator(
298  const RexOperator* rex_operator) const override {
299  const auto rex_window_func_operator =
300  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
301  if (rex_window_func_operator) {
302  pushDownExpressionInWindowFunction(rex_window_func_operator);
303  return deep_copier_.visit(rex_operator);
304  } else {
305  std::unique_ptr<const RexOperator> new_operator{nullptr};
306  std::vector<std::unique_ptr<const RexScalar>> new_operands;
307  for (size_t i = 0; i < rex_operator->size(); ++i) {
308  const auto operand = rex_operator->getOperand(i);
309  auto new_operand = PushDownGenericExpressionInWindowFunction::visit(operand);
310  new_operands.emplace_back(std::move(new_operand));
311  }
312  if (auto function_op = dynamic_cast<const RexFunctionOperator*>(rex_operator)) {
313  new_operator = std::make_unique<const RexFunctionOperator>(
314  function_op->getName(), new_operands, rex_operator->getType());
315  } else {
316  new_operator = std::make_unique<const RexOperator>(
317  rex_operator->getOperator(), new_operands, rex_operator->getType());
318  }
319  CHECK(new_operator);
320  return new_operator;
321  }
322  }
323 
324  bool hasCaseExprAsWindowOperand() { return found_case_expr_window_operand_; }
325 
326  bool hasPartitionExpression() { return has_partition_expr_; }
327 
328  private:
329  std::unique_ptr<const RexScalar> defaultResult() const override { return nullptr; }
330 
331  std::shared_ptr<RelProject> new_project_;
332  std::vector<std::unique_ptr<const RexScalar>>& scalar_exprs_for_new_project_;
333  std::vector<std::string>& fields_for_new_project_;
334  std::unordered_map<size_t, size_t>& expr_offset_cache_;
336  mutable bool has_partition_expr_;
337  mutable std::unordered_map<size_t, size_t> pushed_down_window_operands_offset_;
338  mutable std::unordered_map<size_t, size_t> pushed_down_partition_key_offset_;
339  mutable std::unordered_map<size_t, size_t> pushed_down_order_key_offset_;
341 };
342 
343 } // namespace
344 
346  std::shared_ptr<const RelAlgNode> old_input,
347  std::shared_ptr<const RelAlgNode> input,
348  std::optional<std::unordered_map<unsigned, unsigned>> old_to_new_index_map) {
349  RelAlgNode::replaceInput(old_input, input);
350  std::unique_ptr<RexRebindInputsVisitor> rebind_inputs;
351  if (old_to_new_index_map) {
352  rebind_inputs = std::make_unique<RexRebindReindexInputsVisitor>(
353  old_input.get(), input.get(), *old_to_new_index_map);
354  } else {
355  rebind_inputs =
356  std::make_unique<RexRebindInputsVisitor>(old_input.get(), input.get());
357  }
358  CHECK(rebind_inputs);
359  for (const auto& scalar_expr : scalar_exprs_) {
360  rebind_inputs->visit(scalar_expr.get());
361  }
362 }
363 
364 void RelProject::appendInput(std::string new_field_name,
365  std::unique_ptr<const RexScalar> new_input) {
366  fields_.emplace_back(std::move(new_field_name));
367  scalar_exprs_.emplace_back(std::move(new_input));
368 }
369 
371  const auto scan_node = dynamic_cast<const RelScan*>(ra_node);
372  if (scan_node) {
373  // Scan node has no inputs, output contains all columns in the table.
374  CHECK_EQ(size_t(0), scan_node->inputCount());
375  return n_outputs(scan_node, scan_node->size());
376  }
377  const auto project_node = dynamic_cast<const RelProject*>(ra_node);
378  if (project_node) {
379  // Project output count doesn't depend on the input
380  CHECK_EQ(size_t(1), project_node->inputCount());
381  return n_outputs(project_node, project_node->size());
382  }
383  const auto filter_node = dynamic_cast<const RelFilter*>(ra_node);
384  if (filter_node) {
385  // Filter preserves shape
386  CHECK_EQ(size_t(1), filter_node->inputCount());
387  const auto prev_out = get_node_output(filter_node->getInput(0));
388  return n_outputs(filter_node, prev_out.size());
389  }
390  const auto aggregate_node = dynamic_cast<const RelAggregate*>(ra_node);
391  if (aggregate_node) {
392  // Aggregate output count doesn't depend on the input
393  CHECK_EQ(size_t(1), aggregate_node->inputCount());
394  return n_outputs(aggregate_node, aggregate_node->size());
395  }
396  const auto compound_node = dynamic_cast<const RelCompound*>(ra_node);
397  if (compound_node) {
398  // Compound output count doesn't depend on the input
399  CHECK_EQ(size_t(1), compound_node->inputCount());
400  return n_outputs(compound_node, compound_node->size());
401  }
402  const auto join_node = dynamic_cast<const RelJoin*>(ra_node);
403  if (join_node) {
404  // Join concatenates the outputs from the inputs and the output
405  // directly references the nodes in the input.
406  CHECK_EQ(size_t(2), join_node->inputCount());
407  auto lhs_out =
408  n_outputs(join_node->getInput(0), get_node_output(join_node->getInput(0)).size());
409  const auto rhs_out =
410  n_outputs(join_node->getInput(1), get_node_output(join_node->getInput(1)).size());
411  lhs_out.insert(lhs_out.end(), rhs_out.begin(), rhs_out.end());
412  return lhs_out;
413  }
414  const auto table_func_node = dynamic_cast<const RelTableFunction*>(ra_node);
415  if (table_func_node) {
416  // Table Function output count doesn't depend on the input
417  return n_outputs(table_func_node, table_func_node->size());
418  }
419  const auto sort_node = dynamic_cast<const RelSort*>(ra_node);
420  if (sort_node) {
421  // Sort preserves shape
422  CHECK_EQ(size_t(1), sort_node->inputCount());
423  const auto prev_out = get_node_output(sort_node->getInput(0));
424  return n_outputs(sort_node, prev_out.size());
425  }
426  const auto logical_values_node = dynamic_cast<const RelLogicalValues*>(ra_node);
427  if (logical_values_node) {
428  CHECK_EQ(size_t(0), logical_values_node->inputCount());
429  return n_outputs(logical_values_node, logical_values_node->size());
430  }
431  const auto logical_union_node = dynamic_cast<const RelLogicalUnion*>(ra_node);
432  if (logical_union_node) {
433  return n_outputs(logical_union_node, logical_union_node->size());
434  }
435  LOG(FATAL) << "Unhandled ra_node type: " << ::toString(ra_node);
436  return {};
437 }
438 
440  if (!isSimple()) {
441  return false;
442  }
443  CHECK_EQ(size_t(1), inputCount());
444  const auto source = getInput(0);
445  if (dynamic_cast<const RelJoin*>(source)) {
446  return false;
447  }
448  const auto source_shape = get_node_output(source);
449  if (source_shape.size() != scalar_exprs_.size()) {
450  return false;
451  }
452  for (size_t i = 0; i < scalar_exprs_.size(); ++i) {
453  const auto& scalar_expr = scalar_exprs_[i];
454  const auto input = dynamic_cast<const RexInput*>(scalar_expr.get());
455  CHECK(input);
456  CHECK_EQ(source, input->getSourceNode());
457  // We should add the additional check that input->getIndex() !=
458  // source_shape[i].getIndex(), but Calcite doesn't generate the right
459  // Sort-Project-Sort sequence when joins are involved.
460  if (input->getSourceNode() != source_shape[i].getSourceNode()) {
461  return false;
462  }
463  }
464  return true;
465 }
466 
467 namespace {
468 
469 bool isRenamedInput(const RelAlgNode* node,
470  const size_t index,
471  const std::string& new_name) {
472  CHECK_LT(index, node->size());
473  if (auto join = dynamic_cast<const RelJoin*>(node)) {
474  CHECK_EQ(size_t(2), join->inputCount());
475  const auto lhs_size = join->getInput(0)->size();
476  if (index < lhs_size) {
477  return isRenamedInput(join->getInput(0), index, new_name);
478  }
479  CHECK_GE(index, lhs_size);
480  return isRenamedInput(join->getInput(1), index - lhs_size, new_name);
481  }
482 
483  if (auto scan = dynamic_cast<const RelScan*>(node)) {
484  return new_name != scan->getFieldName(index);
485  }
486 
487  if (auto aggregate = dynamic_cast<const RelAggregate*>(node)) {
488  return new_name != aggregate->getFieldName(index);
489  }
490 
491  if (auto project = dynamic_cast<const RelProject*>(node)) {
492  return new_name != project->getFieldName(index);
493  }
494 
495  if (auto table_func = dynamic_cast<const RelTableFunction*>(node)) {
496  return new_name != table_func->getFieldName(index);
497  }
498 
499  if (auto logical_values = dynamic_cast<const RelLogicalValues*>(node)) {
500  const auto& tuple_type = logical_values->getTupleType();
501  CHECK_LT(index, tuple_type.size());
502  return new_name != tuple_type[index].get_resname();
503  }
504 
505  CHECK(dynamic_cast<const RelSort*>(node) || dynamic_cast<const RelFilter*>(node) ||
506  dynamic_cast<const RelLogicalUnion*>(node));
507  return isRenamedInput(node->getInput(0), index, new_name);
508 }
509 
510 } // namespace
511 
513  if (!isSimple()) {
514  return false;
515  }
516  CHECK_EQ(scalar_exprs_.size(), fields_.size());
517  for (size_t i = 0; i < fields_.size(); ++i) {
518  auto rex_in = dynamic_cast<const RexInput*>(scalar_exprs_[i].get());
519  CHECK(rex_in);
520  if (isRenamedInput(rex_in->getSourceNode(), rex_in->getIndex(), fields_[i])) {
521  return true;
522  }
523  }
524  return false;
525 }
526 
527 void RelJoin::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
528  std::shared_ptr<const RelAlgNode> input) {
529  RelAlgNode::replaceInput(old_input, input);
530  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
531  if (condition_) {
532  rebind_inputs.visit(condition_.get());
533  }
534 }
535 
536 void RelFilter::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
537  std::shared_ptr<const RelAlgNode> input) {
538  RelAlgNode::replaceInput(old_input, input);
539  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
540  rebind_inputs.visit(filter_.get());
541 }
542 
543 void RelCompound::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
544  std::shared_ptr<const RelAlgNode> input) {
545  RelAlgNode::replaceInput(old_input, input);
546  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
547  for (const auto& scalar_source : scalar_sources_) {
548  rebind_inputs.visit(scalar_source.get());
549  }
550  if (filter_expr_) {
551  rebind_inputs.visit(filter_expr_.get());
552  }
553 }
554 
556  : RelAlgNode(rhs)
558  , fields_(rhs.fields_)
559  , hint_applied_(false)
560  , hints_(std::make_unique<Hints>()) {
561  RexDeepCopyVisitor copier;
562  for (auto const& expr : rhs.scalar_exprs_) {
563  scalar_exprs_.push_back(copier.visit(expr.get()));
564  }
565  if (rhs.hint_applied_) {
566  for (auto const& kv : *rhs.hints_) {
567  addHint(kv.second);
568  }
569  }
570 }
571 
573  : RelAlgNode(rhs)
574  , tuple_type_(rhs.tuple_type_)
575  , values_(RexDeepCopyVisitor::copy(rhs.values_)) {}
576 
578  RexDeepCopyVisitor copier;
579  filter_ = copier.visit(rhs.filter_.get());
580 }
581 
583  : RelAlgNode(rhs)
584  , groupby_count_(rhs.groupby_count_)
585  , fields_(rhs.fields_)
586  , hint_applied_(false)
587  , hints_(std::make_unique<Hints>()) {
588  agg_exprs_.reserve(rhs.agg_exprs_.size());
589  for (auto const& agg : rhs.agg_exprs_) {
590  agg_exprs_.push_back(agg->deepCopy());
591  }
592  if (rhs.hint_applied_) {
593  for (auto const& kv : *rhs.hints_) {
594  addHint(kv.second);
595  }
596  }
597 }
598 
600  : RelAlgNode(rhs)
601  , join_type_(rhs.join_type_)
602  , hint_applied_(false)
603  , hints_(std::make_unique<Hints>()) {
604  RexDeepCopyVisitor copier;
605  condition_ = copier.visit(rhs.condition_.get());
606  if (rhs.hint_applied_) {
607  for (auto const& kv : *rhs.hints_) {
608  addHint(kv.second);
609  }
610  }
611 }
612 
613 namespace {
614 
615 std::vector<std::unique_ptr<const RexAgg>> copyAggExprs(
616  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs) {
617  std::vector<std::unique_ptr<const RexAgg>> agg_exprs_copy;
618  agg_exprs_copy.reserve(agg_exprs.size());
619  for (auto const& agg_expr : agg_exprs) {
620  agg_exprs_copy.push_back(agg_expr->deepCopy());
621  }
622  return agg_exprs_copy;
623 }
624 
625 std::vector<std::unique_ptr<const RexScalar>> copyRexScalars(
626  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources) {
627  std::vector<std::unique_ptr<const RexScalar>> scalar_sources_copy;
628  scalar_sources_copy.reserve(scalar_sources.size());
629  RexDeepCopyVisitor copier;
630  for (auto const& scalar_source : scalar_sources) {
631  scalar_sources_copy.push_back(copier.visit(scalar_source.get()));
632  }
633  return scalar_sources_copy;
634 }
635 
636 std::vector<const Rex*> remapTargetPointers(
637  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs_new,
638  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources_new,
639  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs_old,
640  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources_old,
641  std::vector<const Rex*> const& target_exprs_old) {
642  std::vector<const Rex*> target_exprs(target_exprs_old);
643  std::unordered_map<const Rex*, const Rex*> old_to_new_target(target_exprs.size());
644  for (size_t i = 0; i < agg_exprs_new.size(); ++i) {
645  old_to_new_target.emplace(agg_exprs_old[i].get(), agg_exprs_new[i].get());
646  }
647  for (size_t i = 0; i < scalar_sources_new.size(); ++i) {
648  old_to_new_target.emplace(scalar_sources_old[i].get(), scalar_sources_new[i].get());
649  }
650  for (auto& target : target_exprs) {
651  auto target_it = old_to_new_target.find(target);
652  CHECK(target_it != old_to_new_target.end());
653  target = target_it->second;
654  }
655  return target_exprs;
656 }
657 
658 } // namespace
659 
661  : RelAlgNode(rhs)
663  , groupby_count_(rhs.groupby_count_)
664  , agg_exprs_(copyAggExprs(rhs.agg_exprs_))
665  , fields_(rhs.fields_)
666  , is_agg_(rhs.is_agg_)
667  , scalar_sources_(copyRexScalars(rhs.scalar_sources_))
668  , target_exprs_(remapTargetPointers(agg_exprs_,
669  scalar_sources_,
670  rhs.agg_exprs_,
671  rhs.scalar_sources_,
672  rhs.target_exprs_))
673  , hint_applied_(false)
674  , hints_(std::make_unique<Hints>()) {
675  RexDeepCopyVisitor copier;
676  filter_expr_ = rhs.filter_expr_ ? copier.visit(rhs.filter_expr_.get()) : nullptr;
677  if (rhs.hint_applied_) {
678  for (auto const& kv : *rhs.hints_) {
679  addHint(kv.second);
680  }
681  }
682 }
683 
684 void RelTableFunction::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
685  std::shared_ptr<const RelAlgNode> input) {
686  RelAlgNode::replaceInput(old_input, input);
687  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
688  for (const auto& target_expr : target_exprs_) {
689  rebind_inputs.visit(target_expr.get());
690  }
691  for (const auto& func_input : table_func_inputs_) {
692  rebind_inputs.visit(func_input.get());
693  }
694 }
695 
697  int32_t literal_args = 0;
698  for (const auto& arg : table_func_inputs_) {
699  const auto rex_literal = dynamic_cast<const RexLiteral*>(arg.get());
700  if (rex_literal) {
701  literal_args += 1;
702  }
703  }
704  return literal_args;
705 }
706 
708  : RelAlgNode(rhs)
709  , function_name_(rhs.function_name_)
710  , fields_(rhs.fields_)
711  , col_inputs_(rhs.col_inputs_)
712  , table_func_inputs_(copyRexScalars(rhs.table_func_inputs_))
713  , target_exprs_(copyRexScalars(rhs.target_exprs_)) {
714  std::unordered_map<const Rex*, const Rex*> old_to_new_input;
715  for (size_t i = 0; i < table_func_inputs_.size(); ++i) {
716  old_to_new_input.emplace(rhs.table_func_inputs_[i].get(),
717  table_func_inputs_[i].get());
718  }
719  for (auto& target : col_inputs_) {
720  auto target_it = old_to_new_input.find(target);
721  CHECK(target_it != old_to_new_input.end());
722  target = target_it->second;
723  }
724 }
725 
726 namespace std {
727 template <>
728 struct hash<std::pair<const RelAlgNode*, int>> {
729  size_t operator()(const std::pair<const RelAlgNode*, int>& input_col) const {
730  auto ptr_val = reinterpret_cast<const int64_t*>(&input_col.first);
731  auto h = static_cast<size_t>(*ptr_val);
732  boost::hash_combine(h, input_col.second);
733  return h;
734  }
735 };
736 } // namespace std
737 
738 namespace {
739 
740 std::set<std::pair<const RelAlgNode*, int>> get_equiv_cols(const RelAlgNode* node,
741  const size_t which_col) {
742  std::set<std::pair<const RelAlgNode*, int>> work_set;
743  auto walker = node;
744  auto curr_col = which_col;
745  while (true) {
746  work_set.insert(std::make_pair(walker, curr_col));
747  if (dynamic_cast<const RelScan*>(walker) || dynamic_cast<const RelJoin*>(walker)) {
748  break;
749  }
750  CHECK_EQ(size_t(1), walker->inputCount());
751  auto only_source = walker->getInput(0);
752  if (auto project = dynamic_cast<const RelProject*>(walker)) {
753  if (auto input = dynamic_cast<const RexInput*>(project->getProjectAt(curr_col))) {
754  const auto join_source = dynamic_cast<const RelJoin*>(only_source);
755  if (join_source) {
756  CHECK_EQ(size_t(2), join_source->inputCount());
757  auto lhs = join_source->getInput(0);
758  CHECK((input->getIndex() < lhs->size() && lhs == input->getSourceNode()) ||
759  join_source->getInput(1) == input->getSourceNode());
760  } else {
761  CHECK_EQ(input->getSourceNode(), only_source);
762  }
763  curr_col = input->getIndex();
764  } else {
765  break;
766  }
767  } else if (auto aggregate = dynamic_cast<const RelAggregate*>(walker)) {
768  if (curr_col >= aggregate->getGroupByCount()) {
769  break;
770  }
771  }
772  walker = only_source;
773  }
774  return work_set;
775 }
776 
777 } // namespace
778 
779 bool RelSort::hasEquivCollationOf(const RelSort& that) const {
780  if (collation_.size() != that.collation_.size()) {
781  return false;
782  }
783 
784  for (size_t i = 0, e = collation_.size(); i < e; ++i) {
785  auto this_sort_key = collation_[i];
786  auto that_sort_key = that.collation_[i];
787  if (this_sort_key.getSortDir() != that_sort_key.getSortDir()) {
788  return false;
789  }
790  if (this_sort_key.getNullsPosition() != that_sort_key.getNullsPosition()) {
791  return false;
792  }
793  auto this_equiv_keys = get_equiv_cols(this, this_sort_key.getField());
794  auto that_equiv_keys = get_equiv_cols(&that, that_sort_key.getField());
795  std::vector<std::pair<const RelAlgNode*, int>> intersect;
796  std::set_intersection(this_equiv_keys.begin(),
797  this_equiv_keys.end(),
798  that_equiv_keys.begin(),
799  that_equiv_keys.end(),
800  std::back_inserter(intersect));
801  if (intersect.empty()) {
802  return false;
803  }
804  }
805  return true;
806 }
807 
808 // class RelLogicalUnion methods
809 
811  : RelAlgNode(std::move(inputs)), is_all_(is_all) {
812  if (!g_enable_union) {
813  throw QueryNotSupported(
814  "The DEPRECATED enable-union option is set to off. Please remove this option as "
815  "it may be disabled in the future.");
816  }
817  CHECK_LE(2u, inputs_.size());
818  if (!is_all_) {
819  throw QueryNotSupported("UNION without ALL is not supported yet.");
820  }
821 }
822 
823 size_t RelLogicalUnion::size() const {
824  return inputs_.front()->size();
825 }
826 
828  return cat(::typeName(this), "(is_all(", is_all_, "))");
829 }
830 
831 size_t RelLogicalUnion::toHash() const {
832  if (!hash_) {
833  hash_ = typeid(RelLogicalUnion).hash_code();
834  boost::hash_combine(*hash_, is_all_);
835  }
836  return *hash_;
837 }
838 
839 std::string RelLogicalUnion::getFieldName(const size_t i) const {
840  if (auto const* input = dynamic_cast<RelCompound const*>(inputs_[0].get())) {
841  return input->getFieldName(i);
842  } else if (auto const* input = dynamic_cast<RelProject const*>(inputs_[0].get())) {
843  return input->getFieldName(i);
844  } else if (auto const* input = dynamic_cast<RelLogicalUnion const*>(inputs_[0].get())) {
845  return input->getFieldName(i);
846  } else if (auto const* input = dynamic_cast<RelAggregate const*>(inputs_[0].get())) {
847  return input->getFieldName(i);
848  } else if (auto const* input = dynamic_cast<RelScan const*>(inputs_[0].get())) {
849  return input->getFieldName(i);
850  } else if (auto const* input =
851  dynamic_cast<RelTableFunction const*>(inputs_[0].get())) {
852  return input->getFieldName(i);
853  }
854  UNREACHABLE() << "Unhandled input type: " << ::toString(inputs_.front());
855  return {};
856 }
857 
// Body of the UNION input-compatibility check. NOTE(review): the function's
// signature line is absent from this copy of the file.
// Every input must expose the same number of output columns as inputs_[0].
// Column types must be equal, except that two dictionary-encoded strings are
// accepted even when their SQLTypeInfo differs. Any mismatch is logged at INFO
// level and raised as std::runtime_error.
859  std::vector<TargetMetaInfo> const& tmis0 = inputs_[0]->getOutputMetainfo();
860  for (size_t i = 1; i < inputs_.size(); ++i) {
861  std::vector<TargetMetaInfo> const& tmisi = inputs_[i]->getOutputMetainfo();
// A column-count mismatch is reported with the same "matching data types" message.
862  if (tmis0.size() != tmisi.size()) {
863  LOG(INFO) << "tmis0.size()=" << tmis0.size() << " != " << tmisi.size()
864  << "=tmisi.size() for i=" << i;
865  throw std::runtime_error("Subqueries of a UNION must have matching data types.");
866  }
867  for (size_t j = 0; j < tmis0.size(); ++j) {
868  if (tmis0[j].get_type_info() != tmisi[j].get_type_info()) {
869  SQLTypeInfo const& ti0 = tmis0[j].get_type_info();
870  SQLTypeInfo const& ti1 = tmisi[j].get_type_info();
871  LOG(INFO) << "Types do not match for UNION:\n tmis0[" << j
872  << "].get_type_info().to_string() = " << ti0.to_string() << "\n tmis"
873  << i << '[' << j
874  << "].get_type_info().to_string() = " << ti1.to_string();
875  // The only permitted difference is when both columns are dictionary-encoded.
876  if (!(ti0.is_dict_encoded_string() && ti1.is_dict_encoded_string())) {
877  throw std::runtime_error(
878  "Subqueries of a UNION must have the exact same data types.");
879  }
880  }
881  }
882  }
883 }
884 
// Returns an expression usable against input `input_idx` of this UNION.
// A RexInput is copied and re-pointed at getInput(input_idx). The copy is stored
// in scalar_exprs_, so the returned raw pointer stays valid while scalar_exprs_
// keeps it. Any other expression kind is returned unchanged.
// NOTE(review): the first line of the signature is absent from this copy of the file.
885 // Rest of code requires a raw pointer, but RexInput object needs to live somewhere.
887  size_t input_idx) const {
888  if (auto const* rex_input_ptr = dynamic_cast<RexInput const*>(rex_scalar)) {
889  RexInput rex_input(*rex_input_ptr);
890  rex_input.setSourceNode(getInput(input_idx));
// Ownership is kept by this node; only a non-owning pointer is handed out.
891  scalar_exprs_.emplace_back(std::make_shared<RexInput const>(std::move(rex_input)));
892  return scalar_exprs_.back().get();
893  }
894  return rex_scalar;
895 }
896 
897 namespace {
898 
899 unsigned node_id(const rapidjson::Value& ra_node) noexcept {
900  const auto& id = field(ra_node, "id");
901  return std::stoi(json_str(id));
902 }
903 
904 std::string json_node_to_string(const rapidjson::Value& node) noexcept {
905  rapidjson::StringBuffer buffer;
906  rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
907  node.Accept(writer);
908  return buffer.GetString();
909 }
910 
911 // The parse_* functions below deserialize expressions as they arrive from Calcite.
912 // RelAlgDagBuilder then makes the representation easy to navigate for lower
913 // layers, for example by replacing RexAbstractInput with RexInput.
914 
915 std::unique_ptr<RexAbstractInput> parse_abstract_input(
916  const rapidjson::Value& expr) noexcept {
917  const auto& input = field(expr, "input");
918  return std::unique_ptr<RexAbstractInput>(new RexAbstractInput(json_i64(input)));
919 }
920 
// Deserializes a Calcite literal object into a RexLiteral.
// The object must carry "literal", "type", "target_type", "scale", "precision",
// "type_scale" and "type_precision" members. A JSON null literal produces a NULL
// RexLiteral of `target_type`. Otherwise the literal's "type" selects how the
// JSON value is read: integer, double, string or bool.
921 std::unique_ptr<RexLiteral> parse_literal(const rapidjson::Value& expr) {
922  CHECK(expr.IsObject());
923  const auto& literal = field(expr, "literal");
924  const auto type = to_sql_type(json_str(field(expr, "type")));
925  const auto target_type = to_sql_type(json_str(field(expr, "target_type")));
926  const auto scale = json_i64(field(expr, "scale"));
927  const auto precision = json_i64(field(expr, "precision"));
928  const auto type_scale = json_i64(field(expr, "type_scale"));
929  const auto type_precision = json_i64(field(expr, "type_precision"));
930  if (literal.IsNull()) {
931  return std::unique_ptr<RexLiteral>(new RexLiteral(target_type));
932  }
933  switch (type) {
// Integer-backed kinds (decimal, intervals, date/time) are all read via json_i64.
// NOTE(review): one case label (source line 938) is absent from this copy of the file.
934  case kINT:
935  case kBIGINT:
936  case kDECIMAL:
937  case kINTERVAL_DAY_TIME:
939  case kTIME:
940  case kTIMESTAMP:
941  case kDATE:
942  return std::unique_ptr<RexLiteral>(new RexLiteral(json_i64(literal),
943  type,
944  target_type,
945  scale,
946  precision,
947  type_scale,
948  type_precision));
// A DOUBLE literal may arrive as a JSON double or as an integral JSON number;
// integral values are converted to double.
949  case kDOUBLE: {
950  if (literal.IsDouble()) {
951  return std::unique_ptr<RexLiteral>(new RexLiteral(json_double(literal),
952  type,
953  target_type,
954  scale,
955  precision,
956  type_scale,
957  type_precision));
958  } else if (literal.IsInt64()) {
959  return std::make_unique<RexLiteral>(static_cast<double>(literal.GetInt64()),
960  type,
961  target_type,
962  scale,
963  precision,
964  type_scale,
965  type_precision);
966 
967  } else if (literal.IsUint64()) {
968  return std::make_unique<RexLiteral>(static_cast<double>(literal.GetUint64()),
969  type,
970  target_type,
971  scale,
972  precision,
973  type_scale,
974  type_precision);
975  }
// NOTE(review): this relies on UNREACHABLE() not returning; if it did, control
// would fall through into `case kTEXT`.
976  UNREACHABLE() << "Unhandled type: " << literal.GetType();
977  }
978  case kTEXT:
979  return std::unique_ptr<RexLiteral>(new RexLiteral(json_str(literal),
980  type,
981  target_type,
982  scale,
983  precision,
984  type_scale,
985  type_precision));
986  case kBOOLEAN:
987  return std::unique_ptr<RexLiteral>(new RexLiteral(json_bool(literal),
988  type,
989  target_type,
990  scale,
991  precision,
992  type_scale,
993  type_precision));
// A NULL-typed literal is treated like a JSON null: typed by target_type.
994  case kNULLT:
995  return std::unique_ptr<RexLiteral>(new RexLiteral(target_type));
996  default:
997  CHECK(false);
998  }
999  CHECK(false);
1000  return nullptr;
1001 }
1002 
// Forward declaration. parse_scalar_expr and the parse_* helpers below are
// mutually recursive (e.g. parse_case and parse_operator call back into it).
// NOTE(review): a parameter line (presumably the `cat` catalog argument used by
// the definition) is absent from this copy of the file.
1003 std::unique_ptr<const RexScalar> parse_scalar_expr(const rapidjson::Value& expr,
1005  RelAlgDag& root_dag);
1006 
1007 SQLTypeInfo parse_type(const rapidjson::Value& type_obj) {
1008  if (type_obj.IsArray()) {
1009  throw QueryNotSupported("Composite types are not currently supported.");
1010  }
1011  CHECK(type_obj.IsObject() && type_obj.MemberCount() >= 2)
1012  << json_node_to_string(type_obj);
1013  const auto type = to_sql_type(json_str(field(type_obj, "type")));
1014  const auto nullable = json_bool(field(type_obj, "nullable"));
1015  const auto precision_it = type_obj.FindMember("precision");
1016  const int precision =
1017  precision_it != type_obj.MemberEnd() ? json_i64(precision_it->value) : 0;
1018  const auto scale_it = type_obj.FindMember("scale");
1019  const int scale = scale_it != type_obj.MemberEnd() ? json_i64(scale_it->value) : 0;
1020  SQLTypeInfo ti(type, !nullable);
1021  ti.set_precision(precision);
1022  ti.set_scale(scale);
1023  return ti;
1024 }
1025 
// Parses each element of the JSON array `arr` with parse_scalar_expr and returns
// the expressions in array order.
// NOTE(review): the `cat` parameter line is absent from this copy of the file.
1026 std::vector<std::unique_ptr<const RexScalar>> parse_expr_array(
1027  const rapidjson::Value& arr,
1029  RelAlgDag& root_dag) {
1030  std::vector<std::unique_ptr<const RexScalar>> exprs;
1031  for (auto it = arr.Begin(); it != arr.End(); ++it) {
1032  exprs.emplace_back(parse_scalar_expr(*it, cat, root_dag));
1033  }
1034  return exprs;
1035 }
1036 
// Body of parse_window_function_kind(name), called from parse_operator with the
// operator name. It maps a Calcite window-function name ("ROW_NUMBER", "RANK",
// ..., "$SUM0") to the window-function kind passed to RexWindowFunctionOperator.
// Any unrecognized name throws std::runtime_error.
// NOTE(review): the signature line and every per-branch return statement are
// absent from this copy of the file.
1038  if (name == "ROW_NUMBER") {
1040  }
1041  if (name == "RANK") {
1043  }
1044  if (name == "DENSE_RANK") {
1046  }
1047  if (name == "PERCENT_RANK") {
1049  }
1050  if (name == "CUME_DIST") {
1052  }
1053  if (name == "NTILE") {
1055  }
1056  if (name == "LAG") {
1058  }
1059  if (name == "LEAD") {
1061  }
1062  if (name == "FIRST_VALUE") {
1064  }
1065  if (name == "LAST_VALUE") {
1067  }
// Aggregate functions used with an OVER clause.
1068  if (name == "AVG") {
1070  }
1071  if (name == "MIN") {
1073  }
1074  if (name == "MAX") {
1076  }
1077  if (name == "SUM") {
1079  }
1080  if (name == "COUNT") {
1082  }
1083  if (name == "$SUM0") {
1085  }
1086  throw std::runtime_error("Unsupported window function: " + name);
1087 }
1088 
// Parses the "field" expression of every element of a window ORDER BY array,
// keeping array order. Collation for the same array is parsed separately by
// parse_window_order_collation.
// NOTE(review): the `cat` parameter line is absent from this copy of the file.
1089 std::vector<std::unique_ptr<const RexScalar>> parse_window_order_exprs(
1090  const rapidjson::Value& arr,
1092  RelAlgDag& root_dag) {
1093  std::vector<std::unique_ptr<const RexScalar>> exprs;
1094  for (auto it = arr.Begin(); it != arr.End(); ++it) {
1095  exprs.emplace_back(parse_scalar_expr(field(*it, "field"), cat, root_dag));
1096  }
1097  return exprs;
1098 }
1099 
// Derives the SortDirection of a collation entry from whether its "direction"
// member equals "DESCENDING".
// NOTE(review): the two result lines of the conditional expression are absent
// from this copy of the file.
1100 SortDirection parse_sort_direction(const rapidjson::Value& collation) {
1101  return json_str(field(collation, "direction")) == std::string("DESCENDING")
1104 }
1105 
// Derives the NullSortedPosition of a collation entry from whether its "nulls"
// member equals "FIRST".
// NOTE(review): the two result lines of the conditional expression are absent
// from this copy of the file.
1106 NullSortedPosition parse_nulls_position(const rapidjson::Value& collation) {
1107  return json_str(field(collation, "nulls")) == std::string("FIRST")
1110 }
1111 
// Builds one SortField per element of a window ORDER BY array. The i-th SortField
// refers to position i in that array, i.e. to the i-th expression returned by
// parse_window_order_exprs for the same array, not to the JSON "field" value.
// `root_dag` is not used here.
// NOTE(review): a parameter line (presumably `cat`) is absent from this copy of the file.
1112 std::vector<SortField> parse_window_order_collation(const rapidjson::Value& arr,
1114  RelAlgDag& root_dag) {
1115  std::vector<SortField> collation;
1116  size_t field_idx = 0;
1117  for (auto it = arr.Begin(); it != arr.End(); ++it, ++field_idx) {
1118  const auto sort_dir = parse_sort_direction(*it);
1119  const auto null_pos = parse_nulls_position(*it);
1120  collation.emplace_back(field_idx, sort_dir, null_pos);
1121  }
1122  return collation;
1123 }
1124 
// Body of parse_window_bound, which parses a window frame bound object. The
// boolean flags "unbounded", "preceding", "following" and "is_current_row" and
// the integer "order_key" are copied into the bound. "offset" is either an
// expression object (parsed into bound_expr) or JSON null.
// NOTE(review): the signature line, the `cat` parameter line and the declaration
// of the local `window_bound` are absent from this copy of the file.
1126  const rapidjson::Value& window_bound_obj,
1128  RelAlgDag& root_dag) {
1129  CHECK(window_bound_obj.IsObject());
1131  window_bound.unbounded = json_bool(field(window_bound_obj, "unbounded"));
1132  window_bound.preceding = json_bool(field(window_bound_obj, "preceding"));
1133  window_bound.following = json_bool(field(window_bound_obj, "following"));
1134  window_bound.is_current_row = json_bool(field(window_bound_obj, "is_current_row"));
1135  const auto& offset_field = field(window_bound_obj, "offset");
1136  if (offset_field.IsObject()) {
1137  window_bound.bound_expr = parse_scalar_expr(offset_field, cat, root_dag);
1138  } else {
1139  CHECK(offset_field.IsNull());
1140  }
1141  window_bound.order_key = json_i64(field(window_bound_obj, "order_key"));
1142  return window_bound;
1143 }
1144 
// Builds a separate DAG for the "subquery" member of `expr` and wraps its root in
// a RexSubQuery. The shared RexSubQuery is registered with `root_dag`, and the
// caller receives a deep copy of it.
// NOTE(review): the `cat` parameter line is absent from this copy of the file.
1145 std::unique_ptr<const RexSubQuery> parse_subquery(const rapidjson::Value& expr,
1147  RelAlgDag& root_dag) {
1148  const auto& operands = field(expr, "operands");
1149  CHECK(operands.IsArray());
// NOTE(review): an unsigned size is always >= 0, so this CHECK can never fail.
1150  CHECK_GE(operands.Size(), unsigned(0));
1151  const auto& subquery_ast = field(expr, "subquery");
1152 
1153  auto subquery_dag = RelAlgDagBuilder::buildDagForSubquery(root_dag, subquery_ast, cat);
1154  auto subquery = std::make_shared<RexSubQuery>(subquery_dag->getRootNodeShPtr());
1155  root_dag.registerSubquery(subquery);
1156  return subquery->deepCopy();
1157 }
1158 
// Parses an operator expression ("op", "operands", "type").
// - PG_ANY / PG_ALL quantifiers are treated as kFUNCTION operators.
// - For kIN with a "subquery" member, the parsed subquery is appended as the last operand.
// - A "partition_keys" member marks a window function. Its partition keys,
//   order keys, collation, frame bounds and "is_rows" flag are parsed, and the
//   result type is forced to be nullable.
// - Otherwise a RexFunctionOperator (for kFUNCTION) or a plain RexOperator is built.
// NOTE(review): the `cat` parameter line is absent from this copy of the file.
1159 std::unique_ptr<RexOperator> parse_operator(const rapidjson::Value& expr,
1161  RelAlgDag& root_dag) {
1162  const auto op_name = json_str(field(expr, "op"));
1163  const bool is_quantifier =
1164  op_name == std::string("PG_ANY") || op_name == std::string("PG_ALL");
1165  const auto op = is_quantifier ? kFUNCTION : to_sql_op(op_name);
1166  const auto& operators_json_arr = field(expr, "operands");
1167  CHECK(operators_json_arr.IsArray());
1168  auto operands = parse_expr_array(operators_json_arr, cat, root_dag);
1169  const auto type_it = expr.FindMember("type");
1170  CHECK(type_it != expr.MemberEnd());
1171  auto ti = parse_type(type_it->value);
1172  if (op == kIN && expr.HasMember("subquery")) {
1173  auto subquery = parse_subquery(expr, cat, root_dag);
1174  operands.emplace_back(std::move(subquery));
1175  }
1176  if (expr.FindMember("partition_keys") != expr.MemberEnd()) {
1177  const auto& partition_keys_arr = field(expr, "partition_keys");
1178  auto partition_keys = parse_expr_array(partition_keys_arr, cat, root_dag);
1179  const auto& order_keys_arr = field(expr, "order_keys");
1180  auto order_keys = parse_window_order_exprs(order_keys_arr, cat, root_dag);
1181  const auto collation = parse_window_order_collation(order_keys_arr, cat, root_dag);
1182  const auto kind = parse_window_function_kind(op_name);
1183  const auto lower_bound =
1184  parse_window_bound(field(expr, "lower_bound"), cat, root_dag);
1185  const auto upper_bound =
1186  parse_window_bound(field(expr, "upper_bound"), cat, root_dag);
1187  bool is_rows = json_bool(field(expr, "is_rows"));
// Window function results are always marked nullable.
1188  ti.set_notnull(false);
1189  return std::make_unique<RexWindowFunctionOperator>(kind,
1190  operands,
1191  partition_keys,
1192  order_keys,
1193  collation,
1194  lower_bound,
1195  upper_bound,
1196  is_rows,
1197  ti);
1198  }
1199  return std::unique_ptr<RexOperator>(op == kFUNCTION
1200  ? new RexFunctionOperator(op_name, operands, ti)
1201  : new RexOperator(op, operands, ti));
1202 }
1203 
// Parses a CASE expression. Operands are consumed as (WHEN, THEN) pairs, and a
// trailing unpaired operand becomes the ELSE expression. At least two operands
// are required.
// NOTE(review): with an even operand count, else_expr stays null, and
// disambiguate_case later passes it to disambiguate_rex unchecked.
// NOTE(review): the `cat` parameter line is absent from this copy of the file.
1204 std::unique_ptr<RexCase> parse_case(const rapidjson::Value& expr,
1206  RelAlgDag& root_dag) {
1207  const auto& operands = field(expr, "operands");
1208  CHECK(operands.IsArray());
1209  CHECK_GE(operands.Size(), unsigned(2));
1210  std::unique_ptr<const RexScalar> else_expr;
1211  std::vector<
1212  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
1213  expr_pair_list;
1214  for (auto operands_it = operands.Begin(); operands_it != operands.End();) {
1215  auto when_expr = parse_scalar_expr(*operands_it++, cat, root_dag);
// A lone trailing operand is the ELSE branch.
1216  if (operands_it == operands.End()) {
1217  else_expr = std::move(when_expr);
1218  break;
1219  }
1220  auto then_expr = parse_scalar_expr(*operands_it++, cat, root_dag);
1221  expr_pair_list.emplace_back(std::move(when_expr), std::move(then_expr));
1222  }
1223  return std::unique_ptr<RexCase>(new RexCase(expr_pair_list, else_expr));
1224 }
1225 
1226 std::vector<std::string> strings_from_json_array(
1227  const rapidjson::Value& json_str_arr) noexcept {
1228  CHECK(json_str_arr.IsArray());
1229  std::vector<std::string> fields;
1230  for (auto json_str_arr_it = json_str_arr.Begin(); json_str_arr_it != json_str_arr.End();
1231  ++json_str_arr_it) {
1232  CHECK(json_str_arr_it->IsString());
1233  fields.emplace_back(json_str_arr_it->GetString());
1234  }
1235  return fields;
1236 }
1237 
1238 std::vector<size_t> indices_from_json_array(
1239  const rapidjson::Value& json_idx_arr) noexcept {
1240  CHECK(json_idx_arr.IsArray());
1241  std::vector<size_t> indices;
1242  for (auto json_idx_arr_it = json_idx_arr.Begin(); json_idx_arr_it != json_idx_arr.End();
1243  ++json_idx_arr_it) {
1244  CHECK(json_idx_arr_it->IsInt());
1245  CHECK_GE(json_idx_arr_it->GetInt(), 0);
1246  indices.emplace_back(json_idx_arr_it->GetInt());
1247  }
1248  return indices;
1249 }
1250 
1251 std::unique_ptr<const RexAgg> parse_aggregate_expr(const rapidjson::Value& expr) {
1252  const auto agg_str = json_str(field(expr, "agg"));
1253  if (agg_str == "APPROX_QUANTILE") {
1254  LOG(INFO) << "APPROX_QUANTILE is deprecated. Please use APPROX_PERCENTILE instead.";
1255  }
1256  const auto agg = to_agg_kind(agg_str);
1257  const auto distinct = json_bool(field(expr, "distinct"));
1258  const auto agg_ti = parse_type(field(expr, "type"));
1259  const auto operands = indices_from_json_array(field(expr, "operands"));
1260  if (operands.size() > 1 && (operands.size() != 2 || (agg != kAPPROX_COUNT_DISTINCT &&
1261  agg != kAPPROX_QUANTILE))) {
1262  throw QueryNotSupported("Multiple arguments for aggregates aren't supported");
1263  }
1264  return std::unique_ptr<const RexAgg>(new RexAgg(agg, distinct, agg_ti, operands));
1265 }
1266 
// Dispatches on the members present in a scalar expression object:
// "input" -> abstract input, "literal" -> literal, "op" -> CASE, scalar subquery
// ("$SCALAR_QUERY") or a generic operator. Anything else throws QueryNotSupported
// carrying the serialized node.
// NOTE(review): the `cat` parameter line is absent from this copy of the file.
1267 std::unique_ptr<const RexScalar> parse_scalar_expr(const rapidjson::Value& expr,
1269  RelAlgDag& root_dag) {
1270  CHECK(expr.IsObject());
1271  if (expr.IsObject() && expr.HasMember("input")) {
1272  return std::unique_ptr<const RexScalar>(parse_abstract_input(expr));
1273  }
1274  if (expr.IsObject() && expr.HasMember("literal")) {
1275  return std::unique_ptr<const RexScalar>(parse_literal(expr));
1276  }
1277  if (expr.IsObject() && expr.HasMember("op")) {
1278  const auto op_str = json_str(field(expr, "op"));
1279  if (op_str == std::string("CASE")) {
1280  return std::unique_ptr<const RexScalar>(parse_case(expr, cat, root_dag));
1281  }
1282  if (op_str == std::string("$SCALAR_QUERY")) {
1283  return std::unique_ptr<const RexScalar>(parse_subquery(expr, cat, root_dag));
1284  }
1285  return std::unique_ptr<const RexScalar>(parse_operator(expr, cat, root_dag));
1286  }
1287  throw QueryNotSupported("Expression node " + json_node_to_string(expr) +
1288  " not supported");
1289 }
1290 
1291 JoinType to_join_type(const std::string& join_type_name) {
1292  if (join_type_name == "inner") {
1293  return JoinType::INNER;
1294  }
1295  if (join_type_name == "left") {
1296  return JoinType::LEFT;
1297  }
1298  if (join_type_name == "semi") {
1299  return JoinType::SEMI;
1300  }
1301  if (join_type_name == "anti") {
1302  return JoinType::ANTI;
1303  }
1304  throw QueryNotSupported("Join type (" + join_type_name + ") not supported");
1305 }
1306 
1307 std::unique_ptr<const RexScalar> disambiguate_rex(const RexScalar*, const RANodeOutput&);
1308 
1309 std::unique_ptr<const RexOperator> disambiguate_operator(
1310  const RexOperator* rex_operator,
1311  const RANodeOutput& ra_output) noexcept {
1312  std::vector<std::unique_ptr<const RexScalar>> disambiguated_operands;
1313  for (size_t i = 0; i < rex_operator->size(); ++i) {
1314  auto operand = rex_operator->getOperand(i);
1315  if (dynamic_cast<const RexSubQuery*>(operand)) {
1316  disambiguated_operands.emplace_back(rex_operator->getOperandAndRelease(i));
1317  } else {
1318  disambiguated_operands.emplace_back(disambiguate_rex(operand, ra_output));
1319  }
1320  }
1321  const auto rex_window_function_operator =
1322  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
1323  if (rex_window_function_operator) {
1324  const auto& partition_keys = rex_window_function_operator->getPartitionKeys();
1325  std::vector<std::unique_ptr<const RexScalar>> disambiguated_partition_keys;
1326  for (const auto& partition_key : partition_keys) {
1327  disambiguated_partition_keys.emplace_back(
1328  disambiguate_rex(partition_key.get(), ra_output));
1329  }
1330  std::vector<std::unique_ptr<const RexScalar>> disambiguated_order_keys;
1331  const auto& order_keys = rex_window_function_operator->getOrderKeys();
1332  for (const auto& order_key : order_keys) {
1333  disambiguated_order_keys.emplace_back(disambiguate_rex(order_key.get(), ra_output));
1334  }
1335  return rex_window_function_operator->disambiguatedOperands(
1336  disambiguated_operands,
1337  disambiguated_partition_keys,
1338  disambiguated_order_keys,
1339  rex_window_function_operator->getCollation());
1340  }
1341  return rex_operator->getDisambiguated(disambiguated_operands);
1342 }
1343 
1344 std::unique_ptr<const RexCase> disambiguate_case(const RexCase* rex_case,
1345  const RANodeOutput& ra_output) {
1346  std::vector<
1347  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
1348  disambiguated_expr_pair_list;
1349  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
1350  auto disambiguated_when = disambiguate_rex(rex_case->getWhen(i), ra_output);
1351  auto disambiguated_then = disambiguate_rex(rex_case->getThen(i), ra_output);
1352  disambiguated_expr_pair_list.emplace_back(std::move(disambiguated_when),
1353  std::move(disambiguated_then));
1354  }
1355  std::unique_ptr<const RexScalar> disambiguated_else{
1356  disambiguate_rex(rex_case->getElse(), ra_output)};
1357  return std::unique_ptr<const RexCase>(
1358  new RexCase(disambiguated_expr_pair_list, disambiguated_else));
1359 }
1360 
1361 // The inputs used by scalar expressions are given as indices in the serialized
1362 // representation of the query. This is hard to navigate; make the relationship
1363 // explicit by creating RexInput expressions which hold a pointer to the source
1364 // relational algebra node and the index relative to the output of that node.
1365 std::unique_ptr<const RexScalar> disambiguate_rex(const RexScalar* rex_scalar,
1366  const RANodeOutput& ra_output) {
1367  const auto rex_abstract_input = dynamic_cast<const RexAbstractInput*>(rex_scalar);
1368  if (rex_abstract_input) {
1369  CHECK_LT(static_cast<size_t>(rex_abstract_input->getIndex()), ra_output.size());
1370  return std::unique_ptr<const RexInput>(
1371  new RexInput(ra_output[rex_abstract_input->getIndex()]));
1372  }
1373  const auto rex_operator = dynamic_cast<const RexOperator*>(rex_scalar);
1374  if (rex_operator) {
1375  return disambiguate_operator(rex_operator, ra_output);
1376  }
1377  const auto rex_case = dynamic_cast<const RexCase*>(rex_scalar);
1378  if (rex_case) {
1379  return disambiguate_case(rex_case, ra_output);
1380  }
1381  if (auto const rex_literal = dynamic_cast<const RexLiteral*>(rex_scalar)) {
1382  return rex_literal->deepCopy();
1383  } else if (auto const rex_subquery = dynamic_cast<const RexSubQuery*>(rex_scalar)) {
1384  return rex_subquery->deepCopy();
1385  } else {
1386  throw QueryNotSupported("Unable to disambiguate expression of type " +
1387  std::string(typeid(*rex_scalar).name()));
1388  }
1389 }
1390 
1391 void bind_project_to_input(RelProject* project_node, const RANodeOutput& input) noexcept {
1392  CHECK_EQ(size_t(1), project_node->inputCount());
1393  std::vector<std::unique_ptr<const RexScalar>> disambiguated_exprs;
1394  for (size_t i = 0; i < project_node->size(); ++i) {
1395  const auto projected_expr = project_node->getProjectAt(i);
1396  if (dynamic_cast<const RexSubQuery*>(projected_expr)) {
1397  disambiguated_exprs.emplace_back(project_node->getProjectAtAndRelease(i));
1398  } else {
1399  disambiguated_exprs.emplace_back(disambiguate_rex(projected_expr, input));
1400  }
1401  }
1402  project_node->setExpressions(disambiguated_exprs);
1403 }
1404 
// Body of bind_table_func_to_input(table_func_node, input), called from
// bind_inputs. It replaces every table-function argument with a copy rebound
// against `input`. Subquery arguments are released from the node and moved over
// unchanged rather than rebuilt.
// NOTE(review): the signature line is absent from this copy of the file.
1406  const RANodeOutput& input) noexcept {
1407  std::vector<std::unique_ptr<const RexScalar>> disambiguated_exprs;
1408  for (size_t i = 0; i < table_func_node->getTableFuncInputsSize(); ++i) {
1409  const auto target_expr = table_func_node->getTableFuncInputAt(i);
1410  if (dynamic_cast<const RexSubQuery*>(target_expr)) {
1411  disambiguated_exprs.emplace_back(table_func_node->getTableFuncInputAtAndRelease(i));
1412  } else {
1413  disambiguated_exprs.emplace_back(disambiguate_rex(target_expr, input));
1414  }
1415  }
1416  table_func_node->setTableFuncInputs(disambiguated_exprs);
1417 }
1418 
1419 void bind_inputs(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) noexcept {
1420  for (auto ra_node : nodes) {
1421  const auto filter_node = std::dynamic_pointer_cast<RelFilter>(ra_node);
1422  if (filter_node) {
1423  CHECK_EQ(size_t(1), filter_node->inputCount());
1424  auto disambiguated_condition = disambiguate_rex(
1425  filter_node->getCondition(), get_node_output(filter_node->getInput(0)));
1426  filter_node->setCondition(disambiguated_condition);
1427  continue;
1428  }
1429  const auto join_node = std::dynamic_pointer_cast<RelJoin>(ra_node);
1430  if (join_node) {
1431  CHECK_EQ(size_t(2), join_node->inputCount());
1432  auto disambiguated_condition =
1433  disambiguate_rex(join_node->getCondition(), get_node_output(join_node.get()));
1434  join_node->setCondition(disambiguated_condition);
1435  continue;
1436  }
1437  const auto project_node = std::dynamic_pointer_cast<RelProject>(ra_node);
1438  if (project_node) {
1439  bind_project_to_input(project_node.get(),
1440  get_node_output(project_node->getInput(0)));
1441  continue;
1442  }
1443  const auto table_func_node = std::dynamic_pointer_cast<RelTableFunction>(ra_node);
1444  if (table_func_node) {
1445  /*
1446  Collect all inputs from table function input (non-literal)
1447  arguments.
1448  */
1449  RANodeOutput input;
1450  input.reserve(table_func_node->inputCount());
1451  for (size_t i = 0; i < table_func_node->inputCount(); i++) {
1452  auto node_output = get_node_output(table_func_node->getInput(i));
1453  input.insert(input.end(), node_output.begin(), node_output.end());
1454  }
1455  bind_table_func_to_input(table_func_node.get(), input);
1456  }
1457  }
1458 }
1459 
1460 void handle_query_hint(const std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1461  RelAlgDag& rel_alg_dag) noexcept {
1462  // query hint is delivered by the above three nodes
1463  // when a query block has top-sort node, a hint is registered to
1464  // one of the node which locates at the nearest from the sort node
1465  RegisteredQueryHint global_query_hint;
1466  for (auto node : nodes) {
1467  Hints* hint_delivered = nullptr;
1468  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
1469  if (agg_node) {
1470  if (agg_node->hasDeliveredHint()) {
1471  hint_delivered = agg_node->getDeliveredHints();
1472  }
1473  }
1474  const auto project_node = std::dynamic_pointer_cast<RelProject>(node);
1475  if (project_node) {
1476  if (project_node->hasDeliveredHint()) {
1477  hint_delivered = project_node->getDeliveredHints();
1478  }
1479  }
1480  const auto compound_node = std::dynamic_pointer_cast<RelCompound>(node);
1481  if (compound_node) {
1482  if (compound_node->hasDeliveredHint()) {
1483  hint_delivered = compound_node->getDeliveredHints();
1484  }
1485  }
1486  if (hint_delivered && !hint_delivered->empty()) {
1487  rel_alg_dag.registerQueryHints(node, hint_delivered, global_query_hint);
1488  }
1489  }
1490  rel_alg_dag.setGlobalQueryHints(global_query_hint);
1491 }
1492 
1493 void compute_node_hash(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) {
1494  // compute each rel node's hash value in advance to avoid inconsistency of their hash
1495  // values depending on the toHash's caller
1496  // specifically, we manipulate our logical query plan before retrieving query step
1497  // sequence but once we compute a hash value we cached it so there is no way to update
1498  // it after the plan has been changed starting from the top node, we compute the hash
1499  // value (top-down manner)
1500  std::for_each(
1501  nodes.rbegin(), nodes.rend(), [](const std::shared_ptr<RelAlgNode>& node) {
1502  auto node_hash = node->toHash();
1503  CHECK_NE(node_hash, static_cast<size_t>(0));
1504  });
1505 }
1506 
1507 void mark_nops(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) noexcept {
1508  for (auto node : nodes) {
1509  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
1510  if (!agg_node || agg_node->getAggExprsCount()) {
1511  continue;
1512  }
1513  CHECK_EQ(size_t(1), node->inputCount());
1514  const auto agg_input_node = dynamic_cast<const RelAggregate*>(node->getInput(0));
1515  if (agg_input_node && !agg_input_node->getAggExprsCount() &&
1516  agg_node->getGroupByCount() == agg_input_node->getGroupByCount()) {
1517  agg_node->markAsNop();
1518  }
1519  }
1520 }
1521 
1522 namespace {
1523 
1524 std::vector<const Rex*> reproject_targets(
1525  const RelProject* simple_project,
1526  const std::vector<const Rex*>& target_exprs) noexcept {
1527  std::vector<const Rex*> result;
1528  for (size_t i = 0; i < simple_project->size(); ++i) {
1529  const auto input_rex = dynamic_cast<const RexInput*>(simple_project->getProjectAt(i));
1530  CHECK(input_rex);
1531  CHECK_LT(static_cast<size_t>(input_rex->getIndex()), target_exprs.size());
1532  result.push_back(target_exprs[input_rex->getIndex()]);
1533  }
1534  return result;
1535 }
1536 
// Body of RexInputReplacementVisitor (used by create_compound). visit() rewrites
// an expression so that a RexInput whose source node is `node_to_keep` is
// replaced by the result of visiting scalar_sources[index]. Any other RexInput
// is deep-copied.
// NOTE(review): the class header, base class and constructor-name lines are
// absent from this copy of the file; the base presumably provides a deep-copy
// visit() returning RetType.
// NOTE(review): scalar_sources_ is stored by reference, so the visitor must not
// outlive the vector passed to the constructor.
1543  public:
1545  const RelAlgNode* node_to_keep,
1546  const std::vector<std::unique_ptr<const RexScalar>>& scalar_sources)
1547  : node_to_keep_(node_to_keep), scalar_sources_(scalar_sources) {}
1548 
1549  // Reproject the RexInput from its current RA Node to the RA Node we intend to keep
1550  RetType visitInput(const RexInput* input) const final {
1551  if (input->getSourceNode() == node_to_keep_) {
1552  const auto index = input->getIndex();
1553  CHECK_LT(index, scalar_sources_.size());
1554  return visit(scalar_sources_[index].get());
1555  } else {
1556  return input->deepCopy();
1557  }
1558  }
1559 
1560  private:
1562  const std::vector<std::unique_ptr<const RexScalar>>& scalar_sources_;
1563 };
1564 
1565 } // namespace
1566 
// Body of create_compound. It fuses the 2 to 4 nodes at indices `pattern` (a chain
// of RelFilter / RelProject / RelAggregate nodes, per the casts below) into one
// RelCompound. The compound replaces nodes[pattern.back()], takes over the
// single input of nodes[pattern.front()], and every other node that used the old
// last node is redirected to it. The other pattern slots are reset to null. A
// query hint registered on any fused node is re-registered on the compound node.
// NOTE(review): the function's signature line is absent from this copy of the file.
1568  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1569  const std::vector<size_t>& pattern,
1570  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
1571  query_hints) noexcept {
1572  CHECK_GE(pattern.size(), size_t(2));
1573  CHECK_LE(pattern.size(), size_t(4));
1574 
1575  std::unique_ptr<const RexScalar> filter_rex;
1576  std::vector<std::unique_ptr<const RexScalar>> scalar_sources;
1577  size_t groupby_count{0};
1578  std::vector<std::string> fields;
1579  std::vector<const RexAgg*> agg_exprs;
1580  std::vector<const Rex*> target_exprs;
1581  bool first_project{true};
1582  bool is_agg{false};
1583  RelAlgNode* last_node{nullptr};
1584 
1585  std::shared_ptr<ModifyManipulationTarget> manipulation_target;
1586  size_t node_hash{0};
1587  unsigned node_id{0};
1588  bool hint_registered{false};
1589  RegisteredQueryHint registered_query_hint = RegisteredQueryHint::defaults();
1590  for (const auto node_idx : pattern) {
1591  const auto ra_node = nodes[node_idx];
// Remember a hint registered on this node. If several fused nodes carry hints,
// only the last one found is kept.
1592  auto registered_query_hint_map_it = query_hints.find(ra_node->toHash());
1593  if (registered_query_hint_map_it != query_hints.end()) {
1594  auto& registered_query_hint_map = registered_query_hint_map_it->second;
1595  auto registered_query_hint_it = registered_query_hint_map.find(ra_node->getId());
1596  if (registered_query_hint_it != registered_query_hint_map.end()) {
1597  hint_registered = true;
1598  node_hash = registered_query_hint_map_it->first;
1599  node_id = registered_query_hint_it->first;
1600  registered_query_hint = registered_query_hint_it->second;
1601  }
1602  }
// Filter: take ownership of its condition (at most one filter per pattern).
1603  const auto ra_filter = std::dynamic_pointer_cast<RelFilter>(ra_node);
1604  if (ra_filter) {
1605  CHECK(!filter_rex);
1606  filter_rex.reset(ra_filter->getAndReleaseCondition());
1607  CHECK(filter_rex);
1608  last_node = ra_node.get();
1609  continue;
1610  }
1611  const auto ra_project = std::dynamic_pointer_cast<RelProject>(ra_node);
1612  if (ra_project) {
1613  fields = ra_project->getFields();
1614  manipulation_target = ra_project;
1615 
1616  if (first_project) {
1617  CHECK_EQ(size_t(1), ra_project->inputCount());
1618  // Rebind the input of the project to the input of the filter itself
1619  // since we know that we'll evaluate the filter on the fly, with no
1620  // intermediate buffer.
1621  const auto filter_input = dynamic_cast<const RelFilter*>(ra_project->getInput(0));
1622  if (filter_input) {
1623  CHECK_EQ(size_t(1), filter_input->inputCount());
1624  bind_project_to_input(ra_project.get(),
1625  get_node_output(filter_input->getInput(0)));
1626  }
1627  scalar_sources = ra_project->getExpressionsAndRelease();
1628  for (const auto& scalar_expr : scalar_sources) {
1629  target_exprs.push_back(scalar_expr.get());
1630  }
1631  first_project = false;
1632  } else {
// A later project is expressed in terms of the targets built so far.
1633  if (ra_project->isSimple()) {
1634  target_exprs = reproject_targets(ra_project.get(), target_exprs);
1635  } else {
1636  // TODO(adb): This is essentially a more general case of simple project, we
1637  // could likely merge the two
1638  std::vector<const Rex*> result;
1639  RexInputReplacementVisitor visitor(last_node, scalar_sources);
1640  for (size_t i = 0; i < ra_project->size(); ++i) {
1641  const auto rex = ra_project->getProjectAt(i);
1642  if (auto rex_input = dynamic_cast<const RexInput*>(rex)) {
1643  const auto index = rex_input->getIndex();
1644  CHECK_LT(index, target_exprs.size());
1645  result.push_back(target_exprs[index]);
1646  } else {
1647  scalar_sources.push_back(visitor.visit(rex));
1648  result.push_back(scalar_sources.back().get());
1649  }
1650  }
1651  target_exprs = result;
1652  }
1653  }
1654  last_node = ra_node.get();
1655  continue;
1656  }
// Aggregate: targets become group-by references (RexRef, 1-based, owned by
// scalar_sources) followed by the released aggregate expressions.
1657  const auto ra_aggregate = std::dynamic_pointer_cast<RelAggregate>(ra_node);
1658  if (ra_aggregate) {
1659  is_agg = true;
1660  fields = ra_aggregate->getFields();
1661  agg_exprs = ra_aggregate->getAggregatesAndRelease();
1662  groupby_count = ra_aggregate->getGroupByCount();
1663  decltype(target_exprs){}.swap(target_exprs);
1664  CHECK_LE(groupby_count, scalar_sources.size());
1665  for (size_t group_idx = 0; group_idx < groupby_count; ++group_idx) {
1666  const auto rex_ref = new RexRef(group_idx + 1);
1667  target_exprs.push_back(rex_ref);
1668  scalar_sources.emplace_back(rex_ref);
1669  }
1670  for (const auto rex_agg : agg_exprs) {
1671  target_exprs.push_back(rex_agg);
1672  }
1673  last_node = ra_node.get();
1674  continue;
1675  }
1676  }
1677 
// NOTE(review): manipulation_target is dereferenced unconditionally; this
// assumes every pattern contains at least one RelProject.
1678  auto compound_node =
1679  std::make_shared<RelCompound>(filter_rex,
1680  target_exprs,
1681  groupby_count,
1682  agg_exprs,
1683  fields,
1684  scalar_sources,
1685  is_agg,
1686  manipulation_target->isUpdateViaSelect(),
1687  manipulation_target->isDeleteViaSelect(),
1688  manipulation_target->isVarlenUpdateRequired(),
1689  manipulation_target->getModifiedTableDescriptor(),
1690  manipulation_target->getTargetColumns());
1691  auto old_node = nodes[pattern.back()];
1692  nodes[pattern.back()] = compound_node;
1693  auto first_node = nodes[pattern.front()];
1694  CHECK_EQ(size_t(1), first_node->inputCount());
1695  compound_node->addManagedInput(first_node->getAndOwnInput(0));
1696  if (hint_registered) {
1697  // pass the registered hint from the origin node to newly created compound node
1698  // where it is coalesced
1699  auto registered_query_hint_map_it = query_hints.find(node_hash);
1700  CHECK(registered_query_hint_map_it != query_hints.end());
// NOTE(review): `auto` copies the inner map, so the erase(node_id) below only
// modifies the copy and leaves query_hints unchanged. `auto&` looks intended.
1701  auto registered_query_hint_map = registered_query_hint_map_it->second;
1702  if (registered_query_hint_map.size() > 1) {
1703  registered_query_hint_map.erase(node_id);
1704  } else {
1705  CHECK_EQ(registered_query_hint_map.size(), static_cast<size_t>(1));
1706  query_hints.erase(node_hash);
1707  }
1708  std::unordered_map<unsigned, RegisteredQueryHint> hint_map;
1709  hint_map.emplace(compound_node->getId(), registered_query_hint);
1710  query_hints.emplace(compound_node->toHash(), hint_map);
1711  }
// Drop every fused node except the last slot, which now holds the compound.
1712  for (size_t i = 0; i < pattern.size() - 1; ++i) {
1713  nodes[pattern[i]].reset();
1714  }
// Redirect every remaining consumer of the old last node to the compound node.
1715  for (auto node : nodes) {
1716  if (!node) {
1717  continue;
1718  }
1719  node->replaceInput(old_node, compound_node);
1720  }
1721 }
1722 
1723 class RANodeIterator : public std::vector<std::shared_ptr<RelAlgNode>>::const_iterator {
1724  using ElementType = std::shared_ptr<RelAlgNode>;
1725  using Super = std::vector<ElementType>::const_iterator;
1726  using Container = std::vector<ElementType>;
1727 
1728  public:
1729  enum class AdvancingMode { DUChain, InOrder };
1730 
1731  explicit RANodeIterator(const Container& nodes)
1732  : Super(nodes.begin()), owner_(nodes), nodeCount_([&nodes]() -> size_t {
1733  size_t non_zero_count = 0;
1734  for (const auto& node : nodes) {
1735  if (node) {
1736  ++non_zero_count;
1737  }
1738  }
1740  }()) {}
1741 
1742  explicit operator size_t() {
1743  return std::distance(owner_.begin(), *static_cast<Super*>(this));
1744  }
1745 
1746  RANodeIterator operator++() = delete;
1747 
1748  void advance(AdvancingMode mode) {
1749  Super& super = *this;
1750  switch (mode) {
1751  case AdvancingMode::DUChain: {
1752  size_t use_count = 0;
1753  Super only_use = owner_.end();
1754  for (Super nodeIt = std::next(super); nodeIt != owner_.end(); ++nodeIt) {
1755  if (!*nodeIt) {
1756  continue;
1757  }
1758  for (size_t i = 0; i < (*nodeIt)->inputCount(); ++i) {
1759  if ((*super) == (*nodeIt)->getAndOwnInput(i)) {
1760  ++use_count;
1761  if (1 == use_count) {
1762  only_use = nodeIt;
1763  } else {
1764  super = owner_.end();
1765  return;
1766  }
1767  }
1768  }
1769  }
1770  super = only_use;
1771  break;
1772  }
1773  case AdvancingMode::InOrder:
1774  for (size_t i = 0; i != owner_.size(); ++i) {
1775  if (!visited_.count(i)) {
1776  super = owner_.begin();
1777  std::advance(super, i);
1778  return;
1779  }
1780  }
1781  super = owner_.end();
1782  break;
1783  default:
1784  CHECK(false);
1785  }
1786  }
1787 
1788  bool allVisited() { return visited_.size() == nodeCount_; }
1789 
1791  visited_.insert(size_t(*this));
1792  Super& super = *this;
1793  return *super;
1794  }
1795 
1796  const ElementType* operator->() { return &(operator*()); }
1797 
1798  private:
1800  const size_t nodeCount_;
1801  std::unordered_set<size_t> visited_;
1802 };
1803 
1804 namespace {
1805 
1806 bool input_can_be_coalesced(const RelAlgNode* parent_node,
1807  const size_t index,
1808  const bool first_rex_is_input) {
1809  if (auto agg_node = dynamic_cast<const RelAggregate*>(parent_node)) {
1810  if (index == 0 && agg_node->getGroupByCount() > 0) {
1811  return true;
1812  } else {
1813  // Is an aggregated target, only allow the project to be elided if the aggregate
1814  // target is simply passed through (i.e. if the top level expression attached to
1815  // the project node is a RexInput expression)
1816  return first_rex_is_input;
1817  }
1818  }
1819  return first_rex_is_input;
1820 }
1821 
1828  public:
1829  bool visitInput(const RexInput* input) const final {
1830  // The top level expression node is checked before we apply the visitor. If we get
1831  // here, this input rex is a child of another rex node, and we handle the can be
1832  // coalesced check slightly differently
1833  return input_can_be_coalesced(input->getSourceNode(), input->getIndex(), false);
1834  }
1835 
1836  bool visitLiteral(const RexLiteral*) const final { return false; }
1837 
1838  bool visitSubQuery(const RexSubQuery*) const final { return false; }
1839 
1840  bool visitRef(const RexRef*) const final { return false; }
1841 
1842  protected:
1843  bool aggregateResult(const bool& aggregate, const bool& next_result) const final {
1844  return aggregate && next_result;
1845  }
1846 
1847  bool defaultResult() const final { return true; }
1848 };
1849 
1850 // Detect the window function SUM pattern: CASE WHEN COUNT() > 0 THEN SUM ELSE 0
1852  const auto case_operator = dynamic_cast<const RexCase*>(rex);
1853  if (case_operator && case_operator->branchCount() == 1) {
1854  const auto then_window =
1855  dynamic_cast<const RexWindowFunctionOperator*>(case_operator->getThen(0));
1856  if (then_window && then_window->getKind() == SqlWindowFunctionKind::SUM_INTERNAL) {
1857  return true;
1858  }
1859  }
1860  return false;
1861 }
1862 
1863 // Check for Window Function AVG:
1864 // (CASE WHEN count > 0 THEN sum ELSE 0) / COUNT
1866  const RexOperator* divide_operator = dynamic_cast<const RexOperator*>(rex);
1867  if (divide_operator && divide_operator->getOperator() == kDIVIDE) {
1868  CHECK_EQ(divide_operator->size(), size_t(2));
1869  const auto case_operator =
1870  dynamic_cast<const RexCase*>(divide_operator->getOperand(0));
1871  const auto second_window =
1872  dynamic_cast<const RexWindowFunctionOperator*>(divide_operator->getOperand(1));
1873  if (case_operator && second_window &&
1874  second_window->getKind() == SqlWindowFunctionKind::COUNT) {
1875  if (is_window_function_sum(case_operator)) {
1876  return true;
1877  }
1878  }
1879  }
1880  return false;
1881 }
1882 
1883 // Detect both window function operators and window function operators embedded in case
1884 // statements (for null handling)
1886  if (dynamic_cast<const RexWindowFunctionOperator*>(rex)) {
1887  return true;
1888  }
1889 
1890  // unwrap from casts, if they exist
1891  const auto rex_cast = dynamic_cast<const RexOperator*>(rex);
1892  if (rex_cast && rex_cast->getOperator() == kCAST) {
1893  CHECK_EQ(rex_cast->size(), size_t(1));
1894  return is_window_function_operator(rex_cast->getOperand(0));
1895  }
1896 
1898  return true;
1899  }
1900 
1901  return false;
1902 }
1903 
1904 } // namespace
1905 
1907  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1908  const std::vector<const RelAlgNode*>& left_deep_joins,
1909  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
1910  query_hints) {
1911  enum class CoalesceState { Initial, Filter, FirstProject, Aggregate };
1912  std::vector<size_t> crt_pattern;
1913  CoalesceState crt_state{CoalesceState::Initial};
1914 
1915  auto reset_state = [&crt_pattern, &crt_state]() {
1916  crt_state = CoalesceState::Initial;
1917  std::vector<size_t>().swap(crt_pattern);
1918  };
1919 
1920  for (RANodeIterator nodeIt(nodes); !nodeIt.allVisited();) {
1921  const auto ra_node = nodeIt != nodes.end() ? *nodeIt : nullptr;
1922  switch (crt_state) {
1923  case CoalesceState::Initial: {
1924  if (std::dynamic_pointer_cast<const RelFilter>(ra_node) &&
1925  std::find(left_deep_joins.begin(), left_deep_joins.end(), ra_node.get()) ==
1926  left_deep_joins.end()) {
1927  crt_pattern.push_back(size_t(nodeIt));
1928  crt_state = CoalesceState::Filter;
1929  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1930  } else if (auto project_node =
1931  std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1932  if (project_node->hasWindowFunctionExpr()) {
1933  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1934  } else {
1935  crt_pattern.push_back(size_t(nodeIt));
1936  crt_state = CoalesceState::FirstProject;
1937  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1938  }
1939  } else {
1940  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1941  }
1942  break;
1943  }
1944  case CoalesceState::Filter: {
1945  if (auto project_node = std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1946  // Given we now add preceding projects for all window functions following
1947  // RelFilter nodes, the following should never occur
1948  CHECK(!project_node->hasWindowFunctionExpr());
1949  crt_pattern.push_back(size_t(nodeIt));
1950  crt_state = CoalesceState::FirstProject;
1951  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1952  } else {
1953  reset_state();
1954  }
1955  break;
1956  }
1957  case CoalesceState::FirstProject: {
1958  if (std::dynamic_pointer_cast<const RelAggregate>(ra_node)) {
1959  crt_pattern.push_back(size_t(nodeIt));
1960  crt_state = CoalesceState::Aggregate;
1961  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1962  } else {
1963  if (crt_pattern.size() >= 2) {
1964  create_compound(nodes, crt_pattern, query_hints);
1965  }
1966  reset_state();
1967  }
1968  break;
1969  }
1970  case CoalesceState::Aggregate: {
1971  if (auto project_node = std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1972  if (!project_node->hasWindowFunctionExpr()) {
1973  // TODO(adb): overloading the simple project terminology again here
1974  bool is_simple_project{true};
1975  for (size_t i = 0; i < project_node->size(); i++) {
1976  const auto scalar_rex = project_node->getProjectAt(i);
1977  // If the top level scalar rex is an input node, we can bypass the visitor
1978  if (auto input_rex = dynamic_cast<const RexInput*>(scalar_rex)) {
1980  input_rex->getSourceNode(), input_rex->getIndex(), true)) {
1981  is_simple_project = false;
1982  break;
1983  }
1984  continue;
1985  }
1986  CoalesceSecondaryProjectVisitor visitor;
1987  if (!visitor.visit(project_node->getProjectAt(i))) {
1988  is_simple_project = false;
1989  break;
1990  }
1991  }
1992  if (is_simple_project) {
1993  crt_pattern.push_back(size_t(nodeIt));
1994  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1995  }
1996  }
1997  }
1998  CHECK_GE(crt_pattern.size(), size_t(2));
1999  create_compound(nodes, crt_pattern, query_hints);
2000  reset_state();
2001  break;
2002  }
2003  default:
2004  CHECK(false);
2005  }
2006  }
2007  if (crt_state == CoalesceState::FirstProject || crt_state == CoalesceState::Aggregate) {
2008  if (crt_pattern.size() >= 2) {
2009  create_compound(nodes, crt_pattern, query_hints);
2010  }
2011  CHECK(!crt_pattern.empty());
2012  }
2013 }
2014 
2015 class WindowFunctionCollector : public RexVisitor<void*> {
2016  public:
2018  std::unordered_map<size_t, const RexScalar*>& collected_window_func)
2019  : collected_window_func_(collected_window_func) {}
2020 
2021  protected:
2022  // Detect embedded window function expressions in operators
2023  void* visitOperator(const RexOperator* rex_operator) const final {
2024  if (is_window_function_operator(rex_operator)) {
2025  collected_window_func_.emplace(rex_operator->toHash(), rex_operator);
2026  }
2027  const size_t operand_count = rex_operator->size();
2028  for (size_t i = 0; i < operand_count; ++i) {
2029  const auto operand = rex_operator->getOperand(i);
2030  if (is_window_function_operator(operand)) {
2031  // Handle both RexWindowFunctionOperators and window functions built up from
2032  // multiple RexScalar objects (e.g. AVG)
2033  collected_window_func_.emplace(operand->toHash(), operand);
2034  } else {
2035  visit(operand);
2036  }
2037  }
2038  return defaultResult();
2039  }
2040 
2041  // Detect embedded window function expressions in case statements. Note that this may
2042  // manifest as a nested case statement inside a top level case statement, as some
2043  // window functions (sum, avg) are represented as a case statement. Use the
2044  // is_window_function_operator helper to detect complete window function expressions.
2045  void* visitCase(const RexCase* rex_case) const final {
2046  if (is_window_function_operator(rex_case)) {
2047  collected_window_func_.emplace(rex_case->toHash(), rex_case);
2048  return nullptr;
2049  }
2050 
2051  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
2052  const auto when = rex_case->getWhen(i);
2053  if (is_window_function_operator(when)) {
2054  collected_window_func_.emplace(when->toHash(), when);
2055  } else {
2056  visit(when);
2057  }
2058  const auto then = rex_case->getThen(i);
2059  if (is_window_function_operator(then)) {
2060  collected_window_func_.emplace(then->toHash(), then);
2061  } else {
2062  visit(then);
2063  }
2064  }
2065  if (rex_case->getElse()) {
2066  auto else_expr = rex_case->getElse();
2067  if (is_window_function_operator(else_expr)) {
2068  collected_window_func_.emplace(else_expr->toHash(), else_expr);
2069  } else {
2070  visit(else_expr);
2071  }
2072  }
2073  return defaultResult();
2074  }
2075 
2076  void* defaultResult() const final { return nullptr; }
2077 
2078  private:
2079  std::unordered_map<size_t, const RexScalar*>& collected_window_func_;
2080 };
2081 
2083  public:
2085  std::unordered_set<size_t>& collected_window_func_hash,
2086  std::vector<std::unique_ptr<const RexScalar>>& new_rex_input_for_window_func,
2087  std::unordered_map<size_t, size_t>& window_func_to_new_rex_input_idx_map,
2088  RelProject* new_project,
2089  std::unordered_map<size_t, std::unique_ptr<const RexInput>>&
2090  new_rex_input_from_child_node)
2091  : collected_window_func_hash_(collected_window_func_hash)
2092  , new_rex_input_for_window_func_(new_rex_input_for_window_func)
2093  , window_func_to_new_rex_input_idx_map_(window_func_to_new_rex_input_idx_map)
2094  , new_project_(new_project)
2095  , new_rex_input_from_child_node_(new_rex_input_from_child_node) {
2096  CHECK_EQ(collected_window_func_hash_.size(),
2097  window_func_to_new_rex_input_idx_map_.size());
2098  for (auto hash : collected_window_func_hash_) {
2099  auto rex_it = window_func_to_new_rex_input_idx_map_.find(hash);
2100  CHECK(rex_it != window_func_to_new_rex_input_idx_map_.end());
2101  CHECK_LT(rex_it->second, new_rex_input_for_window_func_.size());
2102  }
2103  CHECK(new_project_);
2104  }
2105 
2106  protected:
2107  RetType visitInput(const RexInput* rex_input) const final {
2108  if (rex_input->getSourceNode() != new_project_) {
2109  const auto cur_index = rex_input->getIndex();
2110  auto cur_source_node = rex_input->getSourceNode();
2111  std::string field_name = "";
2112  if (auto cur_project_node = dynamic_cast<const RelProject*>(cur_source_node)) {
2113  field_name = cur_project_node->getFieldName(cur_index);
2114  }
2115  auto rex_input_hash = rex_input->toHash();
2116  auto rex_input_it = new_rex_input_from_child_node_.find(rex_input_hash);
2117  if (rex_input_it == new_rex_input_from_child_node_.end()) {
2118  auto new_rex_input =
2119  std::make_unique<RexInput>(new_project_, new_project_->size());
2120  new_project_->appendInput(field_name, rex_input->deepCopy());
2121  new_rex_input_from_child_node_.emplace(rex_input_hash, new_rex_input->deepCopy());
2122  return new_rex_input;
2123  } else {
2124  return rex_input_it->second->deepCopy();
2125  }
2126  } else {
2127  return rex_input->deepCopy();
2128  }
2129  }
2130 
2131  RetType visitOperator(const RexOperator* rex_operator) const final {
2132  auto new_rex_idx = is_collected_window_function(rex_operator->toHash());
2133  if (new_rex_idx) {
2134  return get_new_rex_input(*new_rex_idx);
2135  }
2136 
2137  const auto rex_window_function_operator =
2138  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
2139  if (rex_window_function_operator) {
2140  // Deep copy the embedded window function operator
2141  return visitWindowFunctionOperator(rex_window_function_operator);
2142  }
2143 
2144  const size_t operand_count = rex_operator->size();
2145  std::vector<RetType> new_opnds;
2146  for (size_t i = 0; i < operand_count; ++i) {
2147  const auto operand = rex_operator->getOperand(i);
2148  auto new_rex_idx_for_operand = is_collected_window_function(operand->toHash());
2149  if (new_rex_idx_for_operand) {
2150  new_opnds.push_back(get_new_rex_input(*new_rex_idx_for_operand));
2151  } else {
2152  new_opnds.emplace_back(visit(rex_operator->getOperand(i)));
2153  }
2154  }
2155  return rex_operator->getDisambiguated(new_opnds);
2156  }
2157 
2158  RetType visitCase(const RexCase* rex_case) const final {
2159  auto new_rex_idx = is_collected_window_function(rex_case->toHash());
2160  if (new_rex_idx) {
2161  return get_new_rex_input(*new_rex_idx);
2162  }
2163 
2164  std::vector<std::pair<RetType, RetType>> new_pair_list;
2165  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
2166  auto when_operand = rex_case->getWhen(i);
2167  auto new_rex_idx_for_when_operand =
2168  is_collected_window_function(when_operand->toHash());
2169 
2170  auto then_operand = rex_case->getThen(i);
2171  auto new_rex_idx_for_then_operand =
2172  is_collected_window_function(then_operand->toHash());
2173 
2174  new_pair_list.emplace_back(
2175  new_rex_idx_for_when_operand ? get_new_rex_input(*new_rex_idx_for_when_operand)
2176  : visit(when_operand),
2177  new_rex_idx_for_then_operand ? get_new_rex_input(*new_rex_idx_for_then_operand)
2178  : visit(then_operand));
2179  }
2180  auto new_rex_idx_for_else_operand =
2181  is_collected_window_function(rex_case->getElse()->toHash());
2182  auto new_else = new_rex_idx_for_else_operand
2183  ? get_new_rex_input(*new_rex_idx_for_else_operand)
2184  : visit(rex_case->getElse());
2185  return std::make_unique<RexCase>(new_pair_list, new_else);
2186  }
2187 
2188  private:
2189  std::optional<size_t> is_collected_window_function(size_t rex_hash) const {
2190  auto rex_it = window_func_to_new_rex_input_idx_map_.find(rex_hash);
2191  if (rex_it != window_func_to_new_rex_input_idx_map_.end()) {
2192  return rex_it->second;
2193  }
2194  return std::nullopt;
2195  }
2196 
2197  std::unique_ptr<const RexScalar> get_new_rex_input(size_t rex_idx) const {
2198  CHECK_GE(rex_idx, 0UL);
2199  CHECK_LT(rex_idx, new_rex_input_for_window_func_.size());
2200  auto& new_rex_input = new_rex_input_for_window_func_.at(rex_idx);
2201  CHECK(new_rex_input);
2202  auto copied_rex_input = copier_.visit(new_rex_input.get());
2203  return copied_rex_input;
2204  }
2205 
2206  std::unordered_set<size_t>& collected_window_func_hash_;
2207  // we should have new rex_input for each window function collected
2208  std::vector<std::unique_ptr<const RexScalar>>& new_rex_input_for_window_func_;
2209  // an index to get a new rex_input for the collected window function
2210  std::unordered_map<size_t, size_t>& window_func_to_new_rex_input_idx_map_;
2212  std::unordered_map<size_t, std::unique_ptr<const RexInput>>&
2215 };
2216 
2218  std::shared_ptr<RelProject> prev_node,
2219  std::shared_ptr<RelProject> new_node,
2220  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
2221  query_hints) {
2222  auto delivered_hints = prev_node->getDeliveredHints();
2223  bool needs_propagate_hints = !delivered_hints->empty();
2224  if (needs_propagate_hints) {
2225  for (auto& kv : *delivered_hints) {
2226  new_node->addHint(kv.second);
2227  }
2228  auto prev_it = query_hints.find(prev_node->toHash());
2229  // query hint for the prev projection node should be registered
2230  CHECK(prev_it != query_hints.end());
2231  auto prev_hint_it = prev_it->second.find(prev_node->getId());
2232  CHECK(prev_hint_it != prev_it->second.end());
2233  std::unordered_map<unsigned, RegisteredQueryHint> hint_map;
2234  hint_map.emplace(new_node->getId(), prev_hint_it->second);
2235  query_hints.emplace(new_node->toHash(), hint_map);
2236  }
2237 }
2238 
2262  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
2263  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
2264  query_hints) {
2265  std::list<std::shared_ptr<RelAlgNode>> node_list(nodes.begin(), nodes.end());
2266  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
2267  const auto node = *node_itr;
2268  auto window_func_project_node = std::dynamic_pointer_cast<RelProject>(node);
2269  if (!window_func_project_node) {
2270  continue;
2271  }
2272 
2273  const auto prev_node_itr = std::prev(node_itr);
2274  const auto prev_node = *prev_node_itr;
2275  CHECK(prev_node);
2276 
2277  // map scalar expression index in the project node to window function ptr
2278  std::unordered_map<size_t, const RexScalar*> collected_window_func;
2279  WindowFunctionCollector collector(collected_window_func);
2280  // Iterate the target exprs of the project node and check for window function
2281  // expressions. If an embedded expression exists, collect it
2282  for (size_t i = 0; i < window_func_project_node->size(); i++) {
2283  const auto scalar_rex = window_func_project_node->getProjectAt(i);
2284  if (is_window_function_operator(scalar_rex)) {
2285  // top level window function exprs are fine
2286  continue;
2287  }
2288  collector.visit(scalar_rex);
2289  }
2290 
2291  if (!collected_window_func.empty()) {
2292  // we have a nested window function expression
2293  std::unordered_set<size_t> collected_window_func_hash;
2294  // the current window function needs a set of new rex input which references
2295  // expressions in the newly introduced projection node
2296  std::vector<std::unique_ptr<const RexScalar>> new_rex_input_for_window_func;
2297  // a target projection expression of the newly created projection node
2298  std::vector<std::unique_ptr<const RexScalar>> new_scalar_expr_for_window_project;
2299  // a map between nested window function (hash val) and
2300  // its rex index stored in the `new_rex_input_for_window_func`
2301  std::unordered_map<size_t, size_t> window_func_to_new_rex_input_idx_map;
2302  // a map between RexInput of the current window function projection node (hash val)
2303  // and its corresponding new RexInput which is pushed down to the new projection
2304  // node
2305  std::unordered_map<size_t, std::unique_ptr<const RexInput>>
2306  new_rex_input_from_child_node;
2307  RexDeepCopyVisitor copier;
2308 
2309  std::vector<std::unique_ptr<const RexScalar>> dummy_scalar_exprs;
2310  std::vector<std::string> dummy_fields;
2311  std::vector<std::string> new_project_field_names;
2312  // create a new project node, it will contain window function expressions
2313  auto new_project =
2314  std::make_shared<RelProject>(dummy_scalar_exprs, dummy_fields, prev_node);
2315  // insert this new project node between the current window project node and its
2316  // child node
2317  node_list.insert(node_itr, new_project);
2318 
2319  // retrieve various information to replace expressions in the current window
2320  // function project node w/ considering scalar expressions in the new project node
2321  std::for_each(collected_window_func.begin(),
2322  collected_window_func.end(),
2323  [&new_project_field_names,
2324  &collected_window_func_hash,
2325  &new_rex_input_for_window_func,
2326  &new_scalar_expr_for_window_project,
2327  &copier,
2328  &new_project,
2329  &window_func_to_new_rex_input_idx_map](const auto& kv) {
2330  // compute window function expr's hash, and create a new rex_input
2331  // for it
2332  collected_window_func_hash.insert(kv.first);
2333 
2334  // map an old expression in the window function project node
2335  // to an index of the corresponding new RexInput
2336  const auto rex_idx = new_rex_input_for_window_func.size();
2337  window_func_to_new_rex_input_idx_map.emplace(kv.first, rex_idx);
2338 
2339  // create a new RexInput and make it as one of new expression of the
2340  // newly created project node
2341  new_rex_input_for_window_func.emplace_back(
2342  std::make_unique<const RexInput>(new_project.get(), rex_idx));
2343  new_scalar_expr_for_window_project.push_back(
2344  std::move(copier.visit(kv.second)));
2345  new_project_field_names.emplace_back("");
2346  });
2347  new_project->setExpressions(new_scalar_expr_for_window_project);
2348  new_project->setFields(std::move(new_project_field_names));
2349 
2350  auto window_func_scalar_exprs =
2351  window_func_project_node->getExpressionsAndRelease();
2352  RexWindowFuncReplacementVisitor replacer(collected_window_func_hash,
2353  new_rex_input_for_window_func,
2354  window_func_to_new_rex_input_idx_map,
2355  new_project.get(),
2356  new_rex_input_from_child_node);
2357  size_t rex_idx = 0;
2358  for (auto& scalar_expr : window_func_scalar_exprs) {
2359  // try to replace the old expressions in the window function project node
2360  // with expressions of the newly created project node
2361  auto new_parent_rex = replacer.visit(scalar_expr.get());
2362  window_func_scalar_exprs[rex_idx] = std::move(new_parent_rex);
2363  rex_idx++;
2364  }
2365  // Update the previous window project node
2366  window_func_project_node->setExpressions(window_func_scalar_exprs);
2367  window_func_project_node->replaceInput(prev_node, new_project);
2368  propagate_hints_to_new_project(window_func_project_node, new_project, query_hints);
2369  }
2370  }
2371  nodes.assign(node_list.begin(), node_list.end());
2372 }
2373 
// Set of distinct RexInput values (stored by value); RexInputCollector below returns
// one of these holding every RexInput referenced by a visited expression tree.
using RexInputSet = std::unordered_set<RexInput>;
2375 
2376 class RexInputCollector : public RexVisitor<RexInputSet> {
2377  public:
2378  RexInputSet visitInput(const RexInput* input) const override {
2379  return RexInputSet{*input};
2380  }
2381 
2382  protected:
2384  const RexInputSet& next_result) const override {
2385  auto result = aggregate;
2386  result.insert(next_result.begin(), next_result.end());
2387  return result;
2388  }
2389 };
2390 
2404  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
2405  const bool always_add_project_if_first_project_is_window_expr,
2406  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
2407  query_hints) {
2408  std::list<std::shared_ptr<RelAlgNode>> node_list(nodes.begin(), nodes.end());
2409  size_t project_node_counter{0};
2410  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
2411  const auto node = *node_itr;
2412 
2413  auto window_func_project_node = std::dynamic_pointer_cast<RelProject>(node);
2414  if (!window_func_project_node) {
2415  continue;
2416  }
2417  project_node_counter++;
2418  if (!window_func_project_node->hasWindowFunctionExpr()) {
2419  // this projection node does not have a window function
2420  // expression -- skip to the next node in the DAG.
2421  continue;
2422  }
2423 
2424  auto need_pushdown_generic_expr = [&window_func_project_node]() {
2425  for (size_t i = 0; i < window_func_project_node->size(); ++i) {
2426  const auto projected_target = window_func_project_node->getProjectAt(i);
2427  if (auto window_expr =
2428  dynamic_cast<const RexWindowFunctionOperator*>(projected_target)) {
2429  for (const auto& partition_key : window_expr->getPartitionKeys()) {
2430  auto partition_input = dynamic_cast<const RexInput*>(partition_key.get());
2431  if (!partition_input) {
2432  return true;
2433  }
2434  }
2435  for (const auto& order_key : window_expr->getOrderKeys()) {
2436  auto order_input = dynamic_cast<const RexInput*>(order_key.get());
2437  if (!order_input) {
2438  return true;
2439  }
2440  }
2441  }
2442  }
2443  return false;
2444  };
2445 
2446  const auto prev_node_itr = std::prev(node_itr);
2447  const auto prev_node = *prev_node_itr;
2448  CHECK(prev_node);
2449 
2450  auto filter_node = std::dynamic_pointer_cast<RelFilter>(prev_node);
2451  auto join_node = std::dynamic_pointer_cast<RelJoin>(prev_node);
2452 
2453  auto scan_node = std::dynamic_pointer_cast<RelScan>(prev_node);
2454  const bool has_multi_fragment_scan_input =
2455  (scan_node &&
2456  (scan_node->getNumShards() > 0 || scan_node->getNumFragments() > 1));
2457  const bool needs_expr_pushdown = need_pushdown_generic_expr();
2458 
2459  // We currently add a preceding project node in one of two conditions:
2460  // 1. always_add_project_if_first_project_is_window_expr = true, which
2461  // we currently only set for distributed, but could also be set to support
2462  // multi-frag window function inputs, either if we can detect that an input table
2463  // is multi-frag up front, or using a retry mechanism like we do for join filter
2464  // push down.
2465  // TODO(todd): Investigate a viable approach for the above.
2466  // 2. Regardless of #1, if the window function project node is preceded by a
2467  // filter node. This is required both for correctness and to avoid pulling
2468  // all source input columns into memory since non-coalesced filter node
2469  // inputs are currently not pruned or eliminated via dead column elimination.
2470  // Note that we expect any filter node followed by a project node to be coalesced
2471  // into a single compound node in RelAlgDag::coalesce_nodes, and that action
2472  // prunes unused inputs.
2473  // TODO(todd): Investigate whether the shotgun filter node issue affects other
2474  // query plans, i.e. filters before joins, and whether there is a more general
2475  // approach to solving this (will still need the preceding project node for
2476  // window functions preceded by filter nodes for correctness though)
2477  // 3. Similar to the above, when the window function project node is preceded
2478  // by a join node.
2479  // 4. when partition by / order by clauses have a general expression instead of
2480  // referencing column
2481 
2482  if (!((always_add_project_if_first_project_is_window_expr &&
2483  project_node_counter == 1) ||
2484  filter_node || join_node || has_multi_fragment_scan_input ||
2485  needs_expr_pushdown)) {
2486  continue;
2487  }
2488 
2489  if (needs_expr_pushdown || join_node) {
2490  // previous logic cannot cover join_node case well, so use the newly introduced
2491  // push-down expression logic to safely add pre_project node before processing
2492  // window function
2493  std::unordered_map<size_t, size_t> expr_offset_cache;
2494  std::vector<std::unique_ptr<const RexScalar>> scalar_exprs_for_new_project;
2495  std::vector<std::unique_ptr<const RexScalar>> scalar_exprs_for_window_project;
2496  std::vector<std::string> fields_for_window_project;
2497  std::vector<std::string> fields_for_new_project;
2498 
2499  // step 0. create new project node with an empty scalar expr to rebind target exprs
2500  std::vector<std::unique_ptr<const RexScalar>> dummy_scalar_exprs;
2501  std::vector<std::string> dummy_fields;
2502  auto new_project =
2503  std::make_shared<RelProject>(dummy_scalar_exprs, dummy_fields, prev_node);
2504 
2505  // step 1 - 2
2506  PushDownGenericExpressionInWindowFunction visitor(new_project,
2507  scalar_exprs_for_new_project,
2508  fields_for_new_project,
2509  expr_offset_cache);
2510  for (size_t i = 0; i < window_func_project_node->size(); ++i) {
2511  auto projected_target = window_func_project_node->getProjectAt(i);
2512  auto new_projection_target = visitor.visit(projected_target);
2513  scalar_exprs_for_window_project.emplace_back(
2514  std::move(new_projection_target.release()));
2515  }
2516  new_project->setExpressions(scalar_exprs_for_new_project);
2517  new_project->setFields(std::move(fields_for_new_project));
2518  bool has_groupby = false;
2519  auto aggregate = std::dynamic_pointer_cast<RelAggregate>(prev_node);
2520  if (aggregate) {
2521  has_groupby = aggregate->getGroupByCount() > 0;
2522  }
2523  if (has_groupby && visitor.hasPartitionExpression()) {
2524  // we currently may compute incorrect result with columnar output when
2525  // 1) the window function has partition expression, and
2526  // 2) a parent node of the window function projection node has group by expression
2527  // so we force rowwise output (only) for the newly injected projection node
2528  // to prevent computing incorrect query result
2529  // todo (yoonmin) : relax this
2530  VLOG(1)
2531  << "Query output overridden to row-wise format due to presence of a window "
2532  "function with partition expression and group-by expression.";
2533  new_project->forceRowwiseOutput();
2534  }
2535  if (visitor.hasCaseExprAsWindowOperand()) {
2536  // force rowwise output
2537  VLOG(1)
2538  << "Query output overridden to row-wise format due to presence of a window "
2539  "function with a case statement as its operand.";
2540  new_project->forceRowwiseOutput();
2541  }
2542 
2543  // step 3. finalize
2544  propagate_hints_to_new_project(window_func_project_node, new_project, query_hints);
2545  node_list.insert(node_itr, new_project);
2546  window_func_project_node->replaceInput(prev_node, new_project);
2547  window_func_project_node->setExpressions(scalar_exprs_for_window_project);
2548  } else {
2549  // only push rex_inputs listed in the window function down to a new project node
2550  RexInputSet inputs;
2551  RexInputCollector input_collector;
2552  for (size_t i = 0; i < window_func_project_node->size(); i++) {
2553  auto new_inputs =
2554  input_collector.visit(window_func_project_node->getProjectAt(i));
2555  inputs.insert(new_inputs.begin(), new_inputs.end());
2556  }
2557 
2558  // Note: Technically not required since we are mapping old inputs to new input
2559  // indices, but makes the re-mapping of inputs easier to follow.
2560  std::vector<RexInput> sorted_inputs(inputs.begin(), inputs.end());
2561  std::sort(sorted_inputs.begin(),
2562  sorted_inputs.end(),
2563  [](const auto& a, const auto& b) { return a.getIndex() < b.getIndex(); });
2564 
2565  std::vector<std::unique_ptr<const RexScalar>> scalar_exprs;
2566  std::vector<std::string> fields;
2567  std::unordered_map<unsigned, unsigned> old_index_to_new_index;
2568  for (auto& input : sorted_inputs) {
2569  CHECK_EQ(input.getSourceNode(), prev_node.get());
2570  CHECK(old_index_to_new_index
2571  .insert(std::make_pair(input.getIndex(), scalar_exprs.size()))
2572  .second);
2573  scalar_exprs.emplace_back(input.deepCopy());
2574  fields.emplace_back("");
2575  }
2576 
2577  auto new_project = std::make_shared<RelProject>(scalar_exprs, fields, prev_node);
2578  propagate_hints_to_new_project(window_func_project_node, new_project, query_hints);
2579  node_list.insert(node_itr, new_project);
2580  window_func_project_node->replaceInput(
2581  prev_node, new_project, old_index_to_new_index);
2582  }
2583  }
2584  nodes.assign(node_list.begin(), node_list.end());
2585 }
2586 
2587 int64_t get_int_literal_field(const rapidjson::Value& obj,
2588  const char field[],
2589  const int64_t default_val) noexcept {
2590  const auto it = obj.FindMember(field);
2591  if (it == obj.MemberEnd()) {
2592  return default_val;
2593  }
2594  std::unique_ptr<RexLiteral> lit(parse_literal(it->value));
2595  CHECK_EQ(kDECIMAL, lit->getType());
2596  CHECK_EQ(unsigned(0), lit->getScale());
2597  CHECK_EQ(unsigned(0), lit->getTargetScale());
2598  return lit->getVal<int64_t>();
2599 }
2600 
2601 void check_empty_inputs_field(const rapidjson::Value& node) noexcept {
2602  const auto& inputs_json = field(node, "inputs");
2603  CHECK(inputs_json.IsArray() && !inputs_json.Size());
2604 }
2605 
2607  const rapidjson::Value& scan_ra) {
2608  const auto& table_json = field(scan_ra, "table");
2609  CHECK(table_json.IsArray());
2610  CHECK_EQ(unsigned(2), table_json.Size());
2611  const auto td = cat.getMetadataForTable(table_json[1].GetString());
2612  CHECK(td);
2613  return td;
2614 }
2615 
2616 std::vector<std::string> getFieldNamesFromScanNode(const rapidjson::Value& scan_ra) {
2617  const auto& fields_json = field(scan_ra, "fieldNames");
2618  return strings_from_json_array(fields_json);
2619 }
2620 
2621 } // namespace
2622 
2624  for (const auto& expr : scalar_exprs_) {
2625  if (is_window_function_operator(expr.get())) {
2626  return true;
2627  }
2628  }
2629  return false;
2630 }
2631 namespace details {
2632 
2634  public:
2636 
2637  std::vector<std::shared_ptr<RelAlgNode>> run(const rapidjson::Value& rels,
2638  RelAlgDag& root_dag) {
2639  for (auto rels_it = rels.Begin(); rels_it != rels.End(); ++rels_it) {
2640  const auto& crt_node = *rels_it;
2641  const auto id = node_id(crt_node);
2642  CHECK_EQ(static_cast<size_t>(id), nodes_.size());
2643  CHECK(crt_node.IsObject());
2644  std::shared_ptr<RelAlgNode> ra_node = nullptr;
2645  const auto rel_op = json_str(field(crt_node, "relOp"));
2646  if (rel_op == std::string("EnumerableTableScan") ||
2647  rel_op == std::string("LogicalTableScan")) {
2648  ra_node = dispatchTableScan(crt_node);
2649  } else if (rel_op == std::string("LogicalProject")) {
2650  ra_node = dispatchProject(crt_node, root_dag);
2651  } else if (rel_op == std::string("LogicalFilter")) {
2652  ra_node = dispatchFilter(crt_node, root_dag);
2653  } else if (rel_op == std::string("LogicalAggregate")) {
2654  ra_node = dispatchAggregate(crt_node);
2655  } else if (rel_op == std::string("LogicalJoin")) {
2656  ra_node = dispatchJoin(crt_node, root_dag);
2657  } else if (rel_op == std::string("LogicalSort")) {
2658  ra_node = dispatchSort(crt_node);
2659  } else if (rel_op == std::string("LogicalValues")) {
2660  ra_node = dispatchLogicalValues(crt_node);
2661  } else if (rel_op == std::string("LogicalTableModify")) {
2662  ra_node = dispatchModify(crt_node);
2663  } else if (rel_op == std::string("LogicalTableFunctionScan")) {
2664  ra_node = dispatchTableFunction(crt_node, root_dag);
2665  } else if (rel_op == std::string("LogicalUnion")) {
2666  ra_node = dispatchUnion(crt_node);
2667  } else {
2668  throw QueryNotSupported(std::string("Node ") + rel_op + " not supported yet");
2669  }
2670  nodes_.push_back(ra_node);
2671  }
2672 
2673  return std::move(nodes_);
2674  }
2675 
2676  private:
2677  std::shared_ptr<RelScan> dispatchTableScan(const rapidjson::Value& scan_ra) {
2678  check_empty_inputs_field(scan_ra);
2679  CHECK(scan_ra.IsObject());
2680  const auto td = getTableFromScanNode(cat_, scan_ra);
2681  const auto field_names = getFieldNamesFromScanNode(scan_ra);
2682  if (scan_ra.HasMember("hints")) {
2683  auto scan_node = std::make_shared<RelScan>(td, field_names);
2684  getRelAlgHints(scan_ra, scan_node);
2685  return scan_node;
2686  }
2687  return std::make_shared<RelScan>(td, field_names);
2688  }
2689 
2690  std::shared_ptr<RelProject> dispatchProject(const rapidjson::Value& proj_ra,
2691  RelAlgDag& root_dag) {
2692  const auto inputs = getRelAlgInputs(proj_ra);
2693  CHECK_EQ(size_t(1), inputs.size());
2694  const auto& exprs_json = field(proj_ra, "exprs");
2695  CHECK(exprs_json.IsArray());
2696  std::vector<std::unique_ptr<const RexScalar>> exprs;
2697  for (auto exprs_json_it = exprs_json.Begin(); exprs_json_it != exprs_json.End();
2698  ++exprs_json_it) {
2699  exprs.emplace_back(parse_scalar_expr(*exprs_json_it, cat_, root_dag));
2700  }
2701  const auto& fields = field(proj_ra, "fields");
2702  if (proj_ra.HasMember("hints")) {
2703  auto project_node = std::make_shared<RelProject>(
2704  exprs, strings_from_json_array(fields), inputs.front());
2705  getRelAlgHints(proj_ra, project_node);
2706  return project_node;
2707  }
2708  return std::make_shared<RelProject>(
2709  exprs, strings_from_json_array(fields), inputs.front());
2710  }
2711 
2712  std::shared_ptr<RelFilter> dispatchFilter(const rapidjson::Value& filter_ra,
2713  RelAlgDag& root_dag) {
2714  const auto inputs = getRelAlgInputs(filter_ra);
2715  CHECK_EQ(size_t(1), inputs.size());
2716  const auto id = node_id(filter_ra);
2717  CHECK(id);
2718  auto condition = parse_scalar_expr(field(filter_ra, "condition"), cat_, root_dag);
2719  return std::make_shared<RelFilter>(condition, inputs.front());
2720  }
2721 
2722  std::shared_ptr<RelAggregate> dispatchAggregate(const rapidjson::Value& agg_ra) {
2723  const auto inputs = getRelAlgInputs(agg_ra);
2724  CHECK_EQ(size_t(1), inputs.size());
2725  const auto fields = strings_from_json_array(field(agg_ra, "fields"));
2726  const auto group = indices_from_json_array(field(agg_ra, "group"));
2727  for (size_t i = 0; i < group.size(); ++i) {
2728  CHECK_EQ(i, group[i]);
2729  }
2730  if (agg_ra.HasMember("groups") || agg_ra.HasMember("indicator")) {
2731  throw QueryNotSupported("GROUP BY extensions not supported");
2732  }
2733  const auto& aggs_json_arr = field(agg_ra, "aggs");
2734  CHECK(aggs_json_arr.IsArray());
2735  std::vector<std::unique_ptr<const RexAgg>> aggs;
2736  for (auto aggs_json_arr_it = aggs_json_arr.Begin();
2737  aggs_json_arr_it != aggs_json_arr.End();
2738  ++aggs_json_arr_it) {
2739  aggs.emplace_back(parse_aggregate_expr(*aggs_json_arr_it));
2740  }
2741  if (agg_ra.HasMember("hints")) {
2742  auto agg_node =
2743  std::make_shared<RelAggregate>(group.size(), aggs, fields, inputs.front());
2744  getRelAlgHints(agg_ra, agg_node);
2745  return agg_node;
2746  }
2747  return std::make_shared<RelAggregate>(group.size(), aggs, fields, inputs.front());
2748  }
2749 
2750  std::shared_ptr<RelJoin> dispatchJoin(const rapidjson::Value& join_ra,
2751  RelAlgDag& root_dag) {
2752  const auto inputs = getRelAlgInputs(join_ra);
2753  CHECK_EQ(size_t(2), inputs.size());
2754  const auto join_type = to_join_type(json_str(field(join_ra, "joinType")));
2755  auto filter_rex = parse_scalar_expr(field(join_ra, "condition"), cat_, root_dag);
2756  if (join_ra.HasMember("hints")) {
2757  auto join_node =
2758  std::make_shared<RelJoin>(inputs[0], inputs[1], filter_rex, join_type);
2759  getRelAlgHints(join_ra, join_node);
2760  return join_node;
2761  }
2762  return std::make_shared<RelJoin>(inputs[0], inputs[1], filter_rex, join_type);
2763  }
2764 
2765  std::shared_ptr<RelSort> dispatchSort(const rapidjson::Value& sort_ra) {
2766  const auto inputs = getRelAlgInputs(sort_ra);
2767  CHECK_EQ(size_t(1), inputs.size());
2768  std::vector<SortField> collation;
2769  const auto& collation_arr = field(sort_ra, "collation");
2770  CHECK(collation_arr.IsArray());
2771  for (auto collation_arr_it = collation_arr.Begin();
2772  collation_arr_it != collation_arr.End();
2773  ++collation_arr_it) {
2774  const size_t field_idx = json_i64(field(*collation_arr_it, "field"));
2775  const auto sort_dir = parse_sort_direction(*collation_arr_it);
2776  const auto null_pos = parse_nulls_position(*collation_arr_it);
2777  collation.emplace_back(field_idx, sort_dir, null_pos);
2778  }
2779  auto limit = get_int_literal_field(sort_ra, "fetch", -1);
2780  const auto offset = get_int_literal_field(sort_ra, "offset", 0);
2781  auto ret = std::make_shared<RelSort>(
2782  collation, limit > 0 ? limit : 0, offset, inputs.front(), limit > 0);
2783  ret->setEmptyResult(limit == 0);
2784  return ret;
2785  }
2786 
2787  std::shared_ptr<RelModify> dispatchModify(const rapidjson::Value& logical_modify_ra) {
2788  const auto inputs = getRelAlgInputs(logical_modify_ra);
2789  CHECK_EQ(size_t(1), inputs.size());
2790 
2791  const auto table_descriptor = getTableFromScanNode(cat_, logical_modify_ra);
2792  if (table_descriptor->isView) {
2793  throw std::runtime_error("UPDATE of a view is unsupported.");
2794  }
2795 
2796  bool flattened = json_bool(field(logical_modify_ra, "flattened"));
2797  std::string op = json_str(field(logical_modify_ra, "operation"));
2798  RelModify::TargetColumnList target_column_list;
2799 
2800  if (op == "UPDATE") {
2801  const auto& update_columns = field(logical_modify_ra, "updateColumnList");
2802  CHECK(update_columns.IsArray());
2803 
2804  for (auto column_arr_it = update_columns.Begin();
2805  column_arr_it != update_columns.End();
2806  ++column_arr_it) {
2807  target_column_list.push_back(column_arr_it->GetString());
2808  }
2809  }
2810 
2811  auto modify_node = std::make_shared<RelModify>(
2812  cat_, table_descriptor, flattened, op, target_column_list, inputs[0]);
2813  switch (modify_node->getOperation()) {
2815  modify_node->applyDeleteModificationsToInputNode();
2816  break;
2817  }
2819  modify_node->applyUpdateModificationsToInputNode();
2820  break;
2821  }
2822  default:
2823  throw std::runtime_error("Unsupported RelModify operation: " +
2824  json_node_to_string(logical_modify_ra));
2825  }
2826 
2827  return modify_node;
2828  }
2829 
2830  std::shared_ptr<RelTableFunction> dispatchTableFunction(
2831  const rapidjson::Value& table_func_ra,
2832  RelAlgDag& root_dag) {
2833  const auto inputs = getRelAlgInputs(table_func_ra);
2834  const auto& invocation = field(table_func_ra, "invocation");
2835  CHECK(invocation.IsObject());
2836 
2837  const auto& operands = field(invocation, "operands");
2838  CHECK(operands.IsArray());
2839  CHECK_GE(operands.Size(), unsigned(0));
2840 
2841  std::vector<const Rex*> col_inputs;
2842  std::vector<std::unique_ptr<const RexScalar>> table_func_inputs;
2843  std::vector<std::string> fields;
2844 
2845  for (auto exprs_json_it = operands.Begin(); exprs_json_it != operands.End();
2846  ++exprs_json_it) {
2847  const auto& expr_json = *exprs_json_it;
2848  CHECK(expr_json.IsObject());
2849  if (expr_json.HasMember("op")) {
2850  const auto op_str = json_str(field(expr_json, "op"));
2851  if (op_str == "CAST" && expr_json.HasMember("type")) {
2852  const auto& expr_type = field(expr_json, "type");
2853  CHECK(expr_type.IsObject());
2854  CHECK(expr_type.HasMember("type"));
2855  const auto& expr_type_name = json_str(field(expr_type, "type"));
2856  if (expr_type_name == "CURSOR") {
2857  CHECK(expr_json.HasMember("operands"));
2858  const auto& expr_operands = field(expr_json, "operands");
2859  CHECK(expr_operands.IsArray());
2860  if (expr_operands.Size() != 1) {
2861  throw std::runtime_error(
2862  "Table functions currently only support one ResultSet input");
2863  }
2864  auto pos = field(expr_operands[0], "input").GetInt();
2865  CHECK_LT(pos, inputs.size());
2866  for (size_t i = inputs[pos]->size(); i > 0; i--) {
2867  table_func_inputs.emplace_back(
2868  std::make_unique<RexAbstractInput>(col_inputs.size()));
2869  col_inputs.emplace_back(table_func_inputs.back().get());
2870  }
2871  continue;
2872  }
2873  }
2874  }
2875  table_func_inputs.emplace_back(parse_scalar_expr(*exprs_json_it, cat_, root_dag));
2876  }
2877 
2878  const auto& op_name = field(invocation, "op");
2879  CHECK(op_name.IsString());
2880 
2881  std::vector<std::unique_ptr<const RexScalar>> table_function_projected_outputs;
2882  const auto& row_types = field(table_func_ra, "rowType");
2883  CHECK(row_types.IsArray());
2884  CHECK_GE(row_types.Size(), unsigned(0));
2885  const auto& row_types_array = row_types.GetArray();
2886  for (size_t i = 0; i < row_types_array.Size(); i++) {
2887  // We don't care about the type information in rowType -- replace each output with
2888  // a reference to be resolved later in the translator
2889  table_function_projected_outputs.emplace_back(std::make_unique<RexRef>(i));
2890  fields.emplace_back("");
2891  }
2892  return std::make_shared<RelTableFunction>(op_name.GetString(),
2893  inputs,
2894  fields,
2895  col_inputs,
2896  table_func_inputs,
2897  table_function_projected_outputs);
2898  }
2899 
2900  std::shared_ptr<RelLogicalValues> dispatchLogicalValues(
2901  const rapidjson::Value& logical_values_ra) {
2902  const auto& tuple_type_arr = field(logical_values_ra, "type");
2903  CHECK(tuple_type_arr.IsArray());
2904  std::vector<TargetMetaInfo> tuple_type;
2905  for (auto tuple_type_arr_it = tuple_type_arr.Begin();
2906  tuple_type_arr_it != tuple_type_arr.End();
2907  ++tuple_type_arr_it) {
2908  const auto component_type = parse_type(*tuple_type_arr_it);
2909  const auto component_name = json_str(field(*tuple_type_arr_it, "name"));
2910  tuple_type.emplace_back(component_name, component_type);
2911  }
2912  const auto& inputs_arr = field(logical_values_ra, "inputs");
2913  CHECK(inputs_arr.IsArray());
2914  const auto& tuples_arr = field(logical_values_ra, "tuples");
2915  CHECK(tuples_arr.IsArray());
2916 
2917  if (inputs_arr.Size()) {
2918  throw QueryNotSupported("Inputs not supported in logical values yet.");
2919  }
2920 
2921  std::vector<RelLogicalValues::RowValues> values;
2922  if (tuples_arr.Size()) {
2923  for (const auto& row : tuples_arr.GetArray()) {
2924  CHECK(row.IsArray());
2925  const auto values_json = row.GetArray();
2926  if (!values.empty()) {
2927  CHECK_EQ(values[0].size(), values_json.Size());
2928  }
2929  values.emplace_back(RelLogicalValues::RowValues{});
2930  for (const auto& value : values_json) {
2931  CHECK(value.IsObject());
2932  CHECK(value.HasMember("literal"));
2933  values.back().emplace_back(parse_literal(value));
2934  }
2935  }
2936  }
2937 
2938  return std::make_shared<RelLogicalValues>(tuple_type, values);
2939  }
2940 
2941  std::shared_ptr<RelLogicalUnion> dispatchUnion(
2942  const rapidjson::Value& logical_union_ra) {
2943  auto inputs = getRelAlgInputs(logical_union_ra);
2944  auto const& all_type_bool = field(logical_union_ra, "all");
2945  CHECK(all_type_bool.IsBool());
2946  return std::make_shared<RelLogicalUnion>(std::move(inputs), all_type_bool.GetBool());
2947  }
2948 
2949  RelAlgInputs getRelAlgInputs(const rapidjson::Value& node) {
2950  if (node.HasMember("inputs")) {
2951  const auto str_input_ids = strings_from_json_array(field(node, "inputs"));
2952  RelAlgInputs ra_inputs;
2953  for (const auto& str_id : str_input_ids) {
2954  ra_inputs.push_back(nodes_[std::stoi(str_id)]);
2955  }
2956  return ra_inputs;
2957  }
2958  return {prev(node)};
2959  }
2960 
// Splits the leading "key=value" option off the front of `str`.
//
// `pos` is the offset of the ", " separator that ends the first option, or
// std::string::npos when `str` holds only the final option. The consumed option
// and its trailing separator are erased from `str`. For the final option `str`
// is cleared; previously `npos + 2` wrapped around and erased a single stray
// character.
//
// Returns {key, value}. An option without '=' yields the whole option text as
// both key and value, as before.
std::pair<std::string, std::string> getKVOptionPair(std::string& str, size_t& pos) {
  // Length of the ", " separator the caller splits options on. The old code
  // derived it from the "=" delimiter length plus one, which only matched by
  // coincidence.
  constexpr size_t kOptionDelimLen = 2;
  const std::string option = str.substr(0, pos);
  const size_t eq_pos = option.find('=');
  std::string key = option.substr(0, eq_pos);
  std::string val =
      eq_pos == std::string::npos ? option : option.substr(eq_pos + 1);
  if (pos == std::string::npos) {
    str.clear();
  } else {
    str.erase(0, pos + kOptionDelimLen);
  }
  return {std::move(key), std::move(val)};
}
2970 
2971  ExplainedQueryHint parseHintString(std::string& hint_string) {
2972  std::string white_space_delim = " ";
2973  int l = hint_string.length();
2974  hint_string = hint_string.erase(0, 1).substr(0, l - 2);
2975  size_t pos = 0;
2976  auto global_hint_checker = [&](const std::string& input_hint_name) -> HintIdentifier {
2977  bool global_hint = false;
2978  std::string hint_name = input_hint_name;
2979  auto global_hint_identifier = hint_name.substr(0, 2);
2980  if (global_hint_identifier.compare("g_") == 0) {
2981  global_hint = true;
2982  hint_name = hint_name.substr(2, hint_string.length());
2983  }
2984  return {global_hint, hint_name};
2985  };
2986  auto parsed_hint =
2987  global_hint_checker(hint_string.substr(0, hint_string.find(white_space_delim)));
2988  auto hint_type = RegisteredQueryHint::translateQueryHint(parsed_hint.hint_name);
2989  if ((pos = hint_string.find("options:")) != std::string::npos) {
2990  // need to parse hint options
2991  std::vector<std::string> tokens;
2992  bool kv_list_op = false;
2993  std::string raw_options = hint_string.substr(pos + 8, hint_string.length() - 2);
2994  if (raw_options.find('{') != std::string::npos) {
2995  kv_list_op = true;
2996  } else {
2997  CHECK(raw_options.find('[') != std::string::npos);
2998  }
2999  auto t1 = raw_options.erase(0, 1);
3000  raw_options = t1.substr(0, t1.length() - 1);
3001  std::string op_delim = ", ";
3002  if (kv_list_op) {
3003  // kv options
3004  std::unordered_map<std::string, std::string> kv_options;
3005  while ((pos = raw_options.find(op_delim)) != std::string::npos) {
3006  auto kv_pair = getKVOptionPair(raw_options, pos);
3007  kv_options.emplace(kv_pair.first, kv_pair.second);
3008  }
3009  // handle the last kv pair
3010  auto kv_pair = getKVOptionPair(raw_options, pos);
3011  kv_options.emplace(kv_pair.first, kv_pair.second);
3012  return {hint_type, parsed_hint.global_hint, false, true, kv_options};
3013  } else {
3014  std::vector<std::string> list_options;
3015  while ((pos = raw_options.find(op_delim)) != std::string::npos) {
3016  list_options.emplace_back(raw_options.substr(0, pos));
3017  raw_options.erase(0, pos + white_space_delim.length() + 1);
3018  }
3019  // handle the last option
3020  list_options.emplace_back(raw_options.substr(0, pos));
3021  return {hint_type, parsed_hint.global_hint, false, false, list_options};
3022  }
3023  } else {
3024  // marker hint: no extra option for this hint
3025  return {hint_type, parsed_hint.global_hint, true, false};
3026  }
3027  }
3028 
3029  void getRelAlgHints(const rapidjson::Value& json_node,
3030  std::shared_ptr<RelAlgNode> node) {
3031  std::string hint_explained = json_str(field(json_node, "hints"));
3032  size_t pos = 0;
3033  std::string delim = "|";
3034  std::vector<std::string> hint_list;
3035  while ((pos = hint_explained.find(delim)) != std::string::npos) {
3036  hint_list.emplace_back(hint_explained.substr(0, pos));
3037  hint_explained.erase(0, pos + delim.length());
3038  }
3039  // handling the last one
3040  hint_list.emplace_back(hint_explained.substr(0, pos));
3041 
3042  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
3043  if (agg_node) {
3044  for (std::string& hint : hint_list) {
3045  auto parsed_hint = parseHintString(hint);
3046  agg_node->addHint(parsed_hint);
3047  }
3048  }
3049  const auto project_node = std::dynamic_pointer_cast<RelProject>(node);
3050  if (project_node) {
3051  for (std::string& hint : hint_list) {
3052  auto parsed_hint = parseHintString(hint);
3053  project_node->addHint(parsed_hint);
3054  }
3055  }
3056  const auto scan_node = std::dynamic_pointer_cast<RelScan>(node);
3057  if (scan_node) {
3058  for (std::string& hint : hint_list) {
3059  auto parsed_hint = parseHintString(hint);
3060  scan_node->addHint(parsed_hint);
3061  }
3062  }
3063  const auto join_node = std::dynamic_pointer_cast<RelJoin>(node);
3064  if (join_node) {
3065  for (std::string& hint : hint_list) {
3066  auto parsed_hint = parseHintString(hint);
3067  join_node->addHint(parsed_hint);
3068  }
3069  }
3070 
3071  const auto compound_node = std::dynamic_pointer_cast<RelCompound>(node);
3072  if (compound_node) {
3073  for (std::string& hint : hint_list) {
3074  auto parsed_hint = parseHintString(hint);
3075  compound_node->addHint(parsed_hint);
3076  }
3077  }
3078  }
3079 
3080  std::shared_ptr<const RelAlgNode> prev(const rapidjson::Value& crt_node) {
3081  const auto id = node_id(crt_node);
3082  CHECK(id);
3083  CHECK_EQ(static_cast<size_t>(id), nodes_.size());
3084  return nodes_.back();
3085  }
3086 
3088  std::vector<std::shared_ptr<RelAlgNode>> nodes_;
3089 };
3090 
3091 } // namespace details
3092 
3093 std::unique_ptr<RelAlgDag> RelAlgDagBuilder::buildDag(
3094  const std::string& query_ra,
3096  const bool optimize_dag) {
3097  rapidjson::Document query_ast;
3098  query_ast.Parse(query_ra.c_str());
3099  VLOG(2) << "Parsing query RA JSON: " << query_ra;
3100  if (query_ast.HasParseError()) {
3101  query_ast.GetParseError();
3102  LOG(ERROR) << "Failed to parse RA tree from Calcite (offset "
3103  << query_ast.GetErrorOffset() << "):\n"
3104  << rapidjson::GetParseError_En(query_ast.GetParseError());
3105  VLOG(1) << "Failed to parse query RA: " << query_ra;
3106  throw std::runtime_error(
3107  "Failed to parse relational algebra tree. Possible query syntax error.");
3108  }
3109  CHECK(query_ast.IsObject());
3111 
3112  return build(query_ast, cat, nullptr, optimize_dag);
3113 }
3114 
3115 std::unique_ptr<RelAlgDag> RelAlgDagBuilder::buildDagForSubquery(
3116  RelAlgDag& root_dag,
3117  const rapidjson::Value& query_ast,
3119  return build(query_ast, cat, &root_dag, true);
3120 }
3121 
3122 std::unique_ptr<RelAlgDag> RelAlgDagBuilder::build(const rapidjson::Value& query_ast,
3124  RelAlgDag* root_dag,
3125  const bool optimize_dag) {
3126  const auto& rels = field(query_ast, "rels");
3127  CHECK(rels.IsArray());
3128 
3129  auto rel_alg_dag_ptr = std::make_unique<RelAlgDag>();
3130  auto& rel_alg_dag = *rel_alg_dag_ptr;
3131  auto& nodes = getNodes(rel_alg_dag);
3132 
3133  try {
3134  nodes = details::RelAlgDispatcher(cat).run(rels, root_dag ? *root_dag : rel_alg_dag);
3135  } catch (const QueryNotSupported&) {
3136  throw;
3137  }
3138  CHECK(!nodes.empty());
3139  bind_inputs(nodes);
3140 
3142 
3143  if (optimize_dag) {
3144  optimizeDag(rel_alg_dag);
3145  }
3146 
3147  return rel_alg_dag_ptr;
3148 }
3149 
3151  auto const build_state = rel_alg_dag.getBuildState();
3152  if (build_state == RelAlgDag::BuildState::kBuiltOptimized) {
3153  return;
3154  }
3155 
3157  << static_cast<int>(build_state);
3158 
3159  auto& nodes = getNodes(rel_alg_dag);
3160  auto& subqueries = getSubqueries(rel_alg_dag);
3161  auto& query_hints = getQueryHints(rel_alg_dag);
3162 
3163  compute_node_hash(nodes);
3164  handle_query_hint(nodes, rel_alg_dag);
3165  mark_nops(nodes);
3166  simplify_sort(nodes);
3168  eliminate_identical_copy(nodes);
3169  fold_filters(nodes);
3170  std::vector<const RelAlgNode*> filtered_left_deep_joins;
3171  std::vector<const RelAlgNode*> left_deep_joins;
3172  for (const auto& node : nodes) {
3173  const auto left_deep_join_root = get_left_deep_join_root(node);
3174  // The filter which starts a left-deep join pattern must not be coalesced
3175  // since it contains (part of) the join condition.
3176  if (left_deep_join_root) {
3177  left_deep_joins.push_back(left_deep_join_root.get());
3178  if (std::dynamic_pointer_cast<const RelFilter>(left_deep_join_root)) {
3179  filtered_left_deep_joins.push_back(left_deep_join_root.get());
3180  }
3181  }
3182  }
3183  if (filtered_left_deep_joins.empty()) {
3185  }
3186  eliminate_dead_columns(nodes);
3187  eliminate_dead_subqueries(subqueries, nodes.back().get());
3188  separate_window_function_expressions(nodes, query_hints);
3190  nodes,
3191  g_cluster /* always_add_project_if_first_project_is_window_expr */,
3192  query_hints);
3193  coalesce_nodes(nodes, left_deep_joins, query_hints);
3194  CHECK(nodes.back().use_count() == 1);
3195  create_left_deep_join(nodes);
3196 
3198 }
3199 
3200 void RelAlgDag::eachNode(std::function<void(RelAlgNode const*)> const& callback) const {
3201  for (auto const& node : nodes_) {
3202  if (node) {
3203  callback(node.get());
3204  }
3205  }
3206 }
3207 
3209  for (auto& node : nodes_) {
3210  if (node) {
3211  node->resetQueryExecutionState();
3212  }
3213  }
3214 }
3215 
3216 // Return tree with depth represented by indentations.
3217 std::string tree_string(const RelAlgNode* ra, const size_t depth) {
3218  std::string result = std::string(2 * depth, ' ') + ::toString(ra) + '\n';
3219  for (size_t i = 0; i < ra->inputCount(); ++i) {
3220  result += tree_string(ra->getInput(i), depth + 1);
3221  }
3222  return result;
3223 }
3224 
3225 std::string RexSubQuery::toString(RelRexToStringConfig config) const {
3226  return cat(::typeName(this), "(", ra_->toString(config), ")");
3227 }
3228 
3229 size_t RexSubQuery::toHash() const {
3230  if (!hash_) {
3231  hash_ = typeid(RexSubQuery).hash_code();
3232  boost::hash_combine(*hash_, ra_->toHash());
3233  }
3234  return *hash_;
3235 }
3236 
3237 std::string RexInput::toString(RelRexToStringConfig config) const {
3238  const auto scan_node = dynamic_cast<const RelScan*>(node_);
3239  if (scan_node) {
3240  auto field_name = scan_node->getFieldName(getIndex());
3241  auto table_name = scan_node->getTableDescriptor()->tableName;
3242  return ::typeName(this) + "(" + table_name + "." + field_name + ")";
3243  }
3244  auto node_id_in_plan = node_->getIdInPlanTree();
3245  auto node_id_str =
3246  node_id_in_plan ? std::to_string(*node_id_in_plan) : std::to_string(node_->getId());
3247  auto node_str = config.skip_input_nodes ? "(input_node_id=" + node_id_str
3248  : "(input_node=" + node_->toString(config);
3249  return cat(::typeName(this), node_str, ", in_index=", std::to_string(getIndex()), ")");
3250 }
3251 
3252 size_t RexInput::toHash() const {
3253  if (!hash_) {
3254  hash_ = typeid(RexInput).hash_code();
3255  boost::hash_combine(*hash_, node_->toHash());
3256  boost::hash_combine(*hash_, getIndex());
3257  }
3258  return *hash_;
3259 }
3260 
3261 std::string RelCompound::toString(RelRexToStringConfig config) const {
3262  auto ret = cat(::typeName(this),
3263  ", filter_expr=",
3264  (filter_expr_ ? filter_expr_->toString(config) : "null"),
3265  ", target_exprs=");
3266  for (auto& expr : target_exprs_) {
3267  ret += expr->toString(config) + " ";
3268  }
3269  ret += ", agg_exps=";
3270  for (auto& expr : agg_exprs_) {
3271  ret += expr->toString(config) + " ";
3272  }
3273  ret += ", scalar_sources=";
3274  for (auto& expr : scalar_sources_) {
3275  ret += expr->toString(config) + " ";
3276  }
3277  return cat(ret,
3278  ", ",
3280  ", ",
3281  ", fields=",
3282  ::toString(fields_),
3283  ", is_agg=",
3285 }
3286 
3287 size_t RelCompound::toHash() const {
3288  if (!hash_) {
3289  hash_ = typeid(RelCompound).hash_code();
3290  boost::hash_combine(*hash_, filter_expr_ ? filter_expr_->toHash() : HASH_N);
3291  boost::hash_combine(*hash_, is_agg_);
3292  for (auto& target_expr : target_exprs_) {
3293  if (auto rex_scalar = dynamic_cast<const RexScalar*>(target_expr)) {
3294  boost::hash_combine(*hash_, rex_scalar->toHash());
3295  }
3296  }
3297  for (auto& agg_expr : agg_exprs_) {
3298  boost::hash_combine(*hash_, agg_expr->toHash());
3299  }
3300  for (auto& scalar_source : scalar_sources_) {
3301  boost::hash_combine(*hash_, scalar_source->toHash());
3302  }
3303  boost::hash_combine(*hash_, groupby_count_);
3304  boost::hash_combine(*hash_, ::toString(fields_));
3305  }
3306  return *hash_;
3307 }
std::vector< std::shared_ptr< const RexScalar > > scalar_exprs_
Definition: RelAlgDag.h:2257
DEVICE auto upper_bound(ARGS &&...args)
Definition: gpu_enabled.h:123
const size_t getGroupByCount() const
Definition: RelAlgDag.h:1195
SQLTypes to_sql_type(const std::string &type_name)
std::optional< size_t > is_collected_window_function(size_t rex_hash) const
Definition: RelAlgDag.cpp:2189
NullSortedPosition parse_nulls_position(const rapidjson::Value &collation)
Definition: RelAlgDag.cpp:1106
bool is_agg(const Analyzer::Expr *expr)
std::unique_ptr< const RexScalar > condition_
Definition: RelAlgDag.h:1398
std::unique_ptr< const RexOperator > disambiguate_operator(const RexOperator *rex_operator, const RANodeOutput &ra_output) noexcept
Definition: RelAlgDag.cpp:1309
const RexScalar * getThen(const size_t idx) const
Definition: RelAlgDag.h:400
std::shared_ptr< RelAggregate > dispatchAggregate(const rapidjson::Value &agg_ra)
Definition: RelAlgDag.cpp:2722
#define CHECK_EQ(x, y)
Definition: Logger.h:230
std::shared_ptr< RelFilter > dispatchFilter(const rapidjson::Value &filter_ra, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:2712
void * visitInput(const RexInput *rex_input) const override
Definition: RelAlgDag.cpp:112
const Catalog_Namespace::Catalog & cat_
Definition: RelAlgDag.cpp:3087
RexRebindReindexInputsVisitor(const RelAlgNode *old_input, const RelAlgNode *new_input, std::unordered_map< unsigned, unsigned > old_to_new_index_map)
Definition: RelAlgDag.cpp:106
std::unique_ptr< RexOperator > parse_operator(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:1159
void mark_nops(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
Definition: RelAlgDag.cpp:1507
std::unique_ptr< RexSubQuery > deepCopy() const
Definition: RelAlgDag.cpp:59
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
Definition: RelAlgDag.h:1089
JoinType
Definition: sqldefs.h:151
static std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint > > & getQueryHints(RelAlgDag &rel_alg_dag)
Definition: RelAlgDag.h:2588
std::vector< std::unique_ptr< const RexScalar > > table_func_inputs_
Definition: RelAlgDag.h:2167
std::string cat(Ts &&...args)
std::optional< size_t > getOffsetForPushedDownExpr(WindowExprType type, size_t expr_offset) const
Definition: RelAlgDag.cpp:157
RexWindowFuncReplacementVisitor(std::unordered_set< size_t > &collected_window_func_hash, std::vector< std::unique_ptr< const RexScalar >> &new_rex_input_for_window_func, std::unordered_map< size_t, size_t > &window_func_to_new_rex_input_idx_map, RelProject *new_project, std::unordered_map< size_t, std::unique_ptr< const RexInput >> &new_rex_input_from_child_node)
Definition: RelAlgDag.cpp:2084
void hoist_filter_cond_to_cross_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
class for a per-database catalog. also includes metadata for the current database and the current use...
Definition: Catalog.h:132
Definition: sqltypes.h:49
std::vector< std::unique_ptr< const RexScalar > > & scalar_exprs_for_new_project_
Definition: RelAlgDag.cpp:332
void addHint(const ExplainedQueryHint &hint_explained)
Definition: RelAlgDag.h:1715
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
void sink_projected_boolean_expr_to_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
bool input_can_be_coalesced(const RelAlgNode *parent_node, const size_t index, const bool first_rex_is_input)
Definition: RelAlgDag.cpp:1806
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
Definition: RelAlgDag.cpp:3261
void eliminate_identical_copy(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
size_t toHash() const override
Definition: RelAlgDag.cpp:3252
RetType visitInput(const RexInput *rex_input) const final
Definition: RelAlgDag.cpp:2107
std::vector< RexInput > RANodeOutput
Definition: RelAlgDag.h:2638
std::unique_ptr< const RexCase > disambiguate_case(const RexCase *rex_case, const RANodeOutput &ra_output)
Definition: RelAlgDag.cpp:1344
const RexScalar * getElse() const
Definition: RelAlgDag.h:405
RelCompound(std::unique_ptr< const RexScalar > &filter_expr, const std::vector< const Rex * > &target_exprs, const size_t groupby_count, const std::vector< const RexAgg * > &agg_exprs, const std::vector< std::string > &fields, std::vector< std::unique_ptr< const RexScalar >> &scalar_sources, const bool is_agg, bool update_disguised_as_select=false, bool delete_disguised_as_select=false, bool varlen_update_required=false, TableDescriptor const *manipulation_target_table=nullptr, ColumnNameList target_columns=ColumnNameList())
Definition: RelAlgDag.h:1635
static thread_local unsigned crt_id_
Definition: RelAlgDag.h:895
std::unique_ptr< const RexScalar > visitOperator(const RexOperator *rex_operator) const override
Definition: RelAlgDag.cpp:297
SqlWindowFunctionKind parse_window_function_kind(const std::string &name)
Definition: RelAlgDag.cpp:1037
std::shared_ptr< RelScan > dispatchTableScan(const rapidjson::Value &scan_ra)
Definition: RelAlgDag.cpp:2677
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
RexScalar const * copyAndRedirectSource(RexScalar const *, size_t input_idx) const
Definition: RelAlgDag.cpp:886
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
Definition: RelAlgDag.cpp:684
std::unique_ptr< const RexSubQuery > parse_subquery(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:1145
SQLAgg to_agg_kind(const std::string &agg_name)
std::shared_ptr< RelLogicalUnion > dispatchUnion(const rapidjson::Value &logical_union_ra)
Definition: RelAlgDag.cpp:2941
#define LOG(tag)
Definition: Logger.h:216
std::vector< std::string > TargetColumnList
Definition: RelAlgDag.h:1862
const SQLTypeInfo & getType() const
Definition: RelAlgDag.h:259
std::unique_ptr< const RexScalar > get_new_rex_input(size_t rex_idx) const
Definition: RelAlgDag.cpp:2197
size_t size() const
Definition: RelAlgDag.h:245
Hints * getDeliveredHints()
Definition: RelAlgDag.h:1151
std::shared_ptr< RelProject > dispatchProject(const rapidjson::Value &proj_ra, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:2690
const bool json_bool(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:49
const RexScalar * getOperand(const size_t idx) const
Definition: RelAlgDag.h:247
std::vector< std::unique_ptr< const RexScalar > > parse_window_order_exprs(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:1089
std::vector< const Rex * > col_inputs_
Definition: RelAlgDag.h:2165
std::string json_node_to_string(const rapidjson::Value &node) noexcept
Definition: RelAlgDag.cpp:904
bool hasEquivCollationOf(const RelSort &that) const
Definition: RelAlgDag.cpp:779
JoinType to_join_type(const std::string &join_type_name)
Definition: RelAlgDag.cpp:1291
void resetQueryExecutionState()
Definition: RelAlgDag.cpp:3208
std::vector< SortField > parse_window_order_collation(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:1112
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
std::vector< std::shared_ptr< RelAlgNode > > nodes_
Definition: RelAlgDag.h:2557
std::string join(T const &container, std::string const &delim)
#define UNREACHABLE()
Definition: Logger.h:266
void handle_query_hint(const std::vector< std::shared_ptr< RelAlgNode >> &nodes, RelAlgDag &rel_alg_dag) noexcept
Definition: RelAlgDag.cpp:1460
bool hint_applied_
Definition: RelAlgDag.h:1400
DEVICE void sort(ARGS &&...args)
Definition: gpu_enabled.h:105
#define CHECK_GE(x, y)
Definition: Logger.h:235
std::optional< size_t > getIdInPlanTree() const
Definition: RelAlgDag.h:818
SortDirection
Definition: RelAlgDag.h:480
std::vector< std::string > fields_
Definition: RelAlgDag.h:1171
RexInput(const RelAlgNode *node, const unsigned in_index)
Definition: RelAlgDag.h:348
void pushDownExpressionInWindowFunction(const RexWindowFunctionOperator *window_expr) const
Definition: RelAlgDag.cpp:178
void addHint(const ExplainedQueryHint &hint_explained)
Definition: RelAlgDag.h:1372
Definition: sqldefs.h:48
std::unique_ptr< const RexScalar > visitCase(const RexCase *rex_case) const override
Definition: RelAlgDag.cpp:278
const RexScalar * getWhen(const size_t idx) const
Definition: RelAlgDag.h:395
std::vector< size_t > indices_from_json_array(const rapidjson::Value &json_idx_arr) noexcept
Definition: RelAlgDag.cpp:1238
void appendInput(std::string new_field_name, std::unique_ptr< const RexScalar > new_input)
Definition: RelAlgDag.cpp:364
void propagate_hints_to_new_project(std::shared_ptr< RelProject > prev_node, std::shared_ptr< RelProject > new_node, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
Definition: RelAlgDag.cpp:2217
bool isRenamedInput(const RelAlgNode *node, const size_t index, const std::string &new_name)
Definition: RelAlgDag.cpp:469
std::unique_ptr< const RexScalar > defaultResult() const override
Definition: RelAlgDag.cpp:329
void addHint(const ExplainedQueryHint &hint_explained)
Definition: RelAlgDag.h:1269
std::unique_ptr< const RexAgg > parse_aggregate_expr(const rapidjson::Value &expr)
Definition: RelAlgDag.cpp:1251
std::unordered_map< size_t, const RexScalar * > & collected_window_func_
Definition: RelAlgDag.cpp:2079
void checkForMatchingMetaInfoTypes() const
Definition: RelAlgDag.cpp:858
std::unique_ptr< const RexScalar > parse_scalar_expr(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:1267
std::vector< std::unique_ptr< const RexScalar > > scalar_sources_
Definition: RelAlgDag.h:1747
void * visitInput(const RexInput *rex_input) const override
Definition: RelAlgDag.cpp:76
std::unique_ptr< RexAbstractInput > parse_abstract_input(const rapidjson::Value &expr) noexcept
Definition: RelAlgDag.cpp:915
static std::vector< std::shared_ptr< RexSubQuery > > & getSubqueries(RelAlgDag &rel_alg_dag)
Definition: RelAlgDag.h:2582
RexInputSet aggregateResult(const RexInputSet &aggregate, const RexInputSet &next_result) const override
Definition: RelAlgDag.cpp:2383
std::unique_ptr< const RexScalar > disambiguate_rex(const RexScalar *, const RANodeOutput &)
Definition: RelAlgDag.cpp:1365
std::unique_ptr< const RexScalar > visitLiteral(const RexLiteral *rex_literal) const override
Definition: RelAlgDag.cpp:264
bool hint_applied_
Definition: RelAlgDag.h:1750
std::string to_string(char const *&&v)
void add_window_function_pre_project(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const bool always_add_project_if_first_project_is_window_expr, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
Definition: RelAlgDag.cpp:2403
const std::string getFieldName(const size_t i) const
Definition: RelAlgDag.h:919
std::unique_ptr< const RexScalar > visitSubQuery(const RexSubQuery *rex_subquery) const override
Definition: RelAlgDag.cpp:273
std::unique_ptr< RexCase > parse_case(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:1204
void simplify_sort(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::vector< SortField > collation_
Definition: RelAlgDag.h:1849
constexpr double a
Definition: Utm.h:32
std::shared_ptr< RelJoin > dispatchJoin(const rapidjson::Value &join_ra, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:2750
std::vector< std::unique_ptr< const RexScalar > > & new_rex_input_for_window_func_
Definition: RelAlgDag.cpp:2208
std::unordered_set< RexInput > RexInputSet
Definition: RelAlgDag.cpp:2374
This file contains the class specification and related data structures for Catalog.
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
std::string to_string() const
Definition: sqltypes.h:483
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
bool const is_all_
Definition: RelAlgDag.h:2260
virtual std::string toString(RelRexToStringConfig config) const =0
unsigned getIndex() const
Definition: RelAlgDag.h:72
void separate_window_function_expressions(std::vector< std::shared_ptr< RelAlgNode >> &nodes, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
Definition: RelAlgDag.cpp:2261
void markAsNop()
Definition: RelAlgDag.h:866
WindowFunctionCollector(std::unordered_map< size_t, const RexScalar * > &collected_window_func)
Definition: RelAlgDag.cpp:2017
bool aggregateResult(const bool &aggregate, const bool &next_result) const final
Definition: RelAlgDag.cpp:1843
static auto const HASH_N
Definition: RelAlgDag.h:44
SQLOps getOperator() const
Definition: RelAlgDag.h:257
std::shared_ptr< RelTableFunction > dispatchTableFunction(const rapidjson::Value &table_func_ra, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:2830
std::unordered_map< size_t, std::unique_ptr< const RexInput > > & new_rex_input_from_child_node_
Definition: RelAlgDag.cpp:2213
unsigned getId() const
Definition: RelAlgDag.h:814
std::set< std::pair< const RelAlgNode *, int > > get_equiv_cols(const RelAlgNode *node, const size_t which_col)
Definition: RelAlgDag.cpp:740
std::unique_ptr< const RexScalar > visitInput(const RexInput *rex_input) const override
Definition: RelAlgDag.cpp:252
static QueryHint translateQueryHint(const std::string &hint_name)
Definition: QueryHint.h:250
DEVICE auto copy(ARGS &&...args)
Definition: gpu_enabled.h:51
#define CHECK_NE(x, y)
Definition: Logger.h:231
bool isRenaming() const
Definition: RelAlgDag.cpp:512
void setIndex(const unsigned in_index) const
Definition: RelAlgDag.h:74
Hints * getDeliveredHints()
Definition: RelAlgDag.h:1292
size_t toHash() const override
Definition: RelAlgDag.cpp:831
void coalesce_nodes(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::vector< const RelAlgNode * > &left_deep_joins, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
Definition: RelAlgDag.cpp:1906
std::vector< std::string > fields_
Definition: RelAlgDag.h:1744
static std::unique_ptr< RelAlgDag > build(const rapidjson::Value &query_ast, const Catalog_Namespace::Catalog &cat, RelAlgDag *root_dag, const bool optimize_dag)
Definition: RelAlgDag.cpp:3122
SQLOps to_sql_op(const std::string &op_str)
std::unique_ptr< Hints > hints_
Definition: RelAlgDag.h:1299
void set_scale(int s)
Definition: sqltypes.h:434
const int64_t json_i64(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:39
std::unique_ptr< Hints > hints_
Definition: RelAlgDag.h:1401
std::vector< std::unique_ptr< const RexScalar > > copyRexScalars(std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources)
Definition: RelAlgDag.cpp:625
std::vector< std::shared_ptr< const RelAlgNode >> RelAlgInputs
Definition: RelAlgDag.h:289
std::vector< std::unique_ptr< const RexScalar > > scalar_exprs_
Definition: RelAlgDag.h:1170
RetType visitOperator(const RexOperator *rex_operator) const final
Definition: RelAlgDag.cpp:2131
const double json_double(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:54
void addHint(const ExplainedQueryHint &hint_explained)
Definition: RelAlgDag.h:1128
size_t branchCount() const
Definition: RelAlgDag.h:393
const RelAlgNode * getInput(const size_t idx) const
Definition: RelAlgDag.h:826
SQLTypeInfo parse_type(const rapidjson::Value &type_obj)
Definition: RelAlgDag.cpp:1007
Checked json field retrieval.
RelFilter(std::unique_ptr< const RexScalar > &filter, std::shared_ptr< const RelAlgNode > input)
Definition: RelAlgDag.h:1525
void * visitCase(const RexCase *rex_case) const final
Definition: RelAlgDag.cpp:2045
std::vector< std::shared_ptr< RelAlgNode > > nodes_
Definition: RelAlgDag.cpp:3088
RelAggregate(const size_t groupby_count, std::vector< std::unique_ptr< const RexAgg >> &agg_exprs, const std::vector< std::string > &fields, std::shared_ptr< const RelAlgNode > input)
Definition: RelAlgDag.h:1179
std::unique_ptr< const RexScalar > filter_
Definition: RelAlgDag.h:1588
bool isSimple() const
Definition: RelAlgDag.h:1055
std::vector< const Rex * > remapTargetPointers(std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs_new, std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources_new, std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs_old, std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources_old, std::vector< const Rex * > const &target_exprs_old)
Definition: RelAlgDag.cpp:636
std::optional< size_t > hash_
Definition: RelAlgDag.h:62
const size_t groupby_count_
Definition: RelAlgDag.h:1742
void bind_inputs(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
Definition: RelAlgDag.cpp:1419
std::optional< size_t > hash_
Definition: RelAlgDag.h:889
unsigned getId() const
Definition: RelAlgDag.cpp:63
const RelAlgNode * node_
Definition: RelAlgDag.h:372
std::string toString(const Executor::ExtModuleKinds &kind)
Definition: Execute.h:1448
virtual void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input)
Definition: RelAlgDag.h:849
void bind_project_to_input(RelProject *project_node, const RANodeOutput &input) noexcept
Definition: RelAlgDag.cpp:1391
RexInputSet visitInput(const RexInput *input) const override
Definition: RelAlgDag.cpp:2378
std::vector< std::unique_ptr< const RexScalar > > parse_expr_array(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:1026
std::string tree_string(const RelAlgNode *ra, const size_t depth)
Definition: RelAlgDag.cpp:3217
std::vector< std::unique_ptr< const RexAgg > > agg_exprs_
Definition: RelAlgDag.h:1296
void compute_node_hash(const std::vector< std::shared_ptr< RelAlgNode >> &nodes)
Definition: RelAlgDag.cpp:1493
Hints * getDeliveredHints()
Definition: RelAlgDag.h:1738
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
Definition: RelAlgDag.cpp:536
size_t toHash() const override
Definition: RelAlgDag.h:417
PushDownGenericExpressionInWindowFunction(std::shared_ptr< RelProject > new_project, std::vector< std::unique_ptr< const RexScalar >> &scalar_exprs_for_new_project, std::vector< std::string > &fields_for_new_project, std::unordered_map< size_t, size_t > &expr_offset_cache)
Definition: RelAlgDag.cpp:128
const RexScalar * getProjectAt(const size_t idx) const
Definition: RelAlgDag.h:1070
bool hint_applied_
Definition: RelAlgDag.h:1298
static std::unique_ptr< RelAlgDag > buildDag(const std::string &query_ra, const Catalog_Namespace::Catalog &cat, const bool optimize_dag)
Definition: RelAlgDag.cpp:3093
#define CHECK_LT(x, y)
Definition: Logger.h:232
Definition: sqltypes.h:52
Definition: sqltypes.h:53
static RegisteredQueryHint defaults()
Definition: QueryHint.h:247
int32_t countRexLiteralArgs() const
Definition: RelAlgDag.cpp:696
std::unique_ptr< Hints > hints_
Definition: RelAlgDag.h:1751
std::vector< const Rex * > reproject_targets(const RelProject *simple_project, const std::vector< const Rex * > &target_exprs) noexcept
Definition: RelAlgDag.cpp:1524
const ConstRexScalarPtrVector & getPartitionKeys() const
Definition: RelAlgDag.h:573
std::vector< std::shared_ptr< RelAlgNode > > run(const rapidjson::Value &rels, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:2637
DEVICE auto lower_bound(ARGS &&...args)
Definition: gpu_enabled.h:78
#define CHECK_LE(x, y)
Definition: Logger.h:233
const std::unordered_map< unsigned, unsigned > mapping_
Definition: RelAlgDag.cpp:121
std::unique_ptr< Hints > hints_
Definition: RelAlgDag.h:1173
int64_t get_int_literal_field(const rapidjson::Value &obj, const char field[], const int64_t default_val) noexcept
Definition: RelAlgDag.cpp:2587
const std::vector< const Rex * > target_exprs_
Definition: RelAlgDag.h:1749
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
Definition: RelAlgDag.cpp:543
RelLogicalUnion(RelAlgInputs, bool is_all)
Definition: RelAlgDag.cpp:810
void registerSubquery(std::shared_ptr< RexSubQuery > subquery)
Definition: RelAlgDag.h:2302
std::vector< std::unique_ptr< const RexAgg > > agg_exprs_
Definition: RelAlgDag.h:1743
std::unique_ptr< const RexScalar > filter_expr_
Definition: RelAlgDag.h:1741
static std::vector< std::shared_ptr< RelAlgNode > > & getNodes(RelAlgDag &rel_alg_dag)
Definition: RelAlgDag.h:2578
void setSourceNode(const RelAlgNode *node) const
Definition: RelAlgDag.h:356
bool hasWindowFunctionExpr() const
Definition: RelAlgDag.cpp:2623
std::shared_ptr< RelModify > dispatchModify(const rapidjson::Value &logical_modify_ra)
Definition: RelAlgDag.cpp:2787
std::vector< ElementType >::const_iterator Super
Definition: RelAlgDag.cpp:1725
std::vector< std::unique_ptr< const RexAgg > > copyAggExprs(std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs)
Definition: RelAlgDag.cpp:615
std::unique_ptr< RexLiteral > parse_literal(const rapidjson::Value &expr)
Definition: RelAlgDag.cpp:921
std::vector< std::string > strings_from_json_array(const rapidjson::Value &json_str_arr) noexcept
Definition: RelAlgDag.cpp:1226
bool hint_applied_
Definition: RelAlgDag.h:1172
std::unordered_map< QueryHint, ExplainedQueryHint > Hints
Definition: QueryHint.h:273
virtual size_t size() const =0
const RelAlgNode * getSourceNode() const
Definition: RelAlgDag.h:351
void setExecutionResult(const std::shared_ptr< const ExecutionResult > result)
Definition: RelAlgDag.cpp:50
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
Definition: RelAlgDag.cpp:827
RelLogicalValues(const std::vector< TargetMetaInfo > &tuple_type, std::vector< RowValues > &values)
Definition: RelAlgDag.h:2178
size_t toHash() const override
Definition: RelAlgDag.cpp:3229
std::string typeName(const T *v)
Definition: toString.h:102
ExplainedQueryHint parseHintString(std::string &hint_string)
Definition: RelAlgDag.cpp:2971
SqlWindowFunctionKind
Definition: sqldefs.h:110
void * visitOperator(const RexOperator *rex_operator) const final
Definition: RelAlgDag.cpp:2023
Definition: sqldefs.h:52
void eachNode(std::function< void(RelAlgNode const *)> const &) const
Definition: RelAlgDag.cpp:3200
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
Definition: RelAlgDag.cpp:3225
std::shared_ptr< RelSort > dispatchSort(const rapidjson::Value &sort_ra)
Definition: RelAlgDag.cpp:2765
RexWindowFunctionOperator::RexWindowBound parse_window_bound(const rapidjson::Value &window_bound_obj, const Catalog_Namespace::Catalog &cat, RelAlgDag &root_dag)
Definition: RelAlgDag.cpp:1125
RelTableFunction(const std::string &function_name, RelAlgInputs inputs, std::vector< std::string > &fields, std::vector< const Rex * > col_inputs, std::vector< std::unique_ptr< const RexScalar >> &table_func_inputs, std::vector< std::unique_ptr< const RexScalar >> &target_exprs)
Definition: RelAlgDag.h:2049
const std::vector< std::string > & getFields() const
Definition: RelAlgDag.h:1084
std::unique_ptr< const RexScalar > visitRef(const RexRef *rex_ref) const override
Definition: RelAlgDag.cpp:269
std::string getFieldName(const size_t i) const
Definition: RelAlgDag.cpp:839
static void optimizeDag(RelAlgDag &rel_alg_dag)
Definition: RelAlgDag.cpp:3150
bool g_enable_watchdog false
Definition: Execute.cpp:79
#define CHECK(condition)
Definition: Logger.h:222
const ConstRexScalarPtrVector & getOrderKeys() const
Definition: RelAlgDag.h:583
RelProject(std::vector< std::unique_ptr< const RexScalar >> &scalar_exprs, const std::vector< std::string > &fields, std::shared_ptr< const RelAlgNode > input)
Definition: RelAlgDag.h:1035
RexInputReplacementVisitor(const RelAlgNode *node_to_keep, const std::vector< std::unique_ptr< const RexScalar >> &scalar_sources)
Definition: RelAlgDag.cpp:1544
bool g_enable_union
void create_compound(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::vector< size_t > &pattern, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints) noexcept
Definition: RelAlgDag.cpp:1567
bool g_cluster
std::vector< RexInput > n_outputs(const RelAlgNode *node, const size_t n)
Definition: RelAlgDag.cpp:95
std::shared_ptr< const RelAlgNode > prev(const rapidjson::Value &crt_node)
Definition: RelAlgDag.cpp:3080
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
Definition: RelAlgDag.cpp:527
void getRelAlgHints(const rapidjson::Value &json_node, std::shared_ptr< RelAlgNode > node)
Definition: RelAlgDag.cpp:3029
virtual size_t toHash() const =0
SortDirection parse_sort_direction(const rapidjson::Value &collation)
Definition: RelAlgDag.cpp:1100
RelAlgDispatcher(const Catalog_Namespace::Catalog &cat)
Definition: RelAlgDag.cpp:2635
Common Enum definitions for SQL processing.
bool is_dict_encoded_string() const
Definition: sqltypes.h:548
Definition: sqltypes.h:45
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
Definition: RelAlgDag.cpp:3237
void fold_filters(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
RexRebindInputsVisitor(const RelAlgNode *old_input, const RelAlgNode *new_input)
Definition: RelAlgDag.cpp:71
const TableDescriptor * getMetadataForTable(const std::string &tableName, const bool populateFragmenter=true) const
Returns a pointer to a const TableDescriptor struct matching the provided tableName.
std::vector< std::unique_ptr< const RexScalar >> RowValues
Definition: RelAlgDag.h:2176
void bind_table_func_to_input(RelTableFunction *table_func_node, const RANodeOutput &input) noexcept
Definition: RelAlgDag.cpp:1405
RetType visitCase(const RexCase *rex_case) const final
Definition: RelAlgDag.cpp:2158
const size_t inputCount() const
Definition: RelAlgDag.h:824
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)
void check_empty_inputs_field(const rapidjson::Value &node) noexcept
Definition: RelAlgDag.cpp:2601
unsigned node_id(const rapidjson::Value &ra_node) noexcept
Definition: RelAlgDag.cpp:899
const TableDescriptor * getTableFromScanNode(const Catalog_Namespace::Catalog &cat, const rapidjson::Value &scan_ra)
Definition: RelAlgDag.cpp:2606
void eliminate_dead_subqueries(std::vector< std::shared_ptr< RexSubQuery >> &subqueries, RelAlgNode const *root)
size_t size() const override
Definition: RelAlgDag.cpp:823
size_t operator()(const std::pair< const RelAlgNode *, int > &input_col) const
Definition: RelAlgDag.cpp:729
std::unordered_map< size_t, size_t > & window_func_to_new_rex_input_idx_map_
Definition: RelAlgDag.cpp:2210
RelAlgInputs getRelAlgInputs(const rapidjson::Value &node)
Definition: RelAlgDag.cpp:2949
std::vector< std::string > getFieldNamesFromScanNode(const rapidjson::Value &scan_ra)
Definition: RelAlgDag.cpp:2616
static std::unique_ptr< RelAlgDag > buildDagForSubquery(RelAlgDag &root_dag, const rapidjson::Value &query_ast, const Catalog_Namespace::Catalog &cat)
Definition: RelAlgDag.cpp:3115
std::shared_ptr< RelLogicalValues > dispatchLogicalValues(const rapidjson::Value &logical_values_ra)
Definition: RelAlgDag.cpp:2900
DEVICE void swap(ARGS &&...args)
Definition: gpu_enabled.h:114
std::unique_ptr< const RexScalar > RetType
Definition: RexVisitor.h:140
size_t toHash() const override
Definition: RelAlgDag.cpp:3287
NullSortedPosition
Definition: RelAlgDag.h:482
RANodeOutput get_node_output(const RelAlgNode *ra_node)
Definition: RelAlgDag.cpp:370
virtual size_t toHash() const =0
#define VLOG(n)
Definition: Logger.h:316
RelJoin(std::shared_ptr< const RelAlgNode > lhs, std::shared_ptr< const RelAlgNode > rhs, std::unique_ptr< const RexScalar > &condition, const JoinType join_type)
Definition: RelAlgDag.h:1304
BuildState getBuildState() const
Definition: RelAlgDag.h:2279
RelAlgInputs inputs_
Definition: RelAlgDag.h:886
void set_precision(int d)
Definition: sqltypes.h:432
std::pair< std::string, std::string > getKVOptionPair(std::string &str, size_t &pos)
Definition: RelAlgDag.cpp:2961
void eliminate_dead_columns(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
bool isIdentity() const
Definition: RelAlgDag.cpp:439
std::vector< std::unique_ptr< const RexScalar > > target_exprs_
Definition: RelAlgDag.h:2170
const bool is_agg_
Definition: RelAlgDag.h:1745
static void setBuildState(RelAlgDag &rel_alg_dag, const RelAlgDag::BuildState build_state)
Definition: RelAlgDag.h:2592
static void resetRelAlgFirstId() noexcept
Definition: RelAlgDag.cpp:46