OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelLeftDeepInnerJoin.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RelLeftDeepInnerJoin.h"
18 #include "Logger/Logger.h"
19 #include "RelAlgDag.h"
20 #include "RexVisitor.h"
21 
22 #include <numeric>
23 
25  const std::shared_ptr<RelFilter>& filter,
26  std::vector<std::shared_ptr<const RelAlgNode>> inputs,
27  std::vector<std::shared_ptr<const RelJoin>>& original_joins)
28  : condition_(filter ? filter->getAndReleaseCondition() : nullptr)
29  , original_filter_(filter)
30  , original_joins_(original_joins) {
31  std::vector<std::unique_ptr<const RexScalar>> operands;
32  bool is_notnull = true;
33  // Accumulate join conditions from the (explicit) joins themselves and
34  // from the filter node at the root of the left-deep tree pattern.
35  outer_conditions_per_level_.resize(original_joins.size());
36  for (size_t nesting_level = 0; nesting_level < original_joins.size(); ++nesting_level) {
37  const auto& original_join = original_joins[nesting_level];
38  const auto condition_true =
39  dynamic_cast<const RexLiteral*>(original_join->getCondition());
40  if (!condition_true || !condition_true->getVal<bool>()) {
41  if (dynamic_cast<const RexOperator*>(original_join->getCondition())) {
42  is_notnull =
43  is_notnull && dynamic_cast<const RexOperator*>(original_join->getCondition())
44  ->getType()
45  .get_notnull();
46  }
47  switch (original_join->getJoinType()) {
48  case JoinType::INNER:
49  case JoinType::SEMI:
50  case JoinType::ANTI: {
51  if (original_join->getCondition()) {
52  operands.emplace_back(original_join->getAndReleaseCondition());
53  }
54  break;
55  }
56  case JoinType::LEFT: {
57  if (original_join->getCondition()) {
58  outer_conditions_per_level_[nesting_level].reset(
59  original_join->getAndReleaseCondition());
60  }
61  break;
62  }
63  default:
64  CHECK(false);
65  }
66  }
67  }
68  if (!operands.empty()) {
69  if (condition_) {
70  CHECK(dynamic_cast<const RexOperator*>(condition_.get()));
71  is_notnull =
72  is_notnull &&
73  static_cast<const RexOperator*>(condition_.get())->getType().get_notnull();
74  operands.emplace_back(std::move(condition_));
75  }
76  if (operands.size() > 1) {
77  condition_.reset(
78  new RexOperator(kAND, operands, SQLTypeInfo(kBOOLEAN, is_notnull)));
79  } else {
80  condition_ = std::move(operands.front());
81  }
82  }
83  if (!condition_) {
84  condition_.reset(new RexLiteral(true, kBOOLEAN, kBOOLEAN, 0, 0, 0, 0));
85  }
86  for (const auto& input : inputs) {
87  addManagedInput(input);
88  }
89 }
90 
92  return condition_.get();
93 }
94 
96  const size_t nesting_level) const {
97  CHECK_GE(nesting_level, size_t(1));
98  CHECK_LE(nesting_level, outer_conditions_per_level_.size());
99  // Outer join conditions are collected depth-first while the returned condition
100  // must be consistent with the order of the loops (which is reverse depth-first).
101  return outer_conditions_per_level_[outer_conditions_per_level_.size() - nesting_level]
102  .get();
103 }
104 
105 const JoinType RelLeftDeepInnerJoin::getJoinType(const size_t nesting_level) const {
106  CHECK_LE(nesting_level, original_joins_.size());
107  return original_joins_[original_joins_.size() - nesting_level]->getJoinType();
108 }
109 
111  std::string ret = ::typeName(this) + "(";
112  ret += condition_->toString(config);
113  if (!config.skip_input_nodes) {
114  for (const auto& input : inputs_) {
115  ret += " " + input->toString(config);
116  }
117  } else {
118  ret += ", input node id={";
119  for (auto& input : inputs_) {
120  ret += std::to_string(input->getId()) + " ";
121  }
122  ret += "}";
123  }
124  ret += ")";
125  return ret;
126 }
127 
129  if (!hash_) {
130  hash_ = typeid(RelLeftDeepInnerJoin).hash_code();
131  boost::hash_combine(*hash_, condition_ ? condition_->toHash() : HASH_N);
132  for (auto& expr : outer_conditions_per_level_) {
133  boost::hash_combine(*hash_, expr ? expr->toHash() : HASH_N);
134  }
135  boost::hash_combine(*hash_, original_filter_ ? original_filter_->toHash() : HASH_N);
136  for (auto& node : inputs_) {
137  boost::hash_combine(*hash_, node->toHash());
138  }
139  }
140  return *hash_;
141 }
142 
144  size_t total_size = 0;
145  for (const auto& input : inputs_) {
146  total_size += input->size();
147  }
148  return total_size;
149 }
150 
151 std::shared_ptr<RelAlgNode> RelLeftDeepInnerJoin::deepCopy() const {
152  CHECK(false);
153  return nullptr;
154 }
155 
157  if (node == original_filter_.get()) {
158  return true;
159  }
160  for (const auto& original_join : original_joins_) {
161  if (original_join.get() == node) {
162  return true;
163  }
164  }
165  return false;
166 }
167 
169  return original_filter_.get();
170 }
171 
172 std::vector<std::shared_ptr<const RelJoin>> RelLeftDeepInnerJoin::getOriginalJoins()
173  const {
174  std::vector<std::shared_ptr<const RelJoin>> original_joins;
175  original_joins.assign(original_joins_.begin(), original_joins_.end());
176  return original_joins;
177 }
178 
179 namespace {
180 
182  std::deque<std::shared_ptr<const RelAlgNode>>& inputs,
183  std::vector<std::shared_ptr<const RelJoin>>& original_joins,
184  const std::shared_ptr<const RelJoin>& join) {
185  original_joins.push_back(join);
186  CHECK_EQ(size_t(2), join->inputCount());
187  const auto left_input_join =
188  std::dynamic_pointer_cast<const RelJoin>(join->getAndOwnInput(0));
189  if (left_input_join) {
190  inputs.push_front(join->getAndOwnInput(1));
191  collect_left_deep_join_inputs(inputs, original_joins, left_input_join);
192  } else {
193  inputs.push_front(join->getAndOwnInput(1));
194  inputs.push_front(join->getAndOwnInput(0));
195  }
196 }
197 
198 std::pair<std::shared_ptr<RelLeftDeepInnerJoin>, std::shared_ptr<const RelAlgNode>>
199 create_left_deep_join(const std::shared_ptr<RelAlgNode>& left_deep_join_root) {
200  const auto old_root = get_left_deep_join_root(left_deep_join_root);
201  if (!old_root) {
202  return {nullptr, nullptr};
203  }
204  std::deque<std::shared_ptr<const RelAlgNode>> inputs_deque;
205  const auto left_deep_join_filter =
206  std::dynamic_pointer_cast<RelFilter>(left_deep_join_root);
207  const auto join =
208  std::dynamic_pointer_cast<const RelJoin>(left_deep_join_root->getAndOwnInput(0));
209  CHECK(join);
210  std::vector<std::shared_ptr<const RelJoin>> original_joins;
211  collect_left_deep_join_inputs(inputs_deque, original_joins, join);
212  std::vector<std::shared_ptr<const RelAlgNode>> inputs(inputs_deque.begin(),
213  inputs_deque.end());
214  return {std::make_shared<RelLeftDeepInnerJoin>(
215  left_deep_join_filter, inputs, original_joins),
216  old_root};
217 }
218 
219 class RebindRexInputsFromLeftDeepJoin : public RexVisitor<void*> {
220  public:
222  : left_deep_join_(left_deep_join) {
223  std::vector<size_t> input_sizes;
224  CHECK_GT(left_deep_join->inputCount(), size_t(1));
225  for (size_t i = 0; i < left_deep_join->inputCount(); ++i) {
226  input_sizes.push_back(left_deep_join->getInput(i)->size());
227  }
228  input_size_prefix_sums_.resize(input_sizes.size());
230  input_sizes.begin(), input_sizes.end(), input_size_prefix_sums_.begin());
231  }
232 
233  void* visitInput(const RexInput* rex_input) const override {
234  const auto source_node = rex_input->getSourceNode();
235  if (left_deep_join_->coversOriginalNode(source_node)) {
236  const auto it = std::lower_bound(input_size_prefix_sums_.begin(),
237  input_size_prefix_sums_.end(),
238  rex_input->getIndex(),
239  std::less_equal<size_t>());
240  CHECK(it != input_size_prefix_sums_.end());
241  const auto input_node =
242  left_deep_join_->getInput(std::distance(input_size_prefix_sums_.begin(), it));
243  if (it != input_size_prefix_sums_.begin()) {
244  const auto prev_input_count = *(it - 1);
245  CHECK_LE(prev_input_count, rex_input->getIndex());
246  const auto input_index = rex_input->getIndex() - prev_input_count;
247  rex_input->setIndex(input_index);
248  }
249  rex_input->setSourceNode(input_node);
250  }
251  return nullptr;
252  };
253 
254  private:
255  std::vector<size_t> input_size_prefix_sums_;
257 };
258 
259 } // namespace
260 
261 // Recognize the left-deep join tree pattern with an optional filter as root
262 // with `node` as the parent of the join sub-tree. On match, return the root
263 // of the recognized tree (either the filter node or the outermost join).
264 std::shared_ptr<const RelAlgNode> get_left_deep_join_root(
265  const std::shared_ptr<RelAlgNode>& node) {
266  const auto left_deep_join_filter = dynamic_cast<const RelFilter*>(node.get());
267  if (left_deep_join_filter) {
268  const auto join = dynamic_cast<const RelJoin*>(left_deep_join_filter->getInput(0));
269  if (!join) {
270  return nullptr;
271  }
272  if (join->getJoinType() == JoinType::INNER || join->getJoinType() == JoinType::SEMI ||
273  join->getJoinType() == JoinType::ANTI) {
274  return node;
275  }
276  }
277  if (!node || node->inputCount() != 1) {
278  return nullptr;
279  }
280  const auto join = dynamic_cast<const RelJoin*>(node->getInput(0));
281  if (!join) {
282  return nullptr;
283  }
284  return node->getAndOwnInput(0);
285 }
286 
288  const RelLeftDeepInnerJoin* left_deep_join) {
289  RebindRexInputsFromLeftDeepJoin rebind_rex_inputs_from_left_deep_join(left_deep_join);
290  rebind_rex_inputs_from_left_deep_join.visit(rex);
291 }
292 
293 void create_left_deep_join(std::vector<std::shared_ptr<RelAlgNode>>& nodes) {
294  std::list<std::shared_ptr<RelAlgNode>> new_nodes;
295  for (auto& left_deep_join_candidate : nodes) {
296  std::shared_ptr<RelLeftDeepInnerJoin> left_deep_join;
297  std::shared_ptr<const RelAlgNode> old_root;
298  std::tie(left_deep_join, old_root) = create_left_deep_join(left_deep_join_candidate);
299  if (!left_deep_join) {
300  continue;
301  }
302  CHECK_GE(left_deep_join->inputCount(), size_t(2));
303  for (size_t nesting_level = 1; nesting_level <= left_deep_join->inputCount() - 1;
304  ++nesting_level) {
305  const auto outer_condition = left_deep_join->getOuterCondition(nesting_level);
306  if (outer_condition) {
307  rebind_inputs_from_left_deep_join(outer_condition, left_deep_join.get());
308  }
309  }
310  rebind_inputs_from_left_deep_join(left_deep_join->getInnerCondition(),
311  left_deep_join.get());
312  for (auto& node : nodes) {
313  if (node && node->hasInput(old_root.get())) {
314  node->replaceInput(left_deep_join_candidate, left_deep_join);
315  std::shared_ptr<const RelJoin> old_join;
316  if (std::dynamic_pointer_cast<const RelJoin>(left_deep_join_candidate)) {
317  old_join = std::static_pointer_cast<const RelJoin>(left_deep_join_candidate);
318  } else {
319  CHECK_EQ(size_t(1), left_deep_join_candidate->inputCount());
320  old_join = std::dynamic_pointer_cast<const RelJoin>(
321  left_deep_join_candidate->getAndOwnInput(0));
322  }
323  while (old_join) {
324  node->replaceInput(old_join, left_deep_join);
325  old_join =
326  std::dynamic_pointer_cast<const RelJoin>(old_join->getAndOwnInput(0));
327  }
328  }
329  }
330 
331  new_nodes.emplace_back(std::move(left_deep_join));
332  }
333 
334  // insert the new left join nodes to the front of the owned RelAlgNode list.
335  // This is done to ensure all created RelAlgNodes exist in this list for later
336  // visitation, such as RelAlgDag::resetQueryExecutionState.
337  nodes.insert(nodes.begin(), new_nodes.begin(), new_nodes.end());
338 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
std::vector< std::unique_ptr< const RexScalar > > outer_conditions_per_level_
Definition: RelAlgDag.h:1783
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
JoinType
Definition: sqldefs.h:165
size_t size() const override
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
bool coversOriginalNode(const RelAlgNode *node) const
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
const RexScalar * getOuterCondition(const size_t nesting_level) const
std::shared_ptr< RelFilter > original_filter_
Definition: RelAlgDag.h:1784
std::string join(T const &container, std::string const &delim)
#define CHECK_GE(x, y)
Definition: Logger.h:306
std::vector< std::shared_ptr< const RelJoin > > original_joins_
Definition: RelAlgDag.h:1785
#define CHECK_GT(x, y)
Definition: Logger.h:305
std::shared_ptr< const RelAlgNode > getAndOwnInput(const size_t idx) const
Definition: RelAlgDag.h:897
std::string to_string(char const *&&v)
unsigned getIndex() const
Definition: RelAlgDag.h:77
static auto const HASH_N
Definition: RelAlgDag.h:44
RelLeftDeepInnerJoin()=default
void setIndex(const unsigned in_index) const
Definition: RelAlgDag.h:79
std::shared_ptr< RelAlgNode > deepCopy() const override
DEVICE void partial_sum(ARGS &&...args)
Definition: gpu_enabled.h:87
const RelAlgNode * getInput(const size_t idx) const
Definition: RelAlgDag.h:892
Definition: sqldefs.h:36
std::vector< std::shared_ptr< const RelJoin > > getOriginalJoins() const
std::optional< size_t > hash_
Definition: RelAlgDag.h:955
DEVICE auto lower_bound(ARGS &&...args)
Definition: gpu_enabled.h:78
#define CHECK_LE(x, y)
Definition: Logger.h:304
void setSourceNode(const RelAlgNode *node) const
Definition: RelAlgDag.h:394
virtual size_t size() const =0
const RelAlgNode * getSourceNode() const
Definition: RelAlgDag.h:389
const RelFilter * getOriginalFilter() const
std::string typeName(const T *v)
Definition: toString.h:103
std::unique_ptr< const RexScalar > condition_
Definition: RelAlgDag.h:1782
const JoinType getJoinType(const size_t nesting_level) const
const RexScalar * getInnerCondition() const
#define CHECK(condition)
Definition: Logger.h:291
void collect_left_deep_join_inputs(std::deque< std::shared_ptr< const RelAlgNode >> &inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins, const std::shared_ptr< const RelJoin > &join)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
Definition: RelAlgDag.cpp:527
const size_t inputCount() const
Definition: RelAlgDag.h:890
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)
size_t toHash() const override
RelAlgInputs inputs_
Definition: RelAlgDag.h:952