OmniSciDB  1dac507f6e
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
JoinLoop.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "JoinLoop.h"
18 #include "Shared/Logger.h"
19 
20 #include <llvm/IR/Type.h>
21 
22 #include <stack>
23 
25  const JoinType type,
26  const std::function<JoinLoopDomain(const std::vector<llvm::Value*>&)>&
27  iteration_domain_codegen,
28  const std::function<llvm::Value*(const std::vector<llvm::Value*>&)>&
29  outer_condition_match,
30  const std::function<void(llvm::Value*)>& found_outer_matches,
31  const std::function<llvm::Value*(const std::vector<llvm::Value*>&,
32  llvm::Value*)>& is_deleted,
33  const std::string& name)
34  : kind_(kind)
35  , type_(type)
36  , iteration_domain_codegen_(iteration_domain_codegen)
37  , outer_condition_match_(outer_condition_match)
38  , found_outer_matches_(found_outer_matches)
39  , is_deleted_(is_deleted)
40  , name_(name) {
41  CHECK(outer_condition_match == nullptr || type == JoinType::LEFT);
42  CHECK_EQ(static_cast<bool>(found_outer_matches), (type == JoinType::LEFT));
43 }
44 
// Emits the LLVM basic blocks and branches for a nest of join loops, one nesting
// level per element of `join_loops` (outermost first), and places the code
// produced by `body_codegen` inside the innermost level.
//
// Parameters:
//   join_loops   - loop descriptors; each one contributes one value to the
//                  iterator vector handed to the callbacks.
//   body_codegen - called once with the full iterator vector after all loops
//                  have been emitted; returns the basic block the innermost
//                  loop branches to when its condition holds.
//   outer_iter   - first entry of the iterator vector, pushed before any loop.
//   exit_bb      - block used as the exit target while no enclosing loop has
//                  produced an advance block.
//   builder      - IR builder; its insert point is moved repeatedly.
//
// Returns the first block created for the outermost loop (`entry`); it stays
// nullptr when `join_loops` is empty.
//
// NOTE(review): with an empty `join_loops` the tail of this function passes null
// blocks/values to CreateBr/SetInsertPoint/CreateCondBr - presumably callers
// never pass an empty vector; confirm.
45 llvm::BasicBlock* JoinLoop::codegen(
46  const std::vector<JoinLoop>& join_loops,
47  const std::function<llvm::BasicBlock*(const std::vector<llvm::Value*>&)>&
48  body_codegen,
49  llvm::Value* outer_iter,
50  llvm::BasicBlock* exit_bb,
51  llvm::IRBuilder<>& builder) {
// State carried from one loop level to the next:
//   prev_exit_bb           - where the current level goes when it is finished.
//   prev_iter_advance_bb   - advance block of the closest enclosing counting loop.
//   last_head_bb           - block that still needs its terminating conditional
//                            branch (emitted by the next level or after the loop).
//   prev_comparison_result - condition for that pending branch.
52  llvm::BasicBlock* prev_exit_bb{exit_bb};
53  llvm::BasicBlock* prev_iter_advance_bb{nullptr};
54  llvm::BasicBlock* last_head_bb{nullptr};
55  auto& context = builder.getContext();
56  const auto parent_func = builder.GetInsertBlock()->getParent();
57  llvm::Value* prev_comparison_result{nullptr};
58  llvm::BasicBlock* entry{nullptr};
59  std::vector<llvm::Value*> iterators;
60  iterators.push_back(outer_iter);
61  JoinType prev_join_type{JoinType::INVALID};
62  for (const auto& join_loop : join_loops) {
63  switch (join_loop.kind_) {
// NOTE(review): the embedded listing numbers jump from 63 to 65 here, so one
// source line is missing from this extract - presumably a
// `case JoinLoopKind::UpperBound:` label falling through to the Set case, since
// the body below distinguishes UpperBound from Set. Confirm against the real file.
65  case JoinLoopKind::Set: {
66  const auto preheader_bb = llvm::BasicBlock::Create(
67  context, "ub_iter_preheader_" + join_loop.name_, parent_func);
68  if (!entry) {
69  entry = preheader_bb;
70  }
// Terminate the enclosing level's pending block: enter this loop when its
// condition holds, otherwise go to the enclosing advance block (previous level
// was a LEFT join) or to the enclosing exit block.
71  if (prev_comparison_result) {
72  builder.CreateCondBr(
73  prev_comparison_result,
74  preheader_bb,
75  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
76  }
// Leaving this loop means advancing the enclosing loop, if there is one.
77  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
78  builder.SetInsertPoint(preheader_bb);
// Preheader: 64-bit iteration counter slot, plus (LEFT joins only) a 1-bit
// "found an outer match" slot initialised to false.
79  const auto iteration_counter_ptr = builder.CreateAlloca(
80  get_int_type(64, context), nullptr, "ub_iter_counter_ptr_" + join_loop.name_);
81  llvm::Value* found_an_outer_match_ptr{nullptr};
82  if (join_loop.type_ == JoinType::LEFT) {
83  found_an_outer_match_ptr = builder.CreateAlloca(
84  get_int_type(1, context), nullptr, "found_an_outer_match");
85  builder.CreateStore(ll_bool(false, context), found_an_outer_match_ptr);
86  }
87  builder.CreateStore(ll_int(int64_t(0), context), iteration_counter_ptr);
// The iteration domain is generated in the preheader from the iterators of the
// enclosing levels only (this level's iterator has not been pushed yet).
88  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
89  const auto head_bb = llvm::BasicBlock::Create(
90  context, "ub_iter_head_" + join_loop.name_, parent_func);
91  builder.CreateBr(head_bb);
92  builder.SetInsertPoint(head_bb);
93  llvm::Value* iteration_counter = builder.CreateLoad(
94  iteration_counter_ptr, "ub_iter_counter_val_" + join_loop.name_);
// For Set loops the iterator is the address of element `counter` in
// values_buffer (a GEP, not a loaded value); otherwise it is the counter itself.
95  auto iteration_val = iteration_counter;
96  CHECK(join_loop.kind_ == JoinLoopKind::Set || !iteration_domain.values_buffer);
97  if (join_loop.kind_ == JoinLoopKind::Set) {
98  iteration_val =
99  builder.CreateGEP(iteration_domain.values_buffer, iteration_counter);
100  }
101  iterators.push_back(iteration_val);
// Signed comparison of the counter against upper_bound (UpperBound loops) or
// element_count (Set loops).
102  const auto have_more_inner_rows = builder.CreateICmpSLT(
103  iteration_counter,
104  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
105  : iteration_domain.element_count);
106  const auto iter_advance_bb = llvm::BasicBlock::Create(
107  context, "ub_iter_advance_" + join_loop.name_, parent_func);
// Optional deleted-row filter: rows reported as deleted branch straight to the
// advance block; the rest continue in a dedicated row_not_deleted block.
108  llvm::BasicBlock* row_not_deleted_bb{nullptr};
109  if (join_loop.is_deleted_) {
110  row_not_deleted_bb = llvm::BasicBlock::Create(
111  context, "row_not_deleted_" + join_loop.name_, parent_func);
112  const auto row_is_deleted =
113  join_loop.is_deleted_(iterators, have_more_inner_rows);
114  builder.CreateCondBr(row_is_deleted, iter_advance_bb, row_not_deleted_bb);
115  builder.SetInsertPoint(row_not_deleted_bb);
116  }
// LEFT joins get their loop condition (and pending block) from
// evaluateOuterJoinCondition; other joins simply continue while rows remain.
117  if (join_loop.type_ == JoinType::LEFT) {
118  std::tie(last_head_bb, prev_comparison_result) =
119  evaluateOuterJoinCondition(join_loop,
120  iteration_domain,
121  iterators,
122  iteration_counter,
123  have_more_inner_rows,
124  found_an_outer_match_ptr,
125  builder);
126  } else {
127  prev_comparison_result = have_more_inner_rows;
128  last_head_bb = row_not_deleted_bb ? row_not_deleted_bb : head_bb;
129  }
// Advance block: counter += 1, then loop back to the head.
130  builder.SetInsertPoint(iter_advance_bb);
131  const auto iteration_counter_next_val =
132  builder.CreateAdd(iteration_counter, ll_int(int64_t(1), context));
133  builder.CreateStore(iteration_counter_next_val, iteration_counter_ptr);
// LEFT joins leave from the advance block, and only once the next counter is
// strictly greater than the bound, so the head also runs with counter == bound.
// evaluateOuterJoinCondition tests for exactly that counter value.
134  if (join_loop.type_ == JoinType::LEFT) {
135  const auto no_more_inner_rows =
136  builder.CreateICmpSGT(iteration_counter_next_val,
137  join_loop.kind_ == JoinLoopKind::UpperBound
138  ? iteration_domain.upper_bound
139  : iteration_domain.element_count);
140  builder.CreateCondBr(no_more_inner_rows, prev_exit_bb, head_bb);
141  } else {
142  builder.CreateBr(head_bb);
143  }
// Leave the builder in the block whose terminator is still pending so the next
// level (or the tail of this function) can emit it.
144  builder.SetInsertPoint(last_head_bb);
145  prev_iter_advance_bb = iter_advance_bb;
146  break;
147  }
// NOTE(review): the embedded listing numbers jump from 147 to 149 here, so the
// case label that opens the next block is missing from this extract -
// presumably the single-match (hash slot lookup) loop kind, judging by the
// "singleton_true_" block name. Confirm against the real file.
149  const auto true_bb = llvm::BasicBlock::Create(
150  context, "singleton_true_" + join_loop.name_, parent_func);
151  if (!entry) {
152  entry = true_bb;
153  }
// Same hand-over from the enclosing level as in the counting-loop case above.
154  if (prev_comparison_result) {
155  builder.CreateCondBr(
156  prev_comparison_result,
157  true_bb,
158  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
159  }
160  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
161  builder.SetInsertPoint(true_bb);
162  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
163  CHECK(!iteration_domain.values_buffer);
// The iterator for this level is the lookup result itself; a result >= 0
// (signed) counts as a match, optionally further filtered by is_deleted_.
164  iterators.push_back(iteration_domain.slot_lookup_result);
165  auto match_found = builder.CreateICmpSGE(iteration_domain.slot_lookup_result,
166  ll_int<int64_t>(0, context));
167  if (join_loop.is_deleted_) {
168  match_found = builder.CreateAnd(
169  match_found, builder.CreateNot(join_loop.is_deleted_(iterators, nullptr)));
170  }
// Captured after is_deleted_ ran, so it is whatever block the builder ended up in.
171  auto match_found_bb = builder.GetInsertBlock();
172  switch (join_loop.type_) {
173  case JoinType::INNER: {
174  prev_comparison_result = match_found;
175  break;
176  }
177  case JoinType::LEFT: {
178  join_loop.found_outer_matches_(match_found);
179  // For outer joins, do the iteration regardless of the result of the match.
180  prev_comparison_result = ll_bool(true, context);
181  break;
182  }
183  default:
184  CHECK(false);
185  }
// No loop of its own to advance: when there is no enclosing counting loop,
// finishing this level goes straight to the exit block.
186  if (!prev_iter_advance_bb) {
187  prev_iter_advance_bb = prev_exit_bb;
188  }
189  last_head_bb = match_found_bb;
190  break;
191  }
192  default:
193  CHECK(false);
194  }
195  prev_join_type = join_loop.type_;
196  }
// Emit the body, make it continue with the innermost advance block, then go back
// and terminate the innermost pending block with the branch into the body.
197  const auto body_bb = body_codegen(iterators);
198  builder.CreateBr(prev_iter_advance_bb);
199  builder.SetInsertPoint(last_head_bb);
200  builder.CreateCondBr(
201  prev_comparison_result,
202  body_bb,
203  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
204  return entry;
205 }
206 
// Emits, at the builder's current insert point, the IR that decides whether the
// body of a LEFT join loop should run for the current inner row.
//
// Parameters:
//   join_loop                - loop descriptor; supplies the optional
//                              outer_condition_match_ callback, the
//                              found_outer_matches_ callback, kind_ and name_.
//   iteration_domain         - supplies upper_bound / element_count for the
//                              "past the last row" test.
//   iterators                - iterator values handed to outer_condition_match_.
//   iteration_counter        - current value of the loop counter.
//   have_more_inner_rows     - i1 value; the condition is evaluated only when true.
//   found_an_outer_match_ptr - i1 slot accumulating whether any row has matched.
//   builder                  - IR builder; on return it is positioned in the
//                              returned block.
//
// Returns {block that still needs a terminator, i1 "do the iteration" value}.
207 std::pair<llvm::BasicBlock*, llvm::Value*> JoinLoop::evaluateOuterJoinCondition(
208  const JoinLoop& join_loop,
209  const JoinLoopDomain& iteration_domain,
210  const std::vector<llvm::Value*>& iterators,
211  llvm::Value* iteration_counter,
212  llvm::Value* have_more_inner_rows,
213  llvm::Value* found_an_outer_match_ptr,
214  llvm::IRBuilder<>& builder) {
215  auto& context = builder.getContext();
216  const auto parent_func = builder.GetInsertBlock()->getParent();
// Slot holding the condition result for the current row; it stays false when
// the evaluation block below is skipped.
217  const auto current_condition_match_ptr = builder.CreateAlloca(
218  get_int_type(1, context), nullptr, "outer_condition_current_match");
219  builder.CreateStore(ll_bool(false, context), current_condition_match_ptr);
220  const auto evaluate_outer_condition_bb = llvm::BasicBlock::Create(
221  context, "eval_outer_cond_" + join_loop.name_, parent_func);
222  const auto after_evaluate_outer_condition_bb = llvm::BasicBlock::Create(
223  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
// Only evaluate the condition while there is an inner row to evaluate it on.
224  builder.CreateCondBr(have_more_inner_rows,
225  evaluate_outer_condition_bb,
226  after_evaluate_outer_condition_bb);
227  builder.SetInsertPoint(evaluate_outer_condition_bb);
// Without an explicit outer condition every inner row counts as a match.
228  const auto current_condition_match = join_loop.outer_condition_match_
229  ? join_loop.outer_condition_match_(iterators)
230  : ll_bool(true, context);
231  builder.CreateStore(current_condition_match, current_condition_match_ptr);
// Accumulate: found_an_outer_match |= current match.
232  const auto updated_condition_match = builder.CreateOr(
233  current_condition_match, builder.CreateLoad(found_an_outer_match_ptr));
234  builder.CreateStore(updated_condition_match, found_an_outer_match_ptr);
235  builder.CreateBr(after_evaluate_outer_condition_bb);
236  builder.SetInsertPoint(after_evaluate_outer_condition_bb);
237  const auto no_matches_found =
238  builder.CreateNot(builder.CreateLoad(found_an_outer_match_ptr));
// True exactly when the counter equals the bound, i.e. one step past the last
// valid index (have_more_inner_rows in the caller is counter < bound).
239  const auto no_more_inner_rows = builder.CreateICmpEQ(
240  iteration_counter,
241  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
242  : iteration_domain.element_count);
243  // Do the iteration if the outer condition is true or it's the last iteration and no
244  // matches have been found.
245  const auto do_iteration =
246  builder.CreateOr(builder.CreateLoad(current_condition_match_ptr),
247  builder.CreateAnd(no_matches_found, no_more_inner_rows));
// Hand the per-row match flag to the callback supplied at construction.
248  join_loop.found_outer_matches_(builder.CreateLoad(current_condition_match_ptr));
249  return {after_evaluate_outer_condition_bb, do_iteration};
250 }
#define CHECK_EQ(x, y)
Definition: Logger.h:198
JoinType
Definition: sqldefs.h:98
llvm::Value * element_count
Definition: JoinLoop.h:44
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> outer_condition_match_
Definition: JoinLoop.h:92
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
const JoinLoopKind kind_
Definition: JoinLoop.h:82
CHECK(cgen_state)
const std::function< void(llvm::Value *)> found_outer_matches_
Definition: JoinLoop.h:95
const std::string name_
Definition: JoinLoop.h:103
llvm::Value * upper_bound
Definition: JoinLoop.h:43
static std::pair< llvm::BasicBlock *, llvm::Value * > evaluateOuterJoinCondition(const JoinLoop &join_loop, const JoinLoopDomain &iteration_domain, const std::vector< llvm::Value * > &iterators, llvm::Value *iteration_counter, llvm::Value *have_more_inner_rows, llvm::Value *found_an_outer_match_ptr, llvm::IRBuilder<> &builder)
Definition: JoinLoop.cpp:207
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, llvm::IRBuilder<> &builder)
Definition: JoinLoop.cpp:45
JoinLoop(const JoinLoopKind, const JoinType, const std::function< JoinLoopDomain(const std::vector< llvm::Value * > &)> &, const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> &, const std::function< void(llvm::Value *)> &, const std::function< llvm::Value *(const std::vector< llvm::Value * > &prev_iters, llvm::Value *)> &, const std::string &name="")
Definition: JoinLoop.cpp:24
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
JoinLoopKind
Definition: JoinLoop.h:30