OmniSciDB  085a039ca4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelAlgDagBuilder.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RelAlgDagBuilder.h"
19 #include "Catalog/Catalog.h"
21 #include "JsonAccessors.h"
22 #include "RelAlgOptimizer.h"
23 #include "RelLeftDeepInnerJoin.h"
25 #include "RexVisitor.h"
26 #include "Shared/sqldefs.h"
27 
28 #include <rapidjson/error/en.h>
29 #include <rapidjson/error/error.h>
30 #include <rapidjson/stringbuffer.h>
31 #include <rapidjson/writer.h>
32 
33 #include <string>
34 #include <unordered_set>
35 
36 extern bool g_cluster;
37 extern bool g_enable_union;
38 
namespace {

// Initial value of the per-thread RelAlgNode id counter (RelAlgNode::crt_id_).
constexpr unsigned FIRST_RA_NODE_ID = 1;

}  // namespace
44 
45 thread_local unsigned RelAlgNode::crt_id_ = FIRST_RA_NODE_ID;
46 
49 }
50 
52  const std::shared_ptr<const ExecutionResult> result) {
53  auto row_set = result->getRows();
54  CHECK(row_set);
55  CHECK_EQ(size_t(1), row_set->colCount());
56  *(type_.get()) = row_set->getColType(0);
57  (*(result_.get())) = result;
58 }
59 
60 std::unique_ptr<RexSubQuery> RexSubQuery::deepCopy() const {
61  return std::make_unique<RexSubQuery>(type_, result_, ra_->deepCopy());
62 }
63 
64 unsigned RexSubQuery::getId() const {
65  return ra_->getId();
66 }
67 
68 namespace {
69 
70 class RexRebindInputsVisitor : public RexVisitor<void*> {
71  public:
72  RexRebindInputsVisitor(const RelAlgNode* old_input, const RelAlgNode* new_input)
73  : old_input_(old_input), new_input_(new_input) {}
74 
75  virtual ~RexRebindInputsVisitor() = default;
76 
77  void* visitInput(const RexInput* rex_input) const override {
78  const auto old_source = rex_input->getSourceNode();
79  if (old_source == old_input_) {
80  const auto left_deep_join = dynamic_cast<const RelLeftDeepInnerJoin*>(new_input_);
81  if (left_deep_join) {
82  rebind_inputs_from_left_deep_join(rex_input, left_deep_join);
83  return nullptr;
84  }
85  rex_input->setSourceNode(new_input_);
86  }
87  return nullptr;
88  };
89 
90  private:
91  const RelAlgNode* old_input_;
93 };
94 
95 // Creates an output with n columns.
96 std::vector<RexInput> n_outputs(const RelAlgNode* node, const size_t n) {
97  std::vector<RexInput> outputs;
98  outputs.reserve(n);
99  for (size_t i = 0; i < n; ++i) {
100  outputs.emplace_back(node, i);
101  }
102  return outputs;
103 }
104 
106  public:
108  const RelAlgNode* old_input,
109  const RelAlgNode* new_input,
110  std::unordered_map<unsigned, unsigned> old_to_new_index_map)
111  : RexRebindInputsVisitor(old_input, new_input), mapping_(old_to_new_index_map) {}
112 
113  void* visitInput(const RexInput* rex_input) const override {
114  RexRebindInputsVisitor::visitInput(rex_input);
115  auto mapping_itr = mapping_.find(rex_input->getIndex());
116  CHECK(mapping_itr != mapping_.end());
117  rex_input->setIndex(mapping_itr->second);
118  return nullptr;
119  }
120 
121  private:
122  const std::unordered_map<unsigned, unsigned> mapping_;
123 };
124 
126  : public RexVisitorBase<std::unique_ptr<const RexScalar>> {
127  public:
128  enum class WindowExprType { PARTITION_KEY, ORDER_KEY };
130  std::shared_ptr<RelProject> new_project,
131  std::vector<std::unique_ptr<const RexScalar>>& scalar_exprs_for_new_project,
132  std::vector<std::string>& fields_for_new_project,
133  std::unordered_map<size_t, size_t>& expr_offset_cache)
134  : new_project_(new_project)
135  , scalar_exprs_for_new_project_(scalar_exprs_for_new_project)
136  , fields_for_new_project_(fields_for_new_project)
137  , expr_offset_cache_(expr_offset_cache)
138  , found_case_expr_window_operand_(false)
139  , has_partition_expr_(false) {}
140 
141  size_t pushDownExpressionImpl(const RexScalar* expr) const {
142  auto hash = expr->toHash();
143  auto it = expr_offset_cache_.find(hash);
144  auto new_offset = -1;
145  if (it == expr_offset_cache_.end()) {
146  CHECK(
147  expr_offset_cache_.emplace(hash, scalar_exprs_for_new_project_.size()).second);
148  new_offset = scalar_exprs_for_new_project_.size();
149  fields_for_new_project_.emplace_back("");
150  scalar_exprs_for_new_project_.emplace_back(deep_copier_.visit(expr));
151  } else {
152  // we already pushed down the same expression, so reuse it
153  new_offset = it->second;
154  }
155  return new_offset;
156  }
157 
159  size_t expr_offset) const {
160  // given window expr offset and inner expr's offset,
161  // return a (push-downed) expression's offset from the new projection node
162  switch (type) {
163  case WindowExprType::PARTITION_KEY: {
164  auto it = pushed_down_partition_key_offset_.find(expr_offset);
165  CHECK(it != pushed_down_partition_key_offset_.end());
166  return it->second;
167  }
168  case WindowExprType::ORDER_KEY: {
169  auto it = pushed_down_order_key_offset_.find(expr_offset);
170  CHECK(it != pushed_down_order_key_offset_.end());
171  return it->second;
172  }
173  default:
174  UNREACHABLE();
175  return std::nullopt;
176  }
177  }
178 
180  const RexWindowFunctionOperator* window_expr) const {
181  // step 1. push "all" target expressions of the window_func_project_node down to the
182  // new child projection
183  // each window expr is a separate target expression of the projection node
184  // and they have their own inner expression related to partition / order clauses
185  // so we capture their offsets to correctly rebind their input
186  pushed_down_window_operands_offset_.clear();
187  pushed_down_partition_key_offset_.clear();
188  pushed_down_order_key_offset_.clear();
189  for (size_t offset = 0; offset < window_expr->size(); ++offset) {
190  auto expr = window_expr->getOperand(offset);
191  auto literal_expr = dynamic_cast<const RexLiteral*>(expr);
192  auto case_expr = dynamic_cast<const RexCase*>(expr);
193  if (case_expr) {
194  // when columnar output is enabled, pushdown case expr can incur an issue
195  // during columnarization, so we record this and try to force rowwise-output
196  // until we fix the issue
197  // todo (yoonmin) : relax this
198  found_case_expr_window_operand_ = true;
199  }
200  if (!literal_expr) {
201  auto new_offset = pushDownExpressionImpl(expr);
202  pushed_down_window_operands_offset_.emplace(offset, new_offset);
203  }
204  }
205  size_t offset = 0;
206  for (const auto& partition_key : window_expr->getPartitionKeys()) {
207  auto new_offset = pushDownExpressionImpl(partition_key.get());
208  pushed_down_partition_key_offset_.emplace(offset, new_offset);
209  ++offset;
210  }
211  has_partition_expr_ = !window_expr->getPartitionKeys().empty();
212  offset = 0;
213  for (const auto& order_key : window_expr->getOrderKeys()) {
214  auto new_offset = pushDownExpressionImpl(order_key.get());
215  pushed_down_order_key_offset_.emplace(offset, new_offset);
216  ++offset;
217  }
218 
219  // step 2. rebind projected targets of the window_func_project_node with the new
220  // project node
221  std::vector<std::unique_ptr<const RexScalar>> window_operands;
222  auto deconst_window_expr = const_cast<RexWindowFunctionOperator*>(window_expr);
223  for (size_t idx = 0; idx < window_expr->size(); ++idx) {
224  auto it = pushed_down_window_operands_offset_.find(idx);
225  if (it != pushed_down_window_operands_offset_.end()) {
226  auto new_input = std::make_unique<const RexInput>(new_project_.get(), it->second);
227  CHECK(new_input);
228  window_operands.emplace_back(std::move(new_input));
229  } else {
230  auto copied_expr = deep_copier_.visit(window_expr->getOperand(idx));
231  window_operands.emplace_back(std::move(copied_expr));
232  }
233  }
234  deconst_window_expr->replaceOperands(std::move(window_operands));
235 
236  for (size_t idx = 0; idx < window_expr->getPartitionKeys().size(); ++idx) {
237  auto new_offset = getOffsetForPushedDownExpr(WindowExprType::PARTITION_KEY, idx);
238  CHECK(new_offset);
239  auto new_input = std::make_unique<const RexInput>(new_project_.get(), *new_offset);
240  CHECK(new_input);
241  deconst_window_expr->replacePartitionKey(idx, std::move(new_input));
242  }
243 
244  for (size_t idx = 0; idx < window_expr->getOrderKeys().size(); ++idx) {
245  auto new_offset = getOffsetForPushedDownExpr(WindowExprType::ORDER_KEY, idx);
246  CHECK(new_offset);
247  auto new_input = std::make_unique<const RexInput>(new_project_.get(), *new_offset);
248  CHECK(new_input);
249  deconst_window_expr->replaceOrderKey(idx, std::move(new_input));
250  }
251  }
252 
253  std::unique_ptr<const RexScalar> visitInput(const RexInput* rex_input) const override {
254  auto new_offset = pushDownExpressionImpl(rex_input);
255  CHECK_LT(new_offset, scalar_exprs_for_new_project_.size());
256  auto hash = rex_input->toHash();
257  auto it = expr_offset_cache_.find(hash);
258  CHECK(it != expr_offset_cache_.end());
259  CHECK_EQ(new_offset, it->second);
260  auto new_input = std::make_unique<const RexInput>(new_project_.get(), new_offset);
261  CHECK(new_input);
262  return new_input;
263  }
264 
265  std::unique_ptr<const RexScalar> visitLiteral(
266  const RexLiteral* rex_literal) const override {
267  return deep_copier_.visit(rex_literal);
268  }
269 
270  std::unique_ptr<const RexScalar> visitRef(const RexRef* rex_ref) const override {
271  return deep_copier_.visit(rex_ref);
272  }
273 
274  std::unique_ptr<const RexScalar> visitSubQuery(
275  const RexSubQuery* rex_subquery) const override {
276  return deep_copier_.visit(rex_subquery);
277  }
278 
279  std::unique_ptr<const RexScalar> visitCase(const RexCase* rex_case) const override {
280  std::vector<
281  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
282  new_expr_pair_list;
283  std::unique_ptr<const RexScalar> new_else_expr;
284  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
285  const auto when = rex_case->getWhen(i);
286  auto new_when = PushDownGenericExpressionInWindowFunction::visit(when);
287  const auto then = rex_case->getThen(i);
288  auto new_then = PushDownGenericExpressionInWindowFunction::visit(then);
289  new_expr_pair_list.emplace_back(std::move(new_when), std::move(new_then));
290  }
291  if (rex_case->getElse()) {
292  new_else_expr = deep_copier_.visit(rex_case->getElse());
293  }
294  auto new_case = std::make_unique<const RexCase>(new_expr_pair_list, new_else_expr);
295  return new_case;
296  }
297 
298  std::unique_ptr<const RexScalar> visitOperator(
299  const RexOperator* rex_operator) const override {
300  const auto rex_window_func_operator =
301  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
302  if (rex_window_func_operator) {
303  pushDownExpressionInWindowFunction(rex_window_func_operator);
304  return deep_copier_.visit(rex_operator);
305  } else {
306  std::unique_ptr<const RexOperator> new_operator{nullptr};
307  std::vector<std::unique_ptr<const RexScalar>> new_operands;
308  for (size_t i = 0; i < rex_operator->size(); ++i) {
309  const auto operand = rex_operator->getOperand(i);
310  auto new_operand = PushDownGenericExpressionInWindowFunction::visit(operand);
311  new_operands.emplace_back(std::move(new_operand));
312  }
313  if (auto function_op = dynamic_cast<const RexFunctionOperator*>(rex_operator)) {
314  new_operator = std::make_unique<const RexFunctionOperator>(
315  function_op->getName(), new_operands, rex_operator->getType());
316  } else {
317  new_operator = std::make_unique<const RexOperator>(
318  rex_operator->getOperator(), new_operands, rex_operator->getType());
319  }
320  CHECK(new_operator);
321  return new_operator;
322  }
323  }
324 
325  bool hasCaseExprAsWindowOperand() { return found_case_expr_window_operand_; }
326 
327  bool hasPartitionExpression() { return has_partition_expr_; }
328 
329  private:
330  std::unique_ptr<const RexScalar> defaultResult() const override { return nullptr; }
331 
332  std::shared_ptr<RelProject> new_project_;
333  std::vector<std::unique_ptr<const RexScalar>>& scalar_exprs_for_new_project_;
334  std::vector<std::string>& fields_for_new_project_;
335  std::unordered_map<size_t, size_t>& expr_offset_cache_;
337  mutable bool has_partition_expr_;
338  mutable std::unordered_map<size_t, size_t> pushed_down_window_operands_offset_;
339  mutable std::unordered_map<size_t, size_t> pushed_down_partition_key_offset_;
340  mutable std::unordered_map<size_t, size_t> pushed_down_order_key_offset_;
342 };
343 
344 } // namespace
345 
347  std::shared_ptr<const RelAlgNode> old_input,
348  std::shared_ptr<const RelAlgNode> input,
349  std::optional<std::unordered_map<unsigned, unsigned>> old_to_new_index_map) {
350  RelAlgNode::replaceInput(old_input, input);
351  std::unique_ptr<RexRebindInputsVisitor> rebind_inputs;
352  if (old_to_new_index_map) {
353  rebind_inputs = std::make_unique<RexRebindReindexInputsVisitor>(
354  old_input.get(), input.get(), *old_to_new_index_map);
355  } else {
356  rebind_inputs =
357  std::make_unique<RexRebindInputsVisitor>(old_input.get(), input.get());
358  }
359  CHECK(rebind_inputs);
360  for (const auto& scalar_expr : scalar_exprs_) {
361  rebind_inputs->visit(scalar_expr.get());
362  }
363 }
364 
365 void RelProject::appendInput(std::string new_field_name,
366  std::unique_ptr<const RexScalar> new_input) {
367  fields_.emplace_back(std::move(new_field_name));
368  scalar_exprs_.emplace_back(std::move(new_input));
369 }
370 
372  const auto scan_node = dynamic_cast<const RelScan*>(ra_node);
373  if (scan_node) {
374  // Scan node has no inputs, output contains all columns in the table.
375  CHECK_EQ(size_t(0), scan_node->inputCount());
376  return n_outputs(scan_node, scan_node->size());
377  }
378  const auto project_node = dynamic_cast<const RelProject*>(ra_node);
379  if (project_node) {
380  // Project output count doesn't depend on the input
381  CHECK_EQ(size_t(1), project_node->inputCount());
382  return n_outputs(project_node, project_node->size());
383  }
384  const auto filter_node = dynamic_cast<const RelFilter*>(ra_node);
385  if (filter_node) {
386  // Filter preserves shape
387  CHECK_EQ(size_t(1), filter_node->inputCount());
388  const auto prev_out = get_node_output(filter_node->getInput(0));
389  return n_outputs(filter_node, prev_out.size());
390  }
391  const auto aggregate_node = dynamic_cast<const RelAggregate*>(ra_node);
392  if (aggregate_node) {
393  // Aggregate output count doesn't depend on the input
394  CHECK_EQ(size_t(1), aggregate_node->inputCount());
395  return n_outputs(aggregate_node, aggregate_node->size());
396  }
397  const auto compound_node = dynamic_cast<const RelCompound*>(ra_node);
398  if (compound_node) {
399  // Compound output count doesn't depend on the input
400  CHECK_EQ(size_t(1), compound_node->inputCount());
401  return n_outputs(compound_node, compound_node->size());
402  }
403  const auto join_node = dynamic_cast<const RelJoin*>(ra_node);
404  if (join_node) {
405  // Join concatenates the outputs from the inputs and the output
406  // directly references the nodes in the input.
407  CHECK_EQ(size_t(2), join_node->inputCount());
408  auto lhs_out =
409  n_outputs(join_node->getInput(0), get_node_output(join_node->getInput(0)).size());
410  const auto rhs_out =
411  n_outputs(join_node->getInput(1), get_node_output(join_node->getInput(1)).size());
412  lhs_out.insert(lhs_out.end(), rhs_out.begin(), rhs_out.end());
413  return lhs_out;
414  }
415  const auto table_func_node = dynamic_cast<const RelTableFunction*>(ra_node);
416  if (table_func_node) {
417  // Table Function output count doesn't depend on the input
418  return n_outputs(table_func_node, table_func_node->size());
419  }
420  const auto sort_node = dynamic_cast<const RelSort*>(ra_node);
421  if (sort_node) {
422  // Sort preserves shape
423  CHECK_EQ(size_t(1), sort_node->inputCount());
424  const auto prev_out = get_node_output(sort_node->getInput(0));
425  return n_outputs(sort_node, prev_out.size());
426  }
427  const auto logical_values_node = dynamic_cast<const RelLogicalValues*>(ra_node);
428  if (logical_values_node) {
429  CHECK_EQ(size_t(0), logical_values_node->inputCount());
430  return n_outputs(logical_values_node, logical_values_node->size());
431  }
432  const auto logical_union_node = dynamic_cast<const RelLogicalUnion*>(ra_node);
433  if (logical_union_node) {
434  return n_outputs(logical_union_node, logical_union_node->size());
435  }
436  LOG(FATAL) << "Unhandled ra_node type: " << ::toString(ra_node);
437  return {};
438 }
439 
441  if (!isSimple()) {
442  return false;
443  }
444  CHECK_EQ(size_t(1), inputCount());
445  const auto source = getInput(0);
446  if (dynamic_cast<const RelJoin*>(source)) {
447  return false;
448  }
449  const auto source_shape = get_node_output(source);
450  if (source_shape.size() != scalar_exprs_.size()) {
451  return false;
452  }
453  for (size_t i = 0; i < scalar_exprs_.size(); ++i) {
454  const auto& scalar_expr = scalar_exprs_[i];
455  const auto input = dynamic_cast<const RexInput*>(scalar_expr.get());
456  CHECK(input);
457  CHECK_EQ(source, input->getSourceNode());
458  // We should add the additional check that input->getIndex() !=
459  // source_shape[i].getIndex(), but Calcite doesn't generate the right
460  // Sort-Project-Sort sequence when joins are involved.
461  if (input->getSourceNode() != source_shape[i].getSourceNode()) {
462  return false;
463  }
464  }
465  return true;
466 }
467 
468 namespace {
469 
470 bool isRenamedInput(const RelAlgNode* node,
471  const size_t index,
472  const std::string& new_name) {
473  CHECK_LT(index, node->size());
474  if (auto join = dynamic_cast<const RelJoin*>(node)) {
475  CHECK_EQ(size_t(2), join->inputCount());
476  const auto lhs_size = join->getInput(0)->size();
477  if (index < lhs_size) {
478  return isRenamedInput(join->getInput(0), index, new_name);
479  }
480  CHECK_GE(index, lhs_size);
481  return isRenamedInput(join->getInput(1), index - lhs_size, new_name);
482  }
483 
484  if (auto scan = dynamic_cast<const RelScan*>(node)) {
485  return new_name != scan->getFieldName(index);
486  }
487 
488  if (auto aggregate = dynamic_cast<const RelAggregate*>(node)) {
489  return new_name != aggregate->getFieldName(index);
490  }
491 
492  if (auto project = dynamic_cast<const RelProject*>(node)) {
493  return new_name != project->getFieldName(index);
494  }
495 
496  if (auto table_func = dynamic_cast<const RelTableFunction*>(node)) {
497  return new_name != table_func->getFieldName(index);
498  }
499 
500  if (auto logical_values = dynamic_cast<const RelLogicalValues*>(node)) {
501  const auto& tuple_type = logical_values->getTupleType();
502  CHECK_LT(index, tuple_type.size());
503  return new_name != tuple_type[index].get_resname();
504  }
505 
506  CHECK(dynamic_cast<const RelSort*>(node) || dynamic_cast<const RelFilter*>(node) ||
507  dynamic_cast<const RelLogicalUnion*>(node));
508  return isRenamedInput(node->getInput(0), index, new_name);
509 }
510 
511 } // namespace
512 
514  if (!isSimple()) {
515  return false;
516  }
517  CHECK_EQ(scalar_exprs_.size(), fields_.size());
518  for (size_t i = 0; i < fields_.size(); ++i) {
519  auto rex_in = dynamic_cast<const RexInput*>(scalar_exprs_[i].get());
520  CHECK(rex_in);
521  if (isRenamedInput(rex_in->getSourceNode(), rex_in->getIndex(), fields_[i])) {
522  return true;
523  }
524  }
525  return false;
526 }
527 
528 void RelJoin::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
529  std::shared_ptr<const RelAlgNode> input) {
530  RelAlgNode::replaceInput(old_input, input);
531  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
532  if (condition_) {
533  rebind_inputs.visit(condition_.get());
534  }
535 }
536 
537 void RelFilter::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
538  std::shared_ptr<const RelAlgNode> input) {
539  RelAlgNode::replaceInput(old_input, input);
540  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
541  rebind_inputs.visit(filter_.get());
542 }
543 
544 void RelCompound::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
545  std::shared_ptr<const RelAlgNode> input) {
546  RelAlgNode::replaceInput(old_input, input);
547  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
548  for (const auto& scalar_source : scalar_sources_) {
549  rebind_inputs.visit(scalar_source.get());
550  }
551  if (filter_expr_) {
552  rebind_inputs.visit(filter_expr_.get());
553  }
554 }
555 
557  : RelAlgNode(rhs)
559  , fields_(rhs.fields_)
560  , hint_applied_(false)
561  , hints_(std::make_unique<Hints>()) {
562  RexDeepCopyVisitor copier;
563  for (auto const& expr : rhs.scalar_exprs_) {
564  scalar_exprs_.push_back(copier.visit(expr.get()));
565  }
566  if (rhs.hint_applied_) {
567  for (auto const& kv : *rhs.hints_) {
568  addHint(kv.second);
569  }
570  }
571 }
572 
574  : RelAlgNode(rhs)
575  , tuple_type_(rhs.tuple_type_)
576  , values_(RexDeepCopyVisitor::copy(rhs.values_)) {}
577 
579  RexDeepCopyVisitor copier;
580  filter_ = copier.visit(rhs.filter_.get());
581 }
582 
584  : RelAlgNode(rhs)
585  , groupby_count_(rhs.groupby_count_)
586  , fields_(rhs.fields_)
587  , hint_applied_(false)
588  , hints_(std::make_unique<Hints>()) {
589  agg_exprs_.reserve(rhs.agg_exprs_.size());
590  for (auto const& agg : rhs.agg_exprs_) {
591  agg_exprs_.push_back(agg->deepCopy());
592  }
593  if (rhs.hint_applied_) {
594  for (auto const& kv : *rhs.hints_) {
595  addHint(kv.second);
596  }
597  }
598 }
599 
601  : RelAlgNode(rhs)
602  , join_type_(rhs.join_type_)
603  , hint_applied_(false)
604  , hints_(std::make_unique<Hints>()) {
605  RexDeepCopyVisitor copier;
606  condition_ = copier.visit(rhs.condition_.get());
607  if (rhs.hint_applied_) {
608  for (auto const& kv : *rhs.hints_) {
609  addHint(kv.second);
610  }
611  }
612 }
613 
614 namespace {
615 
616 std::vector<std::unique_ptr<const RexAgg>> copyAggExprs(
617  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs) {
618  std::vector<std::unique_ptr<const RexAgg>> agg_exprs_copy;
619  agg_exprs_copy.reserve(agg_exprs.size());
620  for (auto const& agg_expr : agg_exprs) {
621  agg_exprs_copy.push_back(agg_expr->deepCopy());
622  }
623  return agg_exprs_copy;
624 }
625 
626 std::vector<std::unique_ptr<const RexScalar>> copyRexScalars(
627  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources) {
628  std::vector<std::unique_ptr<const RexScalar>> scalar_sources_copy;
629  scalar_sources_copy.reserve(scalar_sources.size());
630  RexDeepCopyVisitor copier;
631  for (auto const& scalar_source : scalar_sources) {
632  scalar_sources_copy.push_back(copier.visit(scalar_source.get()));
633  }
634  return scalar_sources_copy;
635 }
636 
637 std::vector<const Rex*> remapTargetPointers(
638  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs_new,
639  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources_new,
640  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs_old,
641  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources_old,
642  std::vector<const Rex*> const& target_exprs_old) {
643  std::vector<const Rex*> target_exprs(target_exprs_old);
644  std::unordered_map<const Rex*, const Rex*> old_to_new_target(target_exprs.size());
645  for (size_t i = 0; i < agg_exprs_new.size(); ++i) {
646  old_to_new_target.emplace(agg_exprs_old[i].get(), agg_exprs_new[i].get());
647  }
648  for (size_t i = 0; i < scalar_sources_new.size(); ++i) {
649  old_to_new_target.emplace(scalar_sources_old[i].get(), scalar_sources_new[i].get());
650  }
651  for (auto& target : target_exprs) {
652  auto target_it = old_to_new_target.find(target);
653  CHECK(target_it != old_to_new_target.end());
654  target = target_it->second;
655  }
656  return target_exprs;
657 }
658 
659 } // namespace
660 
662  : RelAlgNode(rhs)
664  , groupby_count_(rhs.groupby_count_)
665  , agg_exprs_(copyAggExprs(rhs.agg_exprs_))
666  , fields_(rhs.fields_)
667  , is_agg_(rhs.is_agg_)
668  , scalar_sources_(copyRexScalars(rhs.scalar_sources_))
669  , target_exprs_(remapTargetPointers(agg_exprs_,
670  scalar_sources_,
671  rhs.agg_exprs_,
672  rhs.scalar_sources_,
673  rhs.target_exprs_))
674  , hint_applied_(false)
675  , hints_(std::make_unique<Hints>()) {
676  RexDeepCopyVisitor copier;
677  filter_expr_ = rhs.filter_expr_ ? copier.visit(rhs.filter_expr_.get()) : nullptr;
678  if (rhs.hint_applied_) {
679  for (auto const& kv : *rhs.hints_) {
680  addHint(kv.second);
681  }
682  }
683 }
684 
685 void RelTableFunction::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
686  std::shared_ptr<const RelAlgNode> input) {
687  RelAlgNode::replaceInput(old_input, input);
688  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
689  for (const auto& target_expr : target_exprs_) {
690  rebind_inputs.visit(target_expr.get());
691  }
692  for (const auto& func_input : table_func_inputs_) {
693  rebind_inputs.visit(func_input.get());
694  }
695 }
696 
698  int32_t literal_args = 0;
699  for (const auto& arg : table_func_inputs_) {
700  const auto rex_literal = dynamic_cast<const RexLiteral*>(arg.get());
701  if (rex_literal) {
702  literal_args += 1;
703  }
704  }
705  return literal_args;
706 }
707 
709  : RelAlgNode(rhs)
710  , function_name_(rhs.function_name_)
711  , fields_(rhs.fields_)
712  , col_inputs_(rhs.col_inputs_)
713  , table_func_inputs_(copyRexScalars(rhs.table_func_inputs_))
714  , target_exprs_(copyRexScalars(rhs.target_exprs_)) {
715  std::unordered_map<const Rex*, const Rex*> old_to_new_input;
716  for (size_t i = 0; i < table_func_inputs_.size(); ++i) {
717  old_to_new_input.emplace(rhs.table_func_inputs_[i].get(),
718  table_func_inputs_[i].get());
719  }
720  for (auto& target : col_inputs_) {
721  auto target_it = old_to_new_input.find(target);
722  CHECK(target_it != old_to_new_input.end());
723  target = target_it->second;
724  }
725 }
726 
727 namespace std {
728 template <>
729 struct hash<std::pair<const RelAlgNode*, int>> {
730  size_t operator()(const std::pair<const RelAlgNode*, int>& input_col) const {
731  auto ptr_val = reinterpret_cast<const int64_t*>(&input_col.first);
732  return static_cast<int64_t>(*ptr_val) ^ input_col.second;
733  }
734 };
735 } // namespace std
736 
737 namespace {
738 
739 std::set<std::pair<const RelAlgNode*, int>> get_equiv_cols(const RelAlgNode* node,
740  const size_t which_col) {
741  std::set<std::pair<const RelAlgNode*, int>> work_set;
742  auto walker = node;
743  auto curr_col = which_col;
744  while (true) {
745  work_set.insert(std::make_pair(walker, curr_col));
746  if (dynamic_cast<const RelScan*>(walker) || dynamic_cast<const RelJoin*>(walker)) {
747  break;
748  }
749  CHECK_EQ(size_t(1), walker->inputCount());
750  auto only_source = walker->getInput(0);
751  if (auto project = dynamic_cast<const RelProject*>(walker)) {
752  if (auto input = dynamic_cast<const RexInput*>(project->getProjectAt(curr_col))) {
753  const auto join_source = dynamic_cast<const RelJoin*>(only_source);
754  if (join_source) {
755  CHECK_EQ(size_t(2), join_source->inputCount());
756  auto lhs = join_source->getInput(0);
757  CHECK((input->getIndex() < lhs->size() && lhs == input->getSourceNode()) ||
758  join_source->getInput(1) == input->getSourceNode());
759  } else {
760  CHECK_EQ(input->getSourceNode(), only_source);
761  }
762  curr_col = input->getIndex();
763  } else {
764  break;
765  }
766  } else if (auto aggregate = dynamic_cast<const RelAggregate*>(walker)) {
767  if (curr_col >= aggregate->getGroupByCount()) {
768  break;
769  }
770  }
771  walker = only_source;
772  }
773  return work_set;
774 }
775 
776 } // namespace
777 
778 bool RelSort::hasEquivCollationOf(const RelSort& that) const {
779  if (collation_.size() != that.collation_.size()) {
780  return false;
781  }
782 
783  for (size_t i = 0, e = collation_.size(); i < e; ++i) {
784  auto this_sort_key = collation_[i];
785  auto that_sort_key = that.collation_[i];
786  if (this_sort_key.getSortDir() != that_sort_key.getSortDir()) {
787  return false;
788  }
789  if (this_sort_key.getNullsPosition() != that_sort_key.getNullsPosition()) {
790  return false;
791  }
792  auto this_equiv_keys = get_equiv_cols(this, this_sort_key.getField());
793  auto that_equiv_keys = get_equiv_cols(&that, that_sort_key.getField());
794  std::vector<std::pair<const RelAlgNode*, int>> intersect;
795  std::set_intersection(this_equiv_keys.begin(),
796  this_equiv_keys.end(),
797  that_equiv_keys.begin(),
798  that_equiv_keys.end(),
799  std::back_inserter(intersect));
800  if (intersect.empty()) {
801  return false;
802  }
803  }
804  return true;
805 }
806 
807 // class RelLogicalUnion methods
808 
810  : RelAlgNode(std::move(inputs)), is_all_(is_all) {
811  if (!g_enable_union) {
812  throw QueryNotSupported(
813  "The DEPRECATED enable-union option is set to off. Please remove this option as "
814  "it may be disabled in the future.");
815  }
816  CHECK_EQ(2u, inputs_.size());
817  if (!is_all_) {
818  throw QueryNotSupported("UNION without ALL is not supported yet.");
819  }
820 }
821 
822 size_t RelLogicalUnion::size() const {
823  return inputs_.front()->size();
824 }
825 
827  return cat(::typeName(this), "(is_all(", is_all_, "))");
828 }
829 
830 size_t RelLogicalUnion::toHash() const {
831  if (!hash_) {
832  hash_ = typeid(RelLogicalUnion).hash_code();
833  boost::hash_combine(*hash_, is_all_);
834  }
835  return *hash_;
836 }
837 
838 std::string RelLogicalUnion::getFieldName(const size_t i) const {
839  if (auto const* input = dynamic_cast<RelCompound const*>(inputs_[0].get())) {
840  return input->getFieldName(i);
841  } else if (auto const* input = dynamic_cast<RelProject const*>(inputs_[0].get())) {
842  return input->getFieldName(i);
843  } else if (auto const* input = dynamic_cast<RelLogicalUnion const*>(inputs_[0].get())) {
844  return input->getFieldName(i);
845  } else if (auto const* input = dynamic_cast<RelAggregate const*>(inputs_[0].get())) {
846  return input->getFieldName(i);
847  } else if (auto const* input = dynamic_cast<RelScan const*>(inputs_[0].get())) {
848  return input->getFieldName(i);
849  } else if (auto const* input =
850  dynamic_cast<RelTableFunction const*>(inputs_[0].get())) {
851  return input->getFieldName(i);
852  }
853  UNREACHABLE() << "Unhandled input type: " << ::toString(inputs_.front());
854  return {};
855 }
856 
// Validates that the two UNION inputs are type-compatible.
// - Both inputs must expose the same number of output columns.
// - Each pair of columns must have equal SQLTypeInfo. The one tolerated
//   difference is when both columns are dictionary-encoded strings.
// Mismatches are logged at VLOG(2) and reported as std::runtime_error.
858  std::vector<TargetMetaInfo> const& tmis0 = inputs_[0]->getOutputMetainfo();
859  std::vector<TargetMetaInfo> const& tmis1 = inputs_[1]->getOutputMetainfo();
860  if (tmis0.size() != tmis1.size()) {
861  VLOG(2) << "tmis0.size() = " << tmis0.size() << " != " << tmis1.size()
862  << " = tmis1.size()";
863  throw std::runtime_error("Subqueries of a UNION must have matching data types.");
864  }
865  for (size_t i = 0; i < tmis0.size(); ++i) {
866  if (tmis0[i].get_type_info() != tmis1[i].get_type_info()) {
867  SQLTypeInfo const& ti0 = tmis0[i].get_type_info();
868  SQLTypeInfo const& ti1 = tmis1[i].get_type_info();
869  VLOG(2) << "Types do not match for UNION:\n tmis0[" << i
870  << "].get_type_info().to_string() = " << ti0.to_string() << "\n tmis1["
871  << i << "].get_type_info().to_string() = " << ti1.to_string();
872  // The only permitted difference is when both columns are dictionary-encoded.
873  if (!(ti0.is_dict_encoded_string() && ti1.is_dict_encoded_string())) {
874  throw std::runtime_error(
875  "Subqueries of a UNION must have the exact same data types.");
876  }
877  }
878  }
879 }
880 
881 // Rest of code requires a raw pointer, but RexInput object needs to live somewhere.
// If rex_scalar is a RexInput, a copy of it is re-pointed at getInput(input_idx).
// That copy is stored in scalar_exprs_, which keeps it alive, and a raw pointer
// to it is returned. Any other expression is returned unchanged.
883  size_t input_idx) const {
884  if (auto const* rex_input_ptr = dynamic_cast<RexInput const*>(rex_scalar)) {
885  RexInput rex_input(*rex_input_ptr);
886  rex_input.setSourceNode(getInput(input_idx));
887  scalar_exprs_.emplace_back(std::make_shared<RexInput const>(std::move(rex_input)));
888  return scalar_exprs_.back().get();
889  }
890  return rex_scalar;
891 }
892 
893 namespace {
894 
895 unsigned node_id(const rapidjson::Value& ra_node) noexcept {
896  const auto& id = field(ra_node, "id");
897  return std::stoi(json_str(id));
898 }
899 
900 std::string json_node_to_string(const rapidjson::Value& node) noexcept {
901  rapidjson::StringBuffer buffer;
902  rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
903  node.Accept(writer);
904  return buffer.GetString();
905 }
906 
907 // The parse_* functions below de-serialize expressions as they come from Calcite.
908 // RelAlgDagBuilder will take care of making the representation easy to
909 // navigate for lower layers, for example by replacing RexAbstractInput with RexInput.
910 
911 std::unique_ptr<RexAbstractInput> parse_abstract_input(
912  const rapidjson::Value& expr) noexcept {
913  const auto& input = field(expr, "input");
914  return std::unique_ptr<RexAbstractInput>(new RexAbstractInput(json_i64(input)));
915 }
916 
// Builds a RexLiteral from a Calcite literal node.
// - The node carries "literal", "type", "target_type", "scale", "precision",
//   "type_scale" and "type_precision" members.
// - A JSON-null literal (or type kNULLT) yields a typed NULL literal of
//   target_type.
// - Otherwise the value is read according to "type": integer-valued types
//   (e.g. kINT, kDECIMAL, kTIMESTAMP) via json_i64, kDOUBLE from any JSON
//   number, kTEXT via json_str, and kBOOLEAN via json_bool.
// - Any other type fails a CHECK.
917 std::unique_ptr<RexLiteral> parse_literal(const rapidjson::Value& expr) {
918  CHECK(expr.IsObject());
919  const auto& literal = field(expr, "literal");
920  const auto type = to_sql_type(json_str(field(expr, "type")));
921  const auto target_type = to_sql_type(json_str(field(expr, "target_type")));
922  const auto scale = json_i64(field(expr, "scale"));
923  const auto precision = json_i64(field(expr, "precision"));
924  const auto type_scale = json_i64(field(expr, "type_scale"));
925  const auto type_precision = json_i64(field(expr, "type_precision"));
926  if (literal.IsNull()) {
927  return std::unique_ptr<RexLiteral>(new RexLiteral(target_type));
928  }
929  switch (type) {
930  case kINT:
931  case kBIGINT:
932  case kDECIMAL:
933  case kINTERVAL_DAY_TIME:
935  case kTIME:
936  case kTIMESTAMP:
937  case kDATE:
938  return std::unique_ptr<RexLiteral>(new RexLiteral(json_i64(literal),
939  type,
940  target_type,
941  scale,
942  precision,
943  type_scale,
944  type_precision));
945  case kDOUBLE: {
// A kDOUBLE literal may arrive as a JSON double or as a JSON integer. The
// integer forms are presumably whole-valued doubles serialized without a
// fractional part; they are widened to double here.
946  if (literal.IsDouble()) {
947  return std::unique_ptr<RexLiteral>(new RexLiteral(json_double(literal),
948  type,
949  target_type,
950  scale,
951  precision,
952  type_scale,
953  type_precision));
954  } else if (literal.IsInt64()) {
955  return std::make_unique<RexLiteral>(static_cast<double>(literal.GetInt64()),
956  type,
957  target_type,
958  scale,
959  precision,
960  type_scale,
961  type_precision);
962 
963  } else if (literal.IsUint64()) {
964  return std::make_unique<RexLiteral>(static_cast<double>(literal.GetUint64()),
965  type,
966  target_type,
967  scale,
968  precision,
969  type_scale,
970  type_precision);
971  }
972  UNREACHABLE() << "Unhandled type: " << literal.GetType();
// NOTE(review): there is no break/return after UNREACHABLE(). If that macro
// does not abort, control falls through into the kTEXT case below. Confirm
// it is fatal.
973  }
974  case kTEXT:
975  return std::unique_ptr<RexLiteral>(new RexLiteral(json_str(literal),
976  type,
977  target_type,
978  scale,
979  precision,
980  type_scale,
981  type_precision));
982  case kBOOLEAN:
983  return std::unique_ptr<RexLiteral>(new RexLiteral(json_bool(literal),
984  type,
985  target_type,
986  scale,
987  precision,
988  type_scale,
989  type_precision));
990  case kNULLT:
991  return std::unique_ptr<RexLiteral>(new RexLiteral(target_type));
992  default:
993  CHECK(false);
994  }
995  CHECK(false);
996  return nullptr;
997 }
998 
// Forward declaration. parse_scalar_expr and the parse_* helpers below
// (parse_expr_array, parse_case, parse_operator, ...) call each other
// recursively.
999 std::unique_ptr<const RexScalar> parse_scalar_expr(const rapidjson::Value& expr,
1001  RelAlgDagBuilder& root_dag_builder);
1002 
1003 SQLTypeInfo parse_type(const rapidjson::Value& type_obj) {
1004  if (type_obj.IsArray()) {
1005  throw QueryNotSupported("Composite types are not currently supported.");
1006  }
1007  CHECK(type_obj.IsObject() && type_obj.MemberCount() >= 2)
1008  << json_node_to_string(type_obj);
1009  const auto type = to_sql_type(json_str(field(type_obj, "type")));
1010  const auto nullable = json_bool(field(type_obj, "nullable"));
1011  const auto precision_it = type_obj.FindMember("precision");
1012  const int precision =
1013  precision_it != type_obj.MemberEnd() ? json_i64(precision_it->value) : 0;
1014  const auto scale_it = type_obj.FindMember("scale");
1015  const int scale = scale_it != type_obj.MemberEnd() ? json_i64(scale_it->value) : 0;
1016  SQLTypeInfo ti(type, !nullable);
1017  ti.set_precision(precision);
1018  ti.set_scale(scale);
1019  return ti;
1020 }
1021 
// Parses each element of the JSON array `arr` as a scalar expression and
// returns the results in array order.
1022 std::vector<std::unique_ptr<const RexScalar>> parse_expr_array(
1023  const rapidjson::Value& arr,
1025  RelAlgDagBuilder& root_dag_builder) {
1026  std::vector<std::unique_ptr<const RexScalar>> exprs;
1027  for (auto it = arr.Begin(); it != arr.End(); ++it) {
1028  exprs.emplace_back(parse_scalar_expr(*it, cat, root_dag_builder));
1029  }
1030  return exprs;
1031 }
1032 
// Maps a window-function name coming from Calcite to its window function kind.
// Recognized names: ROW_NUMBER, RANK, DENSE_RANK, PERCENT_RANK, CUME_DIST,
// NTILE, LAG, LEAD, FIRST_VALUE, LAST_VALUE, AVG, MIN, MAX, SUM, COUNT and
// $SUM0. Any other name throws std::runtime_error.
1034  if (name == "ROW_NUMBER") {
1036  }
1037  if (name == "RANK") {
1039  }
1040  if (name == "DENSE_RANK") {
1042  }
1043  if (name == "PERCENT_RANK") {
1045  }
1046  if (name == "CUME_DIST") {
1048  }
1049  if (name == "NTILE") {
1051  }
1052  if (name == "LAG") {
1054  }
1055  if (name == "LEAD") {
1057  }
1058  if (name == "FIRST_VALUE") {
1060  }
1061  if (name == "LAST_VALUE") {
1063  }
1064  if (name == "AVG") {
1066  }
1067  if (name == "MIN") {
1069  }
1070  if (name == "MAX") {
1072  }
1073  if (name == "SUM") {
1075  }
1076  if (name == "COUNT") {
1078  }
1079  if (name == "$SUM0") {
1081  }
1082  throw std::runtime_error("Unsupported window function: " + name);
1083 }
1084 
// Parses the "field" member of each window ORDER BY entry in `arr` as a scalar
// expression and returns the results in array order.
1085 std::vector<std::unique_ptr<const RexScalar>> parse_window_order_exprs(
1086  const rapidjson::Value& arr,
1088  RelAlgDagBuilder& root_dag_builder) {
1089  std::vector<std::unique_ptr<const RexScalar>> exprs;
1090  for (auto it = arr.Begin(); it != arr.End(); ++it) {
1091  exprs.emplace_back(parse_scalar_expr(field(*it, "field"), cat, root_dag_builder));
1092  }
1093  return exprs;
1094 }
1095 
// Maps the collation's "direction" string to a SortDirection. Only the exact
// value "DESCENDING" selects the descending branch; every other value selects
// the other branch.
1096 SortDirection parse_sort_direction(const rapidjson::Value& collation) {
1097  return json_str(field(collation, "direction")) == std::string("DESCENDING")
1100 }
1101 
// Maps the collation's "nulls" string to a NullSortedPosition. Only the exact
// value "FIRST" selects the nulls-first branch; every other value selects the
// other branch.
1102 NullSortedPosition parse_nulls_position(const rapidjson::Value& collation) {
1103  return json_str(field(collation, "nulls")) == std::string("FIRST")
1106 }
1107 
// Builds one SortField per window ORDER BY entry in `arr`. Each field uses the
// entry's position as its index, plus the direction and nulls position parsed
// from the entry. root_dag_builder is not used by this function.
1108 std::vector<SortField> parse_window_order_collation(const rapidjson::Value& arr,
1110  RelAlgDagBuilder& root_dag_builder) {
1111  std::vector<SortField> collation;
1112  size_t field_idx = 0;
1113  for (auto it = arr.Begin(); it != arr.End(); ++it, ++field_idx) {
1114  const auto sort_dir = parse_sort_direction(*it);
1115  const auto null_pos = parse_nulls_position(*it);
1116  collation.emplace_back(field_idx, sort_dir, null_pos);
1117  }
1118  return collation;
1119 }
1120 
// Parses a window frame bound object. It reads:
// - the boolean flags "unbounded", "preceding", "following" and
//   "is_current_row";
// - an optional "offset" scalar expression, which is a JSON object when
//   present and JSON null otherwise (enforced by CHECK);
// - the integer "order_key".
1122  const rapidjson::Value& window_bound_obj,
1124  RelAlgDagBuilder& root_dag_builder) {
1125  CHECK(window_bound_obj.IsObject());
1127  window_bound.unbounded = json_bool(field(window_bound_obj, "unbounded"));
1128  window_bound.preceding = json_bool(field(window_bound_obj, "preceding"));
1129  window_bound.following = json_bool(field(window_bound_obj, "following"));
1130  window_bound.is_current_row = json_bool(field(window_bound_obj, "is_current_row"));
1131  const auto& offset_field = field(window_bound_obj, "offset");
1132  if (offset_field.IsObject()) {
1133  window_bound.offset = parse_scalar_expr(offset_field, cat, root_dag_builder);
1134  } else {
1135  CHECK(offset_field.IsNull());
1136  }
1137  window_bound.order_key = json_i64(field(window_bound_obj, "order_key"));
1138  return window_bound;
1139 }
1140 
// Builds the plan of the "subquery" member with a nested RelAlgDagBuilder,
// wraps its root in a RexSubQuery, and registers that subquery with
// root_dag_builder. The registered instance is kept by the builder; the caller
// receives a deep copy.
1141 std::unique_ptr<const RexSubQuery> parse_subquery(const rapidjson::Value& expr,
1143  RelAlgDagBuilder& root_dag_builder) {
1144  const auto& operands = field(expr, "operands");
1145  CHECK(operands.IsArray());
// NOTE(review): Size() is unsigned, so this CHECK can never fail. It only
// documents that an empty operand list is acceptable.
1146  CHECK_GE(operands.Size(), unsigned(0));
1147  const auto& subquery_ast = field(expr, "subquery");
1148 
1149  RelAlgDagBuilder subquery_dag(root_dag_builder, subquery_ast, cat, nullptr);
1150  auto subquery = std::make_shared<RexSubQuery>(subquery_dag.getRootNodeShPtr());
1151  root_dag_builder.registerSubquery(subquery);
1152  return subquery->deepCopy();
1153 }
1154 
// Parses a Calcite operator node into a RexOperator.
// - PG_ANY / PG_ALL quantifiers are treated as kFUNCTION; other names go
//   through to_sql_op.
// - For IN with a "subquery" member, the parsed subquery is appended as the
//   last operand.
// - If "partition_keys" is present, the node is a window function. The result
//   is a RexWindowFunctionOperator built from its keys, collation and frame
//   bounds, and its result type is forced to nullable.
// - Otherwise the result is a RexFunctionOperator (kFUNCTION) or a plain
//   RexOperator.
1155 std::unique_ptr<RexOperator> parse_operator(const rapidjson::Value& expr,
1157  RelAlgDagBuilder& root_dag_builder) {
1158  const auto op_name = json_str(field(expr, "op"));
1159  const bool is_quantifier =
1160  op_name == std::string("PG_ANY") || op_name == std::string("PG_ALL");
1161  const auto op = is_quantifier ? kFUNCTION : to_sql_op(op_name);
1162  const auto& operators_json_arr = field(expr, "operands");
1163  CHECK(operators_json_arr.IsArray());
1164  auto operands = parse_expr_array(operators_json_arr, cat, root_dag_builder);
1165  const auto type_it = expr.FindMember("type");
1166  CHECK(type_it != expr.MemberEnd());
1167  auto ti = parse_type(type_it->value);
1168  if (op == kIN && expr.HasMember("subquery")) {
1169  auto subquery = parse_subquery(expr, cat, root_dag_builder);
1170  operands.emplace_back(std::move(subquery));
1171  }
1172  if (expr.FindMember("partition_keys") != expr.MemberEnd()) {
1173  const auto& partition_keys_arr = field(expr, "partition_keys");
1174  auto partition_keys = parse_expr_array(partition_keys_arr, cat, root_dag_builder);
1175  const auto& order_keys_arr = field(expr, "order_keys");
1176  auto order_keys = parse_window_order_exprs(order_keys_arr, cat, root_dag_builder);
1177  const auto collation =
1178  parse_window_order_collation(order_keys_arr, cat, root_dag_builder);
1179  const auto kind = parse_window_function_kind(op_name);
1180  const auto lower_bound =
1181  parse_window_bound(field(expr, "lower_bound"), cat, root_dag_builder);
1182  const auto upper_bound =
1183  parse_window_bound(field(expr, "upper_bound"), cat, root_dag_builder);
1184  bool is_rows = json_bool(field(expr, "is_rows"));
1185  ti.set_notnull(false);
1186  return std::make_unique<RexWindowFunctionOperator>(kind,
1187  operands,
1188  partition_keys,
1189  order_keys,
1190  collation,
1191  lower_bound,
1192  upper_bound,
1193  is_rows,
1194  ti);
1195  }
1196  return std::unique_ptr<RexOperator>(op == kFUNCTION
1197  ? new RexFunctionOperator(op_name, operands, ti)
1198  : new RexOperator(op, operands, ti));
1199 }
1200 
// Parses a CASE operator whose operands alternate WHEN, THEN, WHEN, THEN, ...
// A trailing unpaired operand becomes the ELSE expression. When the operand
// count is even there is no ELSE, and else_expr stays null.
1201 std::unique_ptr<RexCase> parse_case(const rapidjson::Value& expr,
1203  RelAlgDagBuilder& root_dag_builder) {
1204  const auto& operands = field(expr, "operands");
1205  CHECK(operands.IsArray());
1206  CHECK_GE(operands.Size(), unsigned(2));
1207  std::unique_ptr<const RexScalar> else_expr;
1208  std::vector<
1209  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
1210  expr_pair_list;
1211  for (auto operands_it = operands.Begin(); operands_it != operands.End();) {
1212  auto when_expr = parse_scalar_expr(*operands_it++, cat, root_dag_builder);
1213  if (operands_it == operands.End()) {
1214  else_expr = std::move(when_expr);
1215  break;
1216  }
1217  auto then_expr = parse_scalar_expr(*operands_it++, cat, root_dag_builder);
1218  expr_pair_list.emplace_back(std::move(when_expr), std::move(then_expr));
1219  }
1220  return std::unique_ptr<RexCase>(new RexCase(expr_pair_list, else_expr));
1221 }
1222 
1223 std::vector<std::string> strings_from_json_array(
1224  const rapidjson::Value& json_str_arr) noexcept {
1225  CHECK(json_str_arr.IsArray());
1226  std::vector<std::string> fields;
1227  for (auto json_str_arr_it = json_str_arr.Begin(); json_str_arr_it != json_str_arr.End();
1228  ++json_str_arr_it) {
1229  CHECK(json_str_arr_it->IsString());
1230  fields.emplace_back(json_str_arr_it->GetString());
1231  }
1232  return fields;
1233 }
1234 
1235 std::vector<size_t> indices_from_json_array(
1236  const rapidjson::Value& json_idx_arr) noexcept {
1237  CHECK(json_idx_arr.IsArray());
1238  std::vector<size_t> indices;
1239  for (auto json_idx_arr_it = json_idx_arr.Begin(); json_idx_arr_it != json_idx_arr.End();
1240  ++json_idx_arr_it) {
1241  CHECK(json_idx_arr_it->IsInt());
1242  CHECK_GE(json_idx_arr_it->GetInt(), 0);
1243  indices.emplace_back(json_idx_arr_it->GetInt());
1244  }
1245  return indices;
1246 }
1247 
1248 std::unique_ptr<const RexAgg> parse_aggregate_expr(const rapidjson::Value& expr) {
1249  const auto agg_str = json_str(field(expr, "agg"));
1250  if (agg_str == "APPROX_QUANTILE") {
1251  LOG(INFO) << "APPROX_QUANTILE is deprecated. Please use APPROX_PERCENTILE instead.";
1252  }
1253  const auto agg = to_agg_kind(agg_str);
1254  const auto distinct = json_bool(field(expr, "distinct"));
1255  const auto agg_ti = parse_type(field(expr, "type"));
1256  const auto operands = indices_from_json_array(field(expr, "operands"));
1257  if (operands.size() > 1 && (operands.size() != 2 || (agg != kAPPROX_COUNT_DISTINCT &&
1258  agg != kAPPROX_QUANTILE))) {
1259  throw QueryNotSupported("Multiple arguments for aggregates aren't supported");
1260  }
1261  return std::unique_ptr<const RexAgg>(new RexAgg(agg, distinct, agg_ti, operands));
1262 }
1263 
// Dispatches a Calcite scalar-expression node on its distinguishing member.
// - "input" parses as an abstract input.
// - "literal" parses as a literal.
// - "op" parses as a CASE, a $SCALAR_QUERY subquery, or a generic operator.
// Any other node throws QueryNotSupported.
// The repeated expr.IsObject() tests are redundant after the CHECK but harmless.
1264 std::unique_ptr<const RexScalar> parse_scalar_expr(const rapidjson::Value& expr,
1266  RelAlgDagBuilder& root_dag_builder) {
1267  CHECK(expr.IsObject());
1268  if (expr.IsObject() && expr.HasMember("input")) {
1269  return std::unique_ptr<const RexScalar>(parse_abstract_input(expr));
1270  }
1271  if (expr.IsObject() && expr.HasMember("literal")) {
1272  return std::unique_ptr<const RexScalar>(parse_literal(expr));
1273  }
1274  if (expr.IsObject() && expr.HasMember("op")) {
1275  const auto op_str = json_str(field(expr, "op"));
1276  if (op_str == std::string("CASE")) {
1277  return std::unique_ptr<const RexScalar>(parse_case(expr, cat, root_dag_builder));
1278  }
1279  if (op_str == std::string("$SCALAR_QUERY")) {
1280  return std::unique_ptr<const RexScalar>(
1281  parse_subquery(expr, cat, root_dag_builder));
1282  }
1283  return std::unique_ptr<const RexScalar>(parse_operator(expr, cat, root_dag_builder));
1284  }
1285  throw QueryNotSupported("Expression node " + json_node_to_string(expr) +
1286  " not supported");
1287 }
1288 
1289 JoinType to_join_type(const std::string& join_type_name) {
1290  if (join_type_name == "inner") {
1291  return JoinType::INNER;
1292  }
1293  if (join_type_name == "left") {
1294  return JoinType::LEFT;
1295  }
1296  if (join_type_name == "semi") {
1297  return JoinType::SEMI;
1298  }
1299  if (join_type_name == "anti") {
1300  return JoinType::ANTI;
1301  }
1302  throw QueryNotSupported("Join type (" + join_type_name + ") not supported");
1303 }
1304 
// Forward declaration. disambiguate_rex calls disambiguate_operator and
// disambiguate_case, and both of those call back into it.
1305 std::unique_ptr<const RexScalar> disambiguate_rex(const RexScalar*, const RANodeOutput&);
1306 
1307 std::unique_ptr<const RexOperator> disambiguate_operator(
1308  const RexOperator* rex_operator,
1309  const RANodeOutput& ra_output) noexcept {
1310  std::vector<std::unique_ptr<const RexScalar>> disambiguated_operands;
1311  for (size_t i = 0; i < rex_operator->size(); ++i) {
1312  auto operand = rex_operator->getOperand(i);
1313  if (dynamic_cast<const RexSubQuery*>(operand)) {
1314  disambiguated_operands.emplace_back(rex_operator->getOperandAndRelease(i));
1315  } else {
1316  disambiguated_operands.emplace_back(disambiguate_rex(operand, ra_output));
1317  }
1318  }
1319  const auto rex_window_function_operator =
1320  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
1321  if (rex_window_function_operator) {
1322  const auto& partition_keys = rex_window_function_operator->getPartitionKeys();
1323  std::vector<std::unique_ptr<const RexScalar>> disambiguated_partition_keys;
1324  for (const auto& partition_key : partition_keys) {
1325  disambiguated_partition_keys.emplace_back(
1326  disambiguate_rex(partition_key.get(), ra_output));
1327  }
1328  std::vector<std::unique_ptr<const RexScalar>> disambiguated_order_keys;
1329  const auto& order_keys = rex_window_function_operator->getOrderKeys();
1330  for (const auto& order_key : order_keys) {
1331  disambiguated_order_keys.emplace_back(disambiguate_rex(order_key.get(), ra_output));
1332  }
1333  return rex_window_function_operator->disambiguatedOperands(
1334  disambiguated_operands,
1335  disambiguated_partition_keys,
1336  disambiguated_order_keys,
1337  rex_window_function_operator->getCollation());
1338  }
1339  return rex_operator->getDisambiguated(disambiguated_operands);
1340 }
1341 
1342 std::unique_ptr<const RexCase> disambiguate_case(const RexCase* rex_case,
1343  const RANodeOutput& ra_output) {
1344  std::vector<
1345  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
1346  disambiguated_expr_pair_list;
1347  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
1348  auto disambiguated_when = disambiguate_rex(rex_case->getWhen(i), ra_output);
1349  auto disambiguated_then = disambiguate_rex(rex_case->getThen(i), ra_output);
1350  disambiguated_expr_pair_list.emplace_back(std::move(disambiguated_when),
1351  std::move(disambiguated_then));
1352  }
1353  std::unique_ptr<const RexScalar> disambiguated_else{
1354  disambiguate_rex(rex_case->getElse(), ra_output)};
1355  return std::unique_ptr<const RexCase>(
1356  new RexCase(disambiguated_expr_pair_list, disambiguated_else));
1357 }
1358 
1359 // The inputs used by scalar expressions are given as indices in the serialized
1360 // representation of the query. This is hard to navigate; make the relationship
1361 // explicit by creating RexInput expressions which hold a pointer to the source
1362 // relational algebra node and the index relative to the output of that node.
1363 std::unique_ptr<const RexScalar> disambiguate_rex(const RexScalar* rex_scalar,
1364  const RANodeOutput& ra_output) {
1365  const auto rex_abstract_input = dynamic_cast<const RexAbstractInput*>(rex_scalar);
1366  if (rex_abstract_input) {
1367  CHECK_LT(static_cast<size_t>(rex_abstract_input->getIndex()), ra_output.size());
1368  return std::unique_ptr<const RexInput>(
1369  new RexInput(ra_output[rex_abstract_input->getIndex()]));
1370  }
1371  const auto rex_operator = dynamic_cast<const RexOperator*>(rex_scalar);
1372  if (rex_operator) {
1373  return disambiguate_operator(rex_operator, ra_output);
1374  }
1375  const auto rex_case = dynamic_cast<const RexCase*>(rex_scalar);
1376  if (rex_case) {
1377  return disambiguate_case(rex_case, ra_output);
1378  }
1379  if (auto const rex_literal = dynamic_cast<const RexLiteral*>(rex_scalar)) {
1380  return rex_literal->deepCopy();
1381  } else if (auto const rex_subquery = dynamic_cast<const RexSubQuery*>(rex_scalar)) {
1382  return rex_subquery->deepCopy();
1383  } else {
1384  throw QueryNotSupported("Unable to disambiguate expression of type " +
1385  std::string(typeid(*rex_scalar).name()));
1386  }
1387 }
1388 
1389 void bind_project_to_input(RelProject* project_node, const RANodeOutput& input) noexcept {
1390  CHECK_EQ(size_t(1), project_node->inputCount());
1391  std::vector<std::unique_ptr<const RexScalar>> disambiguated_exprs;
1392  for (size_t i = 0; i < project_node->size(); ++i) {
1393  const auto projected_expr = project_node->getProjectAt(i);
1394  if (dynamic_cast<const RexSubQuery*>(projected_expr)) {
1395  disambiguated_exprs.emplace_back(project_node->getProjectAtAndRelease(i));
1396  } else {
1397  disambiguated_exprs.emplace_back(disambiguate_rex(projected_expr, input));
1398  }
1399  }
1400  project_node->setExpressions(disambiguated_exprs);
1401 }
1402 
// Replaces each table-function input expression with its form disambiguated
// against `input`. Subquery inputs are not disambiguated; they are taken with
// getTableFuncInputAtAndRelease().
1404  const RANodeOutput& input) noexcept {
1405  std::vector<std::unique_ptr<const RexScalar>> disambiguated_exprs;
1406  for (size_t i = 0; i < table_func_node->getTableFuncInputsSize(); ++i) {
1407  const auto target_expr = table_func_node->getTableFuncInputAt(i);
1408  if (dynamic_cast<const RexSubQuery*>(target_expr)) {
1409  disambiguated_exprs.emplace_back(table_func_node->getTableFuncInputAtAndRelease(i));
1410  } else {
1411  disambiguated_exprs.emplace_back(disambiguate_rex(target_expr, input));
1412  }
1413  }
1414  table_func_node->setTableFuncInputs(disambiguated_exprs);
1415 }
1416 
1417 void bind_inputs(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) noexcept {
1418  for (auto ra_node : nodes) {
1419  const auto filter_node = std::dynamic_pointer_cast<RelFilter>(ra_node);
1420  if (filter_node) {
1421  CHECK_EQ(size_t(1), filter_node->inputCount());
1422  auto disambiguated_condition = disambiguate_rex(
1423  filter_node->getCondition(), get_node_output(filter_node->getInput(0)));
1424  filter_node->setCondition(disambiguated_condition);
1425  continue;
1426  }
1427  const auto join_node = std::dynamic_pointer_cast<RelJoin>(ra_node);
1428  if (join_node) {
1429  CHECK_EQ(size_t(2), join_node->inputCount());
1430  auto disambiguated_condition =
1431  disambiguate_rex(join_node->getCondition(), get_node_output(join_node.get()));
1432  join_node->setCondition(disambiguated_condition);
1433  continue;
1434  }
1435  const auto project_node = std::dynamic_pointer_cast<RelProject>(ra_node);
1436  if (project_node) {
1437  bind_project_to_input(project_node.get(),
1438  get_node_output(project_node->getInput(0)));
1439  continue;
1440  }
1441  const auto table_func_node = std::dynamic_pointer_cast<RelTableFunction>(ra_node);
1442  if (table_func_node) {
1443  /*
1444  Collect all inputs from table function input (non-literal)
1445  arguments.
1446  */
1447  RANodeOutput input;
1448  input.reserve(table_func_node->inputCount());
1449  for (size_t i = 0; i < table_func_node->inputCount(); i++) {
1450  auto node_output = get_node_output(table_func_node->getInput(i));
1451  input.insert(input.end(), node_output.begin(), node_output.end());
1452  }
1453  bind_table_func_to_input(table_func_node.get(), input);
1454  }
1455  }
1456 }
1457 
1458 void handle_query_hint(const std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1459  RelAlgDagBuilder* dag_builder) noexcept {
1460  // query hint is delivered by the above three nodes
1461  // when a query block has top-sort node, a hint is registered to
1462  // one of the node which locates at the nearest from the sort node
1463  RegisteredQueryHint global_query_hint;
1464  for (auto node : nodes) {
1465  Hints* hint_delivered = nullptr;
1466  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
1467  if (agg_node) {
1468  if (agg_node->hasDeliveredHint()) {
1469  hint_delivered = agg_node->getDeliveredHints();
1470  }
1471  }
1472  const auto project_node = std::dynamic_pointer_cast<RelProject>(node);
1473  if (project_node) {
1474  if (project_node->hasDeliveredHint()) {
1475  hint_delivered = project_node->getDeliveredHints();
1476  }
1477  }
1478  const auto compound_node = std::dynamic_pointer_cast<RelCompound>(node);
1479  if (compound_node) {
1480  if (compound_node->hasDeliveredHint()) {
1481  hint_delivered = compound_node->getDeliveredHints();
1482  }
1483  }
1484  if (hint_delivered && !hint_delivered->empty()) {
1485  dag_builder->registerQueryHints(node, hint_delivered, global_query_hint);
1486  }
1487  }
1488  dag_builder->setGlobalQueryHints(global_query_hint);
1489 }
1490 
1491 void compute_node_hash(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) {
1492  // compute each rel node's hash value in advance to avoid inconsistency of their hash
1493  // values depending on the toHash's caller
1494  // specifically, we manipulate our logical query plan before retrieving query step
1495  // sequence but once we compute a hash value we cached it so there is no way to update
1496  // it after the plan has been changed starting from the top node, we compute the hash
1497  // value (top-down manner)
1498  std::for_each(
1499  nodes.rbegin(), nodes.rend(), [](const std::shared_ptr<RelAlgNode>& node) {
1500  auto node_hash = node->toHash();
1501  CHECK_NE(node_hash, static_cast<size_t>(0));
1502  });
1503 }
1504 
1505 void mark_nops(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) noexcept {
1506  for (auto node : nodes) {
1507  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
1508  if (!agg_node || agg_node->getAggExprsCount()) {
1509  continue;
1510  }
1511  CHECK_EQ(size_t(1), node->inputCount());
1512  const auto agg_input_node = dynamic_cast<const RelAggregate*>(node->getInput(0));
1513  if (agg_input_node && !agg_input_node->getAggExprsCount() &&
1514  agg_node->getGroupByCount() == agg_input_node->getGroupByCount()) {
1515  agg_node->markAsNop();
1516  }
1517  }
1518 }
1519 
1520 namespace {
1521 
// For a simple project (every expression is a RexInput, enforced by CHECK),
// returns the entries of target_exprs selected by each RexInput's index, in
// project order.
1522 std::vector<const Rex*> reproject_targets(
1523  const RelProject* simple_project,
1524  const std::vector<const Rex*>& target_exprs) noexcept {
1525  std::vector<const Rex*> result;
1526  for (size_t i = 0; i < simple_project->size(); ++i) {
1527  const auto input_rex = dynamic_cast<const RexInput*>(simple_project->getProjectAt(i));
1528  CHECK(input_rex);
1529  CHECK_LT(static_cast<size_t>(input_rex->getIndex()), target_exprs.size());
1530  result.push_back(target_exprs[input_rex->getIndex()]);
1531  }
1532  return result;
1533 }
1534 
// Deep-copy visitor used when coalescing nodes.
// - A RexInput whose source is node_to_keep_ is replaced by a visited copy of
//   scalar_sources_[index].
// - Any other RexInput is deep-copied unchanged.
// scalar_sources_ is held by reference and must outlive the visitor.
1541  public:
1543  const RelAlgNode* node_to_keep,
1544  const std::vector<std::unique_ptr<const RexScalar>>& scalar_sources)
1545  : node_to_keep_(node_to_keep), scalar_sources_(scalar_sources) {}
1546 
1547  // Reproject the RexInput from its current RA Node to the RA Node we intend to keep
1548  RetType visitInput(const RexInput* input) const final {
1549  if (input->getSourceNode() == node_to_keep_) {
1550  const auto index = input->getIndex();
1551  CHECK_LT(index, scalar_sources_.size());
1552  return visit(scalar_sources_[index].get());
1553  } else {
1554  return input->deepCopy();
1555  }
1556  }
1557 
1558  private:
1560  const std::vector<std::unique_ptr<const RexScalar>>& scalar_sources_;
1561 };
1562 
1563 } // namespace
1564 
// Fuses the nodes at the indices in `pattern` into one RelCompound.
// - The pattern holds 2 to 4 nodes (CHECKed), each a filter, project or
//   aggregate.
// - The compound replaces the last pattern node in `nodes` and takes over the
//   first pattern node's input.
// - The other pattern entries are reset, and every remaining node is rewired
//   from the old last node to the compound.
// - If any pattern node has a registered query hint (the last one found wins),
//   the hint is re-registered under the compound node's hash and id.
1566  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1567  const std::vector<size_t>& pattern,
1568  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
1569  query_hints) noexcept {
1570  CHECK_GE(pattern.size(), size_t(2));
1571  CHECK_LE(pattern.size(), size_t(4));
1572 
1573  std::unique_ptr<const RexScalar> filter_rex;
1574  std::vector<std::unique_ptr<const RexScalar>> scalar_sources;
1575  size_t groupby_count{0};
1576  std::vector<std::string> fields;
1577  std::vector<const RexAgg*> agg_exprs;
1578  std::vector<const Rex*> target_exprs;
1579  bool first_project{true};
1580  bool is_agg{false};
1581  RelAlgNode* last_node{nullptr};
1582 
// NOTE(review): manipulation_target is only assigned when the pattern contains
// a RelProject. The dereferences after the loop assume one is always present.
// Confirm that callers guarantee this.
1583  std::shared_ptr<ModifyManipulationTarget> manipulation_target;
1584  size_t node_hash{0};
1585  unsigned node_id{0};
1586  bool hint_registered{false};
1587  RegisteredQueryHint registered_query_hint = RegisteredQueryHint::defaults();
1588  for (const auto node_idx : pattern) {
1589  const auto ra_node = nodes[node_idx];
1590  auto registered_query_hint_map_it = query_hints.find(ra_node->toHash());
1591  if (registered_query_hint_map_it != query_hints.end()) {
1592  auto& registered_query_hint_map = registered_query_hint_map_it->second;
1593  auto registered_query_hint_it = registered_query_hint_map.find(ra_node->getId());
1594  if (registered_query_hint_it != registered_query_hint_map.end()) {
1595  hint_registered = true;
1596  node_hash = registered_query_hint_map_it->first;
1597  node_id = registered_query_hint_it->first;
1598  registered_query_hint = registered_query_hint_it->second;
1599  }
1600  }
// Filter: its condition is taken over by the compound (at most one filter).
1601  const auto ra_filter = std::dynamic_pointer_cast<RelFilter>(ra_node);
1602  if (ra_filter) {
1603  CHECK(!filter_rex);
1604  filter_rex.reset(ra_filter->getAndReleaseCondition());
1605  CHECK(filter_rex);
1606  last_node = ra_node.get();
1607  continue;
1608  }
// Project: the first one provides the scalar sources; later ones are
// re-expressed in terms of the targets accumulated so far.
1609  const auto ra_project = std::dynamic_pointer_cast<RelProject>(ra_node);
1610  if (ra_project) {
1611  fields = ra_project->getFields();
1612  manipulation_target = ra_project;
1613 
1614  if (first_project) {
1615  CHECK_EQ(size_t(1), ra_project->inputCount());
1616  // Rebind the input of the project to the input of the filter itself
1617  // since we know that we'll evaluate the filter on the fly, with no
1618  // intermediate buffer.
1619  const auto filter_input = dynamic_cast<const RelFilter*>(ra_project->getInput(0));
1620  if (filter_input) {
1621  CHECK_EQ(size_t(1), filter_input->inputCount());
1622  bind_project_to_input(ra_project.get(),
1623  get_node_output(filter_input->getInput(0)));
1624  }
1625  scalar_sources = ra_project->getExpressionsAndRelease();
1626  for (const auto& scalar_expr : scalar_sources) {
1627  target_exprs.push_back(scalar_expr.get());
1628  }
1629  first_project = false;
1630  } else {
1631  if (ra_project->isSimple()) {
1632  target_exprs = reproject_targets(ra_project.get(), target_exprs);
1633  } else {
1634  // TODO(adb): This is essentially a more general case of simple project, we
1635  // could likely merge the two
1636  std::vector<const Rex*> result;
1637  RexInputReplacementVisitor visitor(last_node, scalar_sources);
1638  for (size_t i = 0; i < ra_project->size(); ++i) {
1639  const auto rex = ra_project->getProjectAt(i);
1640  if (auto rex_input = dynamic_cast<const RexInput*>(rex)) {
1641  const auto index = rex_input->getIndex();
1642  CHECK_LT(index, target_exprs.size());
1643  result.push_back(target_exprs[index]);
1644  } else {
1645  scalar_sources.push_back(visitor.visit(rex));
1646  result.push_back(scalar_sources.back().get());
1647  }
1648  }
1649  target_exprs = result;
1650  }
1651  }
1652  last_node = ra_node.get();
1653  continue;
1654  }
// Aggregate: the group-by keys become RexRefs (1-based), followed by the
// aggregate expressions.
1655  const auto ra_aggregate = std::dynamic_pointer_cast<RelAggregate>(ra_node);
1656  if (ra_aggregate) {
1657  is_agg = true;
1658  fields = ra_aggregate->getFields();
1659  agg_exprs = ra_aggregate->getAggregatesAndRelease();
1660  groupby_count = ra_aggregate->getGroupByCount();
1661  decltype(target_exprs){}.swap(target_exprs);
1662  CHECK_LE(groupby_count, scalar_sources.size());
1663  for (size_t group_idx = 0; group_idx < groupby_count; ++group_idx) {
1664  const auto rex_ref = new RexRef(group_idx + 1);
1665  target_exprs.push_back(rex_ref);
1666  scalar_sources.emplace_back(rex_ref);
1667  }
1668  for (const auto rex_agg : agg_exprs) {
1669  target_exprs.push_back(rex_agg);
1670  }
1671  last_node = ra_node.get();
1672  continue;
1673  }
1674  }
1675 
1676  auto compound_node =
1677  std::make_shared<RelCompound>(filter_rex,
1678  target_exprs,
1679  groupby_count,
1680  agg_exprs,
1681  fields,
1682  scalar_sources,
1683  is_agg,
1684  manipulation_target->isUpdateViaSelect(),
1685  manipulation_target->isDeleteViaSelect(),
1686  manipulation_target->isVarlenUpdateRequired(),
1687  manipulation_target->getModifiedTableDescriptor(),
1688  manipulation_target->getTargetColumns());
1689  auto old_node = nodes[pattern.back()];
1690  nodes[pattern.back()] = compound_node;
1691  auto first_node = nodes[pattern.front()];
1692  CHECK_EQ(size_t(1), first_node->inputCount());
1693  compound_node->addManagedInput(first_node->getAndOwnInput(0));
1694  if (hint_registered) {
1695  // pass the registered hint from the origin node to newly created compound node
1696  // where it is coalesced
1697  auto registered_query_hint_map_it = query_hints.find(node_hash);
1698  CHECK(registered_query_hint_map_it != query_hints.end());
// NOTE(review): `auto` copies the inner map. The erase(node_id) below therefore
// only modifies the copy, and the stale entry stays in query_hints.
// `auto&` was probably intended; confirm before changing.
1699  auto registered_query_hint_map = registered_query_hint_map_it->second;
1700  if (registered_query_hint_map.size() > 1) {
1701  registered_query_hint_map.erase(node_id);
1702  } else {
1703  CHECK_EQ(registered_query_hint_map.size(), static_cast<size_t>(1));
1704  query_hints.erase(node_hash);
1705  }
1706  std::unordered_map<unsigned, RegisteredQueryHint> hint_map;
1707  hint_map.emplace(compound_node->getId(), registered_query_hint);
1708  query_hints.emplace(compound_node->toHash(), hint_map);
1709  }
// Drop every pattern node except the last (now replaced by the compound).
1710  for (size_t i = 0; i < pattern.size() - 1; ++i) {
1711  nodes[pattern[i]].reset();
1712  }
// Redirect all remaining consumers of the old last node to the compound.
1713  for (auto node : nodes) {
1714  if (!node) {
1715  continue;
1716  }
1717  node->replaceInput(old_node, compound_node);
1718  }
1719 }
1720 
1721 class RANodeIterator : public std::vector<std::shared_ptr<RelAlgNode>>::const_iterator {
1722  using ElementType = std::shared_ptr<RelAlgNode>;
1723  using Super = std::vector<ElementType>::const_iterator;
1724  using Container = std::vector<ElementType>;
1725 
1726  public:
1727  enum class AdvancingMode { DUChain, InOrder };
1728 
1729  explicit RANodeIterator(const Container& nodes)
1730  : Super(nodes.begin()), owner_(nodes), nodeCount_([&nodes]() -> size_t {
1731  size_t non_zero_count = 0;
1732  for (const auto& node : nodes) {
1733  if (node) {
1734  ++non_zero_count;
1735  }
1736  }
1738  }()) {}
1739 
1740  explicit operator size_t() {
1741  return std::distance(owner_.begin(), *static_cast<Super*>(this));
1742  }
1743 
1744  RANodeIterator operator++() = delete;
1745 
1746  void advance(AdvancingMode mode) {
1747  Super& super = *this;
1748  switch (mode) {
1749  case AdvancingMode::DUChain: {
1750  size_t use_count = 0;
1751  Super only_use = owner_.end();
1752  for (Super nodeIt = std::next(super); nodeIt != owner_.end(); ++nodeIt) {
1753  if (!*nodeIt) {
1754  continue;
1755  }
1756  for (size_t i = 0; i < (*nodeIt)->inputCount(); ++i) {
1757  if ((*super) == (*nodeIt)->getAndOwnInput(i)) {
1758  ++use_count;
1759  if (1 == use_count) {
1760  only_use = nodeIt;
1761  } else {
1762  super = owner_.end();
1763  return;
1764  }
1765  }
1766  }
1767  }
1768  super = only_use;
1769  break;
1770  }
1771  case AdvancingMode::InOrder:
1772  for (size_t i = 0; i != owner_.size(); ++i) {
1773  if (!visited_.count(i)) {
1774  super = owner_.begin();
1775  std::advance(super, i);
1776  return;
1777  }
1778  }
1779  super = owner_.end();
1780  break;
1781  default:
1782  CHECK(false);
1783  }
1784  }
1785 
1786  bool allVisited() { return visited_.size() == nodeCount_; }
1787 
1789  visited_.insert(size_t(*this));
1790  Super& super = *this;
1791  return *super;
1792  }
1793 
1794  const ElementType* operator->() { return &(operator*()); }
1795 
1796  private:
1798  const size_t nodeCount_;
1799  std::unordered_set<size_t> visited_;
1800 };
1801 
1802 namespace {
1803 
1804 bool input_can_be_coalesced(const RelAlgNode* parent_node,
1805  const size_t index,
1806  const bool first_rex_is_input) {
1807  if (auto agg_node = dynamic_cast<const RelAggregate*>(parent_node)) {
1808  if (index == 0 && agg_node->getGroupByCount() > 0) {
1809  return true;
1810  } else {
1811  // Is an aggregated target, only allow the project to be elided if the aggregate
1812  // target is simply passed through (i.e. if the top level expression attached to
1813  // the project node is a RexInput expression)
1814  return first_rex_is_input;
1815  }
1816  }
1817  return first_rex_is_input;
1818 }
1819 
1826  public:
1827  bool visitInput(const RexInput* input) const final {
1828  // The top level expression node is checked before we apply the visitor. If we get
1829  // here, this input rex is a child of another rex node, and we handle the can be
1830  // coalesced check slightly differently
1831  return input_can_be_coalesced(input->getSourceNode(), input->getIndex(), false);
1832  }
1833 
1834  bool visitLiteral(const RexLiteral*) const final { return false; }
1835 
1836  bool visitSubQuery(const RexSubQuery*) const final { return false; }
1837 
1838  bool visitRef(const RexRef*) const final { return false; }
1839 
1840  protected:
1841  bool aggregateResult(const bool& aggregate, const bool& next_result) const final {
1842  return aggregate && next_result;
1843  }
1844 
1845  bool defaultResult() const final { return true; }
1846 };
1847 
1848 // Detect the window function SUM pattern: CASE WHEN COUNT() > 0 THEN SUM ELSE 0
1850  const auto case_operator = dynamic_cast<const RexCase*>(rex);
1851  if (case_operator && case_operator->branchCount() == 1) {
1852  const auto then_window =
1853  dynamic_cast<const RexWindowFunctionOperator*>(case_operator->getThen(0));
1854  if (then_window && then_window->getKind() == SqlWindowFunctionKind::SUM_INTERNAL) {
1855  return true;
1856  }
1857  }
1858  return false;
1859 }
1860 
1861 // Check for Window Function AVG:
1862 // (CASE WHEN count > 0 THEN sum ELSE 0) / COUNT
1864  const RexOperator* divide_operator = dynamic_cast<const RexOperator*>(rex);
1865  if (divide_operator && divide_operator->getOperator() == kDIVIDE) {
1866  CHECK_EQ(divide_operator->size(), size_t(2));
1867  const auto case_operator =
1868  dynamic_cast<const RexCase*>(divide_operator->getOperand(0));
1869  const auto second_window =
1870  dynamic_cast<const RexWindowFunctionOperator*>(divide_operator->getOperand(1));
1871  if (case_operator && second_window &&
1872  second_window->getKind() == SqlWindowFunctionKind::COUNT) {
1873  if (is_window_function_sum(case_operator)) {
1874  return true;
1875  }
1876  }
1877  }
1878  return false;
1879 }
1880 
1881 // Detect both window function operators and window function operators embedded in case
1882 // statements (for null handling)
1884  if (dynamic_cast<const RexWindowFunctionOperator*>(rex)) {
1885  return true;
1886  }
1887 
1888  // unwrap from casts, if they exist
1889  const auto rex_cast = dynamic_cast<const RexOperator*>(rex);
1890  if (rex_cast && rex_cast->getOperator() == kCAST) {
1891  CHECK_EQ(rex_cast->size(), size_t(1));
1892  return is_window_function_operator(rex_cast->getOperand(0));
1893  }
1894 
1896  return true;
1897  }
1898 
1899  return false;
1900 }
1901 
1902 } // namespace
1903 
1905  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1906  const std::vector<const RelAlgNode*>& left_deep_joins,
1907  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
1908  query_hints) {
1909  enum class CoalesceState { Initial, Filter, FirstProject, Aggregate };
1910  std::vector<size_t> crt_pattern;
1911  CoalesceState crt_state{CoalesceState::Initial};
1912 
1913  auto reset_state = [&crt_pattern, &crt_state]() {
1914  crt_state = CoalesceState::Initial;
1915  std::vector<size_t>().swap(crt_pattern);
1916  };
1917 
1918  for (RANodeIterator nodeIt(nodes); !nodeIt.allVisited();) {
1919  const auto ra_node = nodeIt != nodes.end() ? *nodeIt : nullptr;
1920  switch (crt_state) {
1921  case CoalesceState::Initial: {
1922  if (std::dynamic_pointer_cast<const RelFilter>(ra_node) &&
1923  std::find(left_deep_joins.begin(), left_deep_joins.end(), ra_node.get()) ==
1924  left_deep_joins.end()) {
1925  crt_pattern.push_back(size_t(nodeIt));
1926  crt_state = CoalesceState::Filter;
1927  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1928  } else if (auto project_node =
1929  std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1930  if (project_node->hasWindowFunctionExpr()) {
1931  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1932  } else {
1933  crt_pattern.push_back(size_t(nodeIt));
1934  crt_state = CoalesceState::FirstProject;
1935  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1936  }
1937  } else {
1938  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1939  }
1940  break;
1941  }
1942  case CoalesceState::Filter: {
1943  if (auto project_node = std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1944  // Given we now add preceding projects for all window functions following
1945  // RelFilter nodes, the following should never occur
1946  CHECK(!project_node->hasWindowFunctionExpr());
1947  crt_pattern.push_back(size_t(nodeIt));
1948  crt_state = CoalesceState::FirstProject;
1949  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1950  } else {
1951  reset_state();
1952  }
1953  break;
1954  }
1955  case CoalesceState::FirstProject: {
1956  if (std::dynamic_pointer_cast<const RelAggregate>(ra_node)) {
1957  crt_pattern.push_back(size_t(nodeIt));
1958  crt_state = CoalesceState::Aggregate;
1959  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1960  } else {
1961  if (crt_pattern.size() >= 2) {
1962  create_compound(nodes, crt_pattern, query_hints);
1963  }
1964  reset_state();
1965  }
1966  break;
1967  }
1968  case CoalesceState::Aggregate: {
1969  if (auto project_node = std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1970  if (!project_node->hasWindowFunctionExpr()) {
1971  // TODO(adb): overloading the simple project terminology again here
1972  bool is_simple_project{true};
1973  for (size_t i = 0; i < project_node->size(); i++) {
1974  const auto scalar_rex = project_node->getProjectAt(i);
1975  // If the top level scalar rex is an input node, we can bypass the visitor
1976  if (auto input_rex = dynamic_cast<const RexInput*>(scalar_rex)) {
1978  input_rex->getSourceNode(), input_rex->getIndex(), true)) {
1979  is_simple_project = false;
1980  break;
1981  }
1982  continue;
1983  }
1984  CoalesceSecondaryProjectVisitor visitor;
1985  if (!visitor.visit(project_node->getProjectAt(i))) {
1986  is_simple_project = false;
1987  break;
1988  }
1989  }
1990  if (is_simple_project) {
1991  crt_pattern.push_back(size_t(nodeIt));
1992  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1993  }
1994  }
1995  }
1996  CHECK_GE(crt_pattern.size(), size_t(2));
1997  create_compound(nodes, crt_pattern, query_hints);
1998  reset_state();
1999  break;
2000  }
2001  default:
2002  CHECK(false);
2003  }
2004  }
2005  if (crt_state == CoalesceState::FirstProject || crt_state == CoalesceState::Aggregate) {
2006  if (crt_pattern.size() >= 2) {
2007  create_compound(nodes, crt_pattern, query_hints);
2008  }
2009  CHECK(!crt_pattern.empty());
2010  }
2011 }
2012 
2013 class WindowFunctionCollector : public RexVisitor<void*> {
2014  public:
2016  std::unordered_map<size_t, const RexScalar*>& collected_window_func)
2017  : collected_window_func_(collected_window_func) {}
2018 
2019  protected:
2020  // Detect embedded window function expressions in operators
2021  void* visitOperator(const RexOperator* rex_operator) const final {
2022  if (is_window_function_operator(rex_operator)) {
2023  collected_window_func_.emplace(rex_operator->toHash(), rex_operator);
2024  }
2025  const size_t operand_count = rex_operator->size();
2026  for (size_t i = 0; i < operand_count; ++i) {
2027  const auto operand = rex_operator->getOperand(i);
2028  if (is_window_function_operator(operand)) {
2029  // Handle both RexWindowFunctionOperators and window functions built up from
2030  // multiple RexScalar objects (e.g. AVG)
2031  collected_window_func_.emplace(operand->toHash(), operand);
2032  } else {
2033  visit(operand);
2034  }
2035  }
2036  return defaultResult();
2037  }
2038 
2039  // Detect embedded window function expressions in case statements. Note that this may
2040  // manifest as a nested case statement inside a top level case statement, as some
2041  // window functions (sum, avg) are represented as a case statement. Use the
2042  // is_window_function_operator helper to detect complete window function expressions.
2043  void* visitCase(const RexCase* rex_case) const final {
2044  if (is_window_function_operator(rex_case)) {
2045  collected_window_func_.emplace(rex_case->toHash(), rex_case);
2046  return nullptr;
2047  }
2048 
2049  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
2050  const auto when = rex_case->getWhen(i);
2051  if (is_window_function_operator(when)) {
2052  collected_window_func_.emplace(when->toHash(), when);
2053  } else {
2054  visit(when);
2055  }
2056  const auto then = rex_case->getThen(i);
2057  if (is_window_function_operator(then)) {
2058  collected_window_func_.emplace(then->toHash(), then);
2059  } else {
2060  visit(then);
2061  }
2062  }
2063  if (rex_case->getElse()) {
2064  auto else_expr = rex_case->getElse();
2065  if (is_window_function_operator(else_expr)) {
2066  collected_window_func_.emplace(else_expr->toHash(), else_expr);
2067  } else {
2068  visit(else_expr);
2069  }
2070  }
2071  return defaultResult();
2072  }
2073 
2074  void* defaultResult() const final { return nullptr; }
2075 
2076  private:
2077  std::unordered_map<size_t, const RexScalar*>& collected_window_func_;
2078 };
2079 
2081  public:
2083  std::unordered_set<size_t>& collected_window_func_hash,
2084  std::vector<std::unique_ptr<const RexScalar>>& new_rex_input_for_window_func,
2085  std::unordered_map<size_t, size_t>& window_func_to_new_rex_input_idx_map,
2086  RelProject* new_project,
2087  std::unordered_map<size_t, std::unique_ptr<const RexInput>>&
2088  new_rex_input_from_child_node)
2089  : collected_window_func_hash_(collected_window_func_hash)
2090  , new_rex_input_for_window_func_(new_rex_input_for_window_func)
2091  , window_func_to_new_rex_input_idx_map_(window_func_to_new_rex_input_idx_map)
2092  , new_project_(new_project)
2093  , new_rex_input_from_child_node_(new_rex_input_from_child_node) {
2094  CHECK_EQ(collected_window_func_hash_.size(),
2095  window_func_to_new_rex_input_idx_map_.size());
2096  for (auto hash : collected_window_func_hash_) {
2097  auto rex_it = window_func_to_new_rex_input_idx_map_.find(hash);
2098  CHECK(rex_it != window_func_to_new_rex_input_idx_map_.end());
2099  CHECK_LT(rex_it->second, new_rex_input_for_window_func_.size());
2100  }
2101  CHECK(new_project_);
2102  }
2103 
2104  protected:
2105  RetType visitInput(const RexInput* rex_input) const final {
2106  if (rex_input->getSourceNode() != new_project_) {
2107  const auto cur_index = rex_input->getIndex();
2108  auto cur_source_node = rex_input->getSourceNode();
2109  std::string field_name = "";
2110  if (auto cur_project_node = dynamic_cast<const RelProject*>(cur_source_node)) {
2111  field_name = cur_project_node->getFieldName(cur_index);
2112  }
2113  auto rex_input_hash = rex_input->toHash();
2114  auto rex_input_it = new_rex_input_from_child_node_.find(rex_input_hash);
2115  if (rex_input_it == new_rex_input_from_child_node_.end()) {
2116  auto new_rex_input =
2117  std::make_unique<RexInput>(new_project_, new_project_->size());
2118  new_project_->appendInput(field_name, rex_input->deepCopy());
2119  new_rex_input_from_child_node_.emplace(rex_input_hash, new_rex_input->deepCopy());
2120  return new_rex_input;
2121  } else {
2122  return rex_input_it->second->deepCopy();
2123  }
2124  } else {
2125  return rex_input->deepCopy();
2126  }
2127  }
2128 
2129  RetType visitOperator(const RexOperator* rex_operator) const final {
2130  auto new_rex_idx = is_collected_window_function(rex_operator->toHash());
2131  if (new_rex_idx) {
2132  return get_new_rex_input(*new_rex_idx);
2133  }
2134 
2135  const auto rex_window_function_operator =
2136  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
2137  if (rex_window_function_operator) {
2138  // Deep copy the embedded window function operator
2139  return visitWindowFunctionOperator(rex_window_function_operator);
2140  }
2141 
2142  const size_t operand_count = rex_operator->size();
2143  std::vector<RetType> new_opnds;
2144  for (size_t i = 0; i < operand_count; ++i) {
2145  const auto operand = rex_operator->getOperand(i);
2146  auto new_rex_idx_for_operand = is_collected_window_function(operand->toHash());
2147  if (new_rex_idx_for_operand) {
2148  new_opnds.push_back(get_new_rex_input(*new_rex_idx_for_operand));
2149  } else {
2150  new_opnds.emplace_back(visit(rex_operator->getOperand(i)));
2151  }
2152  }
2153  return rex_operator->getDisambiguated(new_opnds);
2154  }
2155 
2156  RetType visitCase(const RexCase* rex_case) const final {
2157  auto new_rex_idx = is_collected_window_function(rex_case->toHash());
2158  if (new_rex_idx) {
2159  return get_new_rex_input(*new_rex_idx);
2160  }
2161 
2162  std::vector<std::pair<RetType, RetType>> new_pair_list;
2163  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
2164  auto when_operand = rex_case->getWhen(i);
2165  auto new_rex_idx_for_when_operand =
2166  is_collected_window_function(when_operand->toHash());
2167 
2168  auto then_operand = rex_case->getThen(i);
2169  auto new_rex_idx_for_then_operand =
2170  is_collected_window_function(then_operand->toHash());
2171 
2172  new_pair_list.emplace_back(
2173  new_rex_idx_for_when_operand ? get_new_rex_input(*new_rex_idx_for_when_operand)
2174  : visit(when_operand),
2175  new_rex_idx_for_then_operand ? get_new_rex_input(*new_rex_idx_for_then_operand)
2176  : visit(then_operand));
2177  }
2178  auto new_rex_idx_for_else_operand =
2179  is_collected_window_function(rex_case->getElse()->toHash());
2180  auto new_else = new_rex_idx_for_else_operand
2181  ? get_new_rex_input(*new_rex_idx_for_else_operand)
2182  : visit(rex_case->getElse());
2183  return std::make_unique<RexCase>(new_pair_list, new_else);
2184  }
2185 
2186  private:
2187  std::optional<size_t> is_collected_window_function(size_t rex_hash) const {
2188  auto rex_it = window_func_to_new_rex_input_idx_map_.find(rex_hash);
2189  if (rex_it != window_func_to_new_rex_input_idx_map_.end()) {
2190  return rex_it->second;
2191  }
2192  return std::nullopt;
2193  }
2194 
2195  std::unique_ptr<const RexScalar> get_new_rex_input(size_t rex_idx) const {
2196  CHECK_GE(rex_idx, 0UL);
2197  CHECK_LT(rex_idx, new_rex_input_for_window_func_.size());
2198  auto& new_rex_input = new_rex_input_for_window_func_.at(rex_idx);
2199  CHECK(new_rex_input);
2200  auto copied_rex_input = copier_.visit(new_rex_input.get());
2201  return copied_rex_input;
2202  }
2203 
2204  std::unordered_set<size_t>& collected_window_func_hash_;
2205  // we should have new rex_input for each window function collected
2206  std::vector<std::unique_ptr<const RexScalar>>& new_rex_input_for_window_func_;
2207  // an index to get a new rex_input for the collected window function
2208  std::unordered_map<size_t, size_t>& window_func_to_new_rex_input_idx_map_;
2210  std::unordered_map<size_t, std::unique_ptr<const RexInput>>&
2213 };
2214 
2216  std::shared_ptr<RelProject> prev_node,
2217  std::shared_ptr<RelProject> new_node,
2218  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
2219  query_hints) {
2220  auto delivered_hints = prev_node->getDeliveredHints();
2221  bool needs_propagate_hints = !delivered_hints->empty();
2222  if (needs_propagate_hints) {
2223  for (auto& kv : *delivered_hints) {
2224  new_node->addHint(kv.second);
2225  }
2226  auto prev_it = query_hints.find(prev_node->toHash());
2227  // query hint for the prev projection node should be registered
2228  CHECK(prev_it != query_hints.end());
2229  auto prev_hint_it = prev_it->second.find(prev_node->getId());
2230  CHECK(prev_hint_it != prev_it->second.end());
2231  std::unordered_map<unsigned, RegisteredQueryHint> hint_map;
2232  hint_map.emplace(new_node->getId(), prev_hint_it->second);
2233  query_hints.emplace(new_node->toHash(), hint_map);
2234  }
2235 }
2236 
2260  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
2261  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
2262  query_hints) {
2263  std::list<std::shared_ptr<RelAlgNode>> node_list(nodes.begin(), nodes.end());
2264  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
2265  const auto node = *node_itr;
2266  auto window_func_project_node = std::dynamic_pointer_cast<RelProject>(node);
2267  if (!window_func_project_node) {
2268  continue;
2269  }
2270 
2271  const auto prev_node_itr = std::prev(node_itr);
2272  const auto prev_node = *prev_node_itr;
2273  CHECK(prev_node);
2274 
2275  // map scalar expression index in the project node to window function ptr
2276  std::unordered_map<size_t, const RexScalar*> collected_window_func;
2277  WindowFunctionCollector collector(collected_window_func);
2278  // Iterate the target exprs of the project node and check for window function
2279  // expressions. If an embedded expression exists, collect it
2280  for (size_t i = 0; i < window_func_project_node->size(); i++) {
2281  const auto scalar_rex = window_func_project_node->getProjectAt(i);
2282  if (is_window_function_operator(scalar_rex)) {
2283  // top level window function exprs are fine
2284  continue;
2285  }
2286  collector.visit(scalar_rex);
2287  }
2288 
2289  if (!collected_window_func.empty()) {
2290  // we have a nested window function expression
2291  std::unordered_set<size_t> collected_window_func_hash;
2292  // the current window function needs a set of new rex input which references
2293  // expressions in the newly introduced projection node
2294  std::vector<std::unique_ptr<const RexScalar>> new_rex_input_for_window_func;
2295  // a target projection expression of the newly created projection node
2296  std::vector<std::unique_ptr<const RexScalar>> new_scalar_expr_for_window_project;
2297  // a map between nested window function (hash val) and
2298  // its rex index stored in the `new_rex_input_for_window_func`
2299  std::unordered_map<size_t, size_t> window_func_to_new_rex_input_idx_map;
2300  // a map between RexInput of the current window function projection node (hash val)
2301  // and its corresponding new RexInput which is pushed down to the new projection
2302  // node
2303  std::unordered_map<size_t, std::unique_ptr<const RexInput>>
2304  new_rex_input_from_child_node;
2305  RexDeepCopyVisitor copier;
2306 
2307  std::vector<std::unique_ptr<const RexScalar>> dummy_scalar_exprs;
2308  std::vector<std::string> dummy_fields;
2309  std::vector<std::string> new_project_field_names;
2310  // create a new project node, it will contain window function expressions
2311  auto new_project =
2312  std::make_shared<RelProject>(dummy_scalar_exprs, dummy_fields, prev_node);
2313  // insert this new project node between the current window project node and its
2314  // child node
2315  node_list.insert(node_itr, new_project);
2316 
2317  // retrieve various information to replace expressions in the current window
2318  // function project node w/ considering scalar expressions in the new project node
2319  std::for_each(collected_window_func.begin(),
2320  collected_window_func.end(),
2321  [&new_project_field_names,
2322  &collected_window_func_hash,
2323  &new_rex_input_for_window_func,
2324  &new_scalar_expr_for_window_project,
2325  &copier,
2326  &new_project,
2327  &window_func_to_new_rex_input_idx_map](const auto& kv) {
2328  // compute window function expr's hash, and create a new rex_input
2329  // for it
2330  collected_window_func_hash.insert(kv.first);
2331 
2332  // map an old expression in the window function project node
2333  // to an index of the corresponding new RexInput
2334  const auto rex_idx = new_rex_input_for_window_func.size();
2335  window_func_to_new_rex_input_idx_map.emplace(kv.first, rex_idx);
2336 
2337  // create a new RexInput and make it as one of new expression of the
2338  // newly created project node
2339  new_rex_input_for_window_func.emplace_back(
2340  std::make_unique<const RexInput>(new_project.get(), rex_idx));
2341  new_scalar_expr_for_window_project.push_back(
2342  std::move(copier.visit(kv.second)));
2343  new_project_field_names.emplace_back("");
2344  });
2345  new_project->setExpressions(new_scalar_expr_for_window_project);
2346  new_project->setFields(std::move(new_project_field_names));
2347 
2348  auto window_func_scalar_exprs =
2349  window_func_project_node->getExpressionsAndRelease();
2350  RexWindowFuncReplacementVisitor replacer(collected_window_func_hash,
2351  new_rex_input_for_window_func,
2352  window_func_to_new_rex_input_idx_map,
2353  new_project.get(),
2354  new_rex_input_from_child_node);
2355  size_t rex_idx = 0;
2356  for (auto& scalar_expr : window_func_scalar_exprs) {
2357  // try to replace the old expressions in the window function project node
2358  // with expressions of the newly created project node
2359  auto new_parent_rex = replacer.visit(scalar_expr.get());
2360  window_func_scalar_exprs[rex_idx] = std::move(new_parent_rex);
2361  rex_idx++;
2362  }
2363  // Update the previous window project node
2364  window_func_project_node->setExpressions(window_func_scalar_exprs);
2365  window_func_project_node->replaceInput(prev_node, new_project);
2366  propagate_hints_to_new_project(window_func_project_node, new_project, query_hints);
2367  }
2368  }
2369  nodes.assign(node_list.begin(), node_list.end());
2370 }
2371 
2372 using RexInputSet = std::unordered_set<RexInput>;
2373 
2374 class RexInputCollector : public RexVisitor<RexInputSet> {
2375  public:
2376  RexInputSet visitInput(const RexInput* input) const override {
2377  return RexInputSet{*input};
2378  }
2379 
2380  protected:
2382  const RexInputSet& next_result) const override {
2383  auto result = aggregate;
2384  result.insert(next_result.begin(), next_result.end());
2385  return result;
2386  }
2387 };
2388 
2402  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
2403  const bool always_add_project_if_first_project_is_window_expr,
2404  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
2405  query_hints) {
2406  std::list<std::shared_ptr<RelAlgNode>> node_list(nodes.begin(), nodes.end());
2407  size_t project_node_counter{0};
2408  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
2409  const auto node = *node_itr;
2410 
2411  auto window_func_project_node = std::dynamic_pointer_cast<RelProject>(node);
2412  if (!window_func_project_node) {
2413  continue;
2414  }
2415  project_node_counter++;
2416  if (!window_func_project_node->hasWindowFunctionExpr()) {
2417  // this projection node does not have a window function
2418  // expression -- skip to the next node in the DAG.
2419  continue;
2420  }
2421 
2422  auto need_pushdown_generic_expr = [&window_func_project_node]() {
2423  for (size_t i = 0; i < window_func_project_node->size(); ++i) {
2424  const auto projected_target = window_func_project_node->getProjectAt(i);
2425  if (auto window_expr =
2426  dynamic_cast<const RexWindowFunctionOperator*>(projected_target)) {
2427  for (const auto& partition_key : window_expr->getPartitionKeys()) {
2428  auto partition_input = dynamic_cast<const RexInput*>(partition_key.get());
2429  if (!partition_input) {
2430  return true;
2431  }
2432  }
2433  for (const auto& order_key : window_expr->getOrderKeys()) {
2434  auto order_input = dynamic_cast<const RexInput*>(order_key.get());
2435  if (!order_input) {
2436  return true;
2437  }
2438  }
2439  }
2440  }
2441  return false;
2442  };
2443 
2444  const auto prev_node_itr = std::prev(node_itr);
2445  const auto prev_node = *prev_node_itr;
2446  CHECK(prev_node);
2447 
2448  auto filter_node = std::dynamic_pointer_cast<RelFilter>(prev_node);
2449  auto join_node = std::dynamic_pointer_cast<RelJoin>(prev_node);
2450 
2451  auto scan_node = std::dynamic_pointer_cast<RelScan>(prev_node);
2452  const bool has_multi_fragment_scan_input =
2453  (scan_node &&
2454  (scan_node->getNumShards() > 0 || scan_node->getNumFragments() > 1));
2455  const bool needs_expr_pushdown = need_pushdown_generic_expr();
2456 
2457  // We currently add a preceding project node in one of two conditions:
2458  // 1. always_add_project_if_first_project_is_window_expr = true, which
2459  // we currently only set for distributed, but could also be set to support
2460  // multi-frag window function inputs, either if we can detect that an input table
2461  // is multi-frag up front, or using a retry mechanism like we do for join filter
2462  // push down.
2463  // TODO(todd): Investigate a viable approach for the above.
2464  // 2. Regardless of #1, if the window function project node is preceded by a
2465  // filter node. This is required both for correctness and to avoid pulling
2466  // all source input columns into memory since non-coalesced filter node
2467  // inputs are currently not pruned or eliminated via dead column elimination.
2468  // Note that we expect any filter node followed by a project node to be coalesced
2469  // into a single compound node in RelAlgDagBuilder::coalesce_nodes, and that action
2470  // prunes unused inputs.
2471  // TODO(todd): Investigate whether the shotgun filter node issue affects other
2472  // query plans, i.e. filters before joins, and whether there is a more general
2473  // approach to solving this (will still need the preceding project node for
2474  // window functions preceded by filter nodes for correctness though)
2475  // 3. Similar to the above, when the window function project node is preceded
2476  // by a join node.
2477  // 4. when partition by / order by clauses have a general expression instead of
2478  // referencing column
2479 
2480  if (!((always_add_project_if_first_project_is_window_expr &&
2481  project_node_counter == 1) ||
2482  filter_node || join_node || has_multi_fragment_scan_input ||
2483  needs_expr_pushdown)) {
2484  continue;
2485  }
2486 
2487  if (needs_expr_pushdown || join_node) {
2488  // previous logic cannot cover join_node case well, so use the newly introduced
2489  // push-down expression logic to safely add pre_project node before processing
2490  // window function
2491  std::unordered_map<size_t, size_t> expr_offset_cache;
2492  std::vector<std::unique_ptr<const RexScalar>> scalar_exprs_for_new_project;
2493  std::vector<std::unique_ptr<const RexScalar>> scalar_exprs_for_window_project;
2494  std::vector<std::string> fields_for_window_project;
2495  std::vector<std::string> fields_for_new_project;
2496 
2497  // step 0. create new project node with an empty scalar expr to rebind target exprs
2498  std::vector<std::unique_ptr<const RexScalar>> dummy_scalar_exprs;
2499  std::vector<std::string> dummy_fields;
2500  auto new_project =
2501  std::make_shared<RelProject>(dummy_scalar_exprs, dummy_fields, prev_node);
2502 
2503  // step 1 - 2
2504  PushDownGenericExpressionInWindowFunction visitor(new_project,
2505  scalar_exprs_for_new_project,
2506  fields_for_new_project,
2507  expr_offset_cache);
2508  for (size_t i = 0; i < window_func_project_node->size(); ++i) {
2509  auto projected_target = window_func_project_node->getProjectAt(i);
2510  auto new_projection_target = visitor.visit(projected_target);
2511  scalar_exprs_for_window_project.emplace_back(
2512  std::move(new_projection_target.release()));
2513  }
2514  new_project->setExpressions(scalar_exprs_for_new_project);
2515  new_project->setFields(std::move(fields_for_new_project));
2516  bool has_groupby = false;
2517  auto aggregate = std::dynamic_pointer_cast<RelAggregate>(prev_node);
2518  if (aggregate) {
2519  has_groupby = aggregate->getGroupByCount() > 0;
2520  }
2521  if (has_groupby && visitor.hasPartitionExpression()) {
2522  // we currently may compute incorrect result with columnar output when
2523  // 1) the window function has partition expression, and
2524  // 2) a parent node of the window function projection node has group by expression
2525  // so we force rowwise output (only) for the newly injected projection node
2526  // to prevent computing incorrect query result
2527  // todo (yoonmin) : relax this
2528  VLOG(1)
2529  << "Query output overridden to row-wise format due to presence of a window "
2530  "function with partition expression and group-by expression.";
2531  new_project->forceRowwiseOutput();
2532  }
2533  if (visitor.hasCaseExprAsWindowOperand()) {
2534  // force rowwise output
2535  VLOG(1)
2536  << "Query output overridden to row-wise format due to presence of a window "
2537  "function with a case statement as its operand.";
2538  new_project->forceRowwiseOutput();
2539  }
2540 
2541  // step 3. finalize
2542  propagate_hints_to_new_project(window_func_project_node, new_project, query_hints);
2543  node_list.insert(node_itr, new_project);
2544  window_func_project_node->replaceInput(prev_node, new_project);
2545  window_func_project_node->setExpressions(scalar_exprs_for_window_project);
2546  } else {
2547  // only push rex_inputs listed in the window function down to a new project node
2548  RexInputSet inputs;
2549  RexInputCollector input_collector;
2550  for (size_t i = 0; i < window_func_project_node->size(); i++) {
2551  auto new_inputs =
2552  input_collector.visit(window_func_project_node->getProjectAt(i));
2553  inputs.insert(new_inputs.begin(), new_inputs.end());
2554  }
2555 
2556  // Note: Technically not required since we are mapping old inputs to new input
2557  // indices, but makes the re-mapping of inputs easier to follow.
2558  std::vector<RexInput> sorted_inputs(inputs.begin(), inputs.end());
2559  std::sort(sorted_inputs.begin(),
2560  sorted_inputs.end(),
2561  [](const auto& a, const auto& b) { return a.getIndex() < b.getIndex(); });
2562 
2563  std::vector<std::unique_ptr<const RexScalar>> scalar_exprs;
2564  std::vector<std::string> fields;
2565  std::unordered_map<unsigned, unsigned> old_index_to_new_index;
2566  for (auto& input : sorted_inputs) {
2567  CHECK_EQ(input.getSourceNode(), prev_node.get());
2568  CHECK(old_index_to_new_index
2569  .insert(std::make_pair(input.getIndex(), scalar_exprs.size()))
2570  .second);
2571  scalar_exprs.emplace_back(input.deepCopy());
2572  fields.emplace_back("");
2573  }
2574 
2575  auto new_project = std::make_shared<RelProject>(scalar_exprs, fields, prev_node);
2576  propagate_hints_to_new_project(window_func_project_node, new_project, query_hints);
2577  node_list.insert(node_itr, new_project);
2578  window_func_project_node->replaceInput(
2579  prev_node, new_project, old_index_to_new_index);
2580  }
2581  }
2582  nodes.assign(node_list.begin(), node_list.end());
2583 }
2584 
2585 int64_t get_int_literal_field(const rapidjson::Value& obj,
2586  const char field[],
2587  const int64_t default_val) noexcept {
2588  const auto it = obj.FindMember(field);
2589  if (it == obj.MemberEnd()) {
2590  return default_val;
2591  }
2592  std::unique_ptr<RexLiteral> lit(parse_literal(it->value));
2593  CHECK_EQ(kDECIMAL, lit->getType());
2594  CHECK_EQ(unsigned(0), lit->getScale());
2595  CHECK_EQ(unsigned(0), lit->getTargetScale());
2596  return lit->getVal<int64_t>();
2597 }
2598 
2599 void check_empty_inputs_field(const rapidjson::Value& node) noexcept {
2600  const auto& inputs_json = field(node, "inputs");
2601  CHECK(inputs_json.IsArray() && !inputs_json.Size());
2602 }
2603 
2605  const rapidjson::Value& scan_ra) {
2606  const auto& table_json = field(scan_ra, "table");
2607  CHECK(table_json.IsArray());
2608  CHECK_EQ(unsigned(2), table_json.Size());
2609  const auto td = cat.getMetadataForTable(table_json[1].GetString());
2610  CHECK(td);
2611  return td;
2612 }
2613 
2614 std::vector<std::string> getFieldNamesFromScanNode(const rapidjson::Value& scan_ra) {
2615  const auto& fields_json = field(scan_ra, "fieldNames");
2616  return strings_from_json_array(fields_json);
2617 }
2618 
2619 } // namespace
2620 
2622  for (const auto& expr : scalar_exprs_) {
2623  if (is_window_function_operator(expr.get())) {
2624  return true;
2625  }
2626  }
2627  return false;
2628 }
2629 namespace details {
2630 
2632  public:
2634 
2635  std::vector<std::shared_ptr<RelAlgNode>> run(const rapidjson::Value& rels,
2636  RelAlgDagBuilder& root_dag_builder) {
2637  for (auto rels_it = rels.Begin(); rels_it != rels.End(); ++rels_it) {
2638  const auto& crt_node = *rels_it;
2639  const auto id = node_id(crt_node);
2640  CHECK_EQ(static_cast<size_t>(id), nodes_.size());
2641  CHECK(crt_node.IsObject());
2642  std::shared_ptr<RelAlgNode> ra_node = nullptr;
2643  const auto rel_op = json_str(field(crt_node, "relOp"));
2644  if (rel_op == std::string("EnumerableTableScan") ||
2645  rel_op == std::string("LogicalTableScan")) {
2646  ra_node = dispatchTableScan(crt_node);
2647  } else if (rel_op == std::string("LogicalProject")) {
2648  ra_node = dispatchProject(crt_node, root_dag_builder);
2649  } else if (rel_op == std::string("LogicalFilter")) {
2650  ra_node = dispatchFilter(crt_node, root_dag_builder);
2651  } else if (rel_op == std::string("LogicalAggregate")) {
2652  ra_node = dispatchAggregate(crt_node);
2653  } else if (rel_op == std::string("LogicalJoin")) {
2654  ra_node = dispatchJoin(crt_node, root_dag_builder);
2655  } else if (rel_op == std::string("LogicalSort")) {
2656  ra_node = dispatchSort(crt_node);
2657  } else if (rel_op == std::string("LogicalValues")) {
2658  ra_node = dispatchLogicalValues(crt_node);
2659  } else if (rel_op == std::string("LogicalTableModify")) {
2660  ra_node = dispatchModify(crt_node);
2661  } else if (rel_op == std::string("LogicalTableFunctionScan")) {
2662  ra_node = dispatchTableFunction(crt_node, root_dag_builder);
2663  } else if (rel_op == std::string("LogicalUnion")) {
2664  ra_node = dispatchUnion(crt_node);
2665  } else {
2666  throw QueryNotSupported(std::string("Node ") + rel_op + " not supported yet");
2667  }
2668  nodes_.push_back(ra_node);
2669  }
2670 
2671  return std::move(nodes_);
2672  }
2673 
2674  private:
2675  std::shared_ptr<RelScan> dispatchTableScan(const rapidjson::Value& scan_ra) {
2676  check_empty_inputs_field(scan_ra);
2677  CHECK(scan_ra.IsObject());
2678  const auto td = getTableFromScanNode(cat_, scan_ra);
2679  const auto field_names = getFieldNamesFromScanNode(scan_ra);
2680  if (scan_ra.HasMember("hints")) {
2681  auto scan_node = std::make_shared<RelScan>(td, field_names);
2682  getRelAlgHints(scan_ra, scan_node);
2683  return scan_node;
2684  }
2685  return std::make_shared<RelScan>(td, field_names);
2686  }
2687 
2688  std::shared_ptr<RelProject> dispatchProject(const rapidjson::Value& proj_ra,
2689  RelAlgDagBuilder& root_dag_builder) {
2690  const auto inputs = getRelAlgInputs(proj_ra);
2691  CHECK_EQ(size_t(1), inputs.size());
2692  const auto& exprs_json = field(proj_ra, "exprs");
2693  CHECK(exprs_json.IsArray());
2694  std::vector<std::unique_ptr<const RexScalar>> exprs;
2695  for (auto exprs_json_it = exprs_json.Begin(); exprs_json_it != exprs_json.End();
2696  ++exprs_json_it) {
2697  exprs.emplace_back(parse_scalar_expr(*exprs_json_it, cat_, root_dag_builder));
2698  }
2699  const auto& fields = field(proj_ra, "fields");
2700  if (proj_ra.HasMember("hints")) {
2701  auto project_node = std::make_shared<RelProject>(
2702  exprs, strings_from_json_array(fields), inputs.front());
2703  getRelAlgHints(proj_ra, project_node);
2704  return project_node;
2705  }
2706  return std::make_shared<RelProject>(
2707  exprs, strings_from_json_array(fields), inputs.front());
2708  }
2709 
2710  std::shared_ptr<RelFilter> dispatchFilter(const rapidjson::Value& filter_ra,
2711  RelAlgDagBuilder& root_dag_builder) {
2712  const auto inputs = getRelAlgInputs(filter_ra);
2713  CHECK_EQ(size_t(1), inputs.size());
2714  const auto id = node_id(filter_ra);
2715  CHECK(id);
2716  auto condition =
2717  parse_scalar_expr(field(filter_ra, "condition"), cat_, root_dag_builder);
2718  return std::make_shared<RelFilter>(condition, inputs.front());
2719  }
2720 
2721  std::shared_ptr<RelAggregate> dispatchAggregate(const rapidjson::Value& agg_ra) {
2722  const auto inputs = getRelAlgInputs(agg_ra);
2723  CHECK_EQ(size_t(1), inputs.size());
2724  const auto fields = strings_from_json_array(field(agg_ra, "fields"));
2725  const auto group = indices_from_json_array(field(agg_ra, "group"));
2726  for (size_t i = 0; i < group.size(); ++i) {
2727  CHECK_EQ(i, group[i]);
2728  }
2729  if (agg_ra.HasMember("groups") || agg_ra.HasMember("indicator")) {
2730  throw QueryNotSupported("GROUP BY extensions not supported");
2731  }
2732  const auto& aggs_json_arr = field(agg_ra, "aggs");
2733  CHECK(aggs_json_arr.IsArray());
2734  std::vector<std::unique_ptr<const RexAgg>> aggs;
2735  for (auto aggs_json_arr_it = aggs_json_arr.Begin();
2736  aggs_json_arr_it != aggs_json_arr.End();
2737  ++aggs_json_arr_it) {
2738  aggs.emplace_back(parse_aggregate_expr(*aggs_json_arr_it));
2739  }
2740  if (agg_ra.HasMember("hints")) {
2741  auto agg_node =
2742  std::make_shared<RelAggregate>(group.size(), aggs, fields, inputs.front());
2743  getRelAlgHints(agg_ra, agg_node);
2744  return agg_node;
2745  }
2746  return std::make_shared<RelAggregate>(group.size(), aggs, fields, inputs.front());
2747  }
2748 
2749  std::shared_ptr<RelJoin> dispatchJoin(const rapidjson::Value& join_ra,
2750  RelAlgDagBuilder& root_dag_builder) {
2751  const auto inputs = getRelAlgInputs(join_ra);
2752  CHECK_EQ(size_t(2), inputs.size());
2753  const auto join_type = to_join_type(json_str(field(join_ra, "joinType")));
2754  auto filter_rex =
2755  parse_scalar_expr(field(join_ra, "condition"), cat_, root_dag_builder);
2756  if (join_ra.HasMember("hints")) {
2757  auto join_node =
2758  std::make_shared<RelJoin>(inputs[0], inputs[1], filter_rex, join_type);
2759  getRelAlgHints(join_ra, join_node);
2760  return join_node;
2761  }
2762  return std::make_shared<RelJoin>(inputs[0], inputs[1], filter_rex, join_type);
2763  }
2764 
2765  std::shared_ptr<RelSort> dispatchSort(const rapidjson::Value& sort_ra) {
2766  const auto inputs = getRelAlgInputs(sort_ra);
2767  CHECK_EQ(size_t(1), inputs.size());
2768  std::vector<SortField> collation;
2769  const auto& collation_arr = field(sort_ra, "collation");
2770  CHECK(collation_arr.IsArray());
2771  for (auto collation_arr_it = collation_arr.Begin();
2772  collation_arr_it != collation_arr.End();
2773  ++collation_arr_it) {
2774  const size_t field_idx = json_i64(field(*collation_arr_it, "field"));
2775  const auto sort_dir = parse_sort_direction(*collation_arr_it);
2776  const auto null_pos = parse_nulls_position(*collation_arr_it);
2777  collation.emplace_back(field_idx, sort_dir, null_pos);
2778  }
2779  auto limit = get_int_literal_field(sort_ra, "fetch", -1);
2780  const auto offset = get_int_literal_field(sort_ra, "offset", 0);
2781  auto ret = std::make_shared<RelSort>(
2782  collation, limit > 0 ? limit : 0, offset, inputs.front(), limit > 0);
2783  ret->setEmptyResult(limit == 0);
2784  return ret;
2785  }
2786 
// Builds a RelModify node from a LogicalTableModify relation.
//
// Steps:
//   1. Resolve the target table; views are rejected with std::runtime_error.
//   2. Read the "flattened" flag and the "operation" name.
//   3. For "UPDATE", collect the column names from "updateColumnList".
//   4. Call the operation-specific apply*ModificationsToInputNode() on the new node.
// Operations not handled by the switch throw std::runtime_error.
2787  std::shared_ptr<RelModify> dispatchModify(const rapidjson::Value& logical_modify_ra) {
2788  const auto inputs = getRelAlgInputs(logical_modify_ra);
2789  CHECK_EQ(size_t(1), inputs.size());
2790 
2791  const auto table_descriptor = getTableFromScanNode(cat_, logical_modify_ra);
// NOTE(review): this check runs before the operation is known, so the message says
// "UPDATE" even when the statement is a different modify operation.
2792  if (table_descriptor->isView) {
2793  throw std::runtime_error("UPDATE of a view is unsupported.");
2794  }
2795 
2796  bool flattened = json_bool(field(logical_modify_ra, "flattened"));
2797  std::string op = json_str(field(logical_modify_ra, "operation"));
2798  RelModify::TargetColumnList target_column_list;
2799 
// Only UPDATE carries an explicit list of target column names.
2800  if (op == "UPDATE") {
2801  const auto& update_columns = field(logical_modify_ra, "updateColumnList");
2802  CHECK(update_columns.IsArray());
2803 
2804  for (auto column_arr_it = update_columns.Begin();
2805  column_arr_it != update_columns.End();
2806  ++column_arr_it) {
2807  target_column_list.push_back(column_arr_it->GetString());
2808  }
2809  }
2810 
2811  auto modify_node = std::make_shared<RelModify>(
2812  cat_, table_descriptor, flattened, op, target_column_list, inputs[0]);
// Dispatch on the operation stored in the node to apply the delete- or update-specific
// rewrite of its input node.
2813  switch (modify_node->getOperation()) {
2815  modify_node->applyDeleteModificationsToInputNode();
2816  break;
2817  }
2819  modify_node->applyUpdateModificationsToInputNode();
2820  break;
2821  }
2822  default:
2823  throw std::runtime_error("Unsupported RelModify operation: " +
2824  json_node_to_string(logical_modify_ra));
2825  }
2826 
2827  return modify_node;
2828  }
2829 
2830  std::shared_ptr<RelTableFunction> dispatchTableFunction(
2831  const rapidjson::Value& table_func_ra,
2832  RelAlgDagBuilder& root_dag_builder) {
2833  const auto inputs = getRelAlgInputs(table_func_ra);
2834  const auto& invocation = field(table_func_ra, "invocation");
2835  CHECK(invocation.IsObject());
2836 
2837  const auto& operands = field(invocation, "operands");
2838  CHECK(operands.IsArray());
2839  CHECK_GE(operands.Size(), unsigned(0));
2840 
2841  std::vector<const Rex*> col_inputs;
2842  std::vector<std::unique_ptr<const RexScalar>> table_func_inputs;
2843  std::vector<std::string> fields;
2844 
2845  for (auto exprs_json_it = operands.Begin(); exprs_json_it != operands.End();
2846  ++exprs_json_it) {
2847  const auto& expr_json = *exprs_json_it;
2848  CHECK(expr_json.IsObject());
2849  if (expr_json.HasMember("op")) {
2850  const auto op_str = json_str(field(expr_json, "op"));
2851  if (op_str == "CAST" && expr_json.HasMember("type")) {
2852  const auto& expr_type = field(expr_json, "type");
2853  CHECK(expr_type.IsObject());
2854  CHECK(expr_type.HasMember("type"));
2855  const auto& expr_type_name = json_str(field(expr_type, "type"));
2856  if (expr_type_name == "CURSOR") {
2857  CHECK(expr_json.HasMember("operands"));
2858  const auto& expr_operands = field(expr_json, "operands");
2859  CHECK(expr_operands.IsArray());
2860  if (expr_operands.Size() != 1) {
2861  throw std::runtime_error(
2862  "Table functions currently only support one ResultSet input");
2863  }
2864  auto pos = field(expr_operands[0], "input").GetInt();
2865  CHECK_LT(pos, inputs.size());
2866  for (size_t i = inputs[pos]->size(); i > 0; i--) {
2867  table_func_inputs.emplace_back(
2868  std::make_unique<RexAbstractInput>(col_inputs.size()));
2869  col_inputs.emplace_back(table_func_inputs.back().get());
2870  }
2871  continue;
2872  }
2873  }
2874  }
2875  table_func_inputs.emplace_back(
2876  parse_scalar_expr(*exprs_json_it, cat_, root_dag_builder));
2877  }
2878 
2879  const auto& op_name = field(invocation, "op");
2880  CHECK(op_name.IsString());
2881 
2882  std::vector<std::unique_ptr<const RexScalar>> table_function_projected_outputs;
2883  const auto& row_types = field(table_func_ra, "rowType");
2884  CHECK(row_types.IsArray());
2885  CHECK_GE(row_types.Size(), unsigned(0));
2886  const auto& row_types_array = row_types.GetArray();
2887  for (size_t i = 0; i < row_types_array.Size(); i++) {
2888  // We don't care about the type information in rowType -- replace each output with
2889  // a reference to be resolved later in the translator
2890  table_function_projected_outputs.emplace_back(std::make_unique<RexRef>(i));
2891  fields.emplace_back("");
2892  }
2893  return std::make_shared<RelTableFunction>(op_name.GetString(),
2894  inputs,
2895  fields,
2896  col_inputs,
2897  table_func_inputs,
2898  table_function_projected_outputs);
2899  }
2900 
2901  std::shared_ptr<RelLogicalValues> dispatchLogicalValues(
2902  const rapidjson::Value& logical_values_ra) {
2903  const auto& tuple_type_arr = field(logical_values_ra, "type");
2904  CHECK(tuple_type_arr.IsArray());
2905  std::vector<TargetMetaInfo> tuple_type;
2906  for (auto tuple_type_arr_it = tuple_type_arr.Begin();
2907  tuple_type_arr_it != tuple_type_arr.End();
2908  ++tuple_type_arr_it) {
2909  const auto component_type = parse_type(*tuple_type_arr_it);
2910  const auto component_name = json_str(field(*tuple_type_arr_it, "name"));
2911  tuple_type.emplace_back(component_name, component_type);
2912  }
2913  const auto& inputs_arr = field(logical_values_ra, "inputs");
2914  CHECK(inputs_arr.IsArray());
2915  const auto& tuples_arr = field(logical_values_ra, "tuples");
2916  CHECK(tuples_arr.IsArray());
2917 
2918  if (inputs_arr.Size()) {
2919  throw QueryNotSupported("Inputs not supported in logical values yet.");
2920  }
2921 
2922  std::vector<RelLogicalValues::RowValues> values;
2923  if (tuples_arr.Size()) {
2924  for (const auto& row : tuples_arr.GetArray()) {
2925  CHECK(row.IsArray());
2926  const auto values_json = row.GetArray();
2927  if (!values.empty()) {
2928  CHECK_EQ(values[0].size(), values_json.Size());
2929  }
2930  values.emplace_back(RelLogicalValues::RowValues{});
2931  for (const auto& value : values_json) {
2932  CHECK(value.IsObject());
2933  CHECK(value.HasMember("literal"));
2934  values.back().emplace_back(parse_literal(value));
2935  }
2936  }
2937  }
2938 
2939  return std::make_shared<RelLogicalValues>(tuple_type, values);
2940  }
2941 
2942  std::shared_ptr<RelLogicalUnion> dispatchUnion(
2943  const rapidjson::Value& logical_union_ra) {
2944  auto inputs = getRelAlgInputs(logical_union_ra);
2945  auto const& all_type_bool = field(logical_union_ra, "all");
2946  CHECK(all_type_bool.IsBool());
2947  return std::make_shared<RelLogicalUnion>(std::move(inputs), all_type_bool.GetBool());
2948  }
2949 
2950  RelAlgInputs getRelAlgInputs(const rapidjson::Value& node) {
2951  if (node.HasMember("inputs")) {
2952  const auto str_input_ids = strings_from_json_array(field(node, "inputs"));
2953  RelAlgInputs ra_inputs;
2954  for (const auto& str_id : str_input_ids) {
2955  ra_inputs.push_back(nodes_[std::stoi(str_id)]);
2956  }
2957  return ra_inputs;
2958  }
2959  return {prev(node)};
2960  }
2961 
// Splits the leading "key=value" entry of a comma-separated option list.
//
// `pos` is the index in `str` of the ", " separator that ends the first entry, or
// std::string::npos when that entry is the last one.
// Returns {key, value}. An entry without '=' yields the whole entry as the key and an
// empty value.
//
// On return, the entry and its trailing ", " separator have been removed from `str`.
// For the last entry (pos == npos), `str` is cleared.
std::pair<std::string, std::string> getKVOptionPair(std::string& str, size_t& pos) {
  const std::string option_separator = ", ";
  const std::string option = str.substr(0, pos);
  const size_t delim_pos = option.find('=');
  std::string key = option.substr(0, delim_pos);
  std::string val =
      delim_pos == std::string::npos ? std::string() : option.substr(delim_pos + 1);
  if (pos == std::string::npos) {
    // Last entry: consume everything; `pos + separator length` would wrap around.
    str.clear();
  } else {
    str.erase(0, pos + option_separator.size());
  }
  return {std::move(key), std::move(val)};
}
2971 
2972  ExplainedQueryHint parseHintString(std::string& hint_string) {
2973  std::string white_space_delim = " ";
2974  int l = hint_string.length();
2975  hint_string = hint_string.erase(0, 1).substr(0, l - 2);
2976  size_t pos = 0;
2977  auto global_hint_checker = [&](const std::string& input_hint_name) -> HintIdentifier {
2978  bool global_hint = false;
2979  std::string hint_name = input_hint_name;
2980  auto global_hint_identifier = hint_name.substr(0, 2);
2981  if (global_hint_identifier.compare("g_") == 0) {
2982  global_hint = true;
2983  hint_name = hint_name.substr(2, hint_string.length());
2984  }
2985  return {global_hint, hint_name};
2986  };
2987  auto parsed_hint =
2988  global_hint_checker(hint_string.substr(0, hint_string.find(white_space_delim)));
2989  auto hint_type = RegisteredQueryHint::translateQueryHint(parsed_hint.hint_name);
2990  if ((pos = hint_string.find("options:")) != std::string::npos) {
2991  // need to parse hint options
2992  std::vector<std::string> tokens;
2993  bool kv_list_op = false;
2994  std::string raw_options = hint_string.substr(pos + 8, hint_string.length() - 2);
2995  if (raw_options.find('{') != std::string::npos) {
2996  kv_list_op = true;
2997  } else {
2998  CHECK(raw_options.find('[') != std::string::npos);
2999  }
3000  auto t1 = raw_options.erase(0, 1);
3001  raw_options = t1.substr(0, t1.length() - 1);
3002  std::string op_delim = ", ";
3003  if (kv_list_op) {
3004  // kv options
3005  std::unordered_map<std::string, std::string> kv_options;
3006  while ((pos = raw_options.find(op_delim)) != std::string::npos) {
3007  auto kv_pair = getKVOptionPair(raw_options, pos);
3008  kv_options.emplace(kv_pair.first, kv_pair.second);
3009  }
3010  // handle the last kv pair
3011  auto kv_pair = getKVOptionPair(raw_options, pos);
3012  kv_options.emplace(kv_pair.first, kv_pair.second);
3013  return {hint_type, parsed_hint.global_hint, false, true, kv_options};
3014  } else {
3015  std::vector<std::string> list_options;
3016  while ((pos = raw_options.find(op_delim)) != std::string::npos) {
3017  list_options.emplace_back(raw_options.substr(0, pos));
3018  raw_options.erase(0, pos + white_space_delim.length() + 1);
3019  }
3020  // handle the last option
3021  list_options.emplace_back(raw_options.substr(0, pos));
3022  return {hint_type, parsed_hint.global_hint, false, false, list_options};
3023  }
3024  } else {
3025  // marker hint: no extra option for this hint
3026  return {hint_type, parsed_hint.global_hint, true, false};
3027  }
3028  }
3029 
3030  void getRelAlgHints(const rapidjson::Value& json_node,
3031  std::shared_ptr<RelAlgNode> node) {
3032  std::string hint_explained = json_str(field(json_node, "hints"));
3033  size_t pos = 0;
3034  std::string delim = "|";
3035  std::vector<std::string> hint_list;
3036  while ((pos = hint_explained.find(delim)) != std::string::npos) {
3037  hint_list.emplace_back(hint_explained.substr(0, pos));
3038  hint_explained.erase(0, pos + delim.length());
3039  }
3040  // handling the last one
3041  hint_list.emplace_back(hint_explained.substr(0, pos));
3042 
3043  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
3044  if (agg_node) {
3045  for (std::string& hint : hint_list) {
3046  auto parsed_hint = parseHintString(hint);
3047  agg_node->addHint(parsed_hint);
3048  }
3049  }
3050  const auto project_node = std::dynamic_pointer_cast<RelProject>(node);
3051  if (project_node) {
3052  for (std::string& hint : hint_list) {
3053  auto parsed_hint = parseHintString(hint);
3054  project_node->addHint(parsed_hint);
3055  }
3056  }
3057  const auto scan_node = std::dynamic_pointer_cast<RelScan>(node);
3058  if (scan_node) {
3059  for (std::string& hint : hint_list) {
3060  auto parsed_hint = parseHintString(hint);
3061  scan_node->addHint(parsed_hint);
3062  }
3063  }
3064  const auto join_node = std::dynamic_pointer_cast<RelJoin>(node);
3065  if (join_node) {
3066  for (std::string& hint : hint_list) {
3067  auto parsed_hint = parseHintString(hint);
3068  join_node->addHint(parsed_hint);
3069  }
3070  }
3071 
3072  const auto compound_node = std::dynamic_pointer_cast<RelCompound>(node);
3073  if (compound_node) {
3074  for (std::string& hint : hint_list) {
3075  auto parsed_hint = parseHintString(hint);
3076  compound_node->addHint(parsed_hint);
3077  }
3078  }
3079  }
3080 
3081  std::shared_ptr<const RelAlgNode> prev(const rapidjson::Value& crt_node) {
3082  const auto id = node_id(crt_node);
3083  CHECK(id);
3084  CHECK_EQ(static_cast<size_t>(id), nodes_.size());
3085  return nodes_.back();
3086  }
3087 
3089  std::vector<std::shared_ptr<RelAlgNode>> nodes_;
3090 };
3091 
3092 } // namespace details
3093 
// Parses the relational-algebra JSON text `query_ra` and builds the DAG from it, with
// this builder acting as its own root builder (build(query_ast, *this)).
// Throws std::runtime_error if `query_ra` is not valid JSON; the parse error and its
// offset are logged first.
3094 RelAlgDagBuilder::RelAlgDagBuilder(const std::string& query_ra,
3096  const RenderInfo* render_info)
3097  : cat_(cat), render_info_(render_info) {
3098  rapidjson::Document query_ast;
3099  query_ast.Parse(query_ra.c_str());
3100  VLOG(2) << "Parsing query RA JSON: " << query_ra;
3101  if (query_ast.HasParseError()) {
// NOTE(review): the result of this GetParseError() call is discarded, so the statement
// has no effect; the error code is fetched again for the log message below.
3102  query_ast.GetParseError();
3103  LOG(ERROR) << "Failed to parse RA tree from Calcite (offset "
3104  << query_ast.GetErrorOffset() << "):\n"
3105  << rapidjson::GetParseError_En(query_ast.GetParseError());
3106  VLOG(1) << "Failed to parse query RA: " << query_ra;
3107  throw std::runtime_error(
3108  "Failed to parse relational algebra tree. Possible query syntax error.");
3109  }
3110  CHECK(query_ast.IsObject());
3112  build(query_ast, *this);
3113 }
3114 
3116  const rapidjson::Value& query_ast,
3118  const RenderInfo* render_info)
3119  : cat_(cat), render_info_(render_info) {
3120  build(query_ast, root_dag_builder);
3121 }
3122 
// Builds nodes_ from the "rels" array of the parsed RA JSON.
// The array is dispatched through details::RelAlgDispatcher, with `lead_dag_builder`
// passed as the root builder. The DAG rewrite passes below then run over the result,
// including hint handling, dead column/subquery elimination, node coalescing and
// left-deep join creation.
3123 void RelAlgDagBuilder::build(const rapidjson::Value& query_ast,
3124  RelAlgDagBuilder& lead_dag_builder) {
3125  const auto& rels = field(query_ast, "rels");
3126  CHECK(rels.IsArray());
// The catch clause only rethrows QueryNotSupported unchanged.
3127  try {
3128  nodes_ = details::RelAlgDispatcher(cat_).run(rels, lead_dag_builder);
3129  } catch (const QueryNotSupported&) {
3130  throw;
3131  }
3132  CHECK(!nodes_.empty());
3134 
3135  if (render_info_) {
3136  // Alter the RA for render. Do this before any flattening/optimizations are done to
3137  // the tree.
3139  }
3140 
3142  handle_query_hint(nodes_, this);
3143  mark_nops(nodes_);
// Collect the root of every left-deep join pattern. Roots that are RelFilter nodes are
// additionally recorded in filtered_left_deep_joins.
3148  std::vector<const RelAlgNode*> filtered_left_deep_joins;
3149  std::vector<const RelAlgNode*> left_deep_joins;
3150  for (const auto& node : nodes_) {
3151  const auto left_deep_join_root = get_left_deep_join_root(node);
3152  // The filter which starts a left-deep join pattern must not be coalesced
3153  // since it contains (part of) the join condition.
3154  if (left_deep_join_root) {
3155  left_deep_joins.push_back(left_deep_join_root.get());
3156  if (std::dynamic_pointer_cast<const RelFilter>(left_deep_join_root)) {
3157  filtered_left_deep_joins.push_back(left_deep_join_root.get());
3158  }
3159  }
3160  }
3161  if (filtered_left_deep_joins.empty()) {
3163  }
3164  eliminate_dead_columns(nodes_);
3165  eliminate_dead_subqueries(subqueries_, nodes_.back().get());
3168  nodes_,
3169  g_cluster /* always_add_project_if_first_project_is_window_expr */,
3170  query_hint_);
3171  coalesce_nodes(nodes_, left_deep_joins, query_hint_);
// After coalescing, the root node must be owned solely by nodes_.
3172  CHECK(nodes_.back().use_count() == 1);
3173  create_left_deep_join(nodes_);
3174 }
3175 
3177  std::function<void(RelAlgNode const*)> const& callback) const {
3178  for (auto const& node : nodes_) {
3179  if (node) {
3180  callback(node.get());
3181  }
3182  }
3183 }
3184 
3186  for (auto& node : nodes_) {
3187  if (node) {
3188  node->resetQueryExecutionState();
3189  }
3190  }
3191 }
3192 
3193 // Return tree with depth represented by indentations.
3194 std::string tree_string(const RelAlgNode* ra, const size_t depth) {
3195  std::string result = std::string(2 * depth, ' ') + ::toString(ra) + '\n';
3196  for (size_t i = 0; i < ra->inputCount(); ++i) {
3197  result += tree_string(ra->getInput(i), depth + 1);
3198  }
3199  return result;
3200 }
3201 
3202 std::string RexSubQuery::toString(RelRexToStringConfig config) const {
3203  return cat(::typeName(this), "(", ra_->toString(config), ")");
3204 }
3205 
3206 size_t RexSubQuery::toHash() const {
3207  if (!hash_) {
3208  hash_ = typeid(RexSubQuery).hash_code();
3209  boost::hash_combine(*hash_, ra_->toHash());
3210  }
3211  return *hash_;
3212 }
3213 
3214 std::string RexInput::toString(RelRexToStringConfig config) const {
3215  const auto scan_node = dynamic_cast<const RelScan*>(node_);
3216  if (scan_node) {
3217  auto field_name = scan_node->getFieldName(getIndex());
3218  auto table_name = scan_node->getTableDescriptor()->tableName;
3219  return ::typeName(this) + "(" + table_name + "." + field_name + ")";
3220  }
3221  auto node_id_in_plan = node_->getIdInPlanTree();
3222  auto node_id_str =
3223  node_id_in_plan ? std::to_string(*node_id_in_plan) : std::to_string(node_->getId());
3224  auto node_str = config.skip_input_nodes ? "(input_node_id=" + node_id_str
3225  : "(input_node=" + node_->toString(config);
3226  return cat(::typeName(this), node_str, ", in_index=", std::to_string(getIndex()), ")");
3227 }
3228 
3229 size_t RexInput::toHash() const {
3230  if (!hash_) {
3231  hash_ = typeid(RexInput).hash_code();
3232  boost::hash_combine(*hash_, node_->toHash());
3233  boost::hash_combine(*hash_, getIndex());
3234  }
3235  return *hash_;
3236 }
3237 
// Returns a textual description of this compound node.
// The description includes the type name, the filter expression ("null" when absent),
// each target, aggregate and scalar-source expression (space-separated), the field
// names and the aggregate flag.
3238 std::string RelCompound::toString(RelRexToStringConfig config) const {
3239  auto ret = cat(::typeName(this),
3240  ", filter_expr=",
3241  (filter_expr_ ? filter_expr_->toString(config) : "null"),
3242  ", target_exprs=");
3243  for (auto& expr : target_exprs_) {
3244  ret += expr->toString(config) + " ";
3245  }
3246  ret += ", agg_exps=";
3247  for (auto& expr : agg_exprs_) {
3248  ret += expr->toString(config) + " ";
3249  }
3250  ret += ", scalar_sources=";
3251  for (auto& expr : scalar_sources_) {
3252  ret += expr->toString(config) + " ";
3253  }
3254  return cat(ret,
3255  ", ",
3257  ", ",
3258  ", fields=",
3259  ::toString(fields_),
3260  ", is_agg=",
3262 }
3263 
3264 size_t RelCompound::toHash() const {
3265  if (!hash_) {
3266  hash_ = typeid(RelCompound).hash_code();
3267  boost::hash_combine(*hash_, filter_expr_ ? filter_expr_->toHash() : HASH_N);
3268  boost::hash_combine(*hash_, is_agg_);
3269  for (auto& target_expr : target_exprs_) {
3270  if (auto rex_scalar = dynamic_cast<const RexScalar*>(target_expr)) {
3271  boost::hash_combine(*hash_, rex_scalar->toHash());
3272  }
3273  }
3274  for (auto& agg_expr : agg_exprs_) {
3275  boost::hash_combine(*hash_, agg_expr->toHash());
3276  }
3277  for (auto& scalar_source : scalar_sources_) {
3278  boost::hash_combine(*hash_, scalar_source->toHash());
3279  }
3280  boost::hash_combine(*hash_, groupby_count_);
3281  boost::hash_combine(*hash_, ::toString(fields_));
3282  }
3283  return *hash_;
3284 }
std::vector< std::shared_ptr< const RexScalar > > scalar_exprs_
DEVICE auto upper_bound(ARGS &&...args)
Definition: gpu_enabled.h:123
const size_t getGroupByCount() const
SQLTypes to_sql_type(const std::string &type_name)
std::shared_ptr< const RelAlgNode > getRootNodeShPtr() const
void coalesce_nodes(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::vector< const RelAlgNode * > &left_deep_joins, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
bool is_agg(const Analyzer::Expr *expr)
std::unique_ptr< const RexScalar > condition_
SQLTypeInfo parse_type(const rapidjson::Value &type_obj)
const RexScalar * getThen(const size_t idx) const
std::vector< std::string > strings_from_json_array(const rapidjson::Value &json_str_arr) noexcept
std::unique_ptr< RexCase > parse_case(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
std::shared_ptr< RelAggregate > dispatchAggregate(const rapidjson::Value &agg_ra)
std::vector< std::unique_ptr< const RexScalar > > & scalar_exprs_for_new_project_
#define CHECK_EQ(x, y)
Definition: Logger.h:231
JoinType to_join_type(const std::string &join_type_name)
const Catalog_Namespace::Catalog & cat_
std::shared_ptr< RelProject > dispatchProject(const rapidjson::Value &proj_ra, RelAlgDagBuilder &root_dag_builder)
std::unique_ptr< const RexScalar > visitOperator(const RexOperator *rex_operator) const override
std::unique_ptr< RexSubQuery > deepCopy() const
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
JoinType
Definition: sqldefs.h:136
std::vector< const Rex * > remapTargetPointers(std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs_new, std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources_new, std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs_old, std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources_old, std::vector< const Rex * > const &target_exprs_old)
std::vector< std::unique_ptr< const RexScalar > > table_func_inputs_
std::string cat(Ts &&...args)
void bind_inputs(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
void hoist_filter_cond_to_cross_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
class for a per-database catalog. also includes metadata for the current database and the current use...
Definition: Catalog.h:114
Definition: sqltypes.h:49
std::unique_ptr< const RexScalar > visitRef(const RexRef *rex_ref) const override
void addHint(const ExplainedQueryHint &hint_explained)
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
void sink_projected_boolean_expr_to_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::unique_ptr< const RexAgg > parse_aggregate_expr(const rapidjson::Value &expr)
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
void eliminate_identical_copy(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
size_t toHash() const override
const RexScalar * getElse() const
RelCompound(std::unique_ptr< const RexScalar > &filter_expr, const std::vector< const Rex * > &target_exprs, const size_t groupby_count, const std::vector< const RexAgg * > &agg_exprs, const std::vector< std::string > &fields, std::vector< std::unique_ptr< const RexScalar >> &scalar_sources, const bool is_agg, bool update_disguised_as_select=false, bool delete_disguised_as_select=false, bool varlen_update_required=false, TableDescriptor const *manipulation_target_table=nullptr, ColumnNameList target_columns=ColumnNameList())
static thread_local unsigned crt_id_
void compute_node_hash(const std::vector< std::shared_ptr< RelAlgNode >> &nodes)
std::shared_ptr< RelScan > dispatchTableScan(const rapidjson::Value &scan_ra)
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
RexScalar const * copyAndRedirectSource(RexScalar const *, size_t input_idx) const
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
SQLAgg to_agg_kind(const std::string &agg_name)
std::shared_ptr< RelLogicalUnion > dispatchUnion(const rapidjson::Value &logical_union_ra)
#define LOG(tag)
Definition: Logger.h:217
NullSortedPosition
std::vector< std::string > TargetColumnList
const SQLTypeInfo & getType() const
void * visitOperator(const RexOperator *rex_operator) const final
size_t size() const
Hints * getDeliveredHints()
const bool json_bool(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:49
const RexScalar * getOperand(const size_t idx) const
std::vector< const Rex * > col_inputs_
bool hasEquivCollationOf(const RelSort &that) const
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
void build(const rapidjson::Value &query_ast, RelAlgDagBuilder &root_dag_builder)
RexWindowFunctionOperator::RexWindowBound parse_window_bound(const rapidjson::Value &window_bound_obj, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
std::string join(T const &container, std::string const &delim)
void bind_table_func_to_input(RelTableFunction *table_func_node, const RANodeOutput &input) noexcept
std::unique_ptr< RexOperator > parse_operator(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
#define UNREACHABLE()
Definition: Logger.h:267
bool hint_applied_
DEVICE void sort(ARGS &&...args)
Definition: gpu_enabled.h:105
bool isRenamedInput(const RelAlgNode *node, const size_t index, const std::string &new_name)
#define CHECK_GE(x, y)
Definition: Logger.h:236
std::optional< size_t > getIdInPlanTree() const
std::vector< std::string > fields_
RexInput(const RelAlgNode *node, const unsigned in_index)
void addHint(const ExplainedQueryHint &hint_explained)
Definition: sqldefs.h:49
const RexScalar * getWhen(const size_t idx) const
void appendInput(std::string new_field_name, std::unique_ptr< const RexScalar > new_input)
RetType visitOperator(const RexOperator *rex_operator) const final
void addHint(const ExplainedQueryHint &hint_explained)
void bind_project_to_input(RelProject *project_node, const RANodeOutput &input) noexcept
const Catalog_Namespace::Catalog & cat_
std::vector< std::unique_ptr< const RexScalar > > parse_window_order_exprs(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
void handle_query_hint(const std::vector< std::shared_ptr< RelAlgNode >> &nodes, RelAlgDagBuilder *dag_builder) noexcept
void create_compound(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::vector< size_t > &pattern, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints) noexcept
void checkForMatchingMetaInfoTypes() const
std::vector< std::unique_ptr< const RexScalar > > scalar_sources_
std::string to_string(char const *&&v)
RexWindowFuncReplacementVisitor(std::unordered_set< size_t > &collected_window_func_hash, std::vector< std::unique_ptr< const RexScalar >> &new_rex_input_for_window_func, std::unordered_map< size_t, size_t > &window_func_to_new_rex_input_idx_map, RelProject *new_project, std::unordered_map< size_t, std::unique_ptr< const RexInput >> &new_rex_input_from_child_node)
SqlWindowFunctionKind parse_window_function_kind(const std::string &name)
const std::string getFieldName(const size_t i) const
std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint > > query_hint_
PushDownGenericExpressionInWindowFunction(std::shared_ptr< RelProject > new_project, std::vector< std::unique_ptr< const RexScalar >> &scalar_exprs_for_new_project, std::vector< std::string > &fields_for_new_project, std::unordered_map< size_t, size_t > &expr_offset_cache)
void simplify_sort(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::vector< SortField > collation_
RexRebindInputsVisitor(const RelAlgNode *old_input, const RelAlgNode *new_input)
int64_t get_int_literal_field(const rapidjson::Value &obj, const char field[], const int64_t default_val) noexcept
constexpr double a
Definition: Utm.h:32
std::unique_ptr< const RexScalar > visitSubQuery(const RexSubQuery *rex_subquery) const override
std::vector< std::unique_ptr< const RexScalar > > parse_expr_array(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
This file contains the class specification and related data structures for Catalog.
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
RexInputSet visitInput(const RexInput *input) const override
std::vector< std::shared_ptr< RexSubQuery > > subqueries_
const RenderInfo * render_info_
std::string to_string() const
Definition: sqltypes.h:483
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
virtual std::string toString(RelRexToStringConfig config) const =0
unsigned getIndex() const
SQLOps getOperator() const
std::vector< std::shared_ptr< RelAlgNode > > nodes_
unsigned getId() const
const TableDescriptor * getTableFromScanNode(const Catalog_Namespace::Catalog &cat, const rapidjson::Value &scan_ra)
SortDirection parse_sort_direction(const rapidjson::Value &collation)
std::unique_ptr< const RexScalar > get_new_rex_input(size_t rex_idx) const
std::optional< size_t > getOffsetForPushedDownExpr(WindowExprType type, size_t expr_offset) const
std::unique_ptr< const RexScalar > disambiguate_rex(const RexScalar *, const RANodeOutput &)
static QueryHint translateQueryHint(const std::string &hint_name)
Definition: QueryHint.h:240
DEVICE auto copy(ARGS &&...args)
Definition: gpu_enabled.h:51
#define CHECK_NE(x, y)
Definition: Logger.h:232
bool isRenaming() const
std::vector< SortField > parse_window_order_collation(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
void setIndex(const unsigned in_index) const
Hints * getDeliveredHints()
size_t toHash() const override
std::vector< std::string > fields_
SQLOps to_sql_op(const std::string &op_str)
std::unique_ptr< Hints > hints_
void set_scale(int s)
Definition: sqltypes.h:434
const int64_t json_i64(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:39
void * visitInput(const RexInput *rex_input) const override
std::unique_ptr< Hints > hints_
std::vector< std::unique_ptr< const RexScalar > > scalar_exprs_
const double json_double(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:54
void addHint(const ExplainedQueryHint &hint_explained)
size_t branchCount() const
std::unordered_set< RexInput > RexInputSet
void propagate_hints_to_new_project(std::shared_ptr< RelProject > prev_node, std::shared_ptr< RelProject > new_node, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
const RelAlgNode * getInput(const size_t idx) const
RelFilter(std::unique_ptr< const RexScalar > &filter, std::shared_ptr< const RelAlgNode > input)
std::vector< std::shared_ptr< RelAlgNode > > nodes_
RelAggregate(const size_t groupby_count, std::vector< std::unique_ptr< const RexAgg >> &agg_exprs, const std::vector< std::string > &fields, std::shared_ptr< const RelAlgNode > input)
std::unique_ptr< const RexScalar > filter_
bool isSimple() const
RexInputReplacementVisitor(const RelAlgNode *node_to_keep, const std::vector< std::unique_ptr< const RexScalar >> &scalar_sources)
std::optional< size_t > hash_
const size_t groupby_count_
RexRebindReindexInputsVisitor(const RelAlgNode *old_input, const RelAlgNode *new_input, std::unordered_map< unsigned, unsigned > old_to_new_index_map)
std::optional< size_t > hash_
unsigned getId() const
const RelAlgNode * node_
std::unordered_map< size_t, const RexScalar * > & collected_window_func_
std::string toString(const Executor::ExtModuleKinds &kind)
Definition: Execute.h:1453
virtual void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input)
void eachNode(std::function< void(RelAlgNode const *)> const &) const
std::unordered_map< size_t, std::unique_ptr< const RexInput > > & new_rex_input_from_child_node_
NullSortedPosition parse_nulls_position(const rapidjson::Value &collation)
std::shared_ptr< RelFilter > dispatchFilter(const rapidjson::Value &filter_ra, RelAlgDagBuilder &root_dag_builder)
std::unique_ptr< RexAbstractInput > parse_abstract_input(const rapidjson::Value &expr) noexcept
std::unique_ptr< const RexScalar > visitCase(const RexCase *rex_case) const override
std::vector< std::unique_ptr< const RexAgg > > agg_exprs_
std::set< std::pair< const RelAlgNode *, int > > get_equiv_cols(const RelAlgNode *node, const size_t which_col)
Hints * getDeliveredHints()
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
std::unique_ptr< RexLiteral > parse_literal(const rapidjson::Value &expr)
size_t toHash() const override
SortDirection
const RexScalar * getProjectAt(const size_t idx) const
std::vector< std::unique_ptr< const RexAgg > > copyAggExprs(std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs)
std::vector< std::shared_ptr< const RelAlgNode >> RelAlgInputs
#define CHECK_LT(x, y)
Definition: Logger.h:233
Definition: sqltypes.h:52
Definition: sqltypes.h:53
static RegisteredQueryHint defaults()
Definition: QueryHint.h:237
int32_t countRexLiteralArgs() const
std::unique_ptr< Hints > hints_
std::unique_ptr< const RexOperator > disambiguate_operator(const RexOperator *rex_operator, const RANodeOutput &ra_output) noexcept
const ConstRexScalarPtrVector & getPartitionKeys() const
std::string tree_string(const RelAlgNode *ra, const size_t depth)
DEVICE auto lower_bound(ARGS &&...args)
Definition: gpu_enabled.h:78
#define CHECK_LE(x, y)
Definition: Logger.h:234
std::unique_ptr< Hints > hints_
const std::vector< const Rex * > target_exprs_
std::vector< ElementType >::const_iterator Super
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
RelLogicalUnion(RelAlgInputs, bool is_all)
std::vector< std::unique_ptr< const RexAgg > > agg_exprs_
std::unique_ptr< const RexCase > disambiguate_case(const RexCase *rex_case, const RANodeOutput &ra_output)
std::shared_ptr< RelJoin > dispatchJoin(const rapidjson::Value &join_ra, RelAlgDagBuilder &root_dag_builder)
std::unique_ptr< const RexScalar > filter_expr_
void setSourceNode(const RelAlgNode *node) const
bool hasWindowFunctionExpr() const
std::shared_ptr< RelModify > dispatchModify(const rapidjson::Value &logical_modify_ra)
std::unordered_map< QueryHint, ExplainedQueryHint > Hints
Definition: QueryHint.h:263
virtual size_t size() const =0
const RelAlgNode * getSourceNode() const
void registerSubquery(std::shared_ptr< RexSubQuery > subquery)
void setExecutionResult(const std::shared_ptr< const ExecutionResult > result)
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
RelLogicalValues(const std::vector< TargetMetaInfo > &tuple_type, std::vector< RowValues > &values)
size_t toHash() const override
std::unique_ptr< const RexScalar > visitLiteral(const RexLiteral *rex_literal) const override
std::string typeName(const T *v)
Definition: toString.h:102
ExplainedQueryHint parseHintString(std::string &hint_string)
RelAlgDagBuilder()=delete
SqlWindowFunctionKind
Definition: sqldefs.h:111
WindowFunctionCollector(std::unordered_map< size_t, const RexScalar * > &collected_window_func)
void mark_nops(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
RexInputSet aggregateResult(const RexInputSet &aggregate, const RexInputSet &next_result) const override
Definition: sqldefs.h:53
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
static auto const HASH_N
std::shared_ptr< RelSort > dispatchSort(const rapidjson::Value &sort_ra)
void separate_window_function_expressions(std::vector< std::shared_ptr< RelAlgNode >> &nodes, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
RelTableFunction(const std::string &function_name, RelAlgInputs inputs, std::vector< std::string > &fields, std::vector< const Rex * > col_inputs, std::vector< std::unique_ptr< const RexScalar >> &table_func_inputs, std::vector< std::unique_ptr< const RexScalar >> &target_exprs)
const std::vector< std::string > & getFields() const
std::string getFieldName(const size_t i) const
std::vector< size_t > indices_from_json_array(const rapidjson::Value &json_idx_arr) noexcept
bool g_enable_watchdog false
Definition: Execute.cpp:79
#define CHECK(condition)
Definition: Logger.h:223
std::unique_ptr< const RexScalar > parse_scalar_expr(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
const ConstRexScalarPtrVector & getOrderKeys() const
RelProject(std::vector< std::unique_ptr< const RexScalar >> &scalar_exprs, const std::vector< std::string > &fields, std::shared_ptr< const RelAlgNode > input)
unsigned node_id(const rapidjson::Value &ra_node) noexcept
RANodeOutput get_node_output(const RelAlgNode *ra_node)
bool g_enable_union
std::vector< std::string > getFieldNamesFromScanNode(const rapidjson::Value &scan_ra)
bool g_cluster
void alterRAForRender(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const RenderInfo &render_info)
std::shared_ptr< const RelAlgNode > prev(const rapidjson::Value &crt_node)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
std::vector< std::shared_ptr< RelAlgNode > > run(const rapidjson::Value &rels, RelAlgDagBuilder &root_dag_builder)
void getRelAlgHints(const rapidjson::Value &json_node, std::shared_ptr< RelAlgNode > node)
virtual size_t toHash() const =0
RelAlgDispatcher(const Catalog_Namespace::Catalog &cat)
void add_window_function_pre_project(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const bool always_add_project_if_first_project_is_window_expr, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
Common Enum definitions for SQL processing.
bool is_dict_encoded_string() const
Definition: sqltypes.h:548
Definition: sqltypes.h:45
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
void fold_filters(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
const TableDescriptor * getMetadataForTable(const std::string &tableName, const bool populateFragmenter=true) const
Returns a pointer to a const TableDescriptor struct matching the provided tableName.
std::vector< RexInput > RANodeOutput
void pushDownExpressionInWindowFunction(const RexWindowFunctionOperator *window_expr) const
std::vector< std::unique_ptr< const RexScalar >> RowValues
const size_t inputCount() const
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)
std::unique_ptr< const RexSubQuery > parse_subquery(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
void eliminate_dead_subqueries(std::vector< std::shared_ptr< RexSubQuery >> &subqueries, RelAlgNode const *root)
size_t size() const override
size_t operator()(const std::pair< const RelAlgNode *, int > &input_col) const
bool input_can_be_coalesced(const RelAlgNode *parent_node, const size_t index, const bool first_rex_is_input)
std::vector< std::unique_ptr< const RexScalar > > & new_rex_input_for_window_func_
RelAlgInputs getRelAlgInputs(const rapidjson::Value &node)
std::shared_ptr< RelLogicalValues > dispatchLogicalValues(const rapidjson::Value &logical_values_ra)
std::shared_ptr< RelTableFunction > dispatchTableFunction(const rapidjson::Value &table_func_ra, RelAlgDagBuilder &root_dag_builder)
DEVICE void swap(ARGS &&...args)
Definition: gpu_enabled.h:114
std::unique_ptr< const RexScalar > RetType
Definition: RexVisitor.h:140
std::vector< std::unique_ptr< const RexScalar > > copyRexScalars(std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources)
size_t toHash() const override
virtual size_t toHash() const =0
#define VLOG(n)
Definition: Logger.h:317
std::optional< size_t > is_collected_window_function(size_t rex_hash) const
RelJoin(std::shared_ptr< const RelAlgNode > lhs, std::shared_ptr< const RelAlgNode > rhs, std::unique_ptr< const RexScalar > &condition, const JoinType join_type)
RelAlgInputs inputs_
void set_precision(int d)
Definition: sqltypes.h:432
std::pair< std::string, std::string > getKVOptionPair(std::string &str, size_t &pos)
void eliminate_dead_columns(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
void check_empty_inputs_field(const rapidjson::Value &node) noexcept
std::vector< const Rex * > reproject_targets(const RelProject *simple_project, const std::vector< const Rex * > &target_exprs) noexcept
bool isIdentity() const
std::vector< std::unique_ptr< const RexScalar > > target_exprs_
std::vector< RexInput > n_outputs(const RelAlgNode *node, const size_t n)
const bool is_agg_
std::unique_ptr< const RexScalar > visitInput(const RexInput *rex_input) const override
static void resetRelAlgFirstId() noexcept
std::string json_node_to_string(const rapidjson::Value &node) noexcept