OmniSciDB  17c254d2f8
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
JoinLoop.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "JoinLoop.h"
18 #include "Shared/Logger.h"
19 
20 #include <llvm/IR/Type.h>
21 
22 #include <stack>
23 
25  const JoinType type,
26  const std::function<JoinLoopDomain(const std::vector<llvm::Value*>&)>&
27  iteration_domain_codegen,
28  const std::function<llvm::Value*(const std::vector<llvm::Value*>&)>&
29  outer_condition_match,
30  const std::function<void(llvm::Value*)>& found_outer_matches,
31  const std::function<llvm::Value*(const std::vector<llvm::Value*>&,
32  llvm::Value*)>& is_deleted,
33  const std::string& name)
34  : kind_(kind)
35  , type_(type)
36  , iteration_domain_codegen_(iteration_domain_codegen)
37  , outer_condition_match_(outer_condition_match)
38  , found_outer_matches_(found_outer_matches)
39  , is_deleted_(is_deleted)
40  , name_(name) {
41  CHECK(outer_condition_match == nullptr || type == JoinType::LEFT);
42  CHECK_EQ(static_cast<bool>(found_outer_matches), (type == JoinType::LEFT));
43 }
44 
45 llvm::BasicBlock* JoinLoop::codegen(
46  const std::vector<JoinLoop>& join_loops,
47  const std::function<llvm::BasicBlock*(const std::vector<llvm::Value*>&)>&
48  body_codegen,
49  llvm::Value* outer_iter,
50  llvm::BasicBlock* exit_bb,
51  llvm::IRBuilder<>& builder) {
52  llvm::BasicBlock* prev_exit_bb{exit_bb};
53  llvm::BasicBlock* prev_iter_advance_bb{nullptr};
54  llvm::BasicBlock* last_head_bb{nullptr};
55  auto& context = builder.getContext();
56  const auto parent_func = builder.GetInsertBlock()->getParent();
57  llvm::Value* prev_comparison_result{nullptr};
58  llvm::BasicBlock* entry{nullptr};
59  std::vector<llvm::Value*> iterators;
60  iterators.push_back(outer_iter);
61  JoinType prev_join_type{JoinType::INVALID};
62  for (const auto& join_loop : join_loops) {
63  switch (join_loop.kind_) {
65  case JoinLoopKind::Set: {
66  const auto preheader_bb = llvm::BasicBlock::Create(
67  context, "ub_iter_preheader_" + join_loop.name_, parent_func);
68  if (!entry) {
69  entry = preheader_bb;
70  }
71  if (prev_comparison_result) {
72  builder.CreateCondBr(
73  prev_comparison_result,
74  preheader_bb,
75  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
76  }
77  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
78  builder.SetInsertPoint(preheader_bb);
79  const auto iteration_counter_ptr = builder.CreateAlloca(
80  get_int_type(64, context), nullptr, "ub_iter_counter_ptr_" + join_loop.name_);
81  llvm::Value* found_an_outer_match_ptr{nullptr};
82  llvm::Value* current_condition_match_ptr{nullptr};
83  if (join_loop.type_ == JoinType::LEFT) {
84  found_an_outer_match_ptr = builder.CreateAlloca(
85  get_int_type(1, context), nullptr, "found_an_outer_match");
86  builder.CreateStore(ll_bool(false, context), found_an_outer_match_ptr);
87  current_condition_match_ptr = builder.CreateAlloca(
88  get_int_type(1, context), nullptr, "outer_condition_current_match");
89  }
90  builder.CreateStore(ll_int(int64_t(0), context), iteration_counter_ptr);
91  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
92  const auto head_bb = llvm::BasicBlock::Create(
93  context, "ub_iter_head_" + join_loop.name_, parent_func);
94  builder.CreateBr(head_bb);
95  builder.SetInsertPoint(head_bb);
96  llvm::Value* iteration_counter = builder.CreateLoad(
97  iteration_counter_ptr, "ub_iter_counter_val_" + join_loop.name_);
98  auto iteration_val = iteration_counter;
99  CHECK(join_loop.kind_ == JoinLoopKind::Set || !iteration_domain.values_buffer);
100  if (join_loop.kind_ == JoinLoopKind::Set) {
101  iteration_val =
102  builder.CreateGEP(iteration_domain.values_buffer, iteration_counter);
103  }
104  iterators.push_back(iteration_val);
105  const auto have_more_inner_rows = builder.CreateICmpSLT(
106  iteration_counter,
107  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
108  : iteration_domain.element_count,
109  "have_more_inner_rows");
110  const auto iter_advance_bb = llvm::BasicBlock::Create(
111  context, "ub_iter_advance_" + join_loop.name_, parent_func);
112  llvm::BasicBlock* row_not_deleted_bb{nullptr};
113  if (join_loop.is_deleted_) {
114  row_not_deleted_bb = llvm::BasicBlock::Create(
115  context, "row_not_deleted_" + join_loop.name_, parent_func);
116  const auto row_is_deleted =
117  join_loop.is_deleted_(iterators, have_more_inner_rows);
118  builder.CreateCondBr(row_is_deleted, iter_advance_bb, row_not_deleted_bb);
119  builder.SetInsertPoint(row_not_deleted_bb);
120  }
121  if (join_loop.type_ == JoinType::LEFT) {
122  std::tie(last_head_bb, prev_comparison_result) =
123  evaluateOuterJoinCondition(join_loop,
124  iteration_domain,
125  iterators,
126  iteration_counter,
127  have_more_inner_rows,
128  found_an_outer_match_ptr,
129  current_condition_match_ptr,
130  builder);
131  } else {
132  prev_comparison_result = have_more_inner_rows;
133  last_head_bb = row_not_deleted_bb ? row_not_deleted_bb : head_bb;
134  }
135  builder.SetInsertPoint(iter_advance_bb);
136  const auto iteration_counter_next_val =
137  builder.CreateAdd(iteration_counter, ll_int(int64_t(1), context));
138  builder.CreateStore(iteration_counter_next_val, iteration_counter_ptr);
139  if (join_loop.type_ == JoinType::LEFT) {
140  const auto no_more_inner_rows =
141  builder.CreateICmpSGT(iteration_counter_next_val,
142  join_loop.kind_ == JoinLoopKind::UpperBound
143  ? iteration_domain.upper_bound
144  : iteration_domain.element_count,
145  "no_more_inner_rows");
146  builder.CreateCondBr(no_more_inner_rows, prev_exit_bb, head_bb);
147  } else {
148  builder.CreateBr(head_bb);
149  }
150  builder.SetInsertPoint(last_head_bb);
151  prev_iter_advance_bb = iter_advance_bb;
152  break;
153  }
155  const auto true_bb = llvm::BasicBlock::Create(
156  context, "singleton_true_" + join_loop.name_, parent_func);
157  if (!entry) {
158  entry = true_bb;
159  }
160  if (prev_comparison_result) {
161  builder.CreateCondBr(
162  prev_comparison_result,
163  true_bb,
164  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
165  }
166  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
167  builder.SetInsertPoint(true_bb);
168  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
169  CHECK(!iteration_domain.values_buffer);
170  iterators.push_back(iteration_domain.slot_lookup_result);
171  auto match_found = builder.CreateICmpSGE(iteration_domain.slot_lookup_result,
172  ll_int<int64_t>(0, context));
173  if (join_loop.is_deleted_) {
174  match_found = builder.CreateAnd(
175  match_found, builder.CreateNot(join_loop.is_deleted_(iterators, nullptr)));
176  }
177  auto match_found_bb = builder.GetInsertBlock();
178  switch (join_loop.type_) {
179  case JoinType::INNER: {
180  prev_comparison_result = match_found;
181  break;
182  }
183  case JoinType::LEFT: {
184  join_loop.found_outer_matches_(match_found);
185  // For outer joins, do the iteration regardless of the result of the match.
186  prev_comparison_result = ll_bool(true, context);
187  break;
188  }
189  default:
190  CHECK(false);
191  }
192  if (!prev_iter_advance_bb) {
193  prev_iter_advance_bb = prev_exit_bb;
194  }
195  last_head_bb = match_found_bb;
196  break;
197  }
198  default:
199  CHECK(false);
200  }
201  prev_join_type = join_loop.type_;
202  }
203  const auto body_bb = body_codegen(iterators);
204  builder.CreateBr(prev_iter_advance_bb);
205  builder.SetInsertPoint(last_head_bb);
206  builder.CreateCondBr(
207  prev_comparison_result,
208  body_bb,
209  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
210  return entry;
211 }
212 
213 std::pair<llvm::BasicBlock*, llvm::Value*> JoinLoop::evaluateOuterJoinCondition(
214  const JoinLoop& join_loop,
215  const JoinLoopDomain& iteration_domain,
216  const std::vector<llvm::Value*>& iterators,
217  llvm::Value* iteration_counter,
218  llvm::Value* have_more_inner_rows,
219  llvm::Value* found_an_outer_match_ptr,
220  llvm::Value* current_condition_match_ptr,
221  llvm::IRBuilder<>& builder) {
222  auto& context = builder.getContext();
223  const auto parent_func = builder.GetInsertBlock()->getParent();
224  builder.CreateStore(ll_bool(false, context), current_condition_match_ptr);
225  const auto evaluate_outer_condition_bb = llvm::BasicBlock::Create(
226  context, "eval_outer_cond_" + join_loop.name_, parent_func);
227  const auto after_evaluate_outer_condition_bb = llvm::BasicBlock::Create(
228  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
229  builder.CreateCondBr(have_more_inner_rows,
230  evaluate_outer_condition_bb,
231  after_evaluate_outer_condition_bb);
232  builder.SetInsertPoint(evaluate_outer_condition_bb);
233  const auto current_condition_match = join_loop.outer_condition_match_
234  ? join_loop.outer_condition_match_(iterators)
235  : ll_bool(true, context);
236  builder.CreateStore(current_condition_match, current_condition_match_ptr);
237  const auto updated_condition_match = builder.CreateOr(
238  current_condition_match, builder.CreateLoad(found_an_outer_match_ptr));
239  builder.CreateStore(updated_condition_match, found_an_outer_match_ptr);
240  builder.CreateBr(after_evaluate_outer_condition_bb);
241  builder.SetInsertPoint(after_evaluate_outer_condition_bb);
242  const auto no_matches_found =
243  builder.CreateNot(builder.CreateLoad(found_an_outer_match_ptr));
244  const auto no_more_inner_rows = builder.CreateICmpEQ(
245  iteration_counter,
246  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
247  : iteration_domain.element_count);
248  // Do the iteration if the outer condition is true or it's the last iteration and no
249  // matches have been found.
250  const auto do_iteration =
251  builder.CreateOr(builder.CreateLoad(current_condition_match_ptr),
252  builder.CreateAnd(no_matches_found, no_more_inner_rows));
253  join_loop.found_outer_matches_(builder.CreateLoad(current_condition_match_ptr));
254  return {after_evaluate_outer_condition_bb, do_iteration};
255 }
#define CHECK_EQ(x, y)
Definition: Logger.h:205
JoinType
Definition: sqldefs.h:107
llvm::Value * element_count
Definition: JoinLoop.h:44
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> outer_condition_match_
Definition: JoinLoop.h:93
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
const JoinLoopKind kind_
Definition: JoinLoop.h:83
CHECK(cgen_state)
const std::function< void(llvm::Value *)> found_outer_matches_
Definition: JoinLoop.h:96
const std::string name_
Definition: JoinLoop.h:104
llvm::Value * upper_bound
Definition: JoinLoop.h:43
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, llvm::IRBuilder<> &builder)
Definition: JoinLoop.cpp:45
JoinLoop(const JoinLoopKind, const JoinType, const std::function< JoinLoopDomain(const std::vector< llvm::Value * > &)> &, const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> &, const std::function< void(llvm::Value *)> &, const std::function< llvm::Value *(const std::vector< llvm::Value * > &prev_iters, llvm::Value *)> &, const std::string &name="")
Definition: JoinLoop.cpp:24
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
JoinLoopKind
Definition: JoinLoop.h:30
static std::pair< llvm::BasicBlock *, llvm::Value * > evaluateOuterJoinCondition(const JoinLoop &join_loop, const JoinLoopDomain &iteration_domain, const std::vector< llvm::Value * > &iterators, llvm::Value *iteration_counter, llvm::Value *have_more_inner_rows, llvm::Value *found_an_outer_match_ptr, llvm::Value *current_condition_match_ptr, llvm::IRBuilder<> &builder)
Definition: JoinLoop.cpp:213