OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
JoinLoop.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "JoinLoop.h"
18 #include "../CgenState.h"
19 #include "Logger/Logger.h"
20 
21 #include <llvm/IR/Type.h>
22 
23 #include <stack>
24 
26  const JoinType type,
27  const std::function<JoinLoopDomain(const std::vector<llvm::Value*>&)>&
28  iteration_domain_codegen,
29  const std::function<llvm::Value*(const std::vector<llvm::Value*>&)>&
30  outer_condition_match,
31  const std::function<void(llvm::Value*)>& found_outer_matches,
32  const HoistedFiltersCallback& hoisted_filters,
33  const std::function<llvm::Value*(const std::vector<llvm::Value*>&,
34  llvm::Value*)>& is_deleted,
35  const bool nested_loop_join,
36  const std::string& name)
37  : kind_(kind)
38  , type_(type)
39  , iteration_domain_codegen_(iteration_domain_codegen)
40  , outer_condition_match_(outer_condition_match)
41  , found_outer_matches_(found_outer_matches)
42  , hoisted_filters_(hoisted_filters)
43  , is_deleted_(is_deleted)
44  , nested_loop_join_(nested_loop_join)
45  , name_(name) {
46  CHECK(outer_condition_match == nullptr || type == JoinType::LEFT);
47  CHECK_EQ(static_cast<bool>(found_outer_matches), (type == JoinType::LEFT));
48 }
49 
/**
 * Emits the LLVM control flow for a sequence of join loops and wires the code produced
 * by body_codegen into the innermost position.
 *
 * The loops in join_loops are processed in order. For each loop the function creates
 * its basic blocks, then (for every loop but the first) terminates the previous loop's
 * head block with a conditional branch into the new loop. The previous loop's
 * comparison result is the branch condition.
 *
 * @param join_loops    Loops to generate, outermost first.
 * @param body_codegen  Called once, after all loops were generated, with the collected
 *                      iterator values; the block it returns becomes the "true" target
 *                      of the last loop's comparison.
 * @param outer_iter    Pushed as the first entry of the iterator vector handed to the
 *                      per-loop callbacks and to body_codegen.
 * @param exit_bb       Exit target used while no loop has yet provided an advance block.
 * @param cgen_state    Supplies the IR builder (ir_builder_) used for all emitted IR.
 * @return The block recorded as entry for the first loop: the block returned by its
 *         hoisted-filters callback when that is non-null, otherwise its preheader
 *         ("ub_iter_preheader_*") or "singleton_true_*" block.
 *
 * NOTE(review): if join_loops is empty, entry, prev_iter_advance_bb and last_head_bb
 * are still nullptr when they are used after the loop -- presumably callers never pass
 * an empty vector; confirm.
 */
50 llvm::BasicBlock* JoinLoop::codegen(
51  const std::vector<JoinLoop>& join_loops,
52  const std::function<llvm::BasicBlock*(const std::vector<llvm::Value*>&)>&
53  body_codegen,
54  llvm::Value* outer_iter,
55  llvm::BasicBlock* exit_bb,
56  CgenState* cgen_state) {
57  AUTOMATIC_IR_METADATA(cgen_state);
58  llvm::IRBuilder<>& builder = cgen_state->ir_builder_;
    // State carried from one loop to the next while the nest is being built.
59  llvm::BasicBlock* prev_exit_bb{exit_bb};
60  llvm::BasicBlock* prev_iter_advance_bb{nullptr};
61  llvm::BasicBlock* last_head_bb{nullptr};
62  auto& context = builder.getContext();
63  const auto parent_func = builder.GetInsertBlock()->getParent();
64  llvm::Value* prev_comparison_result{nullptr};
65  llvm::BasicBlock* entry{nullptr};
66  std::vector<llvm::Value*> iterators;
67  iterators.push_back(outer_iter);
68  JoinType prev_join_type{JoinType::INVALID};
69  for (const auto& join_loop : join_loops) {
70  switch (join_loop.kind_) {
    // NOTE(review): this listing's own numbering skips 71 and 73 around the next line,
    // so further case labels and the opening brace that matches the `}` after the
    // first `break;` are presumably missing from this view. The code below tests
    // kind_ against UpperBound, Set and MultiSet, which suggests this branch serves
    // all three -- confirm against the real file.
72  case JoinLoopKind::Set:
74  const auto preheader_bb = llvm::BasicBlock::Create(
75  context, "ub_iter_preheader_" + join_loop.name_, parent_func);
76 
    // When a hoisted-filters callback is set, the block it returns (if non-null) is
    // used as the way into this loop instead of the preheader.
77  llvm::BasicBlock* filter_bb{nullptr};
78  if (join_loop.hoisted_filters_) {
79  filter_bb = join_loop.hoisted_filters_(
80  preheader_bb, prev_exit_bb, join_loop.name_, parent_func, cgen_state);
81  }
82 
    // Only the first loop sets the entry block that is returned to the caller.
83  if (!entry) {
84  entry = filter_bb ? filter_bb : preheader_bb;
85  }
86 
    // Terminate the previous loop's head (the current insert point): enter this loop
    // when the previous comparison holds; otherwise go to the previous advance block
    // for a LEFT join, or to the previous exit block for any other join type.
87  if (prev_comparison_result) {
88  builder.CreateCondBr(
89  prev_comparison_result,
90  filter_bb ? filter_bb : preheader_bb,
91  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
92  }
    // Leaving this loop continues at the enclosing loop's advance block, or at
    // exit_bb when there is none yet.
93  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
94  builder.SetInsertPoint(preheader_bb);
95 
    // The loop counter is a 64-bit stack slot, initialised to 0 further down.
96  const auto iteration_counter_ptr = builder.CreateAlloca(
97  get_int_type(64, context), nullptr, "ub_iter_counter_ptr_" + join_loop.name_);
    // LEFT joins additionally get two 1-bit slots that are handed to
    // evaluateOuterJoinCondition; found_an_outer_match starts out false.
98  llvm::Value* found_an_outer_match_ptr{nullptr};
99  llvm::Value* current_condition_match_ptr{nullptr};
100  if (join_loop.type_ == JoinType::LEFT) {
101  found_an_outer_match_ptr = builder.CreateAlloca(
102  get_int_type(1, context), nullptr, "found_an_outer_match");
103  builder.CreateStore(ll_bool(false, context), found_an_outer_match_ptr);
104  current_condition_match_ptr = builder.CreateAlloca(
105  get_int_type(1, context), nullptr, "outer_condition_current_match");
106  }
107  builder.CreateStore(ll_int(int64_t(0), context), iteration_counter_ptr);
108  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
109  const auto head_bb = llvm::BasicBlock::Create(
110  context, "ub_iter_head_" + join_loop.name_, parent_func);
111  builder.CreateBr(head_bb);
112  builder.SetInsertPoint(head_bb);
113  llvm::Value* iteration_counter =
114  builder.CreateLoad(iteration_counter_ptr->getType()->getPointerElementType(),
115  iteration_counter_ptr,
116  "ub_iter_counter_val_" + join_loop.name_);
    // By default the iterator value is the counter itself; for Set/MultiSet it is
    // replaced below by a pointer into values_buffer.
117  auto iteration_val = iteration_counter;
    // Any kind other than Set/MultiSet must not come with a values buffer.
118  CHECK(join_loop.kind_ == JoinLoopKind::Set ||
119  join_loop.kind_ == JoinLoopKind::MultiSet ||
120  !iteration_domain.values_buffer);
121  if (join_loop.kind_ == JoinLoopKind::Set ||
122  join_loop.kind_ == JoinLoopKind::MultiSet) {
123  CHECK(iteration_domain.values_buffer->getType()->isPointerTy());
124  const auto ptr_type =
125  static_cast<llvm::PointerType*>(iteration_domain.values_buffer->getType());
    // A pointer-to-array buffer needs a leading 0 index in the GEP; any other
    // pointer is indexed directly with the counter.
126  if (ptr_type->getPointerElementType()->isArrayTy()) {
127  iteration_val = builder.CreateGEP(
128  iteration_domain.values_buffer->getType()
129  ->getScalarType()
130  ->getPointerElementType(),
131  iteration_domain.values_buffer,
132  std::vector<llvm::Value*>{
133  llvm::ConstantInt::get(get_int_type(64, context), 0),
134  iteration_counter},
135  "ub_iter_counter_" + join_loop.name_);
136  } else {
137  iteration_val = builder.CreateGEP(iteration_domain.values_buffer->getType()
138  ->getScalarType()
139  ->getPointerElementType(),
140  iteration_domain.values_buffer,
141  iteration_counter,
142  "ub_iter_counter_" + join_loop.name_);
143  }
144  }
145  iterators.push_back(iteration_val);
    // Signed compare of the counter against the loop bound: upper_bound for the
    // UpperBound kind, element_count otherwise.
146  const auto have_more_inner_rows = builder.CreateICmpSLT(
147  iteration_counter,
148  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
149  : iteration_domain.element_count,
150  "have_more_inner_rows");
151  const auto iter_advance_bb = llvm::BasicBlock::Create(
152  context, "ub_iter_advance_" + join_loop.name_, parent_func);
    // With an is_deleted_ callback, rows it flags branch straight to the advance
    // block; the rest of the head continues in row_not_deleted_bb.
153  llvm::BasicBlock* row_not_deleted_bb{nullptr};
154  if (join_loop.is_deleted_) {
155  row_not_deleted_bb = llvm::BasicBlock::Create(
156  context, "row_not_deleted_" + join_loop.name_, parent_func);
157  const auto row_is_deleted =
158  join_loop.is_deleted_(iterators, have_more_inner_rows);
159  builder.CreateCondBr(row_is_deleted, iter_advance_bb, row_not_deleted_bb);
160  builder.SetInsertPoint(row_not_deleted_bb);
161  }
    // LEFT joins take both the head block to continue from and the comparison
    // result from evaluateOuterJoinCondition; all other types simply continue while
    // there are more inner rows.
162  if (join_loop.type_ == JoinType::LEFT) {
163  std::tie(last_head_bb, prev_comparison_result) =
164  evaluateOuterJoinCondition(join_loop,
165  iteration_domain,
166  iterators,
167  iteration_counter,
168  have_more_inner_rows,
169  found_an_outer_match_ptr,
170  current_condition_match_ptr,
171  cgen_state);
172  } else {
173  prev_comparison_result = have_more_inner_rows;
174  last_head_bb = row_not_deleted_bb ? row_not_deleted_bb : head_bb;
175  }
    // Advance block: increment the counter, store it back and loop to the head.
176  builder.SetInsertPoint(iter_advance_bb);
177  const auto iteration_counter_next_val =
178  builder.CreateAdd(iteration_counter, ll_int(int64_t(1), context));
179  builder.CreateStore(iteration_counter_next_val, iteration_counter_ptr);
    // For LEFT joins the loop is left only once the next counter value is strictly
    // greater than the bound, so the head is entered one more time with the counter
    // equal to the bound (the case evaluateOuterJoinCondition tests with ICmpEQ).
180  if (join_loop.type_ == JoinType::LEFT) {
181  const auto no_more_inner_rows =
182  builder.CreateICmpSGT(iteration_counter_next_val,
183  join_loop.kind_ == JoinLoopKind::UpperBound
184  ? iteration_domain.upper_bound
185  : iteration_domain.element_count,
186  "no_more_inner_rows");
187  builder.CreateCondBr(no_more_inner_rows, prev_exit_bb, head_bb);
188  } else {
189  builder.CreateBr(head_bb);
190  }
    // Leave the builder in the head block so that the next loop (or the final wiring
    // after the for loop) can emit its terminator.
191  builder.SetInsertPoint(last_head_bb);
192  prev_iter_advance_bb = iter_advance_bb;
193  break;
194  }
    // NOTE(review): listing line 195 is missing here; presumably it holds the case
    // label (and opening brace) of the branch that follows. Judging by the block names
    // this is the single-lookup ("singleton") case -- confirm against the real file.
196  const auto true_bb = llvm::BasicBlock::Create(
197  context, "singleton_true_" + join_loop.name_, parent_func);
198 
199  llvm::BasicBlock* filter_bb{nullptr};
200  if (join_loop.hoisted_filters_) {
201  filter_bb = join_loop.hoisted_filters_(
202  true_bb, prev_exit_bb, join_loop.name_, parent_func, cgen_state);
203  }
204 
205  if (!entry) {
206  entry = filter_bb ? filter_bb : true_bb;
207  }
208 
    // Same hand-over from the previous loop as in the branch above.
209  if (prev_comparison_result) {
210  builder.CreateCondBr(
211  prev_comparison_result,
212  filter_bb ? filter_bb : true_bb,
213  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
214  }
215  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
216 
217  builder.SetInsertPoint(true_bb);
218  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
219  CHECK(!iteration_domain.values_buffer);
    // The iterator for this loop is the lookup result itself; a result >= 0 is
    // treated as a join-condition match.
220  iterators.push_back(iteration_domain.slot_lookup_result);
221  auto join_cond_match = builder.CreateICmpSGE(iteration_domain.slot_lookup_result,
222  ll_int<int64_t>(0, context));
    // 1-bit slot for the remaining outer condition; it stays true unless the LEFT
    // join path below overwrites it.
223  llvm::Value* remaining_cond_match = builder.CreateAlloca(
224  get_int_type(1, context), nullptr, "remaining_outer_cond_match");
225  builder.CreateStore(ll_bool(true, context), remaining_cond_match);
226 
    // LEFT join with an extra outer condition: evaluate that condition only when
    // the lookup matched, and store (outer condition AND lookup match) in the slot.
227  if (join_loop.type_ == JoinType::LEFT && join_loop.outer_condition_match_) {
    // This local shadows the parent_func declared at the top of the function.
228  const auto parent_func = builder.GetInsertBlock()->getParent();
229  const auto evaluate_remaining_outer_cond_bb = llvm::BasicBlock::Create(
230  context, "eval_remaining_outer_cond_" + join_loop.name_, parent_func);
231  const auto after_evaluate_outer_cond_bb = llvm::BasicBlock::Create(
232  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
233  builder.CreateCondBr(join_cond_match,
234  evaluate_remaining_outer_cond_bb,
235  after_evaluate_outer_cond_bb);
236  builder.SetInsertPoint(evaluate_remaining_outer_cond_bb);
237  const auto outer_cond_match = join_loop.outer_condition_match_(iterators);
238  const auto true_left_cond_match =
239  builder.CreateAnd(outer_cond_match, join_cond_match);
240  builder.CreateStore(true_left_cond_match, remaining_cond_match);
241  builder.CreateBr(after_evaluate_outer_cond_bb);
242  builder.SetInsertPoint(after_evaluate_outer_cond_bb);
243  }
244  auto match_found = builder.CreateAnd(
245  join_cond_match,
246  builder.CreateLoad(remaining_cond_match->getType()->getPointerElementType(),
247  remaining_cond_match));
248  CHECK(match_found);
    // A row reported by is_deleted_ does not count as a match. The second callback
    // argument is nullptr here, unlike in the loop branch above.
249  if (join_loop.is_deleted_) {
250  match_found = builder.CreateAnd(
251  match_found, builder.CreateNot(join_loop.is_deleted_(iterators, nullptr)));
252  }
    // Remember the block the builder is in now; it becomes last_head_bb below and
    // receives this loop's terminating conditional branch later.
253  auto match_found_bb = builder.GetInsertBlock();
254  switch (join_loop.type_) {
255  case JoinType::INNER:
256  case JoinType::SEMI: {
257  prev_comparison_result = match_found;
258  break;
259  }
    // ANTI joins continue when the lookup result is negative; match_found is not
    // used for this join type.
260  case JoinType::ANTI: {
261  auto match_found_for_anti_join = builder.CreateICmpSLT(
262  iteration_domain.slot_lookup_result, ll_int<int64_t>(0, context));
263  prev_comparison_result = match_found_for_anti_join;
264  break;
265  }
266  case JoinType::LEFT: {
267  join_loop.found_outer_matches_(match_found);
268  // For outer joins, do the iteration regardless of the result of the match.
269  prev_comparison_result = ll_bool(true, context);
270  break;
271  }
272  default:
273  CHECK(false);
274  }
    // This branch creates no advance block of its own: when no earlier loop
    // provided one, the current exit block is used in its place.
275  if (!prev_iter_advance_bb) {
276  prev_iter_advance_bb = prev_exit_bb;
277  }
278  last_head_bb = match_found_bb;
279  break;
280  }
281  default:
282  CHECK(false);
283  }
284  prev_join_type = join_loop.type_;
285  }
286 
    // Final wiring: generate the body, branch from wherever body_codegen left the
    // builder to the innermost advance block, then terminate the innermost head with
    // the branch into the body (or to the advance/exit block when the comparison
    // fails, chosen as for the loop-to-loop hand-over).
287  const auto body_bb = body_codegen(iterators);
288  builder.CreateBr(prev_iter_advance_bb);
289  builder.SetInsertPoint(last_head_bb);
290  builder.CreateCondBr(
291  prev_comparison_result,
292  body_bb,
293  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
294  return entry;
295 }
296 
297 std::pair<llvm::BasicBlock*, llvm::Value*> JoinLoop::evaluateOuterJoinCondition(
298  const JoinLoop& join_loop,
299  const JoinLoopDomain& iteration_domain,
300  const std::vector<llvm::Value*>& iterators,
301  llvm::Value* iteration_counter,
302  llvm::Value* have_more_inner_rows,
303  llvm::Value* found_an_outer_match_ptr,
304  llvm::Value* current_condition_match_ptr,
305  CgenState* cgen_state) {
306  AUTOMATIC_IR_METADATA(cgen_state);
307  llvm::IRBuilder<>& builder = cgen_state->ir_builder_;
308  auto& context = builder.getContext();
309  const auto parent_func = builder.GetInsertBlock()->getParent();
310  builder.CreateStore(ll_bool(false, context), current_condition_match_ptr);
311  const auto evaluate_outer_condition_bb = llvm::BasicBlock::Create(
312  context, "eval_outer_cond_" + join_loop.name_, parent_func);
313  const auto after_evaluate_outer_condition_bb = llvm::BasicBlock::Create(
314  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
315  builder.CreateCondBr(have_more_inner_rows,
316  evaluate_outer_condition_bb,
317  after_evaluate_outer_condition_bb);
318  builder.SetInsertPoint(evaluate_outer_condition_bb);
319  const auto current_condition_match = join_loop.outer_condition_match_
320  ? join_loop.outer_condition_match_(iterators)
321  : ll_bool(true, context);
322  builder.CreateStore(current_condition_match, current_condition_match_ptr);
323  const auto updated_condition_match = builder.CreateOr(
324  current_condition_match,
325  builder.CreateLoad(found_an_outer_match_ptr->getType()->getPointerElementType(),
326  found_an_outer_match_ptr));
327  builder.CreateStore(updated_condition_match, found_an_outer_match_ptr);
328  builder.CreateBr(after_evaluate_outer_condition_bb);
329  builder.SetInsertPoint(after_evaluate_outer_condition_bb);
330  const auto no_matches_found = builder.CreateNot(
331  builder.CreateLoad(found_an_outer_match_ptr->getType()->getPointerElementType(),
332  found_an_outer_match_ptr));
333  const auto no_more_inner_rows = builder.CreateICmpEQ(
334  iteration_counter,
335  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
336  : iteration_domain.element_count);
337  // Do the iteration if the outer condition is true or it's the last iteration and no
338  // matches have been found.
339  const auto do_iteration = builder.CreateOr(
340  builder.CreateLoad(current_condition_match_ptr->getType()->getPointerElementType(),
341  current_condition_match_ptr),
342  builder.CreateAnd(no_matches_found, no_more_inner_rows));
343  join_loop.found_outer_matches_(
344  builder.CreateLoad(current_condition_match_ptr->getType()->getPointerElementType(),
345  current_condition_match_ptr));
346  return {after_evaluate_outer_condition_bb, do_iteration};
347 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
JoinType
Definition: sqldefs.h:174
llvm::Value * element_count
Definition: JoinLoop.h:46
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:384
std::function< llvm::BasicBlock *(llvm::BasicBlock *, llvm::BasicBlock *, const std::string &, llvm::Function *, CgenState *)> HoistedFiltersCallback
Definition: JoinLoop.h:61
const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> outer_condition_match_
Definition: JoinLoop.h:109
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
const JoinLoopKind kind_
Definition: JoinLoop.h:99
const std::function< void(llvm::Value *)> found_outer_matches_
Definition: JoinLoop.h:112
const std::string name_
Definition: JoinLoop.h:125
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
Definition: JoinLoop.cpp:50
#define AUTOMATIC_IR_METADATA(CGENSTATE)
llvm::Value * upper_bound
Definition: JoinLoop.h:45
JoinLoop(const JoinLoopKind, const JoinType, const std::function< JoinLoopDomain(const std::vector< llvm::Value * > &)> &iteration_domain_codegen, const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> &outer_condition_match, const std::function< void(llvm::Value *)> &found_outer_matches, const HoistedFiltersCallback &hoisted_filters, const std::function< llvm::Value *(const std::vector< llvm::Value * > &prev_iters, llvm::Value *)> &is_deleted, const bool nested_loop_join=false, const std::string &name="")
Definition: JoinLoop.cpp:25
#define CHECK(condition)
Definition: Logger.h:291
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
JoinLoopKind
Definition: JoinLoop.h:31
string name
Definition: setup.in.py:72
static std::pair< llvm::BasicBlock *, llvm::Value * > evaluateOuterJoinCondition(const JoinLoop &join_loop, const JoinLoopDomain &iteration_domain, const std::vector< llvm::Value * > &iterators, llvm::Value *iteration_counter, llvm::Value *have_more_inner_rows, llvm::Value *found_an_outer_match_ptr, llvm::Value *current_condition_match_ptr, CgenState *cgen_state)
Definition: JoinLoop.cpp:297