OmniSciDB  6686921089
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelLeftDeepInnerJoin.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RelLeftDeepInnerJoin.h"
18 #include "Logger/Logger.h"
19 #include "RelAlgDagBuilder.h"
20 #include "RexVisitor.h"
21 
22 #include <numeric>
23 
25  const std::shared_ptr<RelFilter>& filter,
26  std::vector<std::shared_ptr<const RelAlgNode>> inputs,
27  std::vector<std::shared_ptr<const RelJoin>>& original_joins)
28  : condition_(filter ? filter->getAndReleaseCondition() : nullptr)
29  , original_filter_(filter)
30  , original_joins_(original_joins) {
31  std::vector<std::unique_ptr<const RexScalar>> operands;
32  bool is_notnull = true;
33  // Accumulate join conditions from the (explicit) joins themselves and
34  // from the filter node at the root of the left-deep tree pattern.
35  outer_conditions_per_level_.resize(original_joins.size());
36  for (size_t nesting_level = 0; nesting_level < original_joins.size(); ++nesting_level) {
37  const auto& original_join = original_joins[nesting_level];
38  const auto condition_true =
39  dynamic_cast<const RexLiteral*>(original_join->getCondition());
40  if (!condition_true || !condition_true->getVal<bool>()) {
41  if (dynamic_cast<const RexOperator*>(original_join->getCondition())) {
42  is_notnull =
43  is_notnull && dynamic_cast<const RexOperator*>(original_join->getCondition())
44  ->getType()
45  .get_notnull();
46  }
47  switch (original_join->getJoinType()) {
48  case JoinType::INNER:
49  case JoinType::SEMI:
50  case JoinType::ANTI: {
51  if (original_join->getCondition()) {
52  operands.emplace_back(original_join->getAndReleaseCondition());
53  }
54  break;
55  }
56  case JoinType::LEFT: {
57  if (original_join->getCondition()) {
58  outer_conditions_per_level_[nesting_level].reset(
59  original_join->getAndReleaseCondition());
60  }
61  break;
62  }
63  default:
64  CHECK(false);
65  }
66  }
67  }
68  if (!operands.empty()) {
69  if (condition_) {
70  CHECK(dynamic_cast<const RexOperator*>(condition_.get()));
71  is_notnull =
72  is_notnull &&
73  static_cast<const RexOperator*>(condition_.get())->getType().get_notnull();
74  operands.emplace_back(std::move(condition_));
75  }
76  if (operands.size() > 1) {
77  condition_.reset(
78  new RexOperator(kAND, operands, SQLTypeInfo(kBOOLEAN, is_notnull)));
79  } else {
80  condition_ = std::move(operands.front());
81  }
82  }
83  if (!condition_) {
84  condition_.reset(new RexLiteral(true, kBOOLEAN, kBOOLEAN, 0, 0, 0, 0));
85  }
86  for (const auto& input : inputs) {
87  addManagedInput(input);
88  }
89 }
90 
92  return condition_.get();
93 }
94 
96  const size_t nesting_level) const {
97  CHECK_GE(nesting_level, size_t(1));
98  CHECK_LE(nesting_level, outer_conditions_per_level_.size());
99  // Outer join conditions are collected depth-first while the returned condition
100  // must be consistent with the order of the loops (which is reverse depth-first).
101  return outer_conditions_per_level_[outer_conditions_per_level_.size() - nesting_level]
102  .get();
103 }
104 
105 const JoinType RelLeftDeepInnerJoin::getJoinType(const size_t nesting_level) const {
106  CHECK_LE(nesting_level, original_joins_.size());
107  return original_joins_[original_joins_.size() - nesting_level]->getJoinType();
108 }
109 
110 std::string RelLeftDeepInnerJoin::toString() const {
111  std::string ret = ::typeName(this) + "(";
112  ret += ::toString(condition_);
113  for (const auto& input : inputs_) {
114  ret += " " + ::toString(input);
115  }
116  ret += ")";
117  return ret;
118 }
119 
121  if (!hash_) {
122  hash_ = typeid(RelLeftDeepInnerJoin).hash_code();
123  boost::hash_combine(*hash_,
124  condition_ ? condition_->toHash() : boost::hash_value("n"));
125  for (auto& node : inputs_) {
126  boost::hash_combine(*hash_, node->toHash());
127  }
128  }
129  return *hash_;
130 }
131 
133  size_t total_size = 0;
134  for (const auto& input : inputs_) {
135  total_size += input->size();
136  }
137  return total_size;
138 }
139 
140 std::shared_ptr<RelAlgNode> RelLeftDeepInnerJoin::deepCopy() const {
141  CHECK(false);
142  return nullptr;
143 }
144 
146  if (node == original_filter_.get()) {
147  return true;
148  }
149  for (const auto& original_join : original_joins_) {
150  if (original_join.get() == node) {
151  return true;
152  }
153  }
154  return false;
155 }
156 
158  return original_filter_.get();
159 }
160 
161 std::vector<std::shared_ptr<const RelJoin>> RelLeftDeepInnerJoin::getOriginalJoins()
162  const {
163  std::vector<std::shared_ptr<const RelJoin>> original_joins;
164  original_joins.assign(original_joins_.begin(), original_joins_.end());
165  return original_joins;
166 }
167 
168 namespace {
169 
171  std::deque<std::shared_ptr<const RelAlgNode>>& inputs,
172  std::vector<std::shared_ptr<const RelJoin>>& original_joins,
173  const std::shared_ptr<const RelJoin>& join) {
174  original_joins.push_back(join);
175  CHECK_EQ(size_t(2), join->inputCount());
176  const auto left_input_join =
177  std::dynamic_pointer_cast<const RelJoin>(join->getAndOwnInput(0));
178  if (left_input_join) {
179  inputs.push_front(join->getAndOwnInput(1));
180  collect_left_deep_join_inputs(inputs, original_joins, left_input_join);
181  } else {
182  inputs.push_front(join->getAndOwnInput(1));
183  inputs.push_front(join->getAndOwnInput(0));
184  }
185 }
186 
187 std::pair<std::shared_ptr<RelLeftDeepInnerJoin>, std::shared_ptr<const RelAlgNode>>
188 create_left_deep_join(const std::shared_ptr<RelAlgNode>& left_deep_join_root) {
189  const auto old_root = get_left_deep_join_root(left_deep_join_root);
190  if (!old_root) {
191  return {nullptr, nullptr};
192  }
193  std::deque<std::shared_ptr<const RelAlgNode>> inputs_deque;
194  const auto left_deep_join_filter =
195  std::dynamic_pointer_cast<RelFilter>(left_deep_join_root);
196  const auto join =
197  std::dynamic_pointer_cast<const RelJoin>(left_deep_join_root->getAndOwnInput(0));
198  CHECK(join);
199  std::vector<std::shared_ptr<const RelJoin>> original_joins;
200  collect_left_deep_join_inputs(inputs_deque, original_joins, join);
201  std::vector<std::shared_ptr<const RelAlgNode>> inputs(inputs_deque.begin(),
202  inputs_deque.end());
203  return {std::make_shared<RelLeftDeepInnerJoin>(
204  left_deep_join_filter, inputs, original_joins),
205  old_root};
206 }
207 
209  public:
211  : left_deep_join_(left_deep_join) {
212  std::vector<size_t> input_sizes;
213  CHECK_GT(left_deep_join->inputCount(), size_t(1));
214  for (size_t i = 0; i < left_deep_join->inputCount(); ++i) {
215  input_sizes.push_back(left_deep_join->getInput(i)->size());
216  }
217  input_size_prefix_sums_.resize(input_sizes.size());
219  input_sizes.begin(), input_sizes.end(), input_size_prefix_sums_.begin());
220  }
221 
222  void* visitInput(const RexInput* rex_input) const override {
223  const auto source_node = rex_input->getSourceNode();
224  if (left_deep_join_->coversOriginalNode(source_node)) {
225  const auto it = std::lower_bound(input_size_prefix_sums_.begin(),
226  input_size_prefix_sums_.end(),
227  rex_input->getIndex(),
228  std::less_equal<size_t>());
229  CHECK(it != input_size_prefix_sums_.end());
230  const auto input_node =
231  left_deep_join_->getInput(std::distance(input_size_prefix_sums_.begin(), it));
232  if (it != input_size_prefix_sums_.begin()) {
233  const auto prev_input_count = *(it - 1);
234  CHECK_LE(prev_input_count, rex_input->getIndex());
235  const auto input_index = rex_input->getIndex() - prev_input_count;
236  rex_input->setIndex(input_index);
237  }
238  rex_input->setSourceNode(input_node);
239  }
240  return nullptr;
241  };
242 
243  private:
244  std::vector<size_t> input_size_prefix_sums_;
246 };
247 
248 } // namespace
249 
250 // Recognize the left-deep join tree pattern with an optional filter as root
251 // with `node` as the parent of the join sub-tree. On match, return the root
252 // of the recognized tree (either the filter node or the outermost join).
253 std::shared_ptr<const RelAlgNode> get_left_deep_join_root(
254  const std::shared_ptr<RelAlgNode>& node) {
255  const auto left_deep_join_filter = dynamic_cast<const RelFilter*>(node.get());
256  if (left_deep_join_filter) {
257  const auto join = dynamic_cast<const RelJoin*>(left_deep_join_filter->getInput(0));
258  if (!join) {
259  return nullptr;
260  }
261  if (join->getJoinType() == JoinType::INNER || join->getJoinType() == JoinType::SEMI ||
262  join->getJoinType() == JoinType::ANTI) {
263  return node;
264  }
265  }
266  if (!node || node->inputCount() != 1) {
267  return nullptr;
268  }
269  const auto join = dynamic_cast<const RelJoin*>(node->getInput(0));
270  if (!join) {
271  return nullptr;
272  }
273  return node->getAndOwnInput(0);
274 }
275 
277  const RelLeftDeepInnerJoin* left_deep_join) {
278  RebindRexInputsFromLeftDeepJoin rebind_rex_inputs_from_left_deep_join(left_deep_join);
279  rebind_rex_inputs_from_left_deep_join.visit(rex);
280 }
281 
282 void create_left_deep_join(std::vector<std::shared_ptr<RelAlgNode>>& nodes) {
283  std::list<std::shared_ptr<RelAlgNode>> new_nodes;
284  for (auto& left_deep_join_candidate : nodes) {
285  std::shared_ptr<RelLeftDeepInnerJoin> left_deep_join;
286  std::shared_ptr<const RelAlgNode> old_root;
287  std::tie(left_deep_join, old_root) = create_left_deep_join(left_deep_join_candidate);
288  if (!left_deep_join) {
289  continue;
290  }
291  CHECK_GE(left_deep_join->inputCount(), size_t(2));
292  for (size_t nesting_level = 1; nesting_level <= left_deep_join->inputCount() - 1;
293  ++nesting_level) {
294  const auto outer_condition = left_deep_join->getOuterCondition(nesting_level);
295  if (outer_condition) {
296  rebind_inputs_from_left_deep_join(outer_condition, left_deep_join.get());
297  }
298  }
299  rebind_inputs_from_left_deep_join(left_deep_join->getInnerCondition(),
300  left_deep_join.get());
301  for (auto& node : nodes) {
302  if (node && node->hasInput(old_root.get())) {
303  node->replaceInput(left_deep_join_candidate, left_deep_join);
304  std::shared_ptr<const RelJoin> old_join;
305  if (std::dynamic_pointer_cast<const RelJoin>(left_deep_join_candidate)) {
306  old_join = std::static_pointer_cast<const RelJoin>(left_deep_join_candidate);
307  } else {
308  CHECK_EQ(size_t(1), left_deep_join_candidate->inputCount());
309  old_join = std::dynamic_pointer_cast<const RelJoin>(
310  left_deep_join_candidate->getAndOwnInput(0));
311  }
312  while (old_join) {
313  node->replaceInput(old_join, left_deep_join);
314  old_join =
315  std::dynamic_pointer_cast<const RelJoin>(old_join->getAndOwnInput(0));
316  }
317  }
318  }
319 
320  new_nodes.emplace_back(std::move(left_deep_join));
321  }
322 
323  // insert the new left join nodes to the front of the owned RelAlgNode list.
324  // This is done to ensure all created RelAlgNodes exist in this list for later
325  // visitation, such as RelAlgDagBuilder::resetQueryExecutionState.
326  nodes.insert(nodes.begin(), new_nodes.begin(), new_nodes.end());
327 }
#define CHECK_EQ(x, y)
Definition: Logger.h:217
std::vector< std::unique_ptr< const RexScalar > > outer_conditions_per_level_
JoinType
Definition: sqldefs.h:108
size_t size() const override
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
bool coversOriginalNode(const RelAlgNode *node) const
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
const RexScalar * getOuterCondition(const size_t nesting_level) const
std::string join(T const &container, std::string const &delim)
void addManagedInput(std::shared_ptr< const RelAlgNode > input)
#define CHECK_GE(x, y)
Definition: Logger.h:222
#define CHECK_GT(x, y)
Definition: Logger.h:221
std::shared_ptr< const RelAlgNode > getAndOwnInput(const size_t idx) const
std::string toString() const override
unsigned getIndex() const
void setIndex(const unsigned in_index) const
const std::vector< std::shared_ptr< const RelJoin > > original_joins_
std::shared_ptr< RelAlgNode > deepCopy() const override
DEVICE void partial_sum(ARGS &&...args)
Definition: gpu_enabled.h:87
const RelAlgNode * getInput(const size_t idx) const
Definition: sqldefs.h:37
std::vector< std::shared_ptr< const RelJoin > > getOriginalJoins() const
std::optional< size_t > hash_
DEVICE auto lower_bound(ARGS &&...args)
Definition: gpu_enabled.h:78
#define CHECK_LE(x, y)
Definition: Logger.h:220
void setSourceNode(const RelAlgNode *node) const
virtual size_t size() const =0
const RelAlgNode * getSourceNode() const
const RelFilter * getOriginalFilter() const
std::string typeName(const T *v)
Definition: toString.h:88
std::unique_ptr< const RexScalar > condition_
RelLeftDeepInnerJoin(const std::shared_ptr< RelFilter > &filter, RelAlgInputs inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins)
const JoinType getJoinType(const size_t nesting_level) const
const RexScalar * getInnerCondition() const
#define CHECK(condition)
Definition: Logger.h:209
void collect_left_deep_join_inputs(std::deque< std::shared_ptr< const RelAlgNode >> &inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins, const std::shared_ptr< const RelJoin > &join)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
const std::shared_ptr< RelFilter > original_filter_
const size_t inputCount() const
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)
size_t toHash() const override
RelAlgInputs inputs_