OmniSciDB  d2f719934e
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelAlgDagBuilder.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RelAlgDagBuilder.h"
19 #include "Catalog/Catalog.h"
21 #include "JsonAccessors.h"
22 #include "RelAlgOptimizer.h"
23 #include "RelLeftDeepInnerJoin.h"
25 #include "RexVisitor.h"
26 #include "Shared/sqldefs.h"
27 
28 #include <rapidjson/error/en.h>
29 #include <rapidjson/error/error.h>
30 #include <rapidjson/stringbuffer.h>
31 #include <rapidjson/writer.h>
32 
33 #include <string>
34 #include <unordered_set>
35 
36 extern bool g_cluster;
37 extern bool g_enable_union;
38 
39 namespace {
40 
41 const unsigned FIRST_RA_NODE_ID = 1;
42 
43 } // namespace
44 
45 thread_local unsigned RelAlgNode::crt_id_ = FIRST_RA_NODE_ID;
46 
49 }
50 
52  const std::shared_ptr<const ExecutionResult> result) {
53  auto row_set = result->getRows();
54  CHECK(row_set);
55  CHECK_EQ(size_t(1), row_set->colCount());
56  *(type_.get()) = row_set->getColType(0);
57  (*(result_.get())) = result;
58 }
59 
60 std::unique_ptr<RexSubQuery> RexSubQuery::deepCopy() const {
61  return std::make_unique<RexSubQuery>(type_, result_, ra_->deepCopy());
62 }
63 
64 unsigned RexSubQuery::getId() const {
65  return ra_->getId();
66 }
67 
68 namespace {
69 
70 class RexRebindInputsVisitor : public RexVisitor<void*> {
71  public:
72  RexRebindInputsVisitor(const RelAlgNode* old_input, const RelAlgNode* new_input)
73  : old_input_(old_input), new_input_(new_input) {}
74 
75  virtual ~RexRebindInputsVisitor() = default;
76 
77  void* visitInput(const RexInput* rex_input) const override {
78  const auto old_source = rex_input->getSourceNode();
79  if (old_source == old_input_) {
80  const auto left_deep_join = dynamic_cast<const RelLeftDeepInnerJoin*>(new_input_);
81  if (left_deep_join) {
82  rebind_inputs_from_left_deep_join(rex_input, left_deep_join);
83  return nullptr;
84  }
85  rex_input->setSourceNode(new_input_);
86  }
87  return nullptr;
88  };
89 
90  private:
91  const RelAlgNode* old_input_;
93 };
94 
95 // Creates an output with n columns.
96 std::vector<RexInput> n_outputs(const RelAlgNode* node, const size_t n) {
97  std::vector<RexInput> outputs;
98  outputs.reserve(n);
99  for (size_t i = 0; i < n; ++i) {
100  outputs.emplace_back(node, i);
101  }
102  return outputs;
103 }
104 
106  public:
108  const RelAlgNode* old_input,
109  const RelAlgNode* new_input,
110  std::unordered_map<unsigned, unsigned> old_to_new_index_map)
111  : RexRebindInputsVisitor(old_input, new_input), mapping_(old_to_new_index_map) {}
112 
113  void* visitInput(const RexInput* rex_input) const override {
114  RexRebindInputsVisitor::visitInput(rex_input);
115  auto mapping_itr = mapping_.find(rex_input->getIndex());
116  CHECK(mapping_itr != mapping_.end());
117  rex_input->setIndex(mapping_itr->second);
118  return nullptr;
119  }
120 
121  private:
122  const std::unordered_map<unsigned, unsigned> mapping_;
123 };
124 
125 } // namespace
126 
128  std::shared_ptr<const RelAlgNode> old_input,
129  std::shared_ptr<const RelAlgNode> input,
130  std::optional<std::unordered_map<unsigned, unsigned>> old_to_new_index_map) {
131  RelAlgNode::replaceInput(old_input, input);
132  std::unique_ptr<RexRebindInputsVisitor> rebind_inputs;
133  if (old_to_new_index_map) {
134  rebind_inputs = std::make_unique<RexRebindReindexInputsVisitor>(
135  old_input.get(), input.get(), *old_to_new_index_map);
136  } else {
137  rebind_inputs =
138  std::make_unique<RexRebindInputsVisitor>(old_input.get(), input.get());
139  }
140  CHECK(rebind_inputs);
141  for (const auto& scalar_expr : scalar_exprs_) {
142  rebind_inputs->visit(scalar_expr.get());
143  }
144 }
145 
146 void RelProject::appendInput(std::string new_field_name,
147  std::unique_ptr<const RexScalar> new_input) {
148  fields_.emplace_back(std::move(new_field_name));
149  scalar_exprs_.emplace_back(std::move(new_input));
150 }
151 
153  const auto scan_node = dynamic_cast<const RelScan*>(ra_node);
154  if (scan_node) {
155  // Scan node has no inputs, output contains all columns in the table.
156  CHECK_EQ(size_t(0), scan_node->inputCount());
157  return n_outputs(scan_node, scan_node->size());
158  }
159  const auto project_node = dynamic_cast<const RelProject*>(ra_node);
160  if (project_node) {
161  // Project output count doesn't depend on the input
162  CHECK_EQ(size_t(1), project_node->inputCount());
163  return n_outputs(project_node, project_node->size());
164  }
165  const auto filter_node = dynamic_cast<const RelFilter*>(ra_node);
166  if (filter_node) {
167  // Filter preserves shape
168  CHECK_EQ(size_t(1), filter_node->inputCount());
169  const auto prev_out = get_node_output(filter_node->getInput(0));
170  return n_outputs(filter_node, prev_out.size());
171  }
172  const auto aggregate_node = dynamic_cast<const RelAggregate*>(ra_node);
173  if (aggregate_node) {
174  // Aggregate output count doesn't depend on the input
175  CHECK_EQ(size_t(1), aggregate_node->inputCount());
176  return n_outputs(aggregate_node, aggregate_node->size());
177  }
178  const auto compound_node = dynamic_cast<const RelCompound*>(ra_node);
179  if (compound_node) {
180  // Compound output count doesn't depend on the input
181  CHECK_EQ(size_t(1), compound_node->inputCount());
182  return n_outputs(compound_node, compound_node->size());
183  }
184  const auto join_node = dynamic_cast<const RelJoin*>(ra_node);
185  if (join_node) {
186  // Join concatenates the outputs from the inputs and the output
187  // directly references the nodes in the input.
188  CHECK_EQ(size_t(2), join_node->inputCount());
189  auto lhs_out =
190  n_outputs(join_node->getInput(0), get_node_output(join_node->getInput(0)).size());
191  const auto rhs_out =
192  n_outputs(join_node->getInput(1), get_node_output(join_node->getInput(1)).size());
193  lhs_out.insert(lhs_out.end(), rhs_out.begin(), rhs_out.end());
194  return lhs_out;
195  }
196  const auto table_func_node = dynamic_cast<const RelTableFunction*>(ra_node);
197  if (table_func_node) {
198  // Table Function output count doesn't depend on the input
199  return n_outputs(table_func_node, table_func_node->size());
200  }
201  const auto sort_node = dynamic_cast<const RelSort*>(ra_node);
202  if (sort_node) {
203  // Sort preserves shape
204  CHECK_EQ(size_t(1), sort_node->inputCount());
205  const auto prev_out = get_node_output(sort_node->getInput(0));
206  return n_outputs(sort_node, prev_out.size());
207  }
208  const auto logical_values_node = dynamic_cast<const RelLogicalValues*>(ra_node);
209  if (logical_values_node) {
210  CHECK_EQ(size_t(0), logical_values_node->inputCount());
211  return n_outputs(logical_values_node, logical_values_node->size());
212  }
213  const auto logical_union_node = dynamic_cast<const RelLogicalUnion*>(ra_node);
214  if (logical_union_node) {
215  return n_outputs(logical_union_node, logical_union_node->size());
216  }
217  LOG(FATAL) << "Unhandled ra_node type: " << ::toString(ra_node);
218  return {};
219 }
220 
222  if (!isSimple()) {
223  return false;
224  }
225  CHECK_EQ(size_t(1), inputCount());
226  const auto source = getInput(0);
227  if (dynamic_cast<const RelJoin*>(source)) {
228  return false;
229  }
230  const auto source_shape = get_node_output(source);
231  if (source_shape.size() != scalar_exprs_.size()) {
232  return false;
233  }
234  for (size_t i = 0; i < scalar_exprs_.size(); ++i) {
235  const auto& scalar_expr = scalar_exprs_[i];
236  const auto input = dynamic_cast<const RexInput*>(scalar_expr.get());
237  CHECK(input);
238  CHECK_EQ(source, input->getSourceNode());
239  // We should add the additional check that input->getIndex() !=
240  // source_shape[i].getIndex(), but Calcite doesn't generate the right
241  // Sort-Project-Sort sequence when joins are involved.
242  if (input->getSourceNode() != source_shape[i].getSourceNode()) {
243  return false;
244  }
245  }
246  return true;
247 }
248 
249 namespace {
250 
251 bool isRenamedInput(const RelAlgNode* node,
252  const size_t index,
253  const std::string& new_name) {
254  CHECK_LT(index, node->size());
255  if (auto join = dynamic_cast<const RelJoin*>(node)) {
256  CHECK_EQ(size_t(2), join->inputCount());
257  const auto lhs_size = join->getInput(0)->size();
258  if (index < lhs_size) {
259  return isRenamedInput(join->getInput(0), index, new_name);
260  }
261  CHECK_GE(index, lhs_size);
262  return isRenamedInput(join->getInput(1), index - lhs_size, new_name);
263  }
264 
265  if (auto scan = dynamic_cast<const RelScan*>(node)) {
266  return new_name != scan->getFieldName(index);
267  }
268 
269  if (auto aggregate = dynamic_cast<const RelAggregate*>(node)) {
270  return new_name != aggregate->getFieldName(index);
271  }
272 
273  if (auto project = dynamic_cast<const RelProject*>(node)) {
274  return new_name != project->getFieldName(index);
275  }
276 
277  if (auto table_func = dynamic_cast<const RelTableFunction*>(node)) {
278  return new_name != table_func->getFieldName(index);
279  }
280 
281  if (auto logical_values = dynamic_cast<const RelLogicalValues*>(node)) {
282  const auto& tuple_type = logical_values->getTupleType();
283  CHECK_LT(index, tuple_type.size());
284  return new_name != tuple_type[index].get_resname();
285  }
286 
287  CHECK(dynamic_cast<const RelSort*>(node) || dynamic_cast<const RelFilter*>(node) ||
288  dynamic_cast<const RelLogicalUnion*>(node));
289  return isRenamedInput(node->getInput(0), index, new_name);
290 }
291 
292 } // namespace
293 
295  if (!isSimple()) {
296  return false;
297  }
298  CHECK_EQ(scalar_exprs_.size(), fields_.size());
299  for (size_t i = 0; i < fields_.size(); ++i) {
300  auto rex_in = dynamic_cast<const RexInput*>(scalar_exprs_[i].get());
301  CHECK(rex_in);
302  if (isRenamedInput(rex_in->getSourceNode(), rex_in->getIndex(), fields_[i])) {
303  return true;
304  }
305  }
306  return false;
307 }
308 
309 void RelJoin::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
310  std::shared_ptr<const RelAlgNode> input) {
311  RelAlgNode::replaceInput(old_input, input);
312  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
313  if (condition_) {
314  rebind_inputs.visit(condition_.get());
315  }
316 }
317 
318 void RelFilter::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
319  std::shared_ptr<const RelAlgNode> input) {
320  RelAlgNode::replaceInput(old_input, input);
321  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
322  rebind_inputs.visit(filter_.get());
323 }
324 
325 void RelCompound::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
326  std::shared_ptr<const RelAlgNode> input) {
327  RelAlgNode::replaceInput(old_input, input);
328  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
329  for (const auto& scalar_source : scalar_sources_) {
330  rebind_inputs.visit(scalar_source.get());
331  }
332  if (filter_expr_) {
333  rebind_inputs.visit(filter_expr_.get());
334  }
335 }
336 
338  : RelAlgNode(rhs)
340  , fields_(rhs.fields_)
341  , hint_applied_(false)
342  , hints_(std::make_unique<Hints>()) {
343  RexDeepCopyVisitor copier;
344  for (auto const& expr : rhs.scalar_exprs_) {
345  scalar_exprs_.push_back(copier.visit(expr.get()));
346  }
347  if (rhs.hint_applied_) {
348  for (auto const& kv : *rhs.hints_) {
349  addHint(kv.second);
350  }
351  }
352 }
353 
355  : RelAlgNode(rhs)
356  , tuple_type_(rhs.tuple_type_)
357  , values_(RexDeepCopyVisitor::copy(rhs.values_)) {}
358 
360  RexDeepCopyVisitor copier;
361  filter_ = copier.visit(rhs.filter_.get());
362 }
363 
365  : RelAlgNode(rhs)
366  , groupby_count_(rhs.groupby_count_)
367  , fields_(rhs.fields_)
368  , hint_applied_(false)
369  , hints_(std::make_unique<Hints>()) {
370  agg_exprs_.reserve(rhs.agg_exprs_.size());
371  for (auto const& agg : rhs.agg_exprs_) {
372  agg_exprs_.push_back(agg->deepCopy());
373  }
374  if (rhs.hint_applied_) {
375  for (auto const& kv : *rhs.hints_) {
376  addHint(kv.second);
377  }
378  }
379 }
380 
382  : RelAlgNode(rhs)
383  , join_type_(rhs.join_type_)
384  , hint_applied_(false)
385  , hints_(std::make_unique<Hints>()) {
386  RexDeepCopyVisitor copier;
387  condition_ = copier.visit(rhs.condition_.get());
388  if (rhs.hint_applied_) {
389  for (auto const& kv : *rhs.hints_) {
390  addHint(kv.second);
391  }
392  }
393 }
394 
395 namespace {
396 
397 std::vector<std::unique_ptr<const RexAgg>> copyAggExprs(
398  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs) {
399  std::vector<std::unique_ptr<const RexAgg>> agg_exprs_copy;
400  agg_exprs_copy.reserve(agg_exprs.size());
401  for (auto const& agg_expr : agg_exprs) {
402  agg_exprs_copy.push_back(agg_expr->deepCopy());
403  }
404  return agg_exprs_copy;
405 }
406 
407 std::vector<std::unique_ptr<const RexScalar>> copyRexScalars(
408  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources) {
409  std::vector<std::unique_ptr<const RexScalar>> scalar_sources_copy;
410  scalar_sources_copy.reserve(scalar_sources.size());
411  RexDeepCopyVisitor copier;
412  for (auto const& scalar_source : scalar_sources) {
413  scalar_sources_copy.push_back(copier.visit(scalar_source.get()));
414  }
415  return scalar_sources_copy;
416 }
417 
418 std::vector<const Rex*> remapTargetPointers(
419  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs_new,
420  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources_new,
421  std::vector<std::unique_ptr<const RexAgg>> const& agg_exprs_old,
422  std::vector<std::unique_ptr<const RexScalar>> const& scalar_sources_old,
423  std::vector<const Rex*> const& target_exprs_old) {
424  std::vector<const Rex*> target_exprs(target_exprs_old);
425  std::unordered_map<const Rex*, const Rex*> old_to_new_target(target_exprs.size());
426  for (size_t i = 0; i < agg_exprs_new.size(); ++i) {
427  old_to_new_target.emplace(agg_exprs_old[i].get(), agg_exprs_new[i].get());
428  }
429  for (size_t i = 0; i < scalar_sources_new.size(); ++i) {
430  old_to_new_target.emplace(scalar_sources_old[i].get(), scalar_sources_new[i].get());
431  }
432  for (auto& target : target_exprs) {
433  auto target_it = old_to_new_target.find(target);
434  CHECK(target_it != old_to_new_target.end());
435  target = target_it->second;
436  }
437  return target_exprs;
438 }
439 
440 } // namespace
441 
443  : RelAlgNode(rhs)
445  , groupby_count_(rhs.groupby_count_)
446  , agg_exprs_(copyAggExprs(rhs.agg_exprs_))
447  , fields_(rhs.fields_)
448  , is_agg_(rhs.is_agg_)
449  , scalar_sources_(copyRexScalars(rhs.scalar_sources_))
450  , target_exprs_(remapTargetPointers(agg_exprs_,
451  scalar_sources_,
452  rhs.agg_exprs_,
453  rhs.scalar_sources_,
454  rhs.target_exprs_))
455  , hint_applied_(false)
456  , hints_(std::make_unique<Hints>()) {
457  RexDeepCopyVisitor copier;
458  filter_expr_ = rhs.filter_expr_ ? copier.visit(rhs.filter_expr_.get()) : nullptr;
459  if (rhs.hint_applied_) {
460  for (auto const& kv : *rhs.hints_) {
461  addHint(kv.second);
462  }
463  }
464 }
465 
466 void RelTableFunction::replaceInput(std::shared_ptr<const RelAlgNode> old_input,
467  std::shared_ptr<const RelAlgNode> input) {
468  RelAlgNode::replaceInput(old_input, input);
469  RexRebindInputsVisitor rebind_inputs(old_input.get(), input.get());
470  for (const auto& target_expr : target_exprs_) {
471  rebind_inputs.visit(target_expr.get());
472  }
473  for (const auto& func_input : table_func_inputs_) {
474  rebind_inputs.visit(func_input.get());
475  }
476 }
477 
479  int32_t literal_args = 0;
480  for (const auto& arg : table_func_inputs_) {
481  const auto rex_literal = dynamic_cast<const RexLiteral*>(arg.get());
482  if (rex_literal) {
483  literal_args += 1;
484  }
485  }
486  return literal_args;
487 }
488 
490  : RelAlgNode(rhs)
491  , function_name_(rhs.function_name_)
492  , fields_(rhs.fields_)
493  , col_inputs_(rhs.col_inputs_)
494  , table_func_inputs_(copyRexScalars(rhs.table_func_inputs_))
495  , target_exprs_(copyRexScalars(rhs.target_exprs_)) {
496  std::unordered_map<const Rex*, const Rex*> old_to_new_input;
497  for (size_t i = 0; i < table_func_inputs_.size(); ++i) {
498  old_to_new_input.emplace(rhs.table_func_inputs_[i].get(),
499  table_func_inputs_[i].get());
500  }
501  for (auto& target : col_inputs_) {
502  auto target_it = old_to_new_input.find(target);
503  CHECK(target_it != old_to_new_input.end());
504  target = target_it->second;
505  }
506 }
507 
508 namespace std {
509 template <>
510 struct hash<std::pair<const RelAlgNode*, int>> {
511  size_t operator()(const std::pair<const RelAlgNode*, int>& input_col) const {
512  auto ptr_val = reinterpret_cast<const int64_t*>(&input_col.first);
513  return static_cast<int64_t>(*ptr_val) ^ input_col.second;
514  }
515 };
516 } // namespace std
517 
518 namespace {
519 
520 std::set<std::pair<const RelAlgNode*, int>> get_equiv_cols(const RelAlgNode* node,
521  const size_t which_col) {
522  std::set<std::pair<const RelAlgNode*, int>> work_set;
523  auto walker = node;
524  auto curr_col = which_col;
525  while (true) {
526  work_set.insert(std::make_pair(walker, curr_col));
527  if (dynamic_cast<const RelScan*>(walker) || dynamic_cast<const RelJoin*>(walker)) {
528  break;
529  }
530  CHECK_EQ(size_t(1), walker->inputCount());
531  auto only_source = walker->getInput(0);
532  if (auto project = dynamic_cast<const RelProject*>(walker)) {
533  if (auto input = dynamic_cast<const RexInput*>(project->getProjectAt(curr_col))) {
534  const auto join_source = dynamic_cast<const RelJoin*>(only_source);
535  if (join_source) {
536  CHECK_EQ(size_t(2), join_source->inputCount());
537  auto lhs = join_source->getInput(0);
538  CHECK((input->getIndex() < lhs->size() && lhs == input->getSourceNode()) ||
539  join_source->getInput(1) == input->getSourceNode());
540  } else {
541  CHECK_EQ(input->getSourceNode(), only_source);
542  }
543  curr_col = input->getIndex();
544  } else {
545  break;
546  }
547  } else if (auto aggregate = dynamic_cast<const RelAggregate*>(walker)) {
548  if (curr_col >= aggregate->getGroupByCount()) {
549  break;
550  }
551  }
552  walker = only_source;
553  }
554  return work_set;
555 }
556 
557 } // namespace
558 
559 bool RelSort::hasEquivCollationOf(const RelSort& that) const {
560  if (collation_.size() != that.collation_.size()) {
561  return false;
562  }
563 
564  for (size_t i = 0, e = collation_.size(); i < e; ++i) {
565  auto this_sort_key = collation_[i];
566  auto that_sort_key = that.collation_[i];
567  if (this_sort_key.getSortDir() != that_sort_key.getSortDir()) {
568  return false;
569  }
570  if (this_sort_key.getNullsPosition() != that_sort_key.getNullsPosition()) {
571  return false;
572  }
573  auto this_equiv_keys = get_equiv_cols(this, this_sort_key.getField());
574  auto that_equiv_keys = get_equiv_cols(&that, that_sort_key.getField());
575  std::vector<std::pair<const RelAlgNode*, int>> intersect;
576  std::set_intersection(this_equiv_keys.begin(),
577  this_equiv_keys.end(),
578  that_equiv_keys.begin(),
579  that_equiv_keys.end(),
580  std::back_inserter(intersect));
581  if (intersect.empty()) {
582  return false;
583  }
584  }
585  return true;
586 }
587 
588 // class RelLogicalUnion methods
589 
591  : RelAlgNode(std::move(inputs)), is_all_(is_all) {
592  if (!g_enable_union) {
593  throw QueryNotSupported(
594  "UNION is not supported yet. There is an experimental enable-union option "
595  "available to enable UNION ALL queries.");
596  }
597  CHECK_EQ(2u, inputs_.size());
598  if (!is_all_) {
599  throw QueryNotSupported("UNION without ALL is not supported yet.");
600  }
601 }
602 
603 size_t RelLogicalUnion::size() const {
604  return inputs_.front()->size();
605 }
606 
607 std::string RelLogicalUnion::toString() const {
608  return cat(::typeName(this), "(is_all(", is_all_, "))");
609 }
610 
611 size_t RelLogicalUnion::toHash() const {
612  if (!hash_) {
613  hash_ = typeid(RelLogicalUnion).hash_code();
614  boost::hash_combine(*hash_, is_all_);
615  }
616  return *hash_;
617 }
618 
619 std::string RelLogicalUnion::getFieldName(const size_t i) const {
620  if (auto const* input = dynamic_cast<RelCompound const*>(inputs_[0].get())) {
621  return input->getFieldName(i);
622  } else if (auto const* input = dynamic_cast<RelProject const*>(inputs_[0].get())) {
623  return input->getFieldName(i);
624  } else if (auto const* input = dynamic_cast<RelLogicalUnion const*>(inputs_[0].get())) {
625  return input->getFieldName(i);
626  } else if (auto const* input = dynamic_cast<RelAggregate const*>(inputs_[0].get())) {
627  return input->getFieldName(i);
628  } else if (auto const* input = dynamic_cast<RelScan const*>(inputs_[0].get())) {
629  return input->getFieldName(i);
630  } else if (auto const* input =
631  dynamic_cast<RelTableFunction const*>(inputs_[0].get())) {
632  return input->getFieldName(i);
633  }
634  UNREACHABLE() << "Unhandled input type: " << ::toString(inputs_.front());
635  return {};
636 }
637 
639  std::vector<TargetMetaInfo> const& tmis0 = inputs_[0]->getOutputMetainfo();
640  std::vector<TargetMetaInfo> const& tmis1 = inputs_[1]->getOutputMetainfo();
641  if (tmis0.size() != tmis1.size()) {
642  VLOG(2) << "tmis0.size() = " << tmis0.size() << " != " << tmis1.size()
643  << " = tmis1.size()";
644  throw std::runtime_error("Subqueries of a UNION must have matching data types.");
645  }
646  for (size_t i = 0; i < tmis0.size(); ++i) {
647  if (tmis0[i].get_type_info() != tmis1[i].get_type_info()) {
648  SQLTypeInfo const& ti0 = tmis0[i].get_type_info();
649  SQLTypeInfo const& ti1 = tmis1[i].get_type_info();
650  VLOG(2) << "Types do not match for UNION:\n tmis0[" << i
651  << "].get_type_info().to_string() = " << ti0.to_string() << "\n tmis1["
652  << i << "].get_type_info().to_string() = " << ti1.to_string();
653  if (ti0.get_comp_param() != ti1.get_comp_param()) {
654  if (!ti0.is_dict_encoded_string() || !ti1.is_dict_encoded_string()) {
655  throw std::runtime_error(
656  "Subqueries of a UNION must have the exact same data types.");
657  }
658  }
659  }
660  }
661 }
662 
663 // Rest of code requires a raw pointer, but RexInput object needs to live somewhere.
665  size_t input_idx) const {
666  if (auto const* rex_input_ptr = dynamic_cast<RexInput const*>(rex_scalar)) {
667  RexInput rex_input(*rex_input_ptr);
668  rex_input.setSourceNode(getInput(input_idx));
669  scalar_exprs_.emplace_back(std::make_shared<RexInput const>(std::move(rex_input)));
670  return scalar_exprs_.back().get();
671  }
672  return rex_scalar;
673 }
674 
675 namespace {
676 
677 unsigned node_id(const rapidjson::Value& ra_node) noexcept {
678  const auto& id = field(ra_node, "id");
679  return std::stoi(json_str(id));
680 }
681 
682 std::string json_node_to_string(const rapidjson::Value& node) noexcept {
683  rapidjson::StringBuffer buffer;
684  rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
685  node.Accept(writer);
686  return buffer.GetString();
687 }
688 
689 // The parse_* functions below de-serialize expressions as they come from Calcite.
690 // RelAlgDagBuilder will take care of making the representation easy to
691 // navigate for lower layers, for example by replacing RexAbstractInput with RexInput.
692 
693 std::unique_ptr<RexAbstractInput> parse_abstract_input(
694  const rapidjson::Value& expr) noexcept {
695  const auto& input = field(expr, "input");
696  return std::unique_ptr<RexAbstractInput>(new RexAbstractInput(json_i64(input)));
697 }
698 
699 std::unique_ptr<RexLiteral> parse_literal(const rapidjson::Value& expr) {
700  CHECK(expr.IsObject());
701  const auto& literal = field(expr, "literal");
702  const auto type = to_sql_type(json_str(field(expr, "type")));
703  const auto target_type = to_sql_type(json_str(field(expr, "target_type")));
704  const auto scale = json_i64(field(expr, "scale"));
705  const auto precision = json_i64(field(expr, "precision"));
706  const auto type_scale = json_i64(field(expr, "type_scale"));
707  const auto type_precision = json_i64(field(expr, "type_precision"));
708  if (literal.IsNull()) {
709  return std::unique_ptr<RexLiteral>(new RexLiteral(target_type));
710  }
711  switch (type) {
712  case kINT:
713  case kBIGINT:
714  case kDECIMAL:
715  case kINTERVAL_DAY_TIME:
717  case kTIME:
718  case kTIMESTAMP:
719  case kDATE:
720  return std::unique_ptr<RexLiteral>(new RexLiteral(json_i64(literal),
721  type,
722  target_type,
723  scale,
724  precision,
725  type_scale,
726  type_precision));
727  case kDOUBLE: {
728  if (literal.IsDouble()) {
729  return std::unique_ptr<RexLiteral>(new RexLiteral(json_double(literal),
730  type,
731  target_type,
732  scale,
733  precision,
734  type_scale,
735  type_precision));
736  } else if (literal.IsInt64()) {
737  return std::make_unique<RexLiteral>(static_cast<double>(literal.GetInt64()),
738  type,
739  target_type,
740  scale,
741  precision,
742  type_scale,
743  type_precision);
744 
745  } else if (literal.IsUint64()) {
746  return std::make_unique<RexLiteral>(static_cast<double>(literal.GetUint64()),
747  type,
748  target_type,
749  scale,
750  precision,
751  type_scale,
752  type_precision);
753  }
754  UNREACHABLE() << "Unhandled type: " << literal.GetType();
755  }
756  case kTEXT:
757  return std::unique_ptr<RexLiteral>(new RexLiteral(json_str(literal),
758  type,
759  target_type,
760  scale,
761  precision,
762  type_scale,
763  type_precision));
764  case kBOOLEAN:
765  return std::unique_ptr<RexLiteral>(new RexLiteral(json_bool(literal),
766  type,
767  target_type,
768  scale,
769  precision,
770  type_scale,
771  type_precision));
772  case kNULLT:
773  return std::unique_ptr<RexLiteral>(new RexLiteral(target_type));
774  default:
775  CHECK(false);
776  }
777  CHECK(false);
778  return nullptr;
779 }
780 
781 std::unique_ptr<const RexScalar> parse_scalar_expr(const rapidjson::Value& expr,
783  RelAlgDagBuilder& root_dag_builder);
784 
785 SQLTypeInfo parse_type(const rapidjson::Value& type_obj) {
786  if (type_obj.IsArray()) {
787  throw QueryNotSupported("Composite types are not currently supported.");
788  }
789  CHECK(type_obj.IsObject() && type_obj.MemberCount() >= 2)
790  << json_node_to_string(type_obj);
791  const auto type = to_sql_type(json_str(field(type_obj, "type")));
792  const auto nullable = json_bool(field(type_obj, "nullable"));
793  const auto precision_it = type_obj.FindMember("precision");
794  const int precision =
795  precision_it != type_obj.MemberEnd() ? json_i64(precision_it->value) : 0;
796  const auto scale_it = type_obj.FindMember("scale");
797  const int scale = scale_it != type_obj.MemberEnd() ? json_i64(scale_it->value) : 0;
798  SQLTypeInfo ti(type, !nullable);
799  ti.set_precision(precision);
800  ti.set_scale(scale);
801  return ti;
802 }
803 
804 std::vector<std::unique_ptr<const RexScalar>> parse_expr_array(
805  const rapidjson::Value& arr,
807  RelAlgDagBuilder& root_dag_builder) {
808  std::vector<std::unique_ptr<const RexScalar>> exprs;
809  for (auto it = arr.Begin(); it != arr.End(); ++it) {
810  exprs.emplace_back(parse_scalar_expr(*it, cat, root_dag_builder));
811  }
812  return exprs;
813 }
814 
816  if (name == "ROW_NUMBER") {
818  }
819  if (name == "RANK") {
821  }
822  if (name == "DENSE_RANK") {
824  }
825  if (name == "PERCENT_RANK") {
827  }
828  if (name == "CUME_DIST") {
830  }
831  if (name == "NTILE") {
833  }
834  if (name == "LAG") {
836  }
837  if (name == "LEAD") {
839  }
840  if (name == "FIRST_VALUE") {
842  }
843  if (name == "LAST_VALUE") {
845  }
846  if (name == "AVG") {
848  }
849  if (name == "MIN") {
851  }
852  if (name == "MAX") {
854  }
855  if (name == "SUM") {
857  }
858  if (name == "COUNT") {
860  }
861  if (name == "$SUM0") {
863  }
864  throw std::runtime_error("Unsupported window function: " + name);
865 }
866 
867 std::vector<std::unique_ptr<const RexScalar>> parse_window_order_exprs(
868  const rapidjson::Value& arr,
870  RelAlgDagBuilder& root_dag_builder) {
871  std::vector<std::unique_ptr<const RexScalar>> exprs;
872  for (auto it = arr.Begin(); it != arr.End(); ++it) {
873  exprs.emplace_back(parse_scalar_expr(field(*it, "field"), cat, root_dag_builder));
874  }
875  return exprs;
876 }
877 
878 SortDirection parse_sort_direction(const rapidjson::Value& collation) {
879  return json_str(field(collation, "direction")) == std::string("DESCENDING")
882 }
883 
884 NullSortedPosition parse_nulls_position(const rapidjson::Value& collation) {
885  return json_str(field(collation, "nulls")) == std::string("FIRST")
888 }
889 
890 std::vector<SortField> parse_window_order_collation(const rapidjson::Value& arr,
892  RelAlgDagBuilder& root_dag_builder) {
893  std::vector<SortField> collation;
894  size_t field_idx = 0;
895  for (auto it = arr.Begin(); it != arr.End(); ++it, ++field_idx) {
896  const auto sort_dir = parse_sort_direction(*it);
897  const auto null_pos = parse_nulls_position(*it);
898  collation.emplace_back(field_idx, sort_dir, null_pos);
899  }
900  return collation;
901 }
902 
904  const rapidjson::Value& window_bound_obj,
906  RelAlgDagBuilder& root_dag_builder) {
907  CHECK(window_bound_obj.IsObject());
909  window_bound.unbounded = json_bool(field(window_bound_obj, "unbounded"));
910  window_bound.preceding = json_bool(field(window_bound_obj, "preceding"));
911  window_bound.following = json_bool(field(window_bound_obj, "following"));
912  window_bound.is_current_row = json_bool(field(window_bound_obj, "is_current_row"));
913  const auto& offset_field = field(window_bound_obj, "offset");
914  if (offset_field.IsObject()) {
915  window_bound.offset = parse_scalar_expr(offset_field, cat, root_dag_builder);
916  } else {
917  CHECK(offset_field.IsNull());
918  }
919  window_bound.order_key = json_i64(field(window_bound_obj, "order_key"));
920  return window_bound;
921 }
922 
923 std::unique_ptr<const RexSubQuery> parse_subquery(const rapidjson::Value& expr,
925  RelAlgDagBuilder& root_dag_builder) {
926  const auto& operands = field(expr, "operands");
927  CHECK(operands.IsArray());
928  CHECK_GE(operands.Size(), unsigned(0));
929  const auto& subquery_ast = field(expr, "subquery");
930 
931  RelAlgDagBuilder subquery_dag(root_dag_builder, subquery_ast, cat, nullptr);
932  auto subquery = std::make_shared<RexSubQuery>(subquery_dag.getRootNodeShPtr());
933  root_dag_builder.registerSubquery(subquery);
934  return subquery->deepCopy();
935 }
936 
937 std::unique_ptr<RexOperator> parse_operator(const rapidjson::Value& expr,
939  RelAlgDagBuilder& root_dag_builder) {
940  const auto op_name = json_str(field(expr, "op"));
941  const bool is_quantifier =
942  op_name == std::string("PG_ANY") || op_name == std::string("PG_ALL");
943  const auto op = is_quantifier ? kFUNCTION : to_sql_op(op_name);
944  const auto& operators_json_arr = field(expr, "operands");
945  CHECK(operators_json_arr.IsArray());
946  auto operands = parse_expr_array(operators_json_arr, cat, root_dag_builder);
947  const auto type_it = expr.FindMember("type");
948  CHECK(type_it != expr.MemberEnd());
949  auto ti = parse_type(type_it->value);
950  if (op == kIN && expr.HasMember("subquery")) {
951  auto subquery = parse_subquery(expr, cat, root_dag_builder);
952  operands.emplace_back(std::move(subquery));
953  }
954  if (expr.FindMember("partition_keys") != expr.MemberEnd()) {
955  const auto& partition_keys_arr = field(expr, "partition_keys");
956  auto partition_keys = parse_expr_array(partition_keys_arr, cat, root_dag_builder);
957  const auto& order_keys_arr = field(expr, "order_keys");
958  auto order_keys = parse_window_order_exprs(order_keys_arr, cat, root_dag_builder);
959  const auto collation =
960  parse_window_order_collation(order_keys_arr, cat, root_dag_builder);
961  const auto kind = parse_window_function_kind(op_name);
962  const auto lower_bound =
963  parse_window_bound(field(expr, "lower_bound"), cat, root_dag_builder);
964  const auto upper_bound =
965  parse_window_bound(field(expr, "upper_bound"), cat, root_dag_builder);
966  bool is_rows = json_bool(field(expr, "is_rows"));
967  ti.set_notnull(false);
968  return std::make_unique<RexWindowFunctionOperator>(kind,
969  operands,
970  partition_keys,
971  order_keys,
972  collation,
973  lower_bound,
974  upper_bound,
975  is_rows,
976  ti);
977  }
978  return std::unique_ptr<RexOperator>(op == kFUNCTION
979  ? new RexFunctionOperator(op_name, operands, ti)
980  : new RexOperator(op, operands, ti));
981 }
982 
983 std::unique_ptr<RexCase> parse_case(const rapidjson::Value& expr,
985  RelAlgDagBuilder& root_dag_builder) {
986  const auto& operands = field(expr, "operands");
987  CHECK(operands.IsArray());
988  CHECK_GE(operands.Size(), unsigned(2));
989  std::unique_ptr<const RexScalar> else_expr;
990  std::vector<
991  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
992  expr_pair_list;
993  for (auto operands_it = operands.Begin(); operands_it != operands.End();) {
994  auto when_expr = parse_scalar_expr(*operands_it++, cat, root_dag_builder);
995  if (operands_it == operands.End()) {
996  else_expr = std::move(when_expr);
997  break;
998  }
999  auto then_expr = parse_scalar_expr(*operands_it++, cat, root_dag_builder);
1000  expr_pair_list.emplace_back(std::move(when_expr), std::move(then_expr));
1001  }
1002  return std::unique_ptr<RexCase>(new RexCase(expr_pair_list, else_expr));
1003 }
1004 
1005 std::vector<std::string> strings_from_json_array(
1006  const rapidjson::Value& json_str_arr) noexcept {
1007  CHECK(json_str_arr.IsArray());
1008  std::vector<std::string> fields;
1009  for (auto json_str_arr_it = json_str_arr.Begin(); json_str_arr_it != json_str_arr.End();
1010  ++json_str_arr_it) {
1011  CHECK(json_str_arr_it->IsString());
1012  fields.emplace_back(json_str_arr_it->GetString());
1013  }
1014  return fields;
1015 }
1016 
1017 std::vector<size_t> indices_from_json_array(
1018  const rapidjson::Value& json_idx_arr) noexcept {
1019  CHECK(json_idx_arr.IsArray());
1020  std::vector<size_t> indices;
1021  for (auto json_idx_arr_it = json_idx_arr.Begin(); json_idx_arr_it != json_idx_arr.End();
1022  ++json_idx_arr_it) {
1023  CHECK(json_idx_arr_it->IsInt());
1024  CHECK_GE(json_idx_arr_it->GetInt(), 0);
1025  indices.emplace_back(json_idx_arr_it->GetInt());
1026  }
1027  return indices;
1028 }
1029 
1030 std::unique_ptr<const RexAgg> parse_aggregate_expr(const rapidjson::Value& expr) {
1031  const auto agg_str = json_str(field(expr, "agg"));
1032  if (agg_str == "APPROX_QUANTILE") {
1033  LOG(INFO) << "APPROX_QUANTILE is deprecated. Please use APPROX_PERCENTILE instead.";
1034  }
1035  const auto agg = to_agg_kind(agg_str);
1036  const auto distinct = json_bool(field(expr, "distinct"));
1037  const auto agg_ti = parse_type(field(expr, "type"));
1038  const auto operands = indices_from_json_array(field(expr, "operands"));
1039  if (operands.size() > 1 && (operands.size() != 2 || (agg != kAPPROX_COUNT_DISTINCT &&
1040  agg != kAPPROX_QUANTILE))) {
1041  throw QueryNotSupported("Multiple arguments for aggregates aren't supported");
1042  }
1043  return std::unique_ptr<const RexAgg>(new RexAgg(agg, distinct, agg_ti, operands));
1044 }
1045 
1046 std::unique_ptr<const RexScalar> parse_scalar_expr(const rapidjson::Value& expr,
1048  RelAlgDagBuilder& root_dag_builder) {
1049  CHECK(expr.IsObject());
1050  if (expr.IsObject() && expr.HasMember("input")) {
1051  return std::unique_ptr<const RexScalar>(parse_abstract_input(expr));
1052  }
1053  if (expr.IsObject() && expr.HasMember("literal")) {
1054  return std::unique_ptr<const RexScalar>(parse_literal(expr));
1055  }
1056  if (expr.IsObject() && expr.HasMember("op")) {
1057  const auto op_str = json_str(field(expr, "op"));
1058  if (op_str == std::string("CASE")) {
1059  return std::unique_ptr<const RexScalar>(parse_case(expr, cat, root_dag_builder));
1060  }
1061  if (op_str == std::string("$SCALAR_QUERY")) {
1062  return std::unique_ptr<const RexScalar>(
1063  parse_subquery(expr, cat, root_dag_builder));
1064  }
1065  return std::unique_ptr<const RexScalar>(parse_operator(expr, cat, root_dag_builder));
1066  }
1067  throw QueryNotSupported("Expression node " + json_node_to_string(expr) +
1068  " not supported");
1069 }
1070 
1071 JoinType to_join_type(const std::string& join_type_name) {
1072  if (join_type_name == "inner") {
1073  return JoinType::INNER;
1074  }
1075  if (join_type_name == "left") {
1076  return JoinType::LEFT;
1077  }
1078  if (join_type_name == "semi") {
1079  return JoinType::SEMI;
1080  }
1081  if (join_type_name == "anti") {
1082  return JoinType::ANTI;
1083  }
1084  throw QueryNotSupported("Join type (" + join_type_name + ") not supported");
1085 }
1086 
1087 std::unique_ptr<const RexScalar> disambiguate_rex(const RexScalar*, const RANodeOutput&);
1088 
1089 std::unique_ptr<const RexOperator> disambiguate_operator(
1090  const RexOperator* rex_operator,
1091  const RANodeOutput& ra_output) noexcept {
1092  std::vector<std::unique_ptr<const RexScalar>> disambiguated_operands;
1093  for (size_t i = 0; i < rex_operator->size(); ++i) {
1094  auto operand = rex_operator->getOperand(i);
1095  if (dynamic_cast<const RexSubQuery*>(operand)) {
1096  disambiguated_operands.emplace_back(rex_operator->getOperandAndRelease(i));
1097  } else {
1098  disambiguated_operands.emplace_back(disambiguate_rex(operand, ra_output));
1099  }
1100  }
1101  const auto rex_window_function_operator =
1102  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
1103  if (rex_window_function_operator) {
1104  const auto& partition_keys = rex_window_function_operator->getPartitionKeys();
1105  std::vector<std::unique_ptr<const RexScalar>> disambiguated_partition_keys;
1106  for (const auto& partition_key : partition_keys) {
1107  disambiguated_partition_keys.emplace_back(
1108  disambiguate_rex(partition_key.get(), ra_output));
1109  }
1110  std::vector<std::unique_ptr<const RexScalar>> disambiguated_order_keys;
1111  const auto& order_keys = rex_window_function_operator->getOrderKeys();
1112  for (const auto& order_key : order_keys) {
1113  disambiguated_order_keys.emplace_back(disambiguate_rex(order_key.get(), ra_output));
1114  }
1115  return rex_window_function_operator->disambiguatedOperands(
1116  disambiguated_operands,
1117  disambiguated_partition_keys,
1118  disambiguated_order_keys,
1119  rex_window_function_operator->getCollation());
1120  }
1121  return rex_operator->getDisambiguated(disambiguated_operands);
1122 }
1123 
1124 std::unique_ptr<const RexCase> disambiguate_case(const RexCase* rex_case,
1125  const RANodeOutput& ra_output) {
1126  std::vector<
1127  std::pair<std::unique_ptr<const RexScalar>, std::unique_ptr<const RexScalar>>>
1128  disambiguated_expr_pair_list;
1129  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
1130  auto disambiguated_when = disambiguate_rex(rex_case->getWhen(i), ra_output);
1131  auto disambiguated_then = disambiguate_rex(rex_case->getThen(i), ra_output);
1132  disambiguated_expr_pair_list.emplace_back(std::move(disambiguated_when),
1133  std::move(disambiguated_then));
1134  }
1135  std::unique_ptr<const RexScalar> disambiguated_else{
1136  disambiguate_rex(rex_case->getElse(), ra_output)};
1137  return std::unique_ptr<const RexCase>(
1138  new RexCase(disambiguated_expr_pair_list, disambiguated_else));
1139 }
1140 
1141 // The inputs used by scalar expressions are given as indices in the serialized
1142 // representation of the query. This is hard to navigate; make the relationship
1143 // explicit by creating RexInput expressions which hold a pointer to the source
1144 // relational algebra node and the index relative to the output of that node.
1145 std::unique_ptr<const RexScalar> disambiguate_rex(const RexScalar* rex_scalar,
1146  const RANodeOutput& ra_output) {
1147  const auto rex_abstract_input = dynamic_cast<const RexAbstractInput*>(rex_scalar);
1148  if (rex_abstract_input) {
1149  CHECK_LT(static_cast<size_t>(rex_abstract_input->getIndex()), ra_output.size());
1150  return std::unique_ptr<const RexInput>(
1151  new RexInput(ra_output[rex_abstract_input->getIndex()]));
1152  }
1153  const auto rex_operator = dynamic_cast<const RexOperator*>(rex_scalar);
1154  if (rex_operator) {
1155  return disambiguate_operator(rex_operator, ra_output);
1156  }
1157  const auto rex_case = dynamic_cast<const RexCase*>(rex_scalar);
1158  if (rex_case) {
1159  return disambiguate_case(rex_case, ra_output);
1160  }
1161  const auto rex_literal = dynamic_cast<const RexLiteral*>(rex_scalar);
1162  CHECK(rex_literal);
1163  return std::unique_ptr<const RexLiteral>(new RexLiteral(*rex_literal));
1164 }
1165 
1166 void bind_project_to_input(RelProject* project_node, const RANodeOutput& input) noexcept {
1167  CHECK_EQ(size_t(1), project_node->inputCount());
1168  std::vector<std::unique_ptr<const RexScalar>> disambiguated_exprs;
1169  for (size_t i = 0; i < project_node->size(); ++i) {
1170  const auto projected_expr = project_node->getProjectAt(i);
1171  if (dynamic_cast<const RexSubQuery*>(projected_expr)) {
1172  disambiguated_exprs.emplace_back(project_node->getProjectAtAndRelease(i));
1173  } else {
1174  disambiguated_exprs.emplace_back(disambiguate_rex(projected_expr, input));
1175  }
1176  }
1177  project_node->setExpressions(disambiguated_exprs);
1178 }
1179 
1181  const RANodeOutput& input) noexcept {
1182  std::vector<std::unique_ptr<const RexScalar>> disambiguated_exprs;
1183  for (size_t i = 0; i < table_func_node->getTableFuncInputsSize(); ++i) {
1184  const auto target_expr = table_func_node->getTableFuncInputAt(i);
1185  if (dynamic_cast<const RexSubQuery*>(target_expr)) {
1186  disambiguated_exprs.emplace_back(table_func_node->getTableFuncInputAtAndRelease(i));
1187  } else {
1188  disambiguated_exprs.emplace_back(disambiguate_rex(target_expr, input));
1189  }
1190  }
1191  table_func_node->setTableFuncInputs(disambiguated_exprs);
1192 }
1193 
1194 void bind_inputs(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) noexcept {
1195  for (auto ra_node : nodes) {
1196  const auto filter_node = std::dynamic_pointer_cast<RelFilter>(ra_node);
1197  if (filter_node) {
1198  CHECK_EQ(size_t(1), filter_node->inputCount());
1199  auto disambiguated_condition = disambiguate_rex(
1200  filter_node->getCondition(), get_node_output(filter_node->getInput(0)));
1201  filter_node->setCondition(disambiguated_condition);
1202  continue;
1203  }
1204  const auto join_node = std::dynamic_pointer_cast<RelJoin>(ra_node);
1205  if (join_node) {
1206  CHECK_EQ(size_t(2), join_node->inputCount());
1207  auto disambiguated_condition =
1208  disambiguate_rex(join_node->getCondition(), get_node_output(join_node.get()));
1209  join_node->setCondition(disambiguated_condition);
1210  continue;
1211  }
1212  const auto project_node = std::dynamic_pointer_cast<RelProject>(ra_node);
1213  if (project_node) {
1214  bind_project_to_input(project_node.get(),
1215  get_node_output(project_node->getInput(0)));
1216  continue;
1217  }
1218  const auto table_func_node = std::dynamic_pointer_cast<RelTableFunction>(ra_node);
1219  if (table_func_node) {
1220  /*
1221  Collect all inputs from table function input (non-literal)
1222  arguments.
1223  */
1224  RANodeOutput input;
1225  input.reserve(table_func_node->inputCount());
1226  for (size_t i = 0; i < table_func_node->inputCount(); i++) {
1227  auto node_output = get_node_output(table_func_node->getInput(i));
1228  input.insert(input.end(), node_output.begin(), node_output.end());
1229  }
1230  bind_table_func_to_input(table_func_node.get(), input);
1231  }
1232  }
1233 }
1234 
1235 void handleQueryHint(const std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1236  RelAlgDagBuilder* dag_builder) noexcept {
1237  // query hint is delivered by the above three nodes
1238  // when a query block has top-sort node, a hint is registered to
1239  // one of the node which locates at the nearest from the sort node
1240  RegisteredQueryHint global_query_hint;
1241  for (auto node : nodes) {
1242  Hints* hint_delivered = nullptr;
1243  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
1244  if (agg_node) {
1245  if (agg_node->hasDeliveredHint()) {
1246  hint_delivered = agg_node->getDeliveredHints();
1247  }
1248  }
1249  const auto project_node = std::dynamic_pointer_cast<RelProject>(node);
1250  if (project_node) {
1251  if (project_node->hasDeliveredHint()) {
1252  hint_delivered = project_node->getDeliveredHints();
1253  }
1254  }
1255  const auto compound_node = std::dynamic_pointer_cast<RelCompound>(node);
1256  if (compound_node) {
1257  if (compound_node->hasDeliveredHint()) {
1258  hint_delivered = compound_node->getDeliveredHints();
1259  }
1260  }
1261  if (hint_delivered && !hint_delivered->empty()) {
1262  dag_builder->registerQueryHints(node, hint_delivered, global_query_hint);
1263  }
1264  }
1265  dag_builder->setGlobalQueryHints(global_query_hint);
1266 }
1267 
1268 void mark_nops(const std::vector<std::shared_ptr<RelAlgNode>>& nodes) noexcept {
1269  for (auto node : nodes) {
1270  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
1271  if (!agg_node || agg_node->getAggExprsCount()) {
1272  continue;
1273  }
1274  CHECK_EQ(size_t(1), node->inputCount());
1275  const auto agg_input_node = dynamic_cast<const RelAggregate*>(node->getInput(0));
1276  if (agg_input_node && !agg_input_node->getAggExprsCount() &&
1277  agg_node->getGroupByCount() == agg_input_node->getGroupByCount()) {
1278  agg_node->markAsNop();
1279  }
1280  }
1281 }
1282 
1283 namespace {
1284 
1285 std::vector<const Rex*> reproject_targets(
1286  const RelProject* simple_project,
1287  const std::vector<const Rex*>& target_exprs) noexcept {
1288  std::vector<const Rex*> result;
1289  for (size_t i = 0; i < simple_project->size(); ++i) {
1290  const auto input_rex = dynamic_cast<const RexInput*>(simple_project->getProjectAt(i));
1291  CHECK(input_rex);
1292  CHECK_LT(static_cast<size_t>(input_rex->getIndex()), target_exprs.size());
1293  result.push_back(target_exprs[input_rex->getIndex()]);
1294  }
1295  return result;
1296 }
1297 
1304  public:
1306  const RelAlgNode* node_to_keep,
1307  const std::vector<std::unique_ptr<const RexScalar>>& scalar_sources)
1308  : node_to_keep_(node_to_keep), scalar_sources_(scalar_sources) {}
1309 
1310  // Reproject the RexInput from its current RA Node to the RA Node we intend to keep
1311  RetType visitInput(const RexInput* input) const final {
1312  if (input->getSourceNode() == node_to_keep_) {
1313  const auto index = input->getIndex();
1314  CHECK_LT(index, scalar_sources_.size());
1315  return visit(scalar_sources_[index].get());
1316  } else {
1317  return input->deepCopy();
1318  }
1319  }
1320 
1321  private:
1323  const std::vector<std::unique_ptr<const RexScalar>>& scalar_sources_;
1324 };
1325 
1326 } // namespace
1327 
1329  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1330  const std::vector<size_t>& pattern,
1331  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
1332  query_hints) noexcept {
1333  CHECK_GE(pattern.size(), size_t(2));
1334  CHECK_LE(pattern.size(), size_t(4));
1335 
1336  std::unique_ptr<const RexScalar> filter_rex;
1337  std::vector<std::unique_ptr<const RexScalar>> scalar_sources;
1338  size_t groupby_count{0};
1339  std::vector<std::string> fields;
1340  std::vector<const RexAgg*> agg_exprs;
1341  std::vector<const Rex*> target_exprs;
1342  bool first_project{true};
1343  bool is_agg{false};
1344  RelAlgNode* last_node{nullptr};
1345 
1346  std::shared_ptr<ModifyManipulationTarget> manipulation_target;
1347  size_t node_hash{0};
1348  unsigned node_id{0};
1349  bool hint_registered{false};
1350  RegisteredQueryHint registered_query_hint = RegisteredQueryHint::defaults();
1351  for (const auto node_idx : pattern) {
1352  const auto ra_node = nodes[node_idx];
1353  auto registered_query_hint_map_it = query_hints.find(ra_node->toHash());
1354  if (registered_query_hint_map_it != query_hints.end()) {
1355  auto& registered_query_hint_map = registered_query_hint_map_it->second;
1356  auto registered_query_hint_it = registered_query_hint_map.find(ra_node->getId());
1357  if (registered_query_hint_it != registered_query_hint_map.end()) {
1358  hint_registered = true;
1359  node_hash = registered_query_hint_map_it->first;
1360  node_id = registered_query_hint_it->first;
1361  registered_query_hint = registered_query_hint_it->second;
1362  }
1363  }
1364  const auto ra_filter = std::dynamic_pointer_cast<RelFilter>(ra_node);
1365  if (ra_filter) {
1366  CHECK(!filter_rex);
1367  filter_rex.reset(ra_filter->getAndReleaseCondition());
1368  CHECK(filter_rex);
1369  last_node = ra_node.get();
1370  continue;
1371  }
1372  const auto ra_project = std::dynamic_pointer_cast<RelProject>(ra_node);
1373  if (ra_project) {
1374  fields = ra_project->getFields();
1375  manipulation_target = ra_project;
1376 
1377  if (first_project) {
1378  CHECK_EQ(size_t(1), ra_project->inputCount());
1379  // Rebind the input of the project to the input of the filter itself
1380  // since we know that we'll evaluate the filter on the fly, with no
1381  // intermediate buffer.
1382  const auto filter_input = dynamic_cast<const RelFilter*>(ra_project->getInput(0));
1383  if (filter_input) {
1384  CHECK_EQ(size_t(1), filter_input->inputCount());
1385  bind_project_to_input(ra_project.get(),
1386  get_node_output(filter_input->getInput(0)));
1387  }
1388  scalar_sources = ra_project->getExpressionsAndRelease();
1389  for (const auto& scalar_expr : scalar_sources) {
1390  target_exprs.push_back(scalar_expr.get());
1391  }
1392  first_project = false;
1393  } else {
1394  if (ra_project->isSimple()) {
1395  target_exprs = reproject_targets(ra_project.get(), target_exprs);
1396  } else {
1397  // TODO(adb): This is essentially a more general case of simple project, we
1398  // could likely merge the two
1399  std::vector<const Rex*> result;
1400  RexInputReplacementVisitor visitor(last_node, scalar_sources);
1401  for (size_t i = 0; i < ra_project->size(); ++i) {
1402  const auto rex = ra_project->getProjectAt(i);
1403  if (auto rex_input = dynamic_cast<const RexInput*>(rex)) {
1404  const auto index = rex_input->getIndex();
1405  CHECK_LT(index, target_exprs.size());
1406  result.push_back(target_exprs[index]);
1407  } else {
1408  scalar_sources.push_back(visitor.visit(rex));
1409  result.push_back(scalar_sources.back().get());
1410  }
1411  }
1412  target_exprs = result;
1413  }
1414  }
1415  last_node = ra_node.get();
1416  continue;
1417  }
1418  const auto ra_aggregate = std::dynamic_pointer_cast<RelAggregate>(ra_node);
1419  if (ra_aggregate) {
1420  is_agg = true;
1421  fields = ra_aggregate->getFields();
1422  agg_exprs = ra_aggregate->getAggregatesAndRelease();
1423  groupby_count = ra_aggregate->getGroupByCount();
1424  decltype(target_exprs){}.swap(target_exprs);
1425  CHECK_LE(groupby_count, scalar_sources.size());
1426  for (size_t group_idx = 0; group_idx < groupby_count; ++group_idx) {
1427  const auto rex_ref = new RexRef(group_idx + 1);
1428  target_exprs.push_back(rex_ref);
1429  scalar_sources.emplace_back(rex_ref);
1430  }
1431  for (const auto rex_agg : agg_exprs) {
1432  target_exprs.push_back(rex_agg);
1433  }
1434  last_node = ra_node.get();
1435  continue;
1436  }
1437  }
1438 
1439  auto compound_node =
1440  std::make_shared<RelCompound>(filter_rex,
1441  target_exprs,
1442  groupby_count,
1443  agg_exprs,
1444  fields,
1445  scalar_sources,
1446  is_agg,
1447  manipulation_target->isUpdateViaSelect(),
1448  manipulation_target->isDeleteViaSelect(),
1449  manipulation_target->isVarlenUpdateRequired(),
1450  manipulation_target->getModifiedTableDescriptor(),
1451  manipulation_target->getTargetColumns());
1452  auto old_node = nodes[pattern.back()];
1453  nodes[pattern.back()] = compound_node;
1454  auto first_node = nodes[pattern.front()];
1455  CHECK_EQ(size_t(1), first_node->inputCount());
1456  compound_node->addManagedInput(first_node->getAndOwnInput(0));
1457  if (hint_registered) {
1458  // pass the registered hint from the origin node to newly created compound node
1459  // where it is coalesced
1460  auto registered_query_hint_map_it = query_hints.find(node_hash);
1461  CHECK(registered_query_hint_map_it != query_hints.end());
1462  auto registered_query_hint_map = registered_query_hint_map_it->second;
1463  if (registered_query_hint_map.size() > 1) {
1464  registered_query_hint_map.erase(node_id);
1465  } else {
1466  CHECK_EQ(registered_query_hint_map.size(), static_cast<size_t>(1));
1467  query_hints.erase(node_hash);
1468  }
1469  std::unordered_map<unsigned, RegisteredQueryHint> hint_map;
1470  hint_map.emplace(compound_node->getId(), registered_query_hint);
1471  query_hints.emplace(compound_node->toHash(), hint_map);
1472  }
1473  for (size_t i = 0; i < pattern.size() - 1; ++i) {
1474  nodes[pattern[i]].reset();
1475  }
1476  for (auto node : nodes) {
1477  if (!node) {
1478  continue;
1479  }
1480  node->replaceInput(old_node, compound_node);
1481  }
1482 }
1483 
1484 class RANodeIterator : public std::vector<std::shared_ptr<RelAlgNode>>::const_iterator {
1485  using ElementType = std::shared_ptr<RelAlgNode>;
1486  using Super = std::vector<ElementType>::const_iterator;
1487  using Container = std::vector<ElementType>;
1488 
1489  public:
1490  enum class AdvancingMode { DUChain, InOrder };
1491 
1492  explicit RANodeIterator(const Container& nodes)
1493  : Super(nodes.begin()), owner_(nodes), nodeCount_([&nodes]() -> size_t {
1494  size_t non_zero_count = 0;
1495  for (const auto& node : nodes) {
1496  if (node) {
1497  ++non_zero_count;
1498  }
1499  }
1501  }()) {}
1502 
1503  explicit operator size_t() {
1504  return std::distance(owner_.begin(), *static_cast<Super*>(this));
1505  }
1506 
1507  RANodeIterator operator++() = delete;
1508 
1509  void advance(AdvancingMode mode) {
1510  Super& super = *this;
1511  switch (mode) {
1512  case AdvancingMode::DUChain: {
1513  size_t use_count = 0;
1514  Super only_use = owner_.end();
1515  for (Super nodeIt = std::next(super); nodeIt != owner_.end(); ++nodeIt) {
1516  if (!*nodeIt) {
1517  continue;
1518  }
1519  for (size_t i = 0; i < (*nodeIt)->inputCount(); ++i) {
1520  if ((*super) == (*nodeIt)->getAndOwnInput(i)) {
1521  ++use_count;
1522  if (1 == use_count) {
1523  only_use = nodeIt;
1524  } else {
1525  super = owner_.end();
1526  return;
1527  }
1528  }
1529  }
1530  }
1531  super = only_use;
1532  break;
1533  }
1534  case AdvancingMode::InOrder:
1535  for (size_t i = 0; i != owner_.size(); ++i) {
1536  if (!visited_.count(i)) {
1537  super = owner_.begin();
1538  std::advance(super, i);
1539  return;
1540  }
1541  }
1542  super = owner_.end();
1543  break;
1544  default:
1545  CHECK(false);
1546  }
1547  }
1548 
1549  bool allVisited() { return visited_.size() == nodeCount_; }
1550 
1552  visited_.insert(size_t(*this));
1553  Super& super = *this;
1554  return *super;
1555  }
1556 
1557  const ElementType* operator->() { return &(operator*()); }
1558 
1559  private:
1561  const size_t nodeCount_;
1562  std::unordered_set<size_t> visited_;
1563 };
1564 
1565 namespace {
1566 
1567 bool input_can_be_coalesced(const RelAlgNode* parent_node,
1568  const size_t index,
1569  const bool first_rex_is_input) {
1570  if (auto agg_node = dynamic_cast<const RelAggregate*>(parent_node)) {
1571  if (index == 0 && agg_node->getGroupByCount() > 0) {
1572  return true;
1573  } else {
1574  // Is an aggregated target, only allow the project to be elided if the aggregate
1575  // target is simply passed through (i.e. if the top level expression attached to
1576  // the project node is a RexInput expression)
1577  return first_rex_is_input;
1578  }
1579  }
1580  return first_rex_is_input;
1581 }
1582 
1589  public:
1590  bool visitInput(const RexInput* input) const final {
1591  // The top level expression node is checked before we apply the visitor. If we get
1592  // here, this input rex is a child of another rex node, and we handle the can be
1593  // coalesced check slightly differently
1594  return input_can_be_coalesced(input->getSourceNode(), input->getIndex(), false);
1595  }
1596 
1597  bool visitLiteral(const RexLiteral*) const final { return false; }
1598 
1599  bool visitSubQuery(const RexSubQuery*) const final { return false; }
1600 
1601  bool visitRef(const RexRef*) const final { return false; }
1602 
1603  protected:
1604  bool aggregateResult(const bool& aggregate, const bool& next_result) const final {
1605  return aggregate && next_result;
1606  }
1607 
1608  bool defaultResult() const final { return true; }
1609 };
1610 
1611 // Detect the window function SUM pattern: CASE WHEN COUNT() > 0 THEN SUM ELSE 0
1613  const auto case_operator = dynamic_cast<const RexCase*>(rex);
1614  if (case_operator && case_operator->branchCount() == 1) {
1615  const auto then_window =
1616  dynamic_cast<const RexWindowFunctionOperator*>(case_operator->getThen(0));
1617  if (then_window && then_window->getKind() == SqlWindowFunctionKind::SUM_INTERNAL) {
1618  return true;
1619  }
1620  }
1621  return false;
1622 }
1623 
1624 // Detect both window function operators and window function operators embedded in case
1625 // statements (for null handling)
1627  if (dynamic_cast<const RexWindowFunctionOperator*>(rex)) {
1628  return true;
1629  }
1630 
1631  // unwrap from casts, if they exist
1632  const auto rex_cast = dynamic_cast<const RexOperator*>(rex);
1633  if (rex_cast && rex_cast->getOperator() == kCAST) {
1634  CHECK_EQ(rex_cast->size(), size_t(1));
1635  return is_window_function_operator(rex_cast->getOperand(0));
1636  }
1637 
1638  if (is_window_function_sum(rex)) {
1639  return true;
1640  }
1641  // Check for Window Function AVG:
1642  // (CASE WHEN count > 0 THEN sum ELSE 0) / COUNT
1643  const RexOperator* divide_operator = dynamic_cast<const RexOperator*>(rex);
1644  if (divide_operator && divide_operator->getOperator() == kDIVIDE) {
1645  CHECK_EQ(divide_operator->size(), size_t(2));
1646  const auto case_operator =
1647  dynamic_cast<const RexCase*>(divide_operator->getOperand(0));
1648  const auto second_window =
1649  dynamic_cast<const RexWindowFunctionOperator*>(divide_operator->getOperand(1));
1650  if (case_operator && second_window &&
1651  second_window->getKind() == SqlWindowFunctionKind::COUNT) {
1652  if (is_window_function_sum(case_operator)) {
1653  return true;
1654  }
1655  }
1656  }
1657  return false;
1658 }
1659 
1660 } // namespace
1661 
1663  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1664  const std::vector<const RelAlgNode*>& left_deep_joins,
1665  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
1666  query_hints) {
1667  enum class CoalesceState { Initial, Filter, FirstProject, Aggregate };
1668  std::vector<size_t> crt_pattern;
1669  CoalesceState crt_state{CoalesceState::Initial};
1670 
1671  auto reset_state = [&crt_pattern, &crt_state]() {
1672  crt_state = CoalesceState::Initial;
1673  std::vector<size_t>().swap(crt_pattern);
1674  };
1675 
1676  for (RANodeIterator nodeIt(nodes); !nodeIt.allVisited();) {
1677  const auto ra_node = nodeIt != nodes.end() ? *nodeIt : nullptr;
1678  switch (crt_state) {
1679  case CoalesceState::Initial: {
1680  if (std::dynamic_pointer_cast<const RelFilter>(ra_node) &&
1681  std::find(left_deep_joins.begin(), left_deep_joins.end(), ra_node.get()) ==
1682  left_deep_joins.end()) {
1683  crt_pattern.push_back(size_t(nodeIt));
1684  crt_state = CoalesceState::Filter;
1685  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1686  } else if (auto project_node =
1687  std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1688  if (project_node->hasWindowFunctionExpr()) {
1689  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1690  } else {
1691  crt_pattern.push_back(size_t(nodeIt));
1692  crt_state = CoalesceState::FirstProject;
1693  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1694  }
1695  } else {
1696  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1697  }
1698  break;
1699  }
1700  case CoalesceState::Filter: {
1701  if (auto project_node = std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1702  // Given we now add preceding projects for all window functions following
1703  // RelFilter nodes, the following should never occur
1704  CHECK(!project_node->hasWindowFunctionExpr());
1705  crt_pattern.push_back(size_t(nodeIt));
1706  crt_state = CoalesceState::FirstProject;
1707  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1708  } else {
1709  reset_state();
1710  }
1711  break;
1712  }
1713  case CoalesceState::FirstProject: {
1714  if (std::dynamic_pointer_cast<const RelAggregate>(ra_node)) {
1715  crt_pattern.push_back(size_t(nodeIt));
1716  crt_state = CoalesceState::Aggregate;
1717  nodeIt.advance(RANodeIterator::AdvancingMode::DUChain);
1718  } else {
1719  if (crt_pattern.size() >= 2) {
1720  create_compound(nodes, crt_pattern, query_hints);
1721  }
1722  reset_state();
1723  }
1724  break;
1725  }
1726  case CoalesceState::Aggregate: {
1727  if (auto project_node = std::dynamic_pointer_cast<const RelProject>(ra_node)) {
1728  if (!project_node->hasWindowFunctionExpr()) {
1729  // TODO(adb): overloading the simple project terminology again here
1730  bool is_simple_project{true};
1731  for (size_t i = 0; i < project_node->size(); i++) {
1732  const auto scalar_rex = project_node->getProjectAt(i);
1733  // If the top level scalar rex is an input node, we can bypass the visitor
1734  if (auto input_rex = dynamic_cast<const RexInput*>(scalar_rex)) {
1736  input_rex->getSourceNode(), input_rex->getIndex(), true)) {
1737  is_simple_project = false;
1738  break;
1739  }
1740  continue;
1741  }
1742  CoalesceSecondaryProjectVisitor visitor;
1743  if (!visitor.visit(project_node->getProjectAt(i))) {
1744  is_simple_project = false;
1745  break;
1746  }
1747  }
1748  if (is_simple_project) {
1749  crt_pattern.push_back(size_t(nodeIt));
1750  nodeIt.advance(RANodeIterator::AdvancingMode::InOrder);
1751  }
1752  }
1753  }
1754  CHECK_GE(crt_pattern.size(), size_t(2));
1755  create_compound(nodes, crt_pattern, query_hints);
1756  reset_state();
1757  break;
1758  }
1759  default:
1760  CHECK(false);
1761  }
1762  }
1763  if (crt_state == CoalesceState::FirstProject || crt_state == CoalesceState::Aggregate) {
1764  if (crt_pattern.size() >= 2) {
1765  create_compound(nodes, crt_pattern, query_hints);
1766  }
1767  CHECK(!crt_pattern.empty());
1768  }
1769 }
1770 
1778 class WindowFunctionDetectionVisitor : public RexVisitor<const RexScalar*> {
1779  protected:
1780  // Detect embedded window function expressions in operators
1781  const RexScalar* visitOperator(const RexOperator* rex_operator) const final {
1782  if (is_window_function_operator(rex_operator)) {
1783  return rex_operator;
1784  }
1785 
1786  const size_t operand_count = rex_operator->size();
1787  for (size_t i = 0; i < operand_count; ++i) {
1788  const auto operand = rex_operator->getOperand(i);
1789  if (is_window_function_operator(operand)) {
1790  // Handle both RexWindowFunctionOperators and window functions built up from
1791  // multiple RexScalar objects (e.g. AVG)
1792  return operand;
1793  }
1794  const auto operandResult = visit(operand);
1795  if (operandResult) {
1796  return operandResult;
1797  }
1798  }
1799 
1800  return defaultResult();
1801  }
1802 
1803  // Detect embedded window function expressions in case statements. Note that this may
1804  // manifest as a nested case statement inside a top level case statement, as some
1805  // window functions (sum, avg) are represented as a case statement. Use the
1806  // is_window_function_operator helper to detect complete window function expressions.
1807  const RexScalar* visitCase(const RexCase* rex_case) const final {
1808  if (is_window_function_operator(rex_case)) {
1809  return rex_case;
1810  }
1811 
1812  auto result = defaultResult();
1813  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
1814  const auto when = rex_case->getWhen(i);
1815  result = is_window_function_operator(when) ? when : visit(when);
1816  if (result) {
1817  return result;
1818  }
1819  const auto then = rex_case->getThen(i);
1820  result = is_window_function_operator(then) ? then : visit(then);
1821  if (result) {
1822  return result;
1823  }
1824  }
1825  if (rex_case->getElse()) {
1826  auto else_expr = rex_case->getElse();
1827  result = is_window_function_operator(else_expr) ? else_expr : visit(else_expr);
1828  }
1829  return result;
1830  }
1831 
1832  const RexScalar* aggregateResult(const RexScalar* const& aggregate,
1833  const RexScalar* const& next_result) const final {
1834  // all methods calling aggregate result should be overriden
1835  UNREACHABLE();
1836  return nullptr;
1837  }
1838 
1839  const RexScalar* defaultResult() const final { return nullptr; }
1840 };
1841 
1851  public:
1852  RexWindowFuncReplacementVisitor(std::unique_ptr<const RexScalar> replacement_rex)
1853  : replacement_rex_(std::move(replacement_rex)) {}
1854 
// Checks on destruction that the replacement expression is no longer held by the
// visitor, i.e. that it was moved out during a visit.
1855  ~RexWindowFuncReplacementVisitor() { CHECK(replacement_rex_ == nullptr); }
1856 
1857  protected:
1858  RetType visitOperator(const RexOperator* rex_operator) const final {
1859  if (should_replace_operand(rex_operator)) {
1860  return std::move(replacement_rex_);
1861  }
1862 
1863  const auto rex_window_function_operator =
1864  dynamic_cast<const RexWindowFunctionOperator*>(rex_operator);
1865  if (rex_window_function_operator) {
1866  // Deep copy the embedded window function operator
1867  return visitWindowFunctionOperator(rex_window_function_operator);
1868  }
1869 
1870  const size_t operand_count = rex_operator->size();
1871  std::vector<RetType> new_opnds;
1872  for (size_t i = 0; i < operand_count; ++i) {
1873  const auto operand = rex_operator->getOperand(i);
1874  if (should_replace_operand(operand)) {
1875  new_opnds.push_back(std::move(replacement_rex_));
1876  } else {
1877  new_opnds.emplace_back(visit(rex_operator->getOperand(i)));
1878  }
1879  }
1880  return rex_operator->getDisambiguated(new_opnds);
1881  }
1882 
1883  RetType visitCase(const RexCase* rex_case) const final {
1884  if (should_replace_operand(rex_case)) {
1885  return std::move(replacement_rex_);
1886  }
1887 
1888  std::vector<std::pair<RetType, RetType>> new_pair_list;
1889  for (size_t i = 0; i < rex_case->branchCount(); ++i) {
1890  auto when_operand = rex_case->getWhen(i);
1891  auto then_operand = rex_case->getThen(i);
1892  new_pair_list.emplace_back(
1893  should_replace_operand(when_operand) ? std::move(replacement_rex_)
1894  : visit(when_operand),
1895  should_replace_operand(then_operand) ? std::move(replacement_rex_)
1896  : visit(then_operand));
1897  }
1898  auto new_else = should_replace_operand(rex_case->getElse())
1899  ? std::move(replacement_rex_)
1900  : visit(rex_case->getElse());
1901  return std::make_unique<RexCase>(new_pair_list, new_else);
1902  }
1903 
1904  private:
1905  bool should_replace_operand(const RexScalar* rex) const {
1906  return replacement_rex_ && is_window_function_operator(rex);
1907  }
1908 
1909  mutable std::unique_ptr<const RexScalar> replacement_rex_;
1910 };
1911 
1922  public:
1923  RexInputBackpropagationVisitor(RelProject* node) : node_(node) { CHECK(node_); }
1924 
1925  protected:
1926  RetType visitInput(const RexInput* rex_input) const final {
1927  if (rex_input->getSourceNode() != node_) {
1928  const auto cur_index = rex_input->getIndex();
1929  auto cur_source_node = rex_input->getSourceNode();
1930  std::string field_name = "";
1931  if (auto cur_project_node = dynamic_cast<const RelProject*>(cur_source_node)) {
1932  field_name = cur_project_node->getFieldName(cur_index);
1933  }
1934  node_->appendInput(field_name, rex_input->deepCopy());
1935  return std::make_unique<RexInput>(node_, node_->size() - 1);
1936  } else {
1937  return rex_input->deepCopy();
1938  }
1939  }
1940 
1941  private:
1942  mutable RelProject* node_;
1943 };
1944 
// Parameters and body of the hint propagation helper (invoked elsewhere in this file
// as propagate_hints_to_new_project; the line carrying its name and return type is
// not visible in this excerpt).
//
// Copies every hint delivered to `prev_node` onto `new_node` and registers, in
// `query_hints`, the hint already registered for `prev_node` under the hash / id of
// `new_node`. Does nothing when `prev_node` has no delivered hints.
1946  std::shared_ptr<RelProject> prev_node,
1947  std::shared_ptr<RelProject> new_node,
1948  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
1949  query_hints) {
1950  auto delivered_hints = prev_node->getDeliveredHints();
1951  bool needs_propagate_hints = !delivered_hints->empty();
1952  if (needs_propagate_hints) {
// attach each delivered hint to the new node
1953  for (auto& kv : *delivered_hints) {
1954  new_node->addHint(kv.second);
1955  }
// query_hints is keyed by node hash first, then by node id
1956  auto prev_it = query_hints.find(prev_node->toHash());
1957  // query hint for the prev projection node should be registered
1958  CHECK(prev_it != query_hints.end());
1959  auto prev_hint_it = prev_it->second.find(prev_node->getId());
1960  CHECK(prev_hint_it != prev_it->second.end());
// register the same hint for the new node; emplace leaves an existing entry for the
// new node's hash untouched
1961  std::unordered_map<unsigned, RegisteredQueryHint> hint_map;
1962  hint_map.emplace(new_node->getId(), prev_hint_it->second);
1963  query_hints.emplace(new_node->toHash(), hint_map);
1964  }
1965 }
1966 
// Parameters and body of the pass which splits embedded window function expressions
// out of project nodes (the line carrying the function name and return type is not
// visible in this excerpt).
//
// For every RelProject in `nodes` which has a window function nested inside another
// expression, the project is rewritten so that it computes the bare window function,
// and a new RelProject computing the original parent expression (with the window
// function replaced by an input) is inserted directly after it. All later nodes are
// rebound to read from the new project. `nodes` is replaced with the updated list.
1983  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
1984  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
1985  query_hints) {
// work on a list copy so new nodes can be inserted while iterating
1986  std::list<std::shared_ptr<RelAlgNode>> node_list(nodes.begin(), nodes.end());
1987 
// NOTE(review): `visitor`, used below, is declared on a line not visible in this
// excerpt -- presumably a WindowFunctionDetectionVisitor; confirm against the file.
1989  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
1990  const auto node = *node_itr;
1991  auto window_func_project_node = std::dynamic_pointer_cast<RelProject>(node);
1992  if (!window_func_project_node) {
1993  continue;
1994  }
1995 
1996  // map scalar expression index in the project node to window function ptr
1997  std::unordered_map<size_t, const RexScalar*> embedded_window_function_expressions;
1998 
1999  // Iterate the target exprs of the project node and check for window function
2000  // expressions. If an embedded expression exists, save it in the
2001  // embedded_window_function_expressions map and split the expression into a window
2002  // function expression and a parent expression in a subsequent project node
2003  for (size_t i = 0; i < window_func_project_node->size(); i++) {
2004  const auto scalar_rex = window_func_project_node->getProjectAt(i);
2005  if (is_window_function_operator(scalar_rex)) {
2006  // top level window function exprs are fine
2007  continue;
2008  }
2009 
2010  if (const auto window_func_rex = visitor.visit(scalar_rex)) {
2011  const auto ret = embedded_window_function_expressions.insert(
2012  std::make_pair(i, window_func_rex));
2013  CHECK(ret.second);
2014  }
2015  }
2016 
2017  if (!embedded_window_function_expressions.empty()) {
// expressions for the new (parent) project node, one per original expression
2018  std::vector<std::unique_ptr<const RexScalar>> new_scalar_exprs;
2019 
// take ownership of the current expressions; they are handed back via
// setExpressions below
2020  auto window_func_scalar_exprs =
2021  window_func_project_node->getExpressionsAndRelease();
2022  for (size_t rex_idx = 0; rex_idx < window_func_scalar_exprs.size(); ++rex_idx) {
2023  const auto embedded_window_func_expr_pair =
2024  embedded_window_function_expressions.find(rex_idx);
2025  if (embedded_window_func_expr_pair ==
2026  embedded_window_function_expressions.end()) {
// no embedded window function: the new node just forwards this column
2027  new_scalar_exprs.emplace_back(
2028  std::make_unique<const RexInput>(window_func_project_node.get(), rex_idx));
2029  } else {
2030  const auto window_func_rex_idx = embedded_window_func_expr_pair->first;
2031  CHECK_LT(window_func_rex_idx, window_func_scalar_exprs.size());
2032 
2033  const auto& window_func_rex = embedded_window_func_expr_pair->second;
2034 
// copy the window function before the parent expression that owns it is replaced
2035  RexDeepCopyVisitor copier;
2036  auto window_func_rex_copy = copier.visit(window_func_rex);
2037 
2038  auto window_func_parent_expr =
2039  window_func_scalar_exprs[window_func_rex_idx].get();
2040 
2041  // Replace window func rex with an input rex
2042  auto window_func_result_input = std::make_unique<const RexInput>(
2043  window_func_project_node.get(), window_func_rex_idx);
2044  RexWindowFuncReplacementVisitor replacer(std::move(window_func_result_input));
2045  auto new_parent_rex = replacer.visit(window_func_parent_expr);
2046 
2047  // Put the parent expr in the new scalar exprs
2048  new_scalar_exprs.emplace_back(std::move(new_parent_rex));
2049 
2050  // Put the window func expr in cur scalar exprs
2051  window_func_scalar_exprs[window_func_rex_idx] = std::move(window_func_rex_copy);
2052  }
2053  }
2054 
2055  CHECK_EQ(window_func_scalar_exprs.size(), new_scalar_exprs.size());
2056  window_func_project_node->setExpressions(window_func_scalar_exprs);
2057 
2058  // Ensure any inputs from the node containing the expression (the "new" node)
2059  // exist on the window function project node, e.g. if we had a binary operation
2060  // involving an aggregate value or column not included in the top level
2061  // projection list.
2062  RexInputBackpropagationVisitor input_visitor(window_func_project_node.get());
2063  for (size_t i = 0; i < new_scalar_exprs.size(); i++) {
2064  if (dynamic_cast<const RexInput*>(new_scalar_exprs[i].get())) {
2065  // ignore top level inputs, these were copied directly from the previous
2066  // node
2067  continue;
2068  }
2069  new_scalar_exprs[i] = input_visitor.visit(new_scalar_exprs[i].get());
2070  }
2071 
2072  // Build the new project node and insert it into the list after the project node
2073  // containing the window function
2074  auto new_project =
2075  std::make_shared<RelProject>(new_scalar_exprs,
2076  window_func_project_node->getFields(),
2077  window_func_project_node);
2078  propagate_hints_to_new_project(window_func_project_node, new_project, query_hints);
2079  node_list.insert(std::next(node_itr), new_project);
2080 
2081  // Rebind all the following inputs
// start two past node_itr so the freshly inserted project itself is skipped
2082  for (auto rebind_itr = std::next(node_itr, 2); rebind_itr != node_list.end();
2083  rebind_itr++) {
2084  (*rebind_itr)->replaceInput(window_func_project_node, new_project);
2085  }
2086  }
2087  }
2088  nodes.assign(node_list.begin(), node_list.end());
2089 }
2090 
2091 using RexInputSet = std::unordered_set<RexInput>;
2092 
// Visitor which gathers every RexInput referenced by a scalar expression into a
// RexInputSet.
2093 class RexInputCollector : public RexVisitor<RexInputSet> {
2094  public:
// a single input yields a set holding a copy of just that input
2095  RexInputSet visitInput(const RexInput* input) const override {
2096  return RexInputSet{*input};
2097  }
2098 
2099  protected:
// Merges two partial results: returns a copy of `aggregate` with the elements of
// `next_result` inserted. (The first line of this method's signature is not visible
// in this excerpt.)
2101  const RexInputSet& next_result) const override {
2102  auto result = aggregate;
2103  result.insert(next_result.begin(), next_result.end());
2104  return result;
2105  }
2106 };
2107 
// Parameters and body of the pass which inserts a project node in front of window
// function project nodes (the line carrying the function name and return type is not
// visible in this excerpt).
//
// For each RelProject in `nodes` containing a window function expression which meets
// one of the conditions described below, a new RelProject is inserted immediately
// before it. The new node projects exactly the inputs the window function project
// reads from its predecessor, and the window function project is rewired to read
// from the new node using remapped column indices. `nodes` is replaced with the
// updated list.
2121  std::vector<std::shared_ptr<RelAlgNode>>& nodes,
2122  const bool always_add_project_if_first_project_is_window_expr,
2123  std::unordered_map<size_t, std::unordered_map<unsigned, RegisteredQueryHint>>&
2124  query_hints) {
2125  std::list<std::shared_ptr<RelAlgNode>> node_list(nodes.begin(), nodes.end());
// counts project nodes seen so far, used to recognise the first project node
2126  size_t project_node_counter{0};
2127  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
2128  const auto node = *node_itr;
2129 
2130  auto window_func_project_node = std::dynamic_pointer_cast<RelProject>(node);
2131  if (!window_func_project_node) {
2132  continue;
2133  }
2134  project_node_counter++;
2135  if (!window_func_project_node->hasWindowFunctionExpr()) {
2136  // this projection node does not have a window function
2137  // expression -- skip to the next node in the DAG.
2138  continue;
2139  }
2140 
// NOTE(review): std::prev assumes the project node is never the first element of
// the list -- presumably guaranteed because a project always has an input; confirm.
2141  const auto prev_node_itr = std::prev(node_itr);
2142  const auto prev_node = *prev_node_itr;
2143  CHECK(prev_node);
2144 
2145  auto filter_node = std::dynamic_pointer_cast<RelFilter>(prev_node);
2146 
2147  auto scan_node = std::dynamic_pointer_cast<RelScan>(prev_node);
2148  const bool has_multi_fragment_scan_input =
2149  (scan_node && (scan_node->getNumShards() > 0 || scan_node->getNumFragments() > 1))
2150  ? true
2151  : false;
2152 
2153  // We currently add a preceding project node in one of three conditions:
2154  // 1. always_add_project_if_first_project_is_window_expr = true, which
2155  // we currently only set for distributed, but could also be set to support
2156  // multi-frag window function inputs, either if we can detect that an input table
2157  // is multi-frag up front, or using a retry mechanism like we do for join filter
2158  // push down.
2159  // TODO(todd): Investigate a viable approach for the above.
2160  // 2. Regardless of #1, if the window function project node is preceded by a
2161  // filter node. This is required both for correctness and to avoid pulling
2162  // all source input columns into memory since non-coalesced filter node
2163  // inputs are currently not pruned or eliminated via dead column elimination.
2164  // Note that we expect any filter node followed by a project node to be coalesced
2165  // into a single compound node in RelAlgDagBuilder::coalesce_nodes, and that action
2166  // prunes unused inputs.
2167  // TODO(todd): Investigate whether the shotgun filter node issue affects other
2168  // query plans, i.e. filters before joins, and whether there is a more general
2169  // approach to solving this (will still need the preceding project node for
2170  // window functions preceded by filter nodes for correctness though)
// 3. Regardless of #1, if the window function project node is preceded by a scan
// node reporting at least one shard or more than one fragment
// (has_multi_fragment_scan_input above).
2171 
2172  if (!((always_add_project_if_first_project_is_window_expr &&
2173  project_node_counter == 1) ||
2174  filter_node || has_multi_fragment_scan_input)) {
2175  continue;
2176  }
2177 
// collect every input referenced by any expression of the window function project
2178  RexInputSet inputs;
2179  RexInputCollector input_collector;
2180  for (size_t i = 0; i < window_func_project_node->size(); i++) {
2181  auto new_inputs = input_collector.visit(window_func_project_node->getProjectAt(i));
2182  inputs.insert(new_inputs.begin(), new_inputs.end());
2183  }
2184 
2185  // Note: Technically not required since we are mapping old inputs to new input
2186  // indices, but makes the re-mapping of inputs easier to follow.
2187  std::vector<RexInput> sorted_inputs(inputs.begin(), inputs.end());
2188  std::sort(sorted_inputs.begin(),
2189  sorted_inputs.end(),
2190  [](const auto& a, const auto& b) { return a.getIndex() < b.getIndex(); });
2191 
// build the new project: one forwarded column per distinct input, with empty field
// names, recording where each old column index lands in the new node
2192  std::vector<std::unique_ptr<const RexScalar>> scalar_exprs;
2193  std::vector<std::string> fields;
2194  std::unordered_map<unsigned, unsigned> old_index_to_new_index;
2195  for (auto& input : sorted_inputs) {
// every collected input is expected to read from the immediately preceding node
2196  CHECK_EQ(input.getSourceNode(), prev_node.get());
2197  CHECK(old_index_to_new_index
2198  .insert(std::make_pair(input.getIndex(), scalar_exprs.size()))
2199  .second);
2200  scalar_exprs.emplace_back(input.deepCopy());
2201  fields.emplace_back("");
2202  }
2203 
2204  auto new_project = std::make_shared<RelProject>(scalar_exprs, fields, prev_node);
2205  propagate_hints_to_new_project(window_func_project_node, new_project, query_hints);
// list insert places the new project directly before the window function project
2206  node_list.insert(node_itr, new_project);
2207  window_func_project_node->replaceInput(
2208  prev_node, new_project, old_index_to_new_index);
2209  }
2210 
2211  nodes.assign(node_list.begin(), node_list.end());
2212 }
2213 
2214 int64_t get_int_literal_field(const rapidjson::Value& obj,
2215  const char field[],
2216  const int64_t default_val) noexcept {
2217  const auto it = obj.FindMember(field);
2218  if (it == obj.MemberEnd()) {
2219  return default_val;
2220  }
2221  std::unique_ptr<RexLiteral> lit(parse_literal(it->value));
2222  CHECK_EQ(kDECIMAL, lit->getType());
2223  CHECK_EQ(unsigned(0), lit->getScale());
2224  CHECK_EQ(unsigned(0), lit->getTargetScale());
2225  return lit->getVal<int64_t>();
2226 }
2227 
// Checks that the "inputs" member of `node` is an array with no elements.
2228 void check_empty_inputs_field(const rapidjson::Value& node) noexcept {
2229  const auto& inputs_json = field(node, "inputs");
// Size() is zero only for an empty array
2230  CHECK(inputs_json.IsArray() && !inputs_json.Size());
2231 }
2232 
// Remaining parameter and body of the table lookup helper (called elsewhere in this
// file as getTableFromScanNode(cat_, scan_ra); the first line of its signature is not
// visible in this excerpt).
//
// Reads the "table" member of `scan_ra`, which must be a two element array, and asks
// the catalog `cat` for the metadata of the table named by the second element. The
// lookup must succeed; the resulting descriptor is returned.
2234  const rapidjson::Value& scan_ra) {
2235  const auto& table_json = field(scan_ra, "table");
2236  CHECK(table_json.IsArray());
2237  CHECK_EQ(unsigned(2), table_json.Size());
// element [1] holds the table name; element [0] is not used here
2238  const auto td = cat.getMetadataForTable(table_json[1].GetString());
2239  CHECK(td);
2240  return td;
2241 }
2242 
2243 std::vector<std::string> getFieldNamesFromScanNode(const rapidjson::Value& scan_ra) {
2244  const auto& fields_json = field(scan_ra, "fieldNames");
2245  return strings_from_json_array(fields_json);
2246 }
2247 
2248 } // namespace
2249 
// Body of a member function whose signature line is not visible in this excerpt
// (presumably RelProject::hasWindowFunctionExpr, which is called earlier in this
// file -- confirm). Returns true as soon as one of the expressions held in
// scalar_exprs_ is itself a window function expression, false if none is.
2251  for (const auto& expr : scalar_exprs_) {
2252  if (is_window_function_operator(expr.get())) {
2253  return true;
2254  }
2255  }
2256  return false;
2257 }
2258 namespace details {
2259 
2261  public:
2263 
2264  std::vector<std::shared_ptr<RelAlgNode>> run(const rapidjson::Value& rels,
2265  RelAlgDagBuilder& root_dag_builder) {
2266  for (auto rels_it = rels.Begin(); rels_it != rels.End(); ++rels_it) {
2267  const auto& crt_node = *rels_it;
2268  const auto id = node_id(crt_node);
2269  CHECK_EQ(static_cast<size_t>(id), nodes_.size());
2270  CHECK(crt_node.IsObject());
2271  std::shared_ptr<RelAlgNode> ra_node = nullptr;
2272  const auto rel_op = json_str(field(crt_node, "relOp"));
2273  if (rel_op == std::string("EnumerableTableScan") ||
2274  rel_op == std::string("LogicalTableScan")) {
2275  ra_node = dispatchTableScan(crt_node);
2276  } else if (rel_op == std::string("LogicalProject")) {
2277  ra_node = dispatchProject(crt_node, root_dag_builder);
2278  } else if (rel_op == std::string("LogicalFilter")) {
2279  ra_node = dispatchFilter(crt_node, root_dag_builder);
2280  } else if (rel_op == std::string("LogicalAggregate")) {
2281  ra_node = dispatchAggregate(crt_node);
2282  } else if (rel_op == std::string("LogicalJoin")) {
2283  ra_node = dispatchJoin(crt_node, root_dag_builder);
2284  } else if (rel_op == std::string("LogicalSort")) {
2285  ra_node = dispatchSort(crt_node);
2286  } else if (rel_op == std::string("LogicalValues")) {
2287  ra_node = dispatchLogicalValues(crt_node);
2288  } else if (rel_op == std::string("LogicalTableModify")) {
2289  ra_node = dispatchModify(crt_node);
2290  } else if (rel_op == std::string("LogicalTableFunctionScan")) {
2291  ra_node = dispatchTableFunction(crt_node, root_dag_builder);
2292  } else if (rel_op == std::string("LogicalUnion")) {
2293  ra_node = dispatchUnion(crt_node);
2294  } else {
2295  throw QueryNotSupported(std::string("Node ") + rel_op + " not supported yet");
2296  }
2297  nodes_.push_back(ra_node);
2298  }
2299 
2300  return std::move(nodes_);
2301  }
2302 
2303  private:
2304  std::shared_ptr<RelScan> dispatchTableScan(const rapidjson::Value& scan_ra) {
2305  check_empty_inputs_field(scan_ra);
2306  CHECK(scan_ra.IsObject());
2307  const auto td = getTableFromScanNode(cat_, scan_ra);
2308  const auto field_names = getFieldNamesFromScanNode(scan_ra);
2309  if (scan_ra.HasMember("hints")) {
2310  auto scan_node = std::make_shared<RelScan>(td, field_names);
2311  getRelAlgHints(scan_ra, scan_node);
2312  return scan_node;
2313  }
2314  return std::make_shared<RelScan>(td, field_names);
2315  }
2316 
2317  std::shared_ptr<RelProject> dispatchProject(const rapidjson::Value& proj_ra,
2318  RelAlgDagBuilder& root_dag_builder) {
2319  const auto inputs = getRelAlgInputs(proj_ra);
2320  CHECK_EQ(size_t(1), inputs.size());
2321  const auto& exprs_json = field(proj_ra, "exprs");
2322  CHECK(exprs_json.IsArray());
2323  std::vector<std::unique_ptr<const RexScalar>> exprs;
2324  for (auto exprs_json_it = exprs_json.Begin(); exprs_json_it != exprs_json.End();
2325  ++exprs_json_it) {
2326  exprs.emplace_back(parse_scalar_expr(*exprs_json_it, cat_, root_dag_builder));
2327  }
2328  const auto& fields = field(proj_ra, "fields");
2329  if (proj_ra.HasMember("hints")) {
2330  auto project_node = std::make_shared<RelProject>(
2331  exprs, strings_from_json_array(fields), inputs.front());
2332  getRelAlgHints(proj_ra, project_node);
2333  return project_node;
2334  }
2335  return std::make_shared<RelProject>(
2336  exprs, strings_from_json_array(fields), inputs.front());
2337  }
2338 
2339  std::shared_ptr<RelFilter> dispatchFilter(const rapidjson::Value& filter_ra,
2340  RelAlgDagBuilder& root_dag_builder) {
2341  const auto inputs = getRelAlgInputs(filter_ra);
2342  CHECK_EQ(size_t(1), inputs.size());
2343  const auto id = node_id(filter_ra);
2344  CHECK(id);
2345  auto condition =
2346  parse_scalar_expr(field(filter_ra, "condition"), cat_, root_dag_builder);
2347  return std::make_shared<RelFilter>(condition, inputs.front());
2348  }
2349 
2350  std::shared_ptr<RelAggregate> dispatchAggregate(const rapidjson::Value& agg_ra) {
2351  const auto inputs = getRelAlgInputs(agg_ra);
2352  CHECK_EQ(size_t(1), inputs.size());
2353  const auto fields = strings_from_json_array(field(agg_ra, "fields"));
2354  const auto group = indices_from_json_array(field(agg_ra, "group"));
2355  for (size_t i = 0; i < group.size(); ++i) {
2356  CHECK_EQ(i, group[i]);
2357  }
2358  if (agg_ra.HasMember("groups") || agg_ra.HasMember("indicator")) {
2359  throw QueryNotSupported("GROUP BY extensions not supported");
2360  }
2361  const auto& aggs_json_arr = field(agg_ra, "aggs");
2362  CHECK(aggs_json_arr.IsArray());
2363  std::vector<std::unique_ptr<const RexAgg>> aggs;
2364  for (auto aggs_json_arr_it = aggs_json_arr.Begin();
2365  aggs_json_arr_it != aggs_json_arr.End();
2366  ++aggs_json_arr_it) {
2367  aggs.emplace_back(parse_aggregate_expr(*aggs_json_arr_it));
2368  }
2369  if (agg_ra.HasMember("hints")) {
2370  auto agg_node =
2371  std::make_shared<RelAggregate>(group.size(), aggs, fields, inputs.front());
2372  getRelAlgHints(agg_ra, agg_node);
2373  return agg_node;
2374  }
2375  return std::make_shared<RelAggregate>(group.size(), aggs, fields, inputs.front());
2376  }
2377 
2378  std::shared_ptr<RelJoin> dispatchJoin(const rapidjson::Value& join_ra,
2379  RelAlgDagBuilder& root_dag_builder) {
2380  const auto inputs = getRelAlgInputs(join_ra);
2381  CHECK_EQ(size_t(2), inputs.size());
2382  const auto join_type = to_join_type(json_str(field(join_ra, "joinType")));
2383  auto filter_rex =
2384  parse_scalar_expr(field(join_ra, "condition"), cat_, root_dag_builder);
2385  if (join_ra.HasMember("hints")) {
2386  auto join_node =
2387  std::make_shared<RelJoin>(inputs[0], inputs[1], filter_rex, join_type);
2388  getRelAlgHints(join_ra, join_node);
2389  return join_node;
2390  }
2391  return std::make_shared<RelJoin>(inputs[0], inputs[1], filter_rex, join_type);
2392  }
2393 
2394  std::shared_ptr<RelSort> dispatchSort(const rapidjson::Value& sort_ra) {
2395  const auto inputs = getRelAlgInputs(sort_ra);
2396  CHECK_EQ(size_t(1), inputs.size());
2397  std::vector<SortField> collation;
2398  const auto& collation_arr = field(sort_ra, "collation");
2399  CHECK(collation_arr.IsArray());
2400  for (auto collation_arr_it = collation_arr.Begin();
2401  collation_arr_it != collation_arr.End();
2402  ++collation_arr_it) {
2403  const size_t field_idx = json_i64(field(*collation_arr_it, "field"));
2404  const auto sort_dir = parse_sort_direction(*collation_arr_it);
2405  const auto null_pos = parse_nulls_position(*collation_arr_it);
2406  collation.emplace_back(field_idx, sort_dir, null_pos);
2407  }
2408  auto limit = get_int_literal_field(sort_ra, "fetch", -1);
2409  const auto offset = get_int_literal_field(sort_ra, "offset", 0);
2410  auto ret = std::make_shared<RelSort>(
2411  collation, limit > 0 ? limit : 0, offset, inputs.front());
2412  ret->setEmptyResult(limit == 0);
2413  return ret;
2414  }
2415 
// Builds a RelModify from a table modify node. Exactly one input is required and the
// target table is resolved through the catalog; modifying a view throws. For the
// "UPDATE" operation the target columns are read from "updateColumnList". After
// construction the node applies its delete / update modifications to its input node;
// any other operation throws std::runtime_error.
2416  std::shared_ptr<RelModify> dispatchModify(const rapidjson::Value& logical_modify_ra) {
2417  const auto inputs = getRelAlgInputs(logical_modify_ra);
2418  CHECK_EQ(size_t(1), inputs.size());
2419 
// the same "table" lookup used for scan nodes is applied to the modify node
2420  const auto table_descriptor = getTableFromScanNode(cat_, logical_modify_ra);
2421  if (table_descriptor->isView) {
2422  throw std::runtime_error("UPDATE of a view is unsupported.");
2423  }
2424 
2425  bool flattened = json_bool(field(logical_modify_ra, "flattened"));
2426  std::string op = json_str(field(logical_modify_ra, "operation"));
2427  RelModify::TargetColumnList target_column_list;
2428 
// the column list stays empty for every operation other than UPDATE
2429  if (op == "UPDATE") {
2430  const auto& update_columns = field(logical_modify_ra, "updateColumnList");
2431  CHECK(update_columns.IsArray());
2432 
2433  for (auto column_arr_it = update_columns.Begin();
2434  column_arr_it != update_columns.End();
2435  ++column_arr_it) {
2436  target_column_list.push_back(column_arr_it->GetString());
2437  }
2438  }
2439 
2440  auto modify_node = std::make_shared<RelModify>(
2441  cat_, table_descriptor, flattened, op, target_column_list, inputs[0]);
// NOTE(review): the two case labels of this switch are not visible in this excerpt;
// judging by the calls they guard they are presumably the delete and update
// operations -- confirm against the file.
2442  switch (modify_node->getOperation()) {
2444  modify_node->applyDeleteModificationsToInputNode();
2445  break;
2446  }
2448  modify_node->applyUpdateModificationsToInputNode();
2449  break;
2450  }
2451  default:
2452  throw std::runtime_error("Unsupported RelModify operation: " +
2453  json_node_to_string(logical_modify_ra));
2454  }
2455 
2456  return modify_node;
2457  }
2458 
2459  std::shared_ptr<RelTableFunction> dispatchTableFunction(
2460  const rapidjson::Value& table_func_ra,
2461  RelAlgDagBuilder& root_dag_builder) {
2462  const auto inputs = getRelAlgInputs(table_func_ra);
2463  const auto& invocation = field(table_func_ra, "invocation");
2464  CHECK(invocation.IsObject());
2465 
2466  const auto& operands = field(invocation, "operands");
2467  CHECK(operands.IsArray());
2468  CHECK_GE(operands.Size(), unsigned(0));
2469 
2470  std::vector<const Rex*> col_inputs;
2471  std::vector<std::unique_ptr<const RexScalar>> table_func_inputs;
2472  std::vector<std::string> fields;
2473 
2474  for (auto exprs_json_it = operands.Begin(); exprs_json_it != operands.End();
2475  ++exprs_json_it) {
2476  const auto& expr_json = *exprs_json_it;
2477  CHECK(expr_json.IsObject());
2478  if (expr_json.HasMember("op")) {
2479  const auto op_str = json_str(field(expr_json, "op"));
2480  if (op_str == "CAST" && expr_json.HasMember("type")) {
2481  const auto& expr_type = field(expr_json, "type");
2482  CHECK(expr_type.IsObject());
2483  CHECK(expr_type.HasMember("type"));
2484  const auto& expr_type_name = json_str(field(expr_type, "type"));
2485  if (expr_type_name == "CURSOR") {
2486  CHECK(expr_json.HasMember("operands"));
2487  const auto& expr_operands = field(expr_json, "operands");
2488  CHECK(expr_operands.IsArray());
2489  if (expr_operands.Size() != 1) {
2490  throw std::runtime_error(
2491  "Table functions currently only support one ResultSet input");
2492  }
2493  auto pos = field(expr_operands[0], "input").GetInt();
2494  CHECK_LT(pos, inputs.size());
2495  for (size_t i = inputs[pos]->size(); i > 0; i--) {
2496  table_func_inputs.emplace_back(
2497  std::make_unique<RexAbstractInput>(col_inputs.size()));
2498  col_inputs.emplace_back(table_func_inputs.back().get());
2499  }
2500  continue;
2501  }
2502  }
2503  }
2504  table_func_inputs.emplace_back(
2505  parse_scalar_expr(*exprs_json_it, cat_, root_dag_builder));
2506  }
2507 
2508  const auto& op_name = field(invocation, "op");
2509  CHECK(op_name.IsString());
2510 
2511  std::vector<std::unique_ptr<const RexScalar>> table_function_projected_outputs;
2512  const auto& row_types = field(table_func_ra, "rowType");
2513  CHECK(row_types.IsArray());
2514  CHECK_GE(row_types.Size(), unsigned(0));
2515  const auto& row_types_array = row_types.GetArray();
2516  for (size_t i = 0; i < row_types_array.Size(); i++) {
2517  // We don't care about the type information in rowType -- replace each output with
2518  // a reference to be resolved later in the translator
2519  table_function_projected_outputs.emplace_back(std::make_unique<RexRef>(i));
2520  fields.emplace_back("");
2521  }
2522  return std::make_shared<RelTableFunction>(op_name.GetString(),
2523  inputs,
2524  fields,
2525  col_inputs,
2526  table_func_inputs,
2527  table_function_projected_outputs);
2528  }
2529 
2530  std::shared_ptr<RelLogicalValues> dispatchLogicalValues(
2531  const rapidjson::Value& logical_values_ra) {
2532  const auto& tuple_type_arr = field(logical_values_ra, "type");
2533  CHECK(tuple_type_arr.IsArray());
2534  std::vector<TargetMetaInfo> tuple_type;
2535  for (auto tuple_type_arr_it = tuple_type_arr.Begin();
2536  tuple_type_arr_it != tuple_type_arr.End();
2537  ++tuple_type_arr_it) {
2538  const auto component_type = parse_type(*tuple_type_arr_it);
2539  const auto component_name = json_str(field(*tuple_type_arr_it, "name"));
2540  tuple_type.emplace_back(component_name, component_type);
2541  }
2542  const auto& inputs_arr = field(logical_values_ra, "inputs");
2543  CHECK(inputs_arr.IsArray());
2544  const auto& tuples_arr = field(logical_values_ra, "tuples");
2545  CHECK(tuples_arr.IsArray());
2546 
2547  if (inputs_arr.Size()) {
2548  throw QueryNotSupported("Inputs not supported in logical values yet.");
2549  }
2550 
2551  std::vector<RelLogicalValues::RowValues> values;
2552  if (tuples_arr.Size()) {
2553  for (const auto& row : tuples_arr.GetArray()) {
2554  CHECK(row.IsArray());
2555  const auto values_json = row.GetArray();
2556  if (!values.empty()) {
2557  CHECK_EQ(values[0].size(), values_json.Size());
2558  }
2559  values.emplace_back(RelLogicalValues::RowValues{});
2560  for (const auto& value : values_json) {
2561  CHECK(value.IsObject());
2562  CHECK(value.HasMember("literal"));
2563  values.back().emplace_back(parse_literal(value));
2564  }
2565  }
2566  }
2567 
2568  return std::make_shared<RelLogicalValues>(tuple_type, values);
2569  }
2570 
2571  std::shared_ptr<RelLogicalUnion> dispatchUnion(
2572  const rapidjson::Value& logical_union_ra) {
2573  auto inputs = getRelAlgInputs(logical_union_ra);
2574  auto const& all_type_bool = field(logical_union_ra, "all");
2575  CHECK(all_type_bool.IsBool());
2576  return std::make_shared<RelLogicalUnion>(std::move(inputs), all_type_bool.GetBool());
2577  }
2578 
2579  RelAlgInputs getRelAlgInputs(const rapidjson::Value& node) {
2580  if (node.HasMember("inputs")) {
2581  const auto str_input_ids = strings_from_json_array(field(node, "inputs"));
2582  RelAlgInputs ra_inputs;
2583  for (const auto& str_id : str_input_ids) {
2584  ra_inputs.push_back(nodes_[std::stoi(str_id)]);
2585  }
2586  return ra_inputs;
2587  }
2588  return {prev(node)};
2589  }
2590 
// Splits the first `pos` characters of `str` -- one "key=value" option -- into a
// {key, value} pair and removes that option from the front of `str`.
//
// `pos` is the offset of the ", " separator following the option, or
// std::string::npos when the option is the last one. In the separator case the
// option and the two separator characters are erased. In the npos case (or whenever
// `pos` reaches the end of `str`) the whole string is consumed; this is spelled out
// explicitly instead of relying on unsigned wrap-around of `pos + 2`, which would
// erase only a single character. An option without '=' yields the whole option as
// the key and an empty value.
std::pair<std::string, std::string> getKVOptionPair(std::string& str, size_t& pos) {
  constexpr size_t kOptionSeparatorLength = 2;  // length of ", "

  const std::string option = str.substr(0, pos);
  const size_t delim_pos = option.find('=');
  std::string key = option.substr(0, delim_pos);
  std::string val =
      delim_pos == std::string::npos ? std::string{} : option.substr(delim_pos + 1);

  const size_t consumed =
      pos >= str.length() ? str.length() : pos + kOptionSeparatorLength;
  str.erase(0, consumed);
  return {std::move(key), std::move(val)};
}
2600 
// Parses a single hint description into an ExplainedQueryHint.
//
// `hint_string` is modified in place: its first and last characters are dropped
// (presumably enclosing brackets -- TODO confirm against the producer of the string).
// The hint name is the text up to the first space; a leading "g_" marks the hint as
// global and is removed from the name. When the string contains "options:", the text
// after it is parsed either as a key/value list (if it contains '{') or as a plain
// list (it must then contain '['), with entries separated by ", ". Without
// "options:" the hint carries no options.
2601  ExplainedQueryHint parseHintString(std::string& hint_string) {
2602  std::string white_space_delim = " ";
2603  int l = hint_string.length();
// drop the first character, then keep l - 2 characters, i.e. also drop the last one
2604  hint_string = hint_string.erase(0, 1).substr(0, l - 2);
2605  size_t pos = 0;
// splits an optional "g_" prefix off a hint name and reports whether it was present
2606  auto global_hint_checker = [&](const std::string& input_hint_name) -> HintIdentifier {
2607  bool global_hint = false;
2608  std::string hint_name = input_hint_name;
2609  auto global_hint_identifier = hint_name.substr(0, 2);
2610  if (global_hint_identifier.compare("g_") == 0) {
2611  global_hint = true;
// the count exceeds what remains, so substr returns everything after the prefix
2612  hint_name = hint_name.substr(2, hint_string.length());
2613  }
2614  return {global_hint, hint_name};
2615  };
2616  auto parsed_hint =
2617  global_hint_checker(hint_string.substr(0, hint_string.find(white_space_delim)));
2618  auto hint_type = RegisteredQueryHint::translateQueryHint(parsed_hint.hint_name);
2619  if ((pos = hint_string.find("options:")) != std::string::npos) {
2620  // need to parse hint options
// NOTE(review): `tokens` is never used in this function
2621  std::vector<std::string> tokens;
2622  bool kv_list_op = false;
// 8 is the length of "options:"; the count passed to substr is clamped to the
// characters actually remaining, so this takes the rest of the string
2623  std::string raw_options = hint_string.substr(pos + 8, hint_string.length() - 2);
2624  if (raw_options.find('{') != std::string::npos) {
2625  kv_list_op = true;
2626  } else {
2627  CHECK(raw_options.find('[') != std::string::npos);
2628  }
// strip the first and the last character of the option text
2629  auto t1 = raw_options.erase(0, 1);
2630  raw_options = t1.substr(0, t1.length() - 1);
2631  std::string op_delim = ", ";
2632  if (kv_list_op) {
2633  // kv options
2634  std::unordered_map<std::string, std::string> kv_options;
2635  while ((pos = raw_options.find(op_delim)) != std::string::npos) {
// getKVOptionPair also removes the parsed option from raw_options
2636  auto kv_pair = getKVOptionPair(raw_options, pos);
2637  kv_options.emplace(kv_pair.first, kv_pair.second);
2638  }
2639  // handle the last kv pair
// pos is npos here, so the whole remainder is treated as the final option
2640  auto kv_pair = getKVOptionPair(raw_options, pos);
2641  kv_options.emplace(kv_pair.first, kv_pair.second);
// NOTE(review): the meaning of the two bool fields is defined by ExplainedQueryHint,
// which is not visible in this excerpt
2642  return {hint_type, parsed_hint.global_hint, false, true, kv_options};
2643  } else {
2644  std::vector<std::string> list_options;
2645  while ((pos = raw_options.find(op_delim)) != std::string::npos) {
2646  list_options.emplace_back(raw_options.substr(0, pos));
// erases the option plus two characters, which matches the length of ", "
2647  raw_options.erase(0, pos + white_space_delim.length() + 1);
2648  }
2649  // handle the last option
2650  list_options.emplace_back(raw_options.substr(0, pos));
2651  return {hint_type, parsed_hint.global_hint, false, false, list_options};
2652  }
2653  } else {
2654  // marker hint: no extra option for this hint
2655  return {hint_type, parsed_hint.global_hint, true, false};
2656  }
2657  }
2658 
2659  void getRelAlgHints(const rapidjson::Value& json_node,
2660  std::shared_ptr<RelAlgNode> node) {
2661  std::string hint_explained = json_str(field(json_node, "hints"));
2662  size_t pos = 0;
2663  std::string delim = "|";
2664  std::vector<std::string> hint_list;
2665  while ((pos = hint_explained.find(delim)) != std::string::npos) {
2666  hint_list.emplace_back(hint_explained.substr(0, pos));
2667  hint_explained.erase(0, pos + delim.length());
2668  }
2669  // handling the last one
2670  hint_list.emplace_back(hint_explained.substr(0, pos));
2671 
2672  const auto agg_node = std::dynamic_pointer_cast<RelAggregate>(node);
2673  if (agg_node) {
2674  for (std::string& hint : hint_list) {
2675  auto parsed_hint = parseHintString(hint);
2676  agg_node->addHint(parsed_hint);
2677  }
2678  }
2679  const auto project_node = std::dynamic_pointer_cast<RelProject>(node);
2680  if (project_node) {
2681  for (std::string& hint : hint_list) {
2682  auto parsed_hint = parseHintString(hint);
2683  project_node->addHint(parsed_hint);
2684  }
2685  }
2686  const auto scan_node = std::dynamic_pointer_cast<RelScan>(node);
2687  if (scan_node) {
2688  for (std::string& hint : hint_list) {
2689  auto parsed_hint = parseHintString(hint);
2690  scan_node->addHint(parsed_hint);
2691  }
2692  }
2693  const auto join_node = std::dynamic_pointer_cast<RelJoin>(node);
2694  if (join_node) {
2695  for (std::string& hint : hint_list) {
2696  auto parsed_hint = parseHintString(hint);
2697  join_node->addHint(parsed_hint);
2698  }
2699  }
2700 
2701  const auto compound_node = std::dynamic_pointer_cast<RelCompound>(node);
2702  if (compound_node) {
2703  for (std::string& hint : hint_list) {
2704  auto parsed_hint = parseHintString(hint);
2705  compound_node->addHint(parsed_hint);
2706  }
2707  }
2708  }
2709 
2710  std::shared_ptr<const RelAlgNode> prev(const rapidjson::Value& crt_node) {
2711  const auto id = node_id(crt_node);
2712  CHECK(id);
2713  CHECK_EQ(static_cast<size_t>(id), nodes_.size());
2714  return nodes_.back();
2715  }
2716 
2718  std::vector<std::shared_ptr<RelAlgNode>> nodes_;
2719 };
2720 
2721 } // namespace details
2722 
// Builds the DAG from the serialized relational-algebra JSON text `query_ra`.
//
// The text is parsed with rapidjson. On a parse error the error offset and
// message are logged and std::runtime_error is thrown. Otherwise the document
// is checked to be a JSON object and handed to build() with this builder as
// the lead builder.
//
// NOTE(review): listing lines 2724 and 2740 are absent from this extract (the
// numbering skips them, and `cat` is used below without a visible
// declaration) -- confirm against the full source.
2723 RelAlgDagBuilder::RelAlgDagBuilder(const std::string& query_ra,
2725  const RenderInfo* render_info)
2726  : cat_(cat), render_info_(render_info) {
2727  rapidjson::Document query_ast;
2728  query_ast.Parse(query_ra.c_str());
2729  VLOG(2) << "Parsing query RA JSON: " << query_ra;
2730  if (query_ast.HasParseError()) {
// The return value of this call is discarded; the error code is fetched again
// for the log message below.
2731  query_ast.GetParseError();
2732  LOG(ERROR) << "Failed to parse RA tree from Calcite (offset "
2733  << query_ast.GetErrorOffset() << "):\n"
2734  << rapidjson::GetParseError_En(query_ast.GetParseError());
// The full query text is only emitted at verbose level 1.
2735  VLOG(1) << "Failed to parse query RA: " << query_ra;
2736  throw std::runtime_error(
2737  "Failed to parse relational algebra tree. Possible query syntax error.");
2738  }
2739  CHECK(query_ast.IsObject());
2741  build(query_ast, *this);
2742 }
2743 
2745  const rapidjson::Value& query_ast,
2747  const RenderInfo* render_info)
2748  : cat_(cat), render_info_(render_info) {
// Takes an already parsed JSON value and delegates to build(), forwarding
// `root_dag_builder` as the lead builder.
// NOTE(review): listing lines 2744 and 2746 are absent from this extract, so
// the declarations of `root_dag_builder` and `cat` are not visible here.
2749  build(query_ast, root_dag_builder);
2750 }
2751 
// Populates nodes_ from the "rels" array of `query_ast` and then runs a fixed
// sequence of passes over the node list.
//
// `lead_dag_builder` is forwarded to the dispatcher that creates the nodes.
//
// NOTE(review): several listing lines of this function are absent from this
// extract (the numbering skips 2762, 2767, 2772-2775, 2790 and 2794-2795), so
// some calls in the sequence are not visible here.
2752 void RelAlgDagBuilder::build(const rapidjson::Value& query_ast,
2753  RelAlgDagBuilder& lead_dag_builder) {
2754  const auto& rels = field(query_ast, "rels");
2755  CHECK(rels.IsArray());
// QueryNotSupported is caught only to be rethrown unchanged.
2756  try {
2757  nodes_ = details::RelAlgDispatcher(cat_).run(rels, lead_dag_builder);
2758  } catch (const QueryNotSupported&) {
2759  throw;
2760  }
2761  CHECK(!nodes_.empty());
2763 
2764  if (render_info_) {
2765  // Alter the RA for render. Do this before any flattening/optimizations are done to
2766  // the tree.
2768  }
2769 
2770  handleQueryHint(nodes_, this);
2771  mark_nops(nodes_);
// Collect the left-deep join root of every node that has one; roots that are
// RelFilter nodes are additionally tracked in filtered_left_deep_joins.
2776  std::vector<const RelAlgNode*> filtered_left_deep_joins;
2777  std::vector<const RelAlgNode*> left_deep_joins;
2778  for (const auto& node : nodes_) {
2779  const auto left_deep_join_root = get_left_deep_join_root(node);
2780  // The filter which starts a left-deep join pattern must not be coalesced
2781  // since it contains (part of) the join condition.
2782  if (left_deep_join_root) {
2783  left_deep_joins.push_back(left_deep_join_root.get());
2784  if (std::dynamic_pointer_cast<const RelFilter>(left_deep_join_root)) {
2785  filtered_left_deep_joins.push_back(left_deep_join_root.get());
2786  }
2787  }
2788  }
2789  if (filtered_left_deep_joins.empty()) {
2791  }
2792  eliminate_dead_columns(nodes_);
2793  eliminate_dead_subqueries(subqueries_, nodes_.back().get());
2796  nodes_,
2797  g_cluster /* always_add_project_if_first_project_is_window_expr */,
2798  query_hint_);
2799  coalesce_nodes(nodes_, left_deep_joins, query_hint_);
// The last node must be referenced by exactly one shared_ptr at this point.
2800  CHECK(nodes_.back().use_count() == 1);
2801  create_left_deep_join(nodes_);
2802 }
2803 
2805  std::function<void(RelAlgNode const*)> const& callback) const {
// Invokes `callback` with the raw pointer of every non-null entry of nodes_,
// in container order. Null entries are skipped.
2806  for (auto const& node : nodes_) {
2807  if (node) {
2808  callback(node.get());
2809  }
2810  }
2811 }
2812 
// Forwards resetQueryExecutionState() to every non-null entry of nodes_.
// NOTE(review): the signature line (listing line 2813) is absent from this
// extract.
2814  for (auto& node : nodes_) {
2815  if (node) {
2816  node->resetQueryExecutionState();
2817  }
2818  }
2819 }
2820 
2821 // Return tree with depth represented by indentations.
2822 std::string tree_string(const RelAlgNode* ra, const size_t depth) {
2823  std::string result = std::string(2 * depth, ' ') + ::toString(ra) + '\n';
2824  for (size_t i = 0; i < ra->inputCount(); ++i) {
2825  result += tree_string(ra->getInput(i), depth + 1);
2826  }
2827  return result;
2828 }
2829 
2830 std::string RexSubQuery::toString() const {
2831  return cat(::typeName(this), "(", ::toString(ra_.get()), ")");
2832 }
2833 
2834 size_t RexSubQuery::toHash() const {
2835  if (!hash_) {
2836  hash_ = typeid(RexSubQuery).hash_code();
2837  boost::hash_combine(*hash_, ra_->toHash());
2838  }
2839  return *hash_;
2840 }
2841 
// Textual form of this input reference.
//
// When the source node is a RelScan the result names the column as
// "<table>.<field>" inside the type name's parentheses; otherwise the generic
// form built from the source node's string form is returned.
//
// NOTE(review): listing line 2853 (an argument of the cat() call) is absent
// from this extract.
2842 std::string RexInput::toString() const {
2843  const auto scan_node = dynamic_cast<const RelScan*>(node_);
2844  if (scan_node) {
2845  auto field_name = scan_node->getFieldName(getIndex());
2846  auto table_name = scan_node->getTableDescriptor()->tableName;
2847  return ::typeName(this) + "(" + table_name + "." + field_name + ")";
2848  }
2849  return cat(::typeName(this),
2850  "(node=",
2851  ::toString(node_),
2852  ", in_index=",
2854  ")");
2855 }
2856 
2857 size_t RexInput::toHash() const {
2858  if (!hash_) {
2859  hash_ = typeid(RexInput).hash_code();
2860  boost::hash_combine(*hash_, node_->toHash());
2861  boost::hash_combine(*hash_, getIndex());
2862  }
2863  return *hash_;
2864 }
2865 
// Textual form of this compound node: the type name followed by a
// parenthesized list that starts with the filter expression (or "null" when
// there is none) and continues with labelled sections.
//
// NOTE(review): the value arguments following each label (listing lines 2871,
// 2873, 2875, 2877, 2879 and 2881) are absent from this extract.
2866 std::string RelCompound::toString() const {
2867  return cat(::typeName(this),
2868  "(",
2869  (filter_expr_ ? filter_expr_->toString() : "null"),
2870  ", target_exprs=",
2872  ", ",
2874  ", agg_exps=",
2876  ", fields=",
2878  ", scalar_sources=",
2880  ", is_agg=",
2882  ")");
2883 }
2884 
2885 size_t RelCompound::toHash() const {
2886  if (!hash_) {
2887  hash_ = typeid(RelCompound).hash_code();
2888  boost::hash_combine(*hash_,
2889  filter_expr_ ? filter_expr_->toHash() : boost::hash_value("n"));
2890  boost::hash_combine(*hash_, is_agg_);
2891  for (auto& target_expr : target_exprs_) {
2892  if (auto rex_scalar = dynamic_cast<const RexScalar*>(target_expr)) {
2893  boost::hash_combine(*hash_, rex_scalar->toHash());
2894  }
2895  }
2896  for (auto& agg_expr : agg_exprs_) {
2897  boost::hash_combine(*hash_, agg_expr->toHash());
2898  }
2899  for (auto& scalar_source : scalar_sources_) {
2900  boost::hash_combine(*hash_, scalar_source->toHash());
2901  }
2902  boost::hash_combine(*hash_, groupby_count_);
2903  boost::hash_combine(*hash_, ::toString(fields_));
2904  }
2905  return *hash_;
2906 }
std::vector< std::shared_ptr< const RexScalar > > scalar_exprs_
DEVICE auto upper_bound(ARGS &&...args)
Definition: gpu_enabled.h:123
SQLTypes to_sql_type(const std::string &type_name)
void handleQueryHint(const std::vector< std::shared_ptr< RelAlgNode >> &nodes, RelAlgDagBuilder *dag_builder) noexcept
std::shared_ptr< const RelAlgNode > getRootNodeShPtr() const
void coalesce_nodes(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::vector< const RelAlgNode * > &left_deep_joins, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
bool is_agg(const Analyzer::Expr *expr)
std::unique_ptr< const RexScalar > condition_
SQLTypeInfo parse_type(const rapidjson::Value &type_obj)
const RexScalar * getThen(const size_t idx) const
std::vector< std::string > strings_from_json_array(const rapidjson::Value &json_str_arr) noexcept
std::unique_ptr< RexCase > parse_case(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
std::shared_ptr< RelAggregate > dispatchAggregate(const rapidjson::Value &agg_ra)
#define CHECK_EQ(x, y)
Definition: Logger.h:219
JoinType to_join_type(const std::string &join_type_name)
const Catalog_Namespace::Catalog & cat_
std::shared_ptr< RelProject > dispatchProject(const rapidjson::Value &proj_ra, RelAlgDagBuilder &root_dag_builder)
std::unique_ptr< RexSubQuery > deepCopy() const
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
JoinType
Definition: sqldefs.h:108
std::vector< const Rex * > remapTargetPointers(std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs_new, std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources_new, std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs_old, std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources_old, std::vector< const Rex * > const &target_exprs_old)
std::vector< std::unique_ptr< const RexScalar > > table_func_inputs_
std::string cat(Ts &&...args)
std::string toString(const ExtArgumentType &sig_type)
void bind_inputs(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
void hoist_filter_cond_to_cross_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
class for a per-database catalog. also includes metadata for the current database and the current use...
Definition: Catalog.h:114
Definition: sqltypes.h:49
void addHint(const ExplainedQueryHint &hint_explained)
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
void sink_projected_boolean_expr_to_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::unique_ptr< const RexAgg > parse_aggregate_expr(const rapidjson::Value &expr)
void eliminate_identical_copy(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
size_t toHash() const override
const RexScalar * getElse() const
RelCompound(std::unique_ptr< const RexScalar > &filter_expr, const std::vector< const Rex * > &target_exprs, const size_t groupby_count, const std::vector< const RexAgg * > &agg_exprs, const std::vector< std::string > &fields, std::vector< std::unique_ptr< const RexScalar >> &scalar_sources, const bool is_agg, bool update_disguised_as_select=false, bool delete_disguised_as_select=false, bool varlen_update_required=false, TableDescriptor const *manipulation_target_table=nullptr, ColumnNameList target_columns=ColumnNameList())
static thread_local unsigned crt_id_
std::shared_ptr< RelScan > dispatchTableScan(const rapidjson::Value &scan_ra)
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
RexScalar const * copyAndRedirectSource(RexScalar const *, size_t input_idx) const
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
#define const
SQLAgg to_agg_kind(const std::string &agg_name)
std::shared_ptr< RelLogicalUnion > dispatchUnion(const rapidjson::Value &logical_union_ra)
#define LOG(tag)
Definition: Logger.h:205
NullSortedPosition
std::vector< std::string > TargetColumnList
size_t size() const
Hints * getDeliveredHints()
const bool json_bool(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:49
const RexScalar * getOperand(const size_t idx) const
std::vector< const Rex * > col_inputs_
bool hasEquivCollationOf(const RelSort &that) const
const std::vector< std::string > fields_
std::string toString() const override
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:44
void build(const rapidjson::Value &query_ast, RelAlgDagBuilder &root_dag_builder)
string name
Definition: setup.in.py:72
RexWindowFunctionOperator::RexWindowBound parse_window_bound(const rapidjson::Value &window_bound_obj, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
std::string join(T const &container, std::string const &delim)
void bind_table_func_to_input(RelTableFunction *table_func_node, const RANodeOutput &input) noexcept
std::unique_ptr< RexOperator > parse_operator(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
#define UNREACHABLE()
Definition: Logger.h:255
bool hint_applied_
DEVICE void sort(ARGS &&...args)
Definition: gpu_enabled.h:105
bool isRenamedInput(const RelAlgNode *node, const size_t index, const std::string &new_name)
#define CHECK_GE(x, y)
Definition: Logger.h:224
std::vector< std::string > fields_
RexInput(const RelAlgNode *node, const unsigned in_index)
void addHint(const ExplainedQueryHint &hint_explained)
Definition: sqldefs.h:49
const RexScalar * getWhen(const size_t idx) const
void appendInput(std::string new_field_name, std::unique_ptr< const RexScalar > new_input)
RetType visitOperator(const RexOperator *rex_operator) const final
void addHint(const ExplainedQueryHint &hint_explained)
void bind_project_to_input(RelProject *project_node, const RANodeOutput &input) noexcept
const Catalog_Namespace::Catalog & cat_
std::vector< std::unique_ptr< const RexScalar > > parse_window_order_exprs(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
void create_compound(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::vector< size_t > &pattern, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints) noexcept
void checkForMatchingMetaInfoTypes() const
std::vector< std::unique_ptr< const RexScalar > > scalar_sources_
std::string to_string(char const *&&v)
SqlWindowFunctionKind parse_window_function_kind(const std::string &name)
const std::string getFieldName(const size_t i) const
std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint > > query_hint_
void simplify_sort(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
std::vector< SortField > collation_
RexRebindInputsVisitor(const RelAlgNode *old_input, const RelAlgNode *new_input)
int64_t get_int_literal_field(const rapidjson::Value &obj, const char field[], const int64_t default_val) noexcept
constexpr double a
Definition: Utm.h:33
std::vector< std::unique_ptr< const RexScalar > > parse_expr_array(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
This file contains the class specification and related data structures for Catalog.
virtual T visit(const RexScalar *rex_scalar) const
Definition: RexVisitor.h:27
RexInputSet visitInput(const RexInput *input) const override
std::vector< std::shared_ptr< RexSubQuery > > subqueries_
RexWindowFuncReplacementVisitor(std::unique_ptr< const RexScalar > replacement_rex)
const RenderInfo * render_info_
std::string to_string() const
Definition: sqltypes.h:482
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
const RexScalar * visitOperator(const RexOperator *rex_operator) const final
unsigned getIndex() const
SQLOps getOperator() const
std::vector< std::shared_ptr< RelAlgNode > > nodes_
const TableDescriptor * getTableFromScanNode(const Catalog_Namespace::Catalog &cat, const rapidjson::Value &scan_ra)
SortDirection parse_sort_direction(const rapidjson::Value &collation)
std::unique_ptr< const RexScalar > disambiguate_rex(const RexScalar *, const RANodeOutput &)
static QueryHint translateQueryHint(const std::string &hint_name)
Definition: QueryHint.h:225
DEVICE auto copy(ARGS &&...args)
Definition: gpu_enabled.h:51
bool isRenaming() const
std::vector< SortField > parse_window_order_collation(const rapidjson::Value &arr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
void setIndex(const unsigned in_index) const
Hints * getDeliveredHints()
size_t toHash() const override
SQLOps to_sql_op(const std::string &op_str)
std::unique_ptr< Hints > hints_
void set_scale(int s)
Definition: sqltypes.h:434
const int64_t json_i64(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:39
void * visitInput(const RexInput *rex_input) const override
std::unique_ptr< Hints > hints_
std::vector< std::unique_ptr< const RexScalar > > scalar_exprs_
const double json_double(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:54
void addHint(const ExplainedQueryHint &hint_explained)
size_t branchCount() const
std::unordered_set< RexInput > RexInputSet
void propagate_hints_to_new_project(std::shared_ptr< RelProject > prev_node, std::shared_ptr< RelProject > new_node, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
const RelAlgNode * getInput(const size_t idx) const
RelFilter(std::unique_ptr< const RexScalar > &filter, std::shared_ptr< const RelAlgNode > input)
std::vector< std::shared_ptr< RelAlgNode > > nodes_
std::string toString() const override
RelAggregate(const size_t groupby_count, std::vector< std::unique_ptr< const RexAgg >> &agg_exprs, const std::vector< std::string > &fields, std::shared_ptr< const RelAlgNode > input)
std::unique_ptr< const RexScalar > filter_
bool isSimple() const
RexInputReplacementVisitor(const RelAlgNode *node_to_keep, const std::vector< std::unique_ptr< const RexScalar >> &scalar_sources)
std::optional< size_t > hash_
const size_t groupby_count_
RexRebindReindexInputsVisitor(const RelAlgNode *old_input, const RelAlgNode *new_input, std::unordered_map< unsigned, unsigned > old_to_new_index_map)
std::optional< size_t > hash_
unsigned getId() const
const RelAlgNode * node_
virtual void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input)
void eachNode(std::function< void(RelAlgNode const *)> const &) const
NullSortedPosition parse_nulls_position(const rapidjson::Value &collation)
std::shared_ptr< RelFilter > dispatchFilter(const rapidjson::Value &filter_ra, RelAlgDagBuilder &root_dag_builder)
std::unique_ptr< RexAbstractInput > parse_abstract_input(const rapidjson::Value &expr) noexcept
std::vector< std::unique_ptr< const RexAgg > > agg_exprs_
std::set< std::pair< const RelAlgNode *, int > > get_equiv_cols(const RelAlgNode *node, const size_t which_col)
Hints * getDeliveredHints()
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
std::unique_ptr< RexLiteral > parse_literal(const rapidjson::Value &expr)
SortDirection
const RexScalar * getProjectAt(const size_t idx) const
std::vector< std::unique_ptr< const RexAgg > > copyAggExprs(std::vector< std::unique_ptr< const RexAgg >> const &agg_exprs)
std::vector< std::shared_ptr< const RelAlgNode >> RelAlgInputs
#define CHECK_LT(x, y)
Definition: Logger.h:221
Definition: sqltypes.h:52
Definition: sqltypes.h:53
static RegisteredQueryHint defaults()
Definition: QueryHint.h:222
int32_t countRexLiteralArgs() const
std::unique_ptr< Hints > hints_
std::unique_ptr< const RexOperator > disambiguate_operator(const RexOperator *rex_operator, const RANodeOutput &ra_output) noexcept
const ConstRexScalarPtrVector & getPartitionKeys() const
std::string tree_string(const RelAlgNode *ra, const size_t depth)
DEVICE auto lower_bound(ARGS &&...args)
Definition: gpu_enabled.h:78
#define CHECK_LE(x, y)
Definition: Logger.h:222
std::unique_ptr< Hints > hints_
const std::vector< const Rex * > target_exprs_
std::vector< ElementType >::const_iterator Super
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
RelLogicalUnion(RelAlgInputs, bool is_all)
std::vector< std::unique_ptr< const RexAgg > > agg_exprs_
std::unique_ptr< const RexCase > disambiguate_case(const RexCase *rex_case, const RANodeOutput &ra_output)
std::shared_ptr< RelJoin > dispatchJoin(const rapidjson::Value &join_ra, RelAlgDagBuilder &root_dag_builder)
std::unique_ptr< const RexScalar > filter_expr_
void setSourceNode(const RelAlgNode *node) const
const RexScalar * aggregateResult(const RexScalar *const &aggregate, const RexScalar *const &next_result) const final
bool hasWindowFunctionExpr() const
std::shared_ptr< RelModify > dispatchModify(const rapidjson::Value &logical_modify_ra)
std::unordered_map< QueryHint, ExplainedQueryHint > Hints
Definition: QueryHint.h:248
virtual size_t size() const =0
const RelAlgNode * getSourceNode() const
void registerSubquery(std::shared_ptr< RexSubQuery > subquery)
void setExecutionResult(const std::shared_ptr< const ExecutionResult > result)
RelLogicalValues(const std::vector< TargetMetaInfo > &tuple_type, std::vector< RowValues > &values)
size_t toHash() const override
std::string typeName(const T *v)
Definition: toString.h:93
ExplainedQueryHint parseHintString(std::string &hint_string)
RelAlgDagBuilder()=delete
SqlWindowFunctionKind
Definition: sqldefs.h:83
HOST DEVICE int get_comp_param() const
Definition: sqltypes.h:338
void mark_nops(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
RexInputSet aggregateResult(const RexInputSet &aggregate, const RexInputSet &next_result) const override
Definition: sqldefs.h:53
std::shared_ptr< RelSort > dispatchSort(const rapidjson::Value &sort_ra)
void separate_window_function_expressions(std::vector< std::shared_ptr< RelAlgNode >> &nodes, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
RelTableFunction(const std::string &function_name, RelAlgInputs inputs, std::vector< std::string > &fields, std::vector< const Rex * > col_inputs, std::vector< std::unique_ptr< const RexScalar >> &table_func_inputs, std::vector< std::unique_ptr< const RexScalar >> &target_exprs)
const std::vector< std::string > & getFields() const
std::string getFieldName(const size_t i) const
std::vector< size_t > indices_from_json_array(const rapidjson::Value &json_idx_arr) noexcept
bool g_enable_watchdog false
Definition: Execute.cpp:76
#define CHECK(condition)
Definition: Logger.h:211
std::unique_ptr< const RexScalar > parse_scalar_expr(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
RelProject(std::vector< std::unique_ptr< const RexScalar >> &scalar_exprs, const std::vector< std::string > &fields, std::shared_ptr< const RelAlgNode > input)
unsigned node_id(const rapidjson::Value &ra_node) noexcept
RANodeOutput get_node_output(const RelAlgNode *ra_node)
bool g_enable_union
std::vector< std::string > getFieldNamesFromScanNode(const rapidjson::Value &scan_ra)
bool g_cluster
void alterRAForRender(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const RenderInfo &render_info)
std::shared_ptr< const RelAlgNode > prev(const rapidjson::Value &crt_node)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
std::vector< std::shared_ptr< RelAlgNode > > run(const rapidjson::Value &rels, RelAlgDagBuilder &root_dag_builder)
void getRelAlgHints(const rapidjson::Value &json_node, std::shared_ptr< RelAlgNode > node)
virtual size_t toHash() const =0
RelAlgDispatcher(const Catalog_Namespace::Catalog &cat)
void add_window_function_pre_project(std::vector< std::shared_ptr< RelAlgNode >> &nodes, const bool always_add_project_if_first_project_is_window_expr, std::unordered_map< size_t, std::unordered_map< unsigned, RegisteredQueryHint >> &query_hints)
Common Enum definitions for SQL processing.
bool is_dict_encoded_string() const
Definition: sqltypes.h:557
Definition: sqltypes.h:45
void fold_filters(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
const TableDescriptor * getMetadataForTable(const std::string &tableName, const bool populateFragmenter=true) const
Returns a pointer to a const TableDescriptor struct matching the provided tableName.
std::vector< RexInput > RANodeOutput
std::vector< std::unique_ptr< const RexScalar >> RowValues
const size_t inputCount() const
constexpr double n
Definition: Utm.h:39
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)
std::unique_ptr< const RexSubQuery > parse_subquery(const rapidjson::Value &expr, const Catalog_Namespace::Catalog &cat, RelAlgDagBuilder &root_dag_builder)
void eliminate_dead_subqueries(std::vector< std::shared_ptr< RexSubQuery >> &subqueries, RelAlgNode const *root)
size_t size() const override
size_t operator()(const std::pair< const RelAlgNode *, int > &input_col) const
bool input_can_be_coalesced(const RelAlgNode *parent_node, const size_t index, const bool first_rex_is_input)
RelAlgInputs getRelAlgInputs(const rapidjson::Value &node)
std::shared_ptr< RelLogicalValues > dispatchLogicalValues(const rapidjson::Value &logical_values_ra)
std::shared_ptr< RelTableFunction > dispatchTableFunction(const rapidjson::Value &table_func_ra, RelAlgDagBuilder &root_dag_builder)
DEVICE void swap(ARGS &&...args)
Definition: gpu_enabled.h:114
std::unique_ptr< const RexScalar > RetType
Definition: RexVisitor.h:139
std::vector< std::unique_ptr< const RexScalar > > copyRexScalars(std::vector< std::unique_ptr< const RexScalar >> const &scalar_sources)
size_t toHash() const override
#define VLOG(n)
Definition: Logger.h:305
RelJoin(std::shared_ptr< const RelAlgNode > lhs, std::shared_ptr< const RelAlgNode > rhs, std::unique_ptr< const RexScalar > &condition, const JoinType join_type)
RelAlgInputs inputs_
void set_precision(int d)
Definition: sqltypes.h:432
std::string toString() const override
std::pair< std::string, std::string > getKVOptionPair(std::string &str, size_t &pos)
void eliminate_dead_columns(std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
void check_empty_inputs_field(const rapidjson::Value &node) noexcept
std::vector< const Rex * > reproject_targets(const RelProject *simple_project, const std::vector< const Rex * > &target_exprs) noexcept
bool isIdentity() const
std::vector< std::unique_ptr< const RexScalar > > target_exprs_
std::vector< RexInput > n_outputs(const RelAlgNode *node, const size_t n)
const bool is_agg_
const RexScalar * visitCase(const RexCase *rex_case) const final
std::string toString() const override
static void resetRelAlgFirstId() noexcept
std::string json_node_to_string(const rapidjson::Value &node) noexcept