OmniSciDB  8a228a1076
JoinLoop.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "JoinLoop.h"
18 #include "../CgenState.h"
19 #include "Logger/Logger.h"
20 
21 #include <llvm/IR/Type.h>
22 
23 #include <stack>
24 
// JoinLoop constructor (remaining parameters, member initializers and body).
// NOTE(review): the embedded numbering of this listing skips line 25, so the opening
// of the signature -- including the `kind` parameter consumed by `kind_(kind)` below --
// is not visible here; confirm against the real file.
//
// Each argument is stored in the member of the same name (with a trailing underscore);
// the body only validates the combination of callbacks against the join type.
26  const JoinType type,
27  const std::function<JoinLoopDomain(const std::vector<llvm::Value*>&)>&
28  iteration_domain_codegen,
29  const std::function<llvm::Value*(const std::vector<llvm::Value*>&)>&
30  outer_condition_match,
31  const std::function<void(llvm::Value*)>& found_outer_matches,
32  const std::function<llvm::Value*(const std::vector<llvm::Value*>&,
33  llvm::Value*)>& is_deleted,
34  const std::string& name)
35  : kind_(kind)
36  , type_(type)
37  , iteration_domain_codegen_(iteration_domain_codegen)
38  , outer_condition_match_(outer_condition_match)
39  , found_outer_matches_(found_outer_matches)
40  , is_deleted_(is_deleted)
41  , name_(name) {
    // An outer-condition callback may only be supplied when the join type is LEFT.
42  CHECK(outer_condition_match == nullptr || type == JoinType::LEFT);
    // The found-outer-matches callback must be set if and only if the join type is LEFT.
43  CHECK_EQ(static_cast<bool>(found_outer_matches), (type == JoinType::LEFT));
44 }
45 
// Emits the control flow for a sequence of nested join loops and wires the innermost
// level to the block produced by `body_codegen`.
//
// join_loops   - the loop levels, processed in order (outermost first).
// body_codegen - called once, after all levels are emitted, with the accumulated
//                iterator values; returns the block to branch to for the loop body.
// outer_iter   - first entry of the iterator vector handed to every callback.
// exit_bb      - block used as the exit target for the outermost level.
// cgen_state   - provides the IRBuilder used for all emission.
//
// Returns the first block created for the first join loop (its preheader or
// "singleton_true" block); this stays nullptr if no loop level creates a block.
//
// Each level leaves its "continue into the next level?" decision pending in
// `prev_comparison_result` with the builder positioned in `last_head_bb`; the
// conditional branch for it is emitted when the next level (or the body) is known.
46 llvm::BasicBlock* JoinLoop::codegen(
47  const std::vector<JoinLoop>& join_loops,
48  const std::function<llvm::BasicBlock*(const std::vector<llvm::Value*>&)>&
49  body_codegen,
50  llvm::Value* outer_iter,
51  llvm::BasicBlock* exit_bb,
52  CgenState* cgen_state) {
53  AUTOMATIC_IR_METADATA(cgen_state);
54  llvm::IRBuilder<>& builder = cgen_state->ir_builder_;
    // Block the current level falls back to when it produces no (more) rows.
55  llvm::BasicBlock* prev_exit_bb{exit_bb};
    // Advance block of the closest enclosing level; target of the body's back edge.
56  llvm::BasicBlock* prev_iter_advance_bb{nullptr};
    // Block in which the pending conditional branch of the last level must be emitted.
57  llvm::BasicBlock* last_head_bb{nullptr};
58  auto& context = builder.getContext();
59  const auto parent_func = builder.GetInsertBlock()->getParent();
    // Pending i1 condition of the previous level (nullptr before the first level).
60  llvm::Value* prev_comparison_result{nullptr};
61  llvm::BasicBlock* entry{nullptr};
    // Iterator values visible to callbacks: outer_iter followed by one value per level.
62  std::vector<llvm::Value*> iterators;
63  iterators.push_back(outer_iter);
64  JoinType prev_join_type{JoinType::INVALID};
65  for (const auto& join_loop : join_loops) {
66  switch (join_loop.kind_) {
    // NOTE(review): the embedded numbering skips line 67 here; presumably another case
    // label (the code below also handles JoinLoopKind::UpperBound) shares this body --
    // confirm against the real file.
68  case JoinLoopKind::Set: {
69  const auto preheader_bb = llvm::BasicBlock::Create(
70  context, "ub_iter_preheader_" + join_loop.name_, parent_func);
71  if (!entry) {
72  entry = preheader_bb;
73  }
    // Emit the previous level's pending branch: on success enter this level's
    // preheader; otherwise go to the previous advance block (LEFT) or exit block.
74  if (prev_comparison_result) {
75  builder.CreateCondBr(
76  prev_comparison_result,
77  preheader_bb,
78  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
79  }
    // Leaving this level means advancing the enclosing loop, or exiting if none exists.
80  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
81  builder.SetInsertPoint(preheader_bb);
    // 64-bit loop counter kept in a stack slot; initialized to 0 below.
82  const auto iteration_counter_ptr = builder.CreateAlloca(
83  get_int_type(64, context), nullptr, "ub_iter_counter_ptr_" + join_loop.name_);
84  llvm::Value* found_an_outer_match_ptr{nullptr};
85  llvm::Value* current_condition_match_ptr{nullptr};
    // LEFT joins need two i1 slots: "any match so far" (starts false) and
    // "current row matches" (written by evaluateOuterJoinCondition).
86  if (join_loop.type_ == JoinType::LEFT) {
87  found_an_outer_match_ptr = builder.CreateAlloca(
88  get_int_type(1, context), nullptr, "found_an_outer_match");
89  builder.CreateStore(ll_bool(false, context), found_an_outer_match_ptr);
90  current_condition_match_ptr = builder.CreateAlloca(
91  get_int_type(1, context), nullptr, "outer_condition_current_match");
92  }
93  builder.CreateStore(ll_int(int64_t(0), context), iteration_counter_ptr);
    // The iteration domain is generated in the preheader, before the loop head.
94  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
95  const auto head_bb = llvm::BasicBlock::Create(
96  context, "ub_iter_head_" + join_loop.name_, parent_func);
97  builder.CreateBr(head_bb);
98  builder.SetInsertPoint(head_bb);
99  llvm::Value* iteration_counter = builder.CreateLoad(
100  iteration_counter_ptr, "ub_iter_counter_val_" + join_loop.name_);
101  auto iteration_val = iteration_counter;
102  CHECK(join_loop.kind_ == JoinLoopKind::Set || !iteration_domain.values_buffer);
    // For Set loops the iterator is the address of element `iteration_counter` in
    // values_buffer; otherwise it is the counter itself.
103  if (join_loop.kind_ == JoinLoopKind::Set) {
104  iteration_val =
105  builder.CreateGEP(iteration_domain.values_buffer, iteration_counter);
106  }
107  iterators.push_back(iteration_val);
    // Signed counter < bound, where the bound is upper_bound or element_count
    // depending on the loop kind.
108  const auto have_more_inner_rows = builder.CreateICmpSLT(
109  iteration_counter,
110  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
111  : iteration_domain.element_count,
112  "have_more_inner_rows");
113  const auto iter_advance_bb = llvm::BasicBlock::Create(
114  context, "ub_iter_advance_" + join_loop.name_, parent_func);
115  llvm::BasicBlock* row_not_deleted_bb{nullptr};
    // When a deleted-row callback exists, deleted rows jump straight to the advance
    // block and the rest of the head continues in a separate block.
116  if (join_loop.is_deleted_) {
117  row_not_deleted_bb = llvm::BasicBlock::Create(
118  context, "row_not_deleted_" + join_loop.name_, parent_func);
119  const auto row_is_deleted =
120  join_loop.is_deleted_(iterators, have_more_inner_rows);
121  builder.CreateCondBr(row_is_deleted, iter_advance_bb, row_not_deleted_bb);
122  builder.SetInsertPoint(row_not_deleted_bb);
123  }
    // LEFT: the pending condition and block come from evaluateOuterJoinCondition.
    // Otherwise: continue while rows remain, branching from the current head block.
124  if (join_loop.type_ == JoinType::LEFT) {
125  std::tie(last_head_bb, prev_comparison_result) =
126  evaluateOuterJoinCondition(join_loop,
127  iteration_domain,
128  iterators,
129  iteration_counter,
130  have_more_inner_rows,
131  found_an_outer_match_ptr,
132  current_condition_match_ptr,
133  cgen_state);
134  } else {
135  prev_comparison_result = have_more_inner_rows;
136  last_head_bb = row_not_deleted_bb ? row_not_deleted_bb : head_bb;
137  }
    // Advance block: counter += 1, then loop back to the head.
138  builder.SetInsertPoint(iter_advance_bb);
139  const auto iteration_counter_next_val =
140  builder.CreateAdd(iteration_counter, ll_int(int64_t(1), context));
141  builder.CreateStore(iteration_counter_next_val, iteration_counter_ptr);
    // LEFT loops leave only once the next counter is strictly greater than the bound,
    // so the head is entered one more time with counter == bound.
142  if (join_loop.type_ == JoinType::LEFT) {
143  const auto no_more_inner_rows =
144  builder.CreateICmpSGT(iteration_counter_next_val,
145  join_loop.kind_ == JoinLoopKind::UpperBound
146  ? iteration_domain.upper_bound
147  : iteration_domain.element_count,
148  "no_more_inner_rows");
149  builder.CreateCondBr(no_more_inner_rows, prev_exit_bb, head_bb);
150  } else {
151  builder.CreateBr(head_bb);
152  }
    // Reposition in the block that still needs its pending conditional branch.
153  builder.SetInsertPoint(last_head_bb);
154  prev_iter_advance_bb = iter_advance_bb;
155  break;
156  }
    // NOTE(review): the embedded numbering skips line 157 here; presumably the case
    // label (and opening brace) for the single-lookup loop kind -- confirm against the
    // real file.
158  const auto true_bb = llvm::BasicBlock::Create(
159  context, "singleton_true_" + join_loop.name_, parent_func);
160  if (!entry) {
161  entry = true_bb;
162  }
    // Same pending-branch handling as above, targeting this level's block.
163  if (prev_comparison_result) {
164  builder.CreateCondBr(
165  prev_comparison_result,
166  true_bb,
167  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
168  }
169  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
170  builder.SetInsertPoint(true_bb);
171  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
172  CHECK(!iteration_domain.values_buffer);
    // The lookup result itself is this level's iterator; a value >= 0 (signed)
    // counts as a match.
173  iterators.push_back(iteration_domain.slot_lookup_result);
174  auto match_found = builder.CreateICmpSGE(iteration_domain.slot_lookup_result,
175  ll_int<int64_t>(0, context));
    // A row reported as deleted does not count as a match.
176  if (join_loop.is_deleted_) {
177  match_found = builder.CreateAnd(
178  match_found, builder.CreateNot(join_loop.is_deleted_(iterators, nullptr)));
179  }
    // Captured after the callbacks above, which may have moved the insert point.
180  auto match_found_bb = builder.GetInsertBlock();
181  switch (join_loop.type_) {
182  case JoinType::INNER: {
183  prev_comparison_result = match_found;
184  break;
185  }
186  case JoinType::LEFT: {
187  join_loop.found_outer_matches_(match_found);
188  // For outer joins, do the iteration regardless of the result of the match.
189  prev_comparison_result = ll_bool(true, context);
190  break;
191  }
192  default:
193  CHECK(false);
194  }
    // This level has no advance block of its own; if no enclosing loop provides one,
    // fall back to the exit block.
195  if (!prev_iter_advance_bb) {
196  prev_iter_advance_bb = prev_exit_bb;
197  }
198  last_head_bb = match_found_bb;
199  break;
200  }
201  default:
202  CHECK(false);
203  }
204  prev_join_type = join_loop.type_;
205  }
    // NOTE(review): if join_loops is empty, prev_iter_advance_bb and last_head_bb are
    // still nullptr and prev_comparison_result is unset at this point -- verify that
    // callers never pass an empty vector.
206  const auto body_bb = body_codegen(iterators);
    // Back edge from wherever body_codegen left the builder to the innermost advance
    // block.
207  builder.CreateBr(prev_iter_advance_bb);
    // Finally emit the innermost level's pending branch into the body.
208  builder.SetInsertPoint(last_head_bb);
209  builder.CreateCondBr(
210  prev_comparison_result,
211  body_bb,
212  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
213  return entry;
214 }
215 
216 std::pair<llvm::BasicBlock*, llvm::Value*> JoinLoop::evaluateOuterJoinCondition(
217  const JoinLoop& join_loop,
218  const JoinLoopDomain& iteration_domain,
219  const std::vector<llvm::Value*>& iterators,
220  llvm::Value* iteration_counter,
221  llvm::Value* have_more_inner_rows,
222  llvm::Value* found_an_outer_match_ptr,
223  llvm::Value* current_condition_match_ptr,
224  CgenState* cgen_state) {
225  AUTOMATIC_IR_METADATA(cgen_state);
226  llvm::IRBuilder<>& builder = cgen_state->ir_builder_;
227  auto& context = builder.getContext();
228  const auto parent_func = builder.GetInsertBlock()->getParent();
229  builder.CreateStore(ll_bool(false, context), current_condition_match_ptr);
230  const auto evaluate_outer_condition_bb = llvm::BasicBlock::Create(
231  context, "eval_outer_cond_" + join_loop.name_, parent_func);
232  const auto after_evaluate_outer_condition_bb = llvm::BasicBlock::Create(
233  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
234  builder.CreateCondBr(have_more_inner_rows,
235  evaluate_outer_condition_bb,
236  after_evaluate_outer_condition_bb);
237  builder.SetInsertPoint(evaluate_outer_condition_bb);
238  const auto current_condition_match = join_loop.outer_condition_match_
239  ? join_loop.outer_condition_match_(iterators)
240  : ll_bool(true, context);
241  builder.CreateStore(current_condition_match, current_condition_match_ptr);
242  const auto updated_condition_match = builder.CreateOr(
243  current_condition_match, builder.CreateLoad(found_an_outer_match_ptr));
244  builder.CreateStore(updated_condition_match, found_an_outer_match_ptr);
245  builder.CreateBr(after_evaluate_outer_condition_bb);
246  builder.SetInsertPoint(after_evaluate_outer_condition_bb);
247  const auto no_matches_found =
248  builder.CreateNot(builder.CreateLoad(found_an_outer_match_ptr));
249  const auto no_more_inner_rows = builder.CreateICmpEQ(
250  iteration_counter,
251  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
252  : iteration_domain.element_count);
253  // Do the iteration if the outer condition is true or it's the last iteration and no
254  // matches have been found.
255  const auto do_iteration =
256  builder.CreateOr(builder.CreateLoad(current_condition_match_ptr),
257  builder.CreateAnd(no_matches_found, no_more_inner_rows));
258  join_loop.found_outer_matches_(builder.CreateLoad(current_condition_match_ptr));
259  return {after_evaluate_outer_condition_bb, do_iteration};
260 }
#define CHECK_EQ(x, y)
Definition: Logger.h:205
JoinType
Definition: sqldefs.h:107
llvm::Value * element_count
Definition: JoinLoop.h:45
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:319
const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> outer_condition_match_
Definition: JoinLoop.h:94
static std::pair< llvm::BasicBlock *, llvm::Value * > evaluateOuterJoinCondition(const JoinLoop &join_loop, const JoinLoopDomain &iteration_domain, const std::vector< llvm::Value *> &iterators, llvm::Value *iteration_counter, llvm::Value *have_more_inner_rows, llvm::Value *found_an_outer_match_ptr, llvm::Value *current_condition_match_ptr, CgenState *cgen_state)
Definition: JoinLoop.cpp:216
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
const JoinLoopKind kind_
Definition: JoinLoop.h:84
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value *> &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
Definition: JoinLoop.cpp:46
const std::function< void(llvm::Value *)> found_outer_matches_
Definition: JoinLoop.h:97
const std::string name_
Definition: JoinLoop.h:105
JoinLoop(const JoinLoopKind, const JoinType, const std::function< JoinLoopDomain(const std::vector< llvm::Value *> &)> &, const std::function< llvm::Value *(const std::vector< llvm::Value *> &)> &, const std::function< void(llvm::Value *)> &, const std::function< llvm::Value *(const std::vector< llvm::Value *> &prev_iters, llvm::Value *)> &, const std::string &name="")
Definition: JoinLoop.cpp:25
#define AUTOMATIC_IR_METADATA(CGENSTATE)
llvm::Value * upper_bound
Definition: JoinLoop.h:44
#define CHECK(condition)
Definition: Logger.h:197
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
JoinLoopKind
Definition: JoinLoop.h:31