OmniSciDB  6686921089
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
JoinLoop.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "JoinLoop.h"
18 #include "../CgenState.h"
19 #include "Logger/Logger.h"
20 
21 #include <llvm/IR/Type.h>
22 
23 #include <stack>
24 
// JoinLoop constructor: stores the loop kind, the join type, the code
// generation callbacks and the loop name in the corresponding members.
// NOTE(review): the line that opens this definition (listing line 25,
// presumably `JoinLoop::JoinLoop(const JoinLoopKind kind,` judging by the
// `kind_(kind)` initializer below) is absent from this listing -- confirm
// against the original file.
26  const JoinType type,
27  const std::function<JoinLoopDomain(const std::vector<llvm::Value*>&)>&
28  iteration_domain_codegen,
29  const std::function<llvm::Value*(const std::vector<llvm::Value*>&)>&
30  outer_condition_match,
31  const std::function<void(llvm::Value*)>& found_outer_matches,
32  const HoistedFiltersCallback& hoisted_filters,
33  const std::function<llvm::Value*(const std::vector<llvm::Value*>&,
34  llvm::Value*)>& is_deleted,
35  const std::string& name)
36  : kind_(kind)
37  , type_(type)
38  , iteration_domain_codegen_(iteration_domain_codegen)
39  , outer_condition_match_(outer_condition_match)
40  , found_outer_matches_(found_outer_matches)
41  , hoisted_filters_(hoisted_filters)
42  , is_deleted_(is_deleted)
43  , name_(name) {
// An outer-condition callback may only be supplied for LEFT joins.
44  CHECK(outer_condition_match == nullptr || type == JoinType::LEFT);
// The found-outer-matches callback must be set exactly when the join is LEFT.
45  CHECK_EQ(static_cast<bool>(found_outer_matches), (type == JoinType::LEFT));
46 }
47 
// Emits the LLVM IR control flow for a sequence of join loops and hooks the
// block produced by `body_codegen` into the innermost level.
//
// Parameters:
//   join_loops   - loops to generate, processed in order; each one appends one
//                  value to the `iterators` vector handed to the callbacks.
//   body_codegen - invoked once, after all loops were processed, with the full
//                  iterator vector; the block it returns becomes the "true"
//                  target of the last comparison.
//   outer_iter   - first entry of the iterator vector.
//   exit_bb      - exit target used while no iterator-advance block of an
//                  enclosing loop exists yet.
//   cgen_state   - provides the IR builder (`ir_builder_`) used for all codegen.
//
// Returns the entry block: the hoisted-filter block of the first loop when its
// callback returned one, otherwise the first loop's preheader / true block
// (nullptr when `join_loops` is empty, since `entry` is never assigned then).
48 llvm::BasicBlock* JoinLoop::codegen(
49  const std::vector<JoinLoop>& join_loops,
50  const std::function<llvm::BasicBlock*(const std::vector<llvm::Value*>&)>&
51  body_codegen,
52  llvm::Value* outer_iter,
53  llvm::BasicBlock* exit_bb,
54  CgenState* cgen_state) {
55  AUTOMATIC_IR_METADATA(cgen_state);
56  llvm::IRBuilder<>& builder = cgen_state->ir_builder_;
// State carried from one loop level to the next.
57  llvm::BasicBlock* prev_exit_bb{exit_bb};
58  llvm::BasicBlock* prev_iter_advance_bb{nullptr};
59  llvm::BasicBlock* last_head_bb{nullptr};
60  auto& context = builder.getContext();
61  const auto parent_func = builder.GetInsertBlock()->getParent();
// Comparison of the previous level whose conditional branch is still pending.
62  llvm::Value* prev_comparison_result{nullptr};
63  llvm::BasicBlock* entry{nullptr};
64  std::vector<llvm::Value*> iterators;
65  iterators.push_back(outer_iter);
66  JoinType prev_join_type{JoinType::INVALID};
67  for (const auto& join_loop : join_loops) {
68  switch (join_loop.kind_) {
// NOTE(review): listing lines 69 and 71 are absent here; judging by the kind
// checks further down, this section handles the Set, MultiSet and UpperBound
// kinds (counter-driven loops) -- confirm against the original file.
70  case JoinLoopKind::Set:
72  const auto preheader_bb = llvm::BasicBlock::Create(
73  context, "ub_iter_preheader_" + join_loop.name_, parent_func);
74 
// Optional hoisted filters: when the callback returns a block, that block is
// used as the way into this loop instead of the preheader.
75  llvm::BasicBlock* filter_bb{nullptr};
76  if (join_loop.hoisted_filters_) {
77  filter_bb = join_loop.hoisted_filters_(
78  preheader_bb, prev_exit_bb, join_loop.name_, parent_func, cgen_state);
79  }
80 
81  if (!entry) {
82  entry = filter_bb ? filter_bb : preheader_bb;
83  }
84 
// Finish the previous level: on a successful comparison enter this loop,
// otherwise go to the previous advance block (LEFT) or the previous exit.
85  if (prev_comparison_result) {
86  builder.CreateCondBr(
87  prev_comparison_result,
88  filter_bb ? filter_bb : preheader_bb,
89  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
90  }
91  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
92  builder.SetInsertPoint(preheader_bb);
93 
// Preheader: allocate the 64-bit iteration counter and, for LEFT joins, the
// two 1-bit flags consumed by evaluateOuterJoinCondition.
94  const auto iteration_counter_ptr = builder.CreateAlloca(
95  get_int_type(64, context), nullptr, "ub_iter_counter_ptr_" + join_loop.name_);
96  llvm::Value* found_an_outer_match_ptr{nullptr};
97  llvm::Value* current_condition_match_ptr{nullptr};
98  if (join_loop.type_ == JoinType::LEFT) {
99  found_an_outer_match_ptr = builder.CreateAlloca(
100  get_int_type(1, context), nullptr, "found_an_outer_match");
101  builder.CreateStore(ll_bool(false, context), found_an_outer_match_ptr);
102  current_condition_match_ptr = builder.CreateAlloca(
103  get_int_type(1, context), nullptr, "outer_condition_current_match");
104  }
105  builder.CreateStore(ll_int(int64_t(0), context), iteration_counter_ptr);
106  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
// Loop head: reload the counter on every pass.
107  const auto head_bb = llvm::BasicBlock::Create(
108  context, "ub_iter_head_" + join_loop.name_, parent_func);
109  builder.CreateBr(head_bb);
110  builder.SetInsertPoint(head_bb);
111  llvm::Value* iteration_counter = builder.CreateLoad(
112  iteration_counter_ptr, "ub_iter_counter_val_" + join_loop.name_);
113  auto iteration_val = iteration_counter;
// Only Set / MultiSet loops may carry a values buffer.
114  CHECK(join_loop.kind_ == JoinLoopKind::Set ||
115  join_loop.kind_ == JoinLoopKind::MultiSet ||
116  !iteration_domain.values_buffer);
117  if (join_loop.kind_ == JoinLoopKind::Set ||
118  join_loop.kind_ == JoinLoopKind::MultiSet) {
// For Set / MultiSet the iterator is the address of the counter-th element of
// the values buffer rather than the raw counter.
119  CHECK(iteration_domain.values_buffer->getType()->isPointerTy());
120  const auto ptr_type =
121  static_cast<llvm::PointerType*>(iteration_domain.values_buffer->getType());
122  if (ptr_type->getElementType()->isArrayTy()) {
// Pointer-to-array: a leading zero index is needed to step into the array.
123  iteration_val = builder.CreateGEP(
124  iteration_domain.values_buffer,
125  std::vector<llvm::Value*>{
126  llvm::ConstantInt::get(get_int_type(64, context), 0),
127  iteration_counter},
128  "ub_iter_counter_" + join_loop.name_);
129  } else {
130  iteration_val = builder.CreateGEP(iteration_domain.values_buffer,
131  iteration_counter,
132  "ub_iter_counter_" + join_loop.name_);
133  }
134  }
135  iterators.push_back(iteration_val);
// Signed counter < bound, where the bound is upper_bound for UpperBound loops
// and element_count otherwise.
136  const auto have_more_inner_rows = builder.CreateICmpSLT(
137  iteration_counter,
138  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
139  : iteration_domain.element_count,
140  "have_more_inner_rows");
141  const auto iter_advance_bb = llvm::BasicBlock::Create(
142  context, "ub_iter_advance_" + join_loop.name_, parent_func);
// With a deletion callback, rows it reports as deleted jump straight to the
// advance block; codegen continues in the "row_not_deleted" block.
143  llvm::BasicBlock* row_not_deleted_bb{nullptr};
144  if (join_loop.is_deleted_) {
145  row_not_deleted_bb = llvm::BasicBlock::Create(
146  context, "row_not_deleted_" + join_loop.name_, parent_func);
147  const auto row_is_deleted =
148  join_loop.is_deleted_(iterators, have_more_inner_rows);
149  builder.CreateCondBr(row_is_deleted, iter_advance_bb, row_not_deleted_bb);
150  builder.SetInsertPoint(row_not_deleted_bb);
151  }
// Decide which value gates entry into the next level / the body, and which
// block will hold that (still pending) conditional branch.
152  if (join_loop.type_ == JoinType::LEFT) {
153  std::tie(last_head_bb, prev_comparison_result) =
154  evaluateOuterJoinCondition(join_loop,
155  iteration_domain,
156  iterators,
157  iteration_counter,
158  have_more_inner_rows,
159  found_an_outer_match_ptr,
160  current_condition_match_ptr,
161  cgen_state);
162  } else {
163  prev_comparison_result = have_more_inner_rows;
164  last_head_bb = row_not_deleted_bb ? row_not_deleted_bb : head_bb;
165  }
// Advance block: counter += 1, then loop back to the head.
166  builder.SetInsertPoint(iter_advance_bb);
167  const auto iteration_counter_next_val =
168  builder.CreateAdd(iteration_counter, ll_int(int64_t(1), context));
169  builder.CreateStore(iteration_counter_next_val, iteration_counter_ptr);
170  if (join_loop.type_ == JoinType::LEFT) {
// LEFT joins leave the loop here, and only once the next counter is strictly
// greater than the bound, so the head is still visited with counter == bound.
171  const auto no_more_inner_rows =
172  builder.CreateICmpSGT(iteration_counter_next_val,
173  join_loop.kind_ == JoinLoopKind::UpperBound
174  ? iteration_domain.upper_bound
175  : iteration_domain.element_count,
176  "no_more_inner_rows");
177  builder.CreateCondBr(no_more_inner_rows, prev_exit_bb, head_bb);
178  } else {
179  builder.CreateBr(head_bb);
180  }
// Leave the builder in the block that still needs its terminating branch.
181  builder.SetInsertPoint(last_head_bb);
182  prev_iter_advance_bb = iter_advance_bb;
183  break;
184  }
// NOTE(review): listing line 185 (a case label) is absent here; the block
// names below use the "singleton_" prefix, so this is presumably the
// single-slot lookup kind -- confirm against the original file.
186  const auto true_bb = llvm::BasicBlock::Create(
187  context, "singleton_true_" + join_loop.name_, parent_func);
188 
// Optional hoisted filters, handled the same way as in the loop case above.
189  llvm::BasicBlock* filter_bb{nullptr};
190  if (join_loop.hoisted_filters_) {
191  filter_bb = join_loop.hoisted_filters_(
192  true_bb, prev_exit_bb, join_loop.name_, parent_func, cgen_state);
193  }
194 
195  if (!entry) {
196  entry = filter_bb ? filter_bb : true_bb;
197  }
198 
// Emit the pending branch of the previous level into this one.
199  if (prev_comparison_result) {
200  builder.CreateCondBr(
201  prev_comparison_result,
202  filter_bb ? filter_bb : true_bb,
203  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
204  }
205  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
206 
207  builder.SetInsertPoint(true_bb);
208  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
209  CHECK(!iteration_domain.values_buffer);
// The looked-up slot itself is the iterator; a value >= 0 counts as a match.
210  iterators.push_back(iteration_domain.slot_lookup_result);
211  auto join_cond_match = builder.CreateICmpSGE(iteration_domain.slot_lookup_result,
212  ll_int<int64_t>(0, context));
// Flag for the remaining outer condition; starts out true and is overwritten
// only when that condition is actually evaluated below.
213  llvm::Value* remaining_cond_match = builder.CreateAlloca(
214  get_int_type(1, context), nullptr, "remaining_outer_cond_match");
215  builder.CreateStore(ll_bool(true, context), remaining_cond_match);
216 
217  if (join_loop.type_ == JoinType::LEFT && join_loop.outer_condition_match_) {
// Evaluate the outer condition only when the slot lookup matched.
// (This `parent_func` shadows the function-level variable of the same name.)
218  const auto parent_func = builder.GetInsertBlock()->getParent();
219  const auto evaluate_remaining_outer_cond_bb = llvm::BasicBlock::Create(
220  context, "eval_remaining_outer_cond_" + join_loop.name_, parent_func);
221  const auto after_evaluate_outer_cond_bb = llvm::BasicBlock::Create(
222  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
223  builder.CreateCondBr(join_cond_match,
224  evaluate_remaining_outer_cond_bb,
225  after_evaluate_outer_cond_bb);
226  builder.SetInsertPoint(evaluate_remaining_outer_cond_bb);
227  const auto outer_cond_match = join_loop.outer_condition_match_(iterators);
228  const auto true_left_cond_match =
229  builder.CreateAnd(outer_cond_match, join_cond_match);
230  builder.CreateStore(true_left_cond_match, remaining_cond_match);
231  builder.CreateBr(after_evaluate_outer_cond_bb);
232  builder.SetInsertPoint(after_evaluate_outer_cond_bb);
233  }
// match = slot matched AND remaining outer condition flag.
234  auto match_found =
235  builder.CreateAnd(join_cond_match, builder.CreateLoad(remaining_cond_match));
236  CHECK(match_found);
// Rows reported as deleted never count as a match; the deletion callback gets
// nullptr as its second argument here.
237  if (join_loop.is_deleted_) {
238  match_found = builder.CreateAnd(
239  match_found, builder.CreateNot(join_loop.is_deleted_(iterators, nullptr)));
240  }
241  auto match_found_bb = builder.GetInsertBlock();
// Pick the value that gates entry into the next level / the body.
242  switch (join_loop.type_) {
243  case JoinType::INNER:
244  case JoinType::SEMI: {
245  prev_comparison_result = match_found;
246  break;
247  }
248  case JoinType::ANTI: {
// ANTI joins continue only when the slot lookup result is negative.
249  auto match_found_for_anti_join = builder.CreateICmpSLT(
250  iteration_domain.slot_lookup_result, ll_int<int64_t>(0, context));
251  prev_comparison_result = match_found_for_anti_join;
252  break;
253  }
254  case JoinType::LEFT: {
255  join_loop.found_outer_matches_(match_found);
256  // For outer joins, do the iteration regardless of the result of the match.
257  prev_comparison_result = ll_bool(true, context);
258  break;
259  }
260  default:
261  CHECK(false);
262  }
// No advance block is created at this level; if no enclosing loop provided
// one, fall back to the current exit block.
263  if (!prev_iter_advance_bb) {
264  prev_iter_advance_bb = prev_exit_bb;
265  }
266  last_head_bb = match_found_bb;
267  break;
268  }
269  default:
270  CHECK(false);
271  }
272  prev_join_type = join_loop.type_;
273  }
274 
// Generate the body, branch from the builder's current position to the
// innermost advance block, then emit the last pending conditional branch.
275  const auto body_bb = body_codegen(iterators);
276  builder.CreateBr(prev_iter_advance_bb);
277  builder.SetInsertPoint(last_head_bb);
278  builder.CreateCondBr(
279  prev_comparison_result,
280  body_bb,
281  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
282  return entry;
283 }
284 
// Emits the IR that evaluates the outer (LEFT) join condition for the current
// inner row and works out whether the loop body should run for it.
//
// Parameters:
//   join_loop                   - loop being generated; supplies the optional
//                                 outer_condition_match_ callback, the
//                                 found_outer_matches_ callback, kind_ and name_.
//   iteration_domain            - supplies upper_bound / element_count.
//   iterators                   - values passed to the outer-condition callback.
//   iteration_counter           - current value of the loop counter.
//   have_more_inner_rows        - 1-bit value; the condition is only evaluated
//                                 when it is true.
//   found_an_outer_match_ptr    - flag accumulating whether any row matched.
//   current_condition_match_ptr - flag holding the result for the current row.
//   cgen_state                  - provides the IR builder.
//
// Returns {block following the evaluation, do_iteration value}. The builder is
// left positioned in the returned block.
285 std::pair<llvm::BasicBlock*, llvm::Value*> JoinLoop::evaluateOuterJoinCondition(
286  const JoinLoop& join_loop,
287  const JoinLoopDomain& iteration_domain,
288  const std::vector<llvm::Value*>& iterators,
289  llvm::Value* iteration_counter,
290  llvm::Value* have_more_inner_rows,
291  llvm::Value* found_an_outer_match_ptr,
292  llvm::Value* current_condition_match_ptr,
293  CgenState* cgen_state) {
294  AUTOMATIC_IR_METADATA(cgen_state);
295  llvm::IRBuilder<>& builder = cgen_state->ir_builder_;
296  auto& context = builder.getContext();
297  const auto parent_func = builder.GetInsertBlock()->getParent();
// Reset the per-row flag; it stays false when the evaluation block is skipped.
298  builder.CreateStore(ll_bool(false, context), current_condition_match_ptr);
299  const auto evaluate_outer_condition_bb = llvm::BasicBlock::Create(
300  context, "eval_outer_cond_" + join_loop.name_, parent_func);
301  const auto after_evaluate_outer_condition_bb = llvm::BasicBlock::Create(
302  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
// Evaluate the condition only while inner rows remain.
303  builder.CreateCondBr(have_more_inner_rows,
304  evaluate_outer_condition_bb,
305  after_evaluate_outer_condition_bb);
306  builder.SetInsertPoint(evaluate_outer_condition_bb);
// Without an outer-condition callback every row counts as a match.
307  const auto current_condition_match = join_loop.outer_condition_match_
308  ? join_loop.outer_condition_match_(iterators)
309  : ll_bool(true, context);
310  builder.CreateStore(current_condition_match, current_condition_match_ptr);
// found_an_outer_match |= current match.
311  const auto updated_condition_match = builder.CreateOr(
312  current_condition_match, builder.CreateLoad(found_an_outer_match_ptr));
313  builder.CreateStore(updated_condition_match, found_an_outer_match_ptr);
314  builder.CreateBr(after_evaluate_outer_condition_bb);
315  builder.SetInsertPoint(after_evaluate_outer_condition_bb);
316  const auto no_matches_found =
317  builder.CreateNot(builder.CreateLoad(found_an_outer_match_ptr));
// True when the counter equals the loop bound (upper_bound for UpperBound
// loops, element_count otherwise).
318  const auto no_more_inner_rows = builder.CreateICmpEQ(
319  iteration_counter,
320  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
321  : iteration_domain.element_count);
322  // Do the iteration if the outer condition is true or it's the last iteration and no
323  // matches have been found.
324  const auto do_iteration =
325  builder.CreateOr(builder.CreateLoad(current_condition_match_ptr),
326  builder.CreateAnd(no_matches_found, no_more_inner_rows));
// Hand the current row's match flag to the found_outer_matches_ callback.
327  join_loop.found_outer_matches_(builder.CreateLoad(current_condition_match_ptr));
328  return {after_evaluate_outer_condition_bb, do_iteration};
329 }
#define CHECK_EQ(x, y)
Definition: Logger.h:217
JoinType
Definition: sqldefs.h:108
llvm::Value * element_count
Definition: JoinLoop.h:46
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:340
std::function< llvm::BasicBlock *(llvm::BasicBlock *, llvm::BasicBlock *, const std::string &, llvm::Function *, CgenState *)> HoistedFiltersCallback
Definition: JoinLoop.h:61
string name
Definition: setup.in.py:72
const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> outer_condition_match_
Definition: JoinLoop.h:106
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
const JoinLoopKind kind_
Definition: JoinLoop.h:96
const std::function< void(llvm::Value *)> found_outer_matches_
Definition: JoinLoop.h:109
const std::string name_
Definition: JoinLoop.h:120
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
Definition: JoinLoop.cpp:48
#define AUTOMATIC_IR_METADATA(CGENSTATE)
llvm::Value * upper_bound
Definition: JoinLoop.h:45
JoinLoop(const JoinLoopKind, const JoinType, const std::function< JoinLoopDomain(const std::vector< llvm::Value * > &)> &iteration_domain_codegen, const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> &outer_condition_match, const std::function< void(llvm::Value *)> &found_outer_matches, const HoistedFiltersCallback &hoisted_filters, const std::function< llvm::Value *(const std::vector< llvm::Value * > &prev_iters, llvm::Value *)> &is_deleted, const std::string &name="")
Definition: JoinLoop.cpp:25
#define CHECK(condition)
Definition: Logger.h:209
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
JoinLoopKind
Definition: JoinLoop.h:31
static std::pair< llvm::BasicBlock *, llvm::Value * > evaluateOuterJoinCondition(const JoinLoop &join_loop, const JoinLoopDomain &iteration_domain, const std::vector< llvm::Value * > &iterators, llvm::Value *iteration_counter, llvm::Value *have_more_inner_rows, llvm::Value *found_an_outer_match_ptr, llvm::Value *current_condition_match_ptr, CgenState *cgen_state)
Definition: JoinLoop.cpp:285