OmniSciDB  c07336695a
RelLeftDeepInnerJoin.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RelLeftDeepInnerJoin.h"
19 #include "RexVisitor.h"
20 #include "Shared/Logger.h"
21 
22 #include <numeric>
23 
25  const std::shared_ptr<RelFilter>& filter,
26  std::vector<std::shared_ptr<const RelAlgNode>> inputs,
27  std::vector<std::shared_ptr<const RelJoin>>& original_joins)
28  : condition_(filter ? filter->getAndReleaseCondition() : nullptr)
29  , original_filter_(filter)
30  , original_joins_(original_joins) {
31  std::vector<std::unique_ptr<const RexScalar>> operands;
32  bool is_notnull = true;
33  // Accumulate join conditions from the (explicit) joins themselves and
34  // from the filter node at the root of the left-deep tree pattern.
35  outer_conditions_per_level_.resize(original_joins.size());
36  for (size_t nesting_level = 0; nesting_level < original_joins.size(); ++nesting_level) {
37  const auto& original_join = original_joins[nesting_level];
38  const auto condition_true =
39  dynamic_cast<const RexLiteral*>(original_join->getCondition());
40  if (!condition_true || !condition_true->getVal<bool>()) {
41  if (dynamic_cast<const RexOperator*>(original_join->getCondition())) {
42  is_notnull =
43  is_notnull && dynamic_cast<const RexOperator*>(original_join->getCondition())
44  ->getType()
45  .get_notnull();
46  }
47  switch (original_join->getJoinType()) {
48  case JoinType::INNER: {
49  if (original_join->getCondition()) {
50  operands.emplace_back(original_join->getAndReleaseCondition());
51  }
52  break;
53  }
54  case JoinType::LEFT: {
55  if (original_join->getCondition()) {
56  outer_conditions_per_level_[nesting_level].reset(
57  original_join->getAndReleaseCondition());
58  }
59  break;
60  }
61  default:
62  CHECK(false);
63  }
64  }
65  }
66  if (!operands.empty()) {
67  if (condition_) {
68  CHECK(dynamic_cast<const RexOperator*>(condition_.get()));
69  is_notnull =
70  is_notnull &&
71  static_cast<const RexOperator*>(condition_.get())->getType().get_notnull();
72  operands.emplace_back(std::move(condition_));
73  }
74  if (operands.size() > 1) {
75  condition_.reset(
76  new RexOperator(kAND, operands, SQLTypeInfo(kBOOLEAN, is_notnull)));
77  } else {
78  condition_ = std::move(operands.front());
79  }
80  }
81  if (!condition_) {
82  condition_.reset(new RexLiteral(true, kBOOLEAN, kBOOLEAN, 0, 0, 0, 0));
83  }
84  for (const auto& input : inputs) {
85  addManagedInput(input);
86  }
87 }
88 
90  return condition_.get();
91 }
92 
94  const size_t nesting_level) const {
95  CHECK_GE(nesting_level, size_t(1));
96  CHECK_LE(nesting_level, outer_conditions_per_level_.size());
97  // Outer join conditions are collected depth-first while the returned condition
98  // must be consistent with the order of the loops (which is reverse depth-first).
99  return outer_conditions_per_level_[outer_conditions_per_level_.size() - nesting_level]
100  .get();
101 }
102 
103 std::string RelLeftDeepInnerJoin::toString() const {
104  std::string result =
105  "(RelLeftDeepInnerJoin<" + std::to_string(reinterpret_cast<uint64_t>(this)) + ">(";
106  result += condition_->toString();
107  for (const auto& input : inputs_) {
108  result += " " + input->toString();
109  }
110  result += ")";
111  return result;
112 }
113 
115  size_t total_size = 0;
116  for (const auto& input : inputs_) {
117  total_size += input->size();
118  }
119  return total_size;
120 }
121 
122 std::shared_ptr<RelAlgNode> RelLeftDeepInnerJoin::deepCopy() const {
123  CHECK(false);
124  return nullptr;
125 }
126 
128  if (node == original_filter_.get()) {
129  return true;
130  }
131  for (const auto& original_join : original_joins_) {
132  if (original_join.get() == node) {
133  return true;
134  }
135  }
136  return false;
137 }
138 
139 namespace {
140 
142  std::deque<std::shared_ptr<const RelAlgNode>>& inputs,
143  std::vector<std::shared_ptr<const RelJoin>>& original_joins,
144  const std::shared_ptr<const RelJoin>& join) {
145  original_joins.push_back(join);
146  CHECK_EQ(size_t(2), join->inputCount());
147  const auto left_input_join =
148  std::dynamic_pointer_cast<const RelJoin>(join->getAndOwnInput(0));
149  if (left_input_join) {
150  inputs.push_front(join->getAndOwnInput(1));
151  collect_left_deep_join_inputs(inputs, original_joins, left_input_join);
152  } else {
153  inputs.push_front(join->getAndOwnInput(1));
154  inputs.push_front(join->getAndOwnInput(0));
155  }
156 }
157 
158 std::pair<std::shared_ptr<RelLeftDeepInnerJoin>, std::shared_ptr<const RelAlgNode>>
159 create_left_deep_join(const std::shared_ptr<RelAlgNode>& left_deep_join_root) {
160  const auto old_root = get_left_deep_join_root(left_deep_join_root);
161  if (!old_root) {
162  return {nullptr, nullptr};
163  }
164  std::deque<std::shared_ptr<const RelAlgNode>> inputs_deque;
165  const auto left_deep_join_filter =
166  std::dynamic_pointer_cast<RelFilter>(left_deep_join_root);
167  const auto join =
168  std::dynamic_pointer_cast<const RelJoin>(left_deep_join_root->getAndOwnInput(0));
169  CHECK(join);
170  std::vector<std::shared_ptr<const RelJoin>> original_joins;
171  collect_left_deep_join_inputs(inputs_deque, original_joins, join);
172  std::vector<std::shared_ptr<const RelAlgNode>> inputs(inputs_deque.begin(),
173  inputs_deque.end());
174  return {std::make_shared<RelLeftDeepInnerJoin>(
175  left_deep_join_filter, inputs, original_joins),
176  old_root};
177 }
178 
180  public:
182  : left_deep_join_(left_deep_join) {
183  std::vector<size_t> input_sizes;
184  CHECK_GT(left_deep_join->inputCount(), size_t(1));
185  for (size_t i = 0; i < left_deep_join->inputCount(); ++i) {
186  input_sizes.push_back(left_deep_join->getInput(i)->size());
187  }
188  input_size_prefix_sums_.resize(input_sizes.size());
189  std::partial_sum(
190  input_sizes.begin(), input_sizes.end(), input_size_prefix_sums_.begin());
191  }
192 
193  void* visitInput(const RexInput* rex_input) const override {
194  const auto source_node = rex_input->getSourceNode();
195  if (left_deep_join_->coversOriginalNode(source_node)) {
196  const auto it = std::lower_bound(input_size_prefix_sums_.begin(),
197  input_size_prefix_sums_.end(),
198  rex_input->getIndex(),
199  std::less_equal<size_t>());
200  CHECK(it != input_size_prefix_sums_.end());
201  const auto input_node =
202  left_deep_join_->getInput(std::distance(input_size_prefix_sums_.begin(), it));
203  if (it != input_size_prefix_sums_.begin()) {
204  const auto prev_input_count = *(it - 1);
205  CHECK_LE(prev_input_count, rex_input->getIndex());
206  const auto input_index = rex_input->getIndex() - prev_input_count;
207  rex_input->setIndex(input_index);
208  }
209  rex_input->setSourceNode(input_node);
210  }
211  return nullptr;
212  };
213 
214  private:
215  std::vector<size_t> input_size_prefix_sums_;
217 };
218 
219 } // namespace
220 
221 // Recognize the left-deep join tree pattern with an optional filter as root
222 // with `node` as the parent of the join sub-tree. On match, return the root
223 // of the recognized tree (either the filter node or the outermost join).
224 std::shared_ptr<const RelAlgNode> get_left_deep_join_root(
225  const std::shared_ptr<RelAlgNode>& node) {
226  const auto left_deep_join_filter = dynamic_cast<const RelFilter*>(node.get());
227  if (left_deep_join_filter) {
228  const auto join = dynamic_cast<const RelJoin*>(left_deep_join_filter->getInput(0));
229  if (!join) {
230  return nullptr;
231  }
232  if (join->getJoinType() == JoinType::INNER) {
233  return node;
234  }
235  }
236  if (!node || node->inputCount() != 1) {
237  return nullptr;
238  }
239  const auto join = dynamic_cast<const RelJoin*>(node->getInput(0));
240  if (!join) {
241  return nullptr;
242  }
243  return node->getAndOwnInput(0);
244 }
245 
247  const RelLeftDeepInnerJoin* left_deep_join) {
248  RebindRexInputsFromLeftDeepJoin rebind_rex_inputs_from_left_deep_join(left_deep_join);
249  rebind_rex_inputs_from_left_deep_join.visit(rex);
250 }
251 
252 void create_left_deep_join(std::vector<std::shared_ptr<RelAlgNode>>& nodes) {
253  for (const auto& left_deep_join_candidate : nodes) {
254  std::shared_ptr<RelLeftDeepInnerJoin> left_deep_join;
255  std::shared_ptr<const RelAlgNode> old_root;
256  std::tie(left_deep_join, old_root) = create_left_deep_join(left_deep_join_candidate);
257  if (!left_deep_join) {
258  continue;
259  }
260  CHECK_GE(left_deep_join->inputCount(), size_t(2));
261  for (size_t nesting_level = 1; nesting_level <= left_deep_join->inputCount() - 1;
262  ++nesting_level) {
263  const auto outer_condition = left_deep_join->getOuterCondition(nesting_level);
264  if (outer_condition) {
265  rebind_inputs_from_left_deep_join(outer_condition, left_deep_join.get());
266  }
267  }
268  rebind_inputs_from_left_deep_join(left_deep_join->getInnerCondition(),
269  left_deep_join.get());
270  for (auto& node : nodes) {
271  if (node && node->hasInput(old_root.get())) {
272  node->replaceInput(left_deep_join_candidate, left_deep_join);
273  std::shared_ptr<const RelJoin> old_join;
274  if (std::dynamic_pointer_cast<const RelJoin>(left_deep_join_candidate)) {
275  old_join = std::static_pointer_cast<const RelJoin>(left_deep_join_candidate);
276  } else {
277  CHECK_EQ(size_t(1), left_deep_join_candidate->inputCount());
278  old_join = std::dynamic_pointer_cast<const RelJoin>(
279  left_deep_join_candidate->getAndOwnInput(0));
280  }
281  while (old_join) {
282  node->replaceInput(old_join, left_deep_join);
283  old_join =
284  std::dynamic_pointer_cast<const RelJoin>(old_join->getAndOwnInput(0));
285  }
286  }
287  }
288  }
289 }
std::vector< std::shared_ptr< const RelAlgNode > > inputs_
#define CHECK_EQ(x, y)
Definition: Logger.h:195
std::vector< std::unique_ptr< const RexScalar > > outer_conditions_per_level_
void setSourceNode(const RelAlgNode *node) const
void setIndex(const unsigned in_index) const
size_t size() const override
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
std::shared_ptr< const RelAlgNode > getAndOwnInput(const size_t idx) const
std::string join(T const &container, std::string const &delim)
void addManagedInput(std::shared_ptr< const RelAlgNode > input)
#define CHECK_GE(x, y)
Definition: Logger.h:200
#define CHECK_GT(x, y)
Definition: Logger.h:199
std::string to_string(char const *&&v)
std::string toString() const override
const RelAlgNode * getSourceNode() const
const std::vector< std::shared_ptr< const RelJoin > > original_joins_
RelLeftDeepInnerJoin(const std::shared_ptr< RelFilter > &filter, std::vector< std::shared_ptr< const RelAlgNode >> inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins)
const RexScalar * getInnerCondition() const
std::shared_ptr< RelAlgNode > deepCopy() const override
Definition: sqldefs.h:37
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:819
const RexScalar * getOuterCondition(const size_t nesting_level) const
void create_left_deep_join(std::vector< std::shared_ptr< RelAlgNode >> &nodes)
#define CHECK_LE(x, y)
Definition: Logger.h:198
const size_t inputCount() const
virtual size_t size() const =0
std::unique_ptr< const RexScalar > condition_
#define CHECK(condition)
Definition: Logger.h:187
void collect_left_deep_join_inputs(std::deque< std::shared_ptr< const RelAlgNode >> &inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins, const std::shared_ptr< const RelJoin > &join)
const RelAlgNode * getInput(const size_t idx) const
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
const std::shared_ptr< RelFilter > original_filter_
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)
bool coversOriginalNode(const RelAlgNode *node) const