OmniSciDB  c1a53651b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
QueryTemplateGenerator.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "QueryTemplateGenerator.h"
18 #include "IRCodegenUtils.h"
19 #include "Logger/Logger.h"
20 
21 #include <llvm/IR/Constants.h>
22 #include <llvm/IR/IRBuilder.h>
23 #include <llvm/IR/Instructions.h>
24 #include <llvm/IR/Verifier.h>
25 
26 // This file was pretty much auto-generated by running:
27 // llc -march=cpp RuntimeFunctions.ll
28 // and formatting the results to be more readable.
29 
30 namespace {
31 
32 inline llvm::Type* get_pointer_element_type(llvm::Value* value) {
33  CHECK(value);
34  auto type = value->getType();
35  CHECK(type && type->isPointerTy());
36  auto pointer_type = llvm::dyn_cast<llvm::PointerType>(type);
37  CHECK(pointer_type);
38  return pointer_type->getPointerElementType();
39 }
40 
41 template <class Attributes>
42 llvm::Function* default_func_builder(llvm::Module* mod, const std::string& name) {
43  using namespace llvm;
44 
45  std::vector<Type*> func_args;
46  FunctionType* func_type = FunctionType::get(
47  /*Result=*/IntegerType::get(mod->getContext(), 32),
48  /*Params=*/func_args,
49  /*isVarArg=*/false);
50 
51  auto func_ptr = mod->getFunction(name);
52  if (!func_ptr) {
53  func_ptr = Function::Create(
54  /*Type=*/func_type,
55  /*Linkage=*/GlobalValue::ExternalLinkage,
56  /*Name=*/name,
57  mod); // (external, no body)
58  func_ptr->setCallingConv(CallingConv::C);
59  }
60 
61  Attributes func_pal;
62  {
63  SmallVector<Attributes, 4> Attrs;
64  Attributes PAS;
65  {
66 #if 14 <= LLVM_VERSION_MAJOR
67  AttrBuilder B(mod->getContext());
68 #else
69  AttrBuilder B;
70 #endif
71  PAS = Attributes::get(mod->getContext(), ~0U, B);
72  }
73 
74  Attrs.push_back(PAS);
75  func_pal = Attributes::get(mod->getContext(), Attrs);
76  }
77  func_ptr->setAttributes(func_pal);
78 
79  return func_ptr;
80 }
81 
82 template <class Attributes>
83 llvm::Function* pos_start(llvm::Module* mod) {
84  return default_func_builder<Attributes>(mod, "pos_start");
85 }
86 
87 template <class Attributes>
88 llvm::Function* group_buff_idx(llvm::Module* mod) {
89  return default_func_builder<Attributes>(mod, "group_buff_idx");
90 }
91 
92 template <class Attributes>
93 llvm::Function* pos_step(llvm::Module* mod) {
94  using namespace llvm;
95 
96  std::vector<Type*> func_args;
97  FunctionType* func_type = FunctionType::get(
98  /*Result=*/IntegerType::get(mod->getContext(), 32),
99  /*Params=*/func_args,
100  /*isVarArg=*/false);
101 
102  auto func_ptr = mod->getFunction("pos_step");
103  if (!func_ptr) {
104  func_ptr = Function::Create(
105  /*Type=*/func_type,
106  /*Linkage=*/GlobalValue::ExternalLinkage,
107  /*Name=*/"pos_step",
108  mod); // (external, no body)
109  func_ptr->setCallingConv(CallingConv::C);
110  }
111 
112  Attributes func_pal;
113  {
114  SmallVector<Attributes, 4> Attrs;
115  Attributes PAS;
116  {
117 #if 14 <= LLVM_VERSION_MAJOR
118  AttrBuilder B(mod->getContext());
119 #else
120  AttrBuilder B;
121 #endif
122  PAS = Attributes::get(mod->getContext(), ~0U, B);
123  }
124 
125  Attrs.push_back(PAS);
126  func_pal = Attributes::get(mod->getContext(), Attrs);
127  }
128  func_ptr->setAttributes(func_pal);
129 
130  return func_ptr;
131 }
132 
133 template <class Attributes>
134 llvm::Function* row_process(llvm::Module* mod,
135  const size_t aggr_col_count,
136  const bool hoist_literals) {
137  using namespace llvm;
138 
139  std::vector<Type*> func_args;
140  auto i8_type = IntegerType::get(mod->getContext(), 8);
141  auto i32_type = IntegerType::get(mod->getContext(), 32);
142  auto i64_type = IntegerType::get(mod->getContext(), 64);
143  auto pi32_type = PointerType::get(i32_type, 0);
144  auto pi64_type = PointerType::get(i64_type, 0);
145 
146  if (aggr_col_count) {
147  for (size_t i = 0; i < aggr_col_count; ++i) {
148  func_args.push_back(pi64_type);
149  }
150  } else { // group by query
151  func_args.push_back(pi64_type); // groups buffer
152  func_args.push_back(pi64_type); // varlen output buffer
153  func_args.push_back(pi32_type); // 1 iff current row matched, else 0
154  func_args.push_back(pi32_type); // total rows matched from the caller
155  func_args.push_back(pi32_type); // total rows matched before atomic increment
156  func_args.push_back(pi32_type); // max number of slots in the output buffer
157  }
158 
159  func_args.push_back(pi64_type); // aggregate init values
160 
161  func_args.push_back(i64_type); // pos
162  func_args.push_back(pi64_type); // frag_row_off_ptr
163  func_args.push_back(pi64_type); // row_count_ptr
164  if (hoist_literals) {
165  func_args.push_back(PointerType::get(i8_type, 0)); // literals
166  }
167  FunctionType* func_type = FunctionType::get(
168  /*Result=*/i32_type,
169  /*Params=*/func_args,
170  /*isVarArg=*/false);
171 
172  std::string func_name{"row_process"};
173  auto func_ptr = mod->getFunction(func_name);
174 
175  if (!func_ptr) {
176  func_ptr = Function::Create(
177  /*Type=*/func_type,
178  /*Linkage=*/GlobalValue::ExternalLinkage,
179  /*Name=*/func_name,
180  mod); // (external, no body)
181  func_ptr->setCallingConv(CallingConv::C);
182 
183  Attributes func_pal;
184  {
185  SmallVector<Attributes, 4> Attrs;
186  Attributes PAS;
187  {
188 #if 14 <= LLVM_VERSION_MAJOR
189  AttrBuilder B(mod->getContext());
190 #else
191  AttrBuilder B;
192 #endif
193  PAS = Attributes::get(mod->getContext(), ~0U, B);
194  }
195 
196  Attrs.push_back(PAS);
197  func_pal = Attributes::get(mod->getContext(), Attrs);
198  }
199  func_ptr->setAttributes(func_pal);
200  }
201 
202  return func_ptr;
203 }
204 
205 } // namespace
206 
207 template <class Attributes>
208 std::tuple<llvm::Function*, llvm::CallInst*> query_template_impl(
209  llvm::Module* mod,
210  const size_t aggr_col_count,
211  const bool hoist_literals,
212  const bool is_estimate_query,
213  const GpuSharedMemoryContext& gpu_smem_context) {
214  using namespace llvm;
215 
216  auto func_pos_start = pos_start<Attributes>(mod);
217  CHECK(func_pos_start);
218  auto func_pos_step = pos_step<Attributes>(mod);
219  CHECK(func_pos_step);
220  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
221  CHECK(func_group_buff_idx);
222  auto func_row_process = row_process<Attributes>(
223  mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
224  CHECK(func_row_process);
225 
226  auto i8_type = IntegerType::get(mod->getContext(), 8);
227  auto i32_type = IntegerType::get(mod->getContext(), 32);
228  auto i64_type = IntegerType::get(mod->getContext(), 64);
229  auto pi8_type = PointerType::get(i8_type, 0);
230  auto ppi8_type = PointerType::get(pi8_type, 0);
231  auto pi32_type = PointerType::get(i32_type, 0);
232  auto pi64_type = PointerType::get(i64_type, 0);
233  auto ppi64_type = PointerType::get(pi64_type, 0);
234 
235  std::vector<Type*> query_args;
236  query_args.push_back(ppi8_type); // byte_stream
237  if (hoist_literals) {
238  query_args.push_back(pi8_type); // literals
239  }
240  query_args.push_back(pi64_type); // row_count_ptr
241  query_args.push_back(pi64_type); // frag_row_off_ptr
242  query_args.push_back(pi32_type); // max_matched_ptr
243 
244  query_args.push_back(pi64_type); // agg_init_val
245  query_args.push_back(ppi64_type); // group_by_buffers
246  query_args.push_back(i32_type); // frag_idx
247  query_args.push_back(pi64_type); // join_hash_tables
248  query_args.push_back(pi32_type); // total_matched
249  query_args.push_back(pi32_type); // error_code
250  query_args.push_back(pi8_type); // row_func_mgr
251 
252  FunctionType* query_func_type = FunctionType::get(
253  /*Result=*/Type::getVoidTy(mod->getContext()),
254  /*Params=*/query_args,
255  /*isVarArg=*/false);
256 
257  std::string query_template_name{"query_template"};
258  auto query_func_ptr = mod->getFunction(query_template_name);
259  CHECK(!query_func_ptr);
260 
261  query_func_ptr = Function::Create(
262  /*Type=*/query_func_type,
263  /*Linkage=*/GlobalValue::ExternalLinkage,
264  /*Name=*/query_template_name,
265  mod);
266  query_func_ptr->setCallingConv(CallingConv::C);
267 
268  Attributes query_func_pal;
269  {
270  SmallVector<Attributes, 4> Attrs;
271  Attributes PAS;
272  {
273 #if 14 <= LLVM_VERSION_MAJOR
274  AttrBuilder B(mod->getContext());
275 #else
276  AttrBuilder B;
277 #endif
278  B.addAttribute(Attribute::NoCapture);
279  PAS = Attributes::get(mod->getContext(), 1U, B);
280  }
281 
282  Attrs.push_back(PAS);
283  {
284 #if 14 <= LLVM_VERSION_MAJOR
285  AttrBuilder B(mod->getContext());
286 #else
287  AttrBuilder B;
288 #endif
289  B.addAttribute(Attribute::NoCapture);
290  PAS = Attributes::get(mod->getContext(), 2U, B);
291  }
292 
293  Attrs.push_back(PAS);
294 
295  {
296 #if 14 <= LLVM_VERSION_MAJOR
297  AttrBuilder B(mod->getContext());
298 #else
299  AttrBuilder B;
300 #endif
301  B.addAttribute(Attribute::NoCapture);
302  Attrs.push_back(Attributes::get(mod->getContext(), 3U, B));
303  }
304 
305  {
306 #if 14 <= LLVM_VERSION_MAJOR
307  AttrBuilder B(mod->getContext());
308 #else
309  AttrBuilder B;
310 #endif
311  B.addAttribute(Attribute::NoCapture);
312  Attrs.push_back(Attributes::get(mod->getContext(), 4U, B));
313  }
314 
315  Attrs.push_back(PAS);
316 
317  query_func_pal = Attributes::get(mod->getContext(), Attrs);
318  }
319  query_func_ptr->setAttributes(query_func_pal);
320 
321  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
322  Value* byte_stream = &*query_arg_it;
323  byte_stream->setName("byte_stream");
324  Value* literals{nullptr};
325  if (hoist_literals) {
326  literals = &*(++query_arg_it);
327  literals->setName("literals");
328  }
329  Value* row_count_ptr = &*(++query_arg_it);
330  row_count_ptr->setName("row_count_ptr");
331  Value* frag_row_off_ptr = &*(++query_arg_it);
332  frag_row_off_ptr->setName("frag_row_off_ptr");
333  Value* max_matched_ptr = &*(++query_arg_it);
334  max_matched_ptr->setName("max_matched_ptr");
335  Value* agg_init_val = &*(++query_arg_it);
336  agg_init_val->setName("agg_init_val");
337  Value* out = &*(++query_arg_it);
338  out->setName("out");
339  Value* frag_idx = &*(++query_arg_it);
340  frag_idx->setName("frag_idx");
341  Value* join_hash_tables = &*(++query_arg_it);
342  join_hash_tables->setName("join_hash_tables");
343  Value* total_matched = &*(++query_arg_it);
344  total_matched->setName("total_matched");
345  Value* error_code = &*(++query_arg_it);
346  error_code->setName("error_code");
347  Value* row_func_mgr = &*(++query_arg_it);
348  row_func_mgr->setName("row_func_mgr");
349 
350  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
351  auto bb_preheader =
352  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
353  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".for.body", query_func_ptr, 0);
354  auto bb_crit_edge =
355  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
356  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
357 
358  // Block (.entry)
359  std::vector<Value*> result_ptr_vec;
360  llvm::CallInst* smem_output_buffer{nullptr};
361  if (!is_estimate_query) {
362  for (size_t i = 0; i < aggr_col_count; ++i) {
363  auto result_ptr = new AllocaInst(i64_type, 0, "result", bb_entry);
364  result_ptr->setAlignment(LLVM_ALIGN(8));
365  result_ptr_vec.push_back(result_ptr);
366  }
367  if (gpu_smem_context.isSharedMemoryUsed()) {
368  auto init_smem_func = mod->getFunction("init_shared_mem");
369  CHECK(init_smem_func);
370  // only one slot per aggregate column is needed, and so we can initialize shared
371  // memory buffer for intermediate results to be exactly like the agg_init_val array
372  smem_output_buffer = CallInst::Create(
373  init_smem_func,
374  std::vector<llvm::Value*>{
375  agg_init_val,
376  llvm::ConstantInt::get(i32_type, aggr_col_count * sizeof(int64_t))},
377  "smem_buffer",
378  bb_entry);
379  }
380  }
381 
382  LoadInst* row_count = new LoadInst(get_pointer_element_type(row_count_ptr),
383  row_count_ptr,
384  "row_count",
385  false,
386  bb_entry);
387  row_count->setAlignment(LLVM_ALIGN(8));
388  row_count->setName("row_count");
389  std::vector<Value*> agg_init_val_vec;
390  if (!is_estimate_query) {
391  for (size_t i = 0; i < aggr_col_count; ++i) {
392  auto idx_lv = ConstantInt::get(i32_type, i);
393  auto agg_init_gep = GetElementPtrInst::CreateInBounds(
394  agg_init_val->getType()->getPointerElementType(),
395  agg_init_val,
396  idx_lv,
397  "",
398  bb_entry);
399  auto agg_init_val = new LoadInst(
400  get_pointer_element_type(agg_init_gep), agg_init_gep, "", false, bb_entry);
401  agg_init_val->setAlignment(LLVM_ALIGN(8));
402  agg_init_val_vec.push_back(agg_init_val);
403  auto init_val_st = new StoreInst(agg_init_val, result_ptr_vec[i], false, bb_entry);
404  init_val_st->setAlignment(LLVM_ALIGN(8));
405  }
406  }
407 
408  CallInst* pos_start = CallInst::Create(func_pos_start, "pos_start", bb_entry);
409  pos_start->setCallingConv(CallingConv::C);
410  pos_start->setTailCall(true);
411  Attributes pos_start_pal;
412  pos_start->setAttributes(pos_start_pal);
413 
414  CallInst* pos_step = CallInst::Create(func_pos_step, "pos_step", bb_entry);
415  pos_step->setCallingConv(CallingConv::C);
416  pos_step->setTailCall(true);
417  Attributes pos_step_pal;
418  pos_step->setAttributes(pos_step_pal);
419 
420  CallInst* group_buff_idx = nullptr;
421  if (!is_estimate_query) {
422  group_buff_idx = CallInst::Create(func_group_buff_idx, "group_buff_idx", bb_entry);
423  group_buff_idx->setCallingConv(CallingConv::C);
424  group_buff_idx->setTailCall(true);
425  Attributes group_buff_idx_pal;
426  group_buff_idx->setAttributes(group_buff_idx_pal);
427  }
428 
429  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
430  ICmpInst* enter_or_not =
431  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
432  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
433 
434  // Block .loop.preheader
435  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
436  BranchInst::Create(bb_forbody, bb_preheader);
437 
438  // Block .forbody
439  Argument* pos_inc_pre = new Argument(i64_type);
440  PHINode* pos = PHINode::Create(i64_type, 2, "pos", bb_forbody);
441  pos->addIncoming(pos_start_i64, bb_preheader);
442  pos->addIncoming(pos_inc_pre, bb_forbody);
443 
444  std::vector<Value*> row_process_params;
445  row_process_params.insert(
446  row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
447  if (is_estimate_query) {
448  row_process_params.push_back(
449  new LoadInst(get_pointer_element_type(out), out, "", false, bb_forbody));
450  }
451  row_process_params.push_back(agg_init_val);
452  row_process_params.push_back(pos);
453  row_process_params.push_back(frag_row_off_ptr);
454  row_process_params.push_back(row_count_ptr);
455  if (hoist_literals) {
456  CHECK(literals);
457  row_process_params.push_back(literals);
458  }
459  CallInst* row_process =
460  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
461  row_process->setCallingConv(CallingConv::C);
462  row_process->setTailCall(false);
463  Attributes row_process_pal;
464  row_process->setAttributes(row_process_pal);
465 
466  BinaryOperator* pos_inc =
467  BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
468  ICmpInst* loop_or_exit =
469  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
470  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
471 
472  // Block ._crit_edge
473  std::vector<Instruction*> result_vec_pre;
474  if (!is_estimate_query) {
475  for (size_t i = 0; i < aggr_col_count; ++i) {
476  auto result = new LoadInst(get_pointer_element_type(result_ptr_vec[i]),
477  result_ptr_vec[i],
478  ".pre.result",
479  false,
480  bb_crit_edge);
481  result->setAlignment(LLVM_ALIGN(8));
482  result_vec_pre.push_back(result);
483  }
484  }
485 
486  BranchInst::Create(bb_exit, bb_crit_edge);
487 
488  // Block .exit
500  if (!is_estimate_query) {
501  std::vector<PHINode*> result_vec;
502  for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
503  auto result =
504  PHINode::Create(IntegerType::get(mod->getContext(), 64), 2, "", bb_exit);
505  result->addIncoming(result_vec_pre[i], bb_crit_edge);
506  result->addIncoming(agg_init_val_vec[i], bb_entry);
507  result_vec.insert(result_vec.begin(), result);
508  }
509 
510  for (size_t i = 0; i < aggr_col_count; ++i) {
511  auto col_idx = ConstantInt::get(i32_type, i);
512  if (gpu_smem_context.isSharedMemoryUsed()) {
513  auto target_addr = GetElementPtrInst::CreateInBounds(
514  smem_output_buffer->getType()->getPointerElementType(),
515  smem_output_buffer,
516  col_idx,
517  "",
518  bb_exit);
519  // TODO: generalize this once we want to support other types of aggregate
520  // functions besides COUNT.
521  auto agg_func = mod->getFunction("agg_sum_shared");
522  CHECK(agg_func);
523  CallInst::Create(
524  agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]}, "", bb_exit);
525  } else {
526  auto out_gep = GetElementPtrInst::CreateInBounds(
527  out->getType()->getPointerElementType(), out, col_idx, "", bb_exit);
528  auto col_buffer =
529  new LoadInst(get_pointer_element_type(out_gep), out_gep, "", false, bb_exit);
530  col_buffer->setAlignment(LLVM_ALIGN(8));
531  auto slot_idx = BinaryOperator::CreateAdd(
533  BinaryOperator::CreateMul(frag_idx, pos_step, "", bb_exit),
534  "",
535  bb_exit);
536  auto target_addr = GetElementPtrInst::CreateInBounds(
537  col_buffer->getType()->getPointerElementType(),
538  col_buffer,
539  slot_idx,
540  "",
541  bb_exit);
542  StoreInst* result_st = new StoreInst(result_vec[i], target_addr, false, bb_exit);
543  result_st->setAlignment(LLVM_ALIGN(8));
544  }
545  }
546  if (gpu_smem_context.isSharedMemoryUsed()) {
547  // final reduction of results from shared memory buffer back into global memory.
548  auto sync_thread_func = mod->getFunction("sync_threadblock");
549  CHECK(sync_thread_func);
550  CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{}, "", bb_exit);
551  auto reduce_smem_to_gmem_func = mod->getFunction("write_back_non_grouped_agg");
552  CHECK(reduce_smem_to_gmem_func);
553  // each thread reduce the aggregate target corresponding to its own thread ID.
554  // If there are more targets than threads we do not currently use shared memory
555  // optimization. This can be relaxed if necessary
556  for (size_t i = 0; i < aggr_col_count; i++) {
557  auto out_gep =
558  GetElementPtrInst::CreateInBounds(out->getType()->getPointerElementType(),
559  out,
560  ConstantInt::get(i32_type, i),
561  "",
562  bb_exit);
563  auto gmem_output_buffer = new LoadInst(get_pointer_element_type(out_gep),
564  out_gep,
565  "gmem_output_buffer_" + std::to_string(i),
566  false,
567  bb_exit);
568  CallInst::Create(
569  reduce_smem_to_gmem_func,
570  std::vector<llvm::Value*>{
571  smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
572  "",
573  bb_exit);
574  }
575  }
576  }
577 
578  ReturnInst::Create(mod->getContext(), bb_exit);
579 
580  // Resolve Forward References
581  pos_inc_pre->replaceAllUsesWith(pos_inc);
582  delete pos_inc_pre;
583 
584  if (verifyFunction(*query_func_ptr)) {
585  LOG(FATAL) << "Generated invalid code. ";
586  }
587 
588  return {query_func_ptr, row_process};
589 }
590 
591 template <class Attributes>
592 std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template_impl(
593  llvm::Module* mod,
594  const bool hoist_literals,
596  const ExecutorDeviceType device_type,
597  const bool check_scan_limit,
598  const GpuSharedMemoryContext& gpu_smem_context) {
599  if (gpu_smem_context.isSharedMemoryUsed()) {
600  CHECK(device_type == ExecutorDeviceType::GPU);
601  }
602  using namespace llvm;
603 
604  auto func_pos_start = pos_start<Attributes>(mod);
605  CHECK(func_pos_start);
606  auto func_pos_step = pos_step<Attributes>(mod);
607  CHECK(func_pos_step);
608  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
609  CHECK(func_group_buff_idx);
610  auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
611  CHECK(func_row_process);
612  auto func_init_shared_mem = gpu_smem_context.isSharedMemoryUsed()
613  ? mod->getFunction("init_shared_mem")
614  : mod->getFunction("init_shared_mem_nop");
615  CHECK(func_init_shared_mem);
616 
617  auto func_write_back = mod->getFunction("write_back_nop");
618  CHECK(func_write_back);
619 
620  auto i32_type = IntegerType::get(mod->getContext(), 32);
621  auto i64_type = IntegerType::get(mod->getContext(), 64);
622  auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
623  auto pi32_type = PointerType::get(i32_type, 0);
624  auto pi64_type = PointerType::get(i64_type, 0);
625  auto ppi64_type = PointerType::get(pi64_type, 0);
626  auto ppi8_type = PointerType::get(pi8_type, 0);
627 
628  std::vector<Type*> query_args;
629  query_args.push_back(ppi8_type); // col_buffers
630  if (hoist_literals) {
631  query_args.push_back(pi8_type); // literals
632  }
633  query_args.push_back(pi64_type); // num_rows
634  query_args.push_back(pi64_type); // frag_row_offsets
635  query_args.push_back(pi32_type); // max_matched
636  query_args.push_back(pi64_type); // init_agg_value
637 
638  query_args.push_back(ppi64_type); // out
639  query_args.push_back(i32_type); // frag_idx
640  query_args.push_back(pi64_type); // join_hash_tables
641  query_args.push_back(pi32_type); // total_matched
642  query_args.push_back(pi32_type); // error_code
643  query_args.push_back(pi8_type); // row_func_mgr
644 
645  FunctionType* query_func_type = FunctionType::get(
646  /*Result=*/Type::getVoidTy(mod->getContext()),
647  /*Params=*/query_args,
648  /*isVarArg=*/false);
649 
650  std::string query_name{"query_group_by_template"};
651  auto query_func_ptr = mod->getFunction(query_name);
652  CHECK(!query_func_ptr);
653 
654  query_func_ptr = Function::Create(
655  /*Type=*/query_func_type,
656  /*Linkage=*/GlobalValue::ExternalLinkage,
657  /*Name=*/"query_group_by_template",
658  mod);
659 
660  query_func_ptr->setCallingConv(CallingConv::C);
661 
662  Attributes query_func_pal;
663  {
664  SmallVector<Attributes, 4> Attrs;
665  Attributes PAS;
666  {
667 #if 14 <= LLVM_VERSION_MAJOR
668  AttrBuilder B(mod->getContext());
669 #else
670  AttrBuilder B;
671 #endif
672  B.addAttribute(Attribute::ReadNone);
673  B.addAttribute(Attribute::NoCapture);
674  PAS = Attributes::get(mod->getContext(), 1U, B);
675  }
676 
677  Attrs.push_back(PAS);
678  {
679 #if 14 <= LLVM_VERSION_MAJOR
680  AttrBuilder B(mod->getContext());
681 #else
682  AttrBuilder B;
683 #endif
684  B.addAttribute(Attribute::ReadOnly);
685  B.addAttribute(Attribute::NoCapture);
686  PAS = Attributes::get(mod->getContext(), 2U, B);
687  }
688 
689  Attrs.push_back(PAS);
690  {
691 #if 14 <= LLVM_VERSION_MAJOR
692  AttrBuilder B(mod->getContext());
693 #else
694  AttrBuilder B;
695 #endif
696  B.addAttribute(Attribute::ReadNone);
697  B.addAttribute(Attribute::NoCapture);
698  PAS = Attributes::get(mod->getContext(), 3U, B);
699  }
700 
701  Attrs.push_back(PAS);
702  {
703 #if 14 <= LLVM_VERSION_MAJOR
704  AttrBuilder B(mod->getContext());
705 #else
706  AttrBuilder B;
707 #endif
708  B.addAttribute(Attribute::ReadOnly);
709  B.addAttribute(Attribute::NoCapture);
710  PAS = Attributes::get(mod->getContext(), 4U, B);
711  }
712 
713  Attrs.push_back(PAS);
714  {
715 #if 14 <= LLVM_VERSION_MAJOR
716  AttrBuilder B(mod->getContext());
717 #else
718  AttrBuilder B;
719 #endif
720  B.addAttribute(Attribute::UWTable);
721  PAS = Attributes::get(mod->getContext(), ~0U, B);
722  }
723 
724  Attrs.push_back(PAS);
725 
726  query_func_pal = Attributes::get(mod->getContext(), Attrs);
727  }
728  query_func_ptr->setAttributes(query_func_pal);
729 
730  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
731  Value* byte_stream = &*query_arg_it;
732  byte_stream->setName("byte_stream");
733  Value* literals{nullptr};
734  if (hoist_literals) {
735  literals = &*(++query_arg_it);
736  ;
737  literals->setName("literals");
738  }
739  Value* row_count_ptr = &*(++query_arg_it);
740  row_count_ptr->setName("row_count_ptr");
741  Value* frag_row_off_ptr = &*(++query_arg_it);
742  frag_row_off_ptr->setName("frag_row_off_ptr");
743  Value* max_matched_ptr = &*(++query_arg_it);
744  max_matched_ptr->setName("max_matched_ptr");
745  Value* agg_init_val = &*(++query_arg_it);
746  agg_init_val->setName("agg_init_val");
747  Value* group_by_buffers = &*(++query_arg_it);
748  group_by_buffers->setName("group_by_buffers");
749  Value* frag_idx = &*(++query_arg_it);
750  frag_idx->setName("frag_idx");
751  Value* join_hash_tables = &*(++query_arg_it);
752  join_hash_tables->setName("join_hash_tables");
753  Value* total_matched = &*(++query_arg_it);
754  total_matched->setName("total_matched");
755  Value* error_code = &*(++query_arg_it);
756  error_code->setName("error_code");
757  Value* row_func_mgr = &*(++query_arg_it);
758  row_func_mgr->setName("row_func_mgr");
759 
760  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
761  auto bb_preheader =
762  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
763  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".forbody", query_func_ptr, 0);
764  auto bb_crit_edge =
765  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
766  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
767 
768  // Block .entry
769  LoadInst* row_count = new LoadInst(
770  get_pointer_element_type(row_count_ptr), row_count_ptr, "", false, bb_entry);
771  row_count->setAlignment(LLVM_ALIGN(8));
772  row_count->setName("row_count");
773 
774  LoadInst* max_matched = new LoadInst(
775  get_pointer_element_type(max_matched_ptr), max_matched_ptr, "", false, bb_entry);
776  max_matched->setAlignment(LLVM_ALIGN(8));
777 
778  auto crt_matched_ptr = new AllocaInst(i32_type, 0, "crt_matched", bb_entry);
779  auto old_total_matched_ptr = new AllocaInst(i32_type, 0, "old_total_matched", bb_entry);
780  CallInst* pos_start = CallInst::Create(func_pos_start, "", bb_entry);
781  pos_start->setCallingConv(CallingConv::C);
782  pos_start->setTailCall(true);
783  Attributes pos_start_pal;
784  pos_start->setAttributes(pos_start_pal);
785 
786  CallInst* pos_step = CallInst::Create(func_pos_step, "", bb_entry);
787  pos_step->setCallingConv(CallingConv::C);
788  pos_step->setTailCall(true);
789  Attributes pos_step_pal;
790  pos_step->setAttributes(pos_step_pal);
791 
792  CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx, "", bb_entry);
793  group_buff_idx_call->setCallingConv(CallingConv::C);
794  group_buff_idx_call->setTailCall(true);
795  Attributes group_buff_idx_pal;
796  group_buff_idx_call->setAttributes(group_buff_idx_pal);
797  Value* group_buff_idx = group_buff_idx_call;
798 
799  const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
800  CHECK(Ty);
801 
802  Value* varlen_output_buffer{nullptr};
803  if (query_mem_desc.hasVarlenOutput()) {
804  // make the varlen buffer the _first_ 8 byte value in the group by buffers double ptr,
805  // and offset the group by buffers index by 8 bytes
806  auto varlen_output_buffer_gep = GetElementPtrInst::Create(
807  Ty->getPointerElementType(),
808  group_by_buffers,
809  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
810  "",
811  bb_entry);
812  varlen_output_buffer =
813  new LoadInst(get_pointer_element_type(varlen_output_buffer_gep),
814  varlen_output_buffer_gep,
815  "varlen_output_buffer",
816  false,
817  bb_entry);
818 
819  group_buff_idx = BinaryOperator::Create(
820  Instruction::Add,
822  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
823  "group_buff_idx_varlen_offset",
824  bb_entry);
825  } else {
826  varlen_output_buffer =
827  ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
828  }
829  CHECK(varlen_output_buffer);
830 
831  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
832  GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
833  Ty->getPointerElementType(), group_by_buffers, group_buff_idx, "", bb_entry);
834  LoadInst* col_buffer = new LoadInst(get_pointer_element_type(group_by_buffers_gep),
835  group_by_buffers_gep,
836  "",
837  false,
838  bb_entry);
839  col_buffer->setName("col_buffer");
840  col_buffer->setAlignment(LLVM_ALIGN(8));
841 
842  llvm::ConstantInt* shared_mem_bytes_lv =
843  ConstantInt::get(i32_type, gpu_smem_context.getSharedMemorySize());
844  // TODO(Saman): change this further, normal path should not go through this
845  llvm::CallInst* result_buffer =
846  CallInst::Create(func_init_shared_mem,
847  std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
848  "result_buffer",
849  bb_entry);
850 
851  ICmpInst* enter_or_not =
852  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
853  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
854 
855  // Block .loop.preheader
856  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
857  BranchInst::Create(bb_forbody, bb_preheader);
858 
859  // Block .forbody
860  Argument* pos_pre = new Argument(i64_type);
861  PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2, "pos", bb_forbody);
862 
863  std::vector<Value*> row_process_params;
864  row_process_params.push_back(result_buffer);
865  row_process_params.push_back(varlen_output_buffer);
866  row_process_params.push_back(crt_matched_ptr);
867  row_process_params.push_back(total_matched);
868  row_process_params.push_back(old_total_matched_ptr);
869  row_process_params.push_back(max_matched_ptr);
870  row_process_params.push_back(agg_init_val);
871  row_process_params.push_back(pos);
872  row_process_params.push_back(frag_row_off_ptr);
873  row_process_params.push_back(row_count_ptr);
874  if (hoist_literals) {
875  CHECK(literals);
876  row_process_params.push_back(literals);
877  }
878  if (check_scan_limit) {
879  new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
880  crt_matched_ptr,
881  bb_forbody);
882  }
883  CallInst* row_process =
884  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
885  row_process->setCallingConv(CallingConv::C);
886  row_process->setTailCall(true);
887  Attributes row_process_pal;
888  row_process->setAttributes(row_process_pal);
889 
890  // Forcing all threads within a warp to be synchronized (Compute >= 7.x)
891  if (query_mem_desc.isWarpSyncRequired(device_type)) {
892  auto func_sync_warp_protected = mod->getFunction("sync_warp_protected");
893  CHECK(func_sync_warp_protected);
894  CallInst::Create(func_sync_warp_protected,
895  std::vector<llvm::Value*>{pos, row_count},
896  "",
897  bb_forbody);
898  }
899 
900  BinaryOperator* pos_inc =
901  BinaryOperator::Create(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
902  ICmpInst* loop_or_exit =
903  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
904  if (check_scan_limit) {
905  auto crt_matched = new LoadInst(get_pointer_element_type(crt_matched_ptr),
906  crt_matched_ptr,
907  "crt_matched",
908  false,
909  bb_forbody);
910  auto filter_match = BasicBlock::Create(
911  mod->getContext(), "filter_match", query_func_ptr, bb_crit_edge);
912  llvm::Value* new_total_matched =
913  new LoadInst(get_pointer_element_type(old_total_matched_ptr),
914  old_total_matched_ptr,
915  "",
916  false,
917  filter_match);
918  new_total_matched =
919  BinaryOperator::CreateAdd(new_total_matched, crt_matched, "", filter_match);
920  CHECK(new_total_matched);
921  ICmpInst* limit_not_reached = new ICmpInst(*filter_match,
922  ICmpInst::ICMP_SLT,
923  new_total_matched,
924  max_matched,
925  "limit_not_reached");
926  BranchInst::Create(
927  bb_forbody,
928  bb_crit_edge,
929  BinaryOperator::Create(
930  BinaryOperator::And, loop_or_exit, limit_not_reached, "", filter_match),
931  filter_match);
932  auto filter_nomatch = BasicBlock::Create(
933  mod->getContext(), "filter_nomatch", query_func_ptr, bb_crit_edge);
934  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
935  ICmpInst* crt_matched_nz = new ICmpInst(
936  *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0), "");
937  BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
938  pos->addIncoming(pos_start_i64, bb_preheader);
939  pos->addIncoming(pos_pre, filter_match);
940  pos->addIncoming(pos_pre, filter_nomatch);
941  } else {
942  pos->addIncoming(pos_start_i64, bb_preheader);
943  pos->addIncoming(pos_pre, bb_forbody);
944  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
945  }
946 
947  // Block ._crit_edge
948  BranchInst::Create(bb_exit, bb_crit_edge);
949 
950  // Block .exit
951  CallInst::Create(func_write_back,
952  std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
953  "",
954  bb_exit);
955 
956  ReturnInst::Create(mod->getContext(), bb_exit);
957 
958  // Resolve Forward References
959  pos_pre->replaceAllUsesWith(pos_inc);
960  delete pos_pre;
961 
962  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
963  LOG(FATAL) << "Generated invalid code. ";
964  }
965 
966  return {query_func_ptr, row_process};
967 }
968 
969 std::tuple<llvm::Function*, llvm::CallInst*> query_template(
970  llvm::Module* mod,
971  const size_t aggr_col_count,
972  const bool hoist_literals,
973  const bool is_estimate_query,
974  const GpuSharedMemoryContext& gpu_smem_context) {
975  return query_template_impl<llvm::AttributeList>(
976  mod, aggr_col_count, hoist_literals, is_estimate_query, gpu_smem_context);
977 }
978 std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template(
979  llvm::Module* mod,
980  const bool hoist_literals,
982  const ExecutorDeviceType device_type,
983  const bool check_scan_limit,
984  const GpuSharedMemoryContext& gpu_smem_context) {
985  return query_group_by_template_impl<llvm::AttributeList>(mod,
986  hoist_literals,
988  device_type,
989  check_scan_limit,
990  gpu_smem_context);
991 }
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template_impl(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
ExecutorDeviceType
#define LOG(tag)
Definition: Logger.h:285
size_t getSharedMemorySize() const
std::tuple< llvm::Function *, llvm::CallInst * > query_template(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
AGG_TYPE agg_func(AGG_TYPE const lhs, AGG_TYPE const rhs)
Type pointer_type(const Type pointee)
#define LLVM_ALIGN(alignment)
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
std::string to_string(char const *&&v)
llvm::Function * group_buff_idx(llvm::Module *mod)
llvm::Function * default_func_builder(llvm::Module *mod, const std::string &name)
bool isWarpSyncRequired(const ExecutorDeviceType) const
#define CHECK(condition)
Definition: Logger.h:291
std::tuple< llvm::Function *, llvm::CallInst * > query_template_impl(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
string name
Definition: setup.in.py:72
llvm::Type * get_pointer_element_type(llvm::Value *value)
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)