OmniSciDB  91042dcc5b
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
QueryTemplateGenerator.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "QueryTemplateGenerator.h"
18 #include "IRCodegenUtils.h"
19 #include "Logger/Logger.h"
20 
21 #include <llvm/IR/Constants.h>
22 #include <llvm/IR/IRBuilder.h>
23 #include <llvm/IR/Instructions.h>
24 #include <llvm/IR/Verifier.h>
25 
26 // This file was pretty much auto-generated by running:
27 // llc -march=cpp RuntimeFunctions.ll
28 // and formatting the results to be more readable.
29 
30 namespace {
31 
32 inline llvm::Type* get_pointer_element_type(llvm::Value* value) {
33  CHECK(value);
34  auto type = value->getType();
35  CHECK(type && type->isPointerTy());
36  auto pointer_type = llvm::dyn_cast<llvm::PointerType>(type);
37  CHECK(pointer_type);
38  return pointer_type->getElementType();
39 }
40 
41 template <class Attributes>
42 llvm::Function* default_func_builder(llvm::Module* mod, const std::string& name) {
43  using namespace llvm;
44 
45  std::vector<Type*> func_args;
46  FunctionType* func_type = FunctionType::get(
47  /*Result=*/IntegerType::get(mod->getContext(), 32),
48  /*Params=*/func_args,
49  /*isVarArg=*/false);
50 
51  auto func_ptr = mod->getFunction(name);
52  if (!func_ptr) {
53  func_ptr = Function::Create(
54  /*Type=*/func_type,
55  /*Linkage=*/GlobalValue::ExternalLinkage,
56  /*Name=*/name,
57  mod); // (external, no body)
58  func_ptr->setCallingConv(CallingConv::C);
59  }
60 
61  Attributes func_pal;
62  {
63  SmallVector<Attributes, 4> Attrs;
64  Attributes PAS;
65  {
66  AttrBuilder B;
67  PAS = Attributes::get(mod->getContext(), ~0U, B);
68  }
69 
70  Attrs.push_back(PAS);
71  func_pal = Attributes::get(mod->getContext(), Attrs);
72  }
73  func_ptr->setAttributes(func_pal);
74 
75  return func_ptr;
76 }
77 
78 template <class Attributes>
79 llvm::Function* pos_start(llvm::Module* mod) {
80  return default_func_builder<Attributes>(mod, "pos_start");
81 }
82 
83 template <class Attributes>
84 llvm::Function* group_buff_idx(llvm::Module* mod) {
85  return default_func_builder<Attributes>(mod, "group_buff_idx");
86 }
87 
88 template <class Attributes>
89 llvm::Function* pos_step(llvm::Module* mod) {
90  using namespace llvm;
91 
92  std::vector<Type*> func_args;
93  FunctionType* func_type = FunctionType::get(
94  /*Result=*/IntegerType::get(mod->getContext(), 32),
95  /*Params=*/func_args,
96  /*isVarArg=*/false);
97 
98  auto func_ptr = mod->getFunction("pos_step");
99  if (!func_ptr) {
100  func_ptr = Function::Create(
101  /*Type=*/func_type,
102  /*Linkage=*/GlobalValue::ExternalLinkage,
103  /*Name=*/"pos_step",
104  mod); // (external, no body)
105  func_ptr->setCallingConv(CallingConv::C);
106  }
107 
108  Attributes func_pal;
109  {
110  SmallVector<Attributes, 4> Attrs;
111  Attributes PAS;
112  {
113  AttrBuilder B;
114  PAS = Attributes::get(mod->getContext(), ~0U, B);
115  }
116 
117  Attrs.push_back(PAS);
118  func_pal = Attributes::get(mod->getContext(), Attrs);
119  }
120  func_ptr->setAttributes(func_pal);
121 
122  return func_ptr;
123 }
124 
125 template <class Attributes>
126 llvm::Function* row_process(llvm::Module* mod,
127  const size_t aggr_col_count,
128  const bool hoist_literals) {
129  using namespace llvm;
130 
131  std::vector<Type*> func_args;
132  auto i8_type = IntegerType::get(mod->getContext(), 8);
133  auto i32_type = IntegerType::get(mod->getContext(), 32);
134  auto i64_type = IntegerType::get(mod->getContext(), 64);
135  auto pi32_type = PointerType::get(i32_type, 0);
136  auto pi64_type = PointerType::get(i64_type, 0);
137 
138  if (aggr_col_count) {
139  for (size_t i = 0; i < aggr_col_count; ++i) {
140  func_args.push_back(pi64_type);
141  }
142  } else { // group by query
143  func_args.push_back(pi64_type); // groups buffer
144  func_args.push_back(pi64_type); // varlen output buffer
145  func_args.push_back(pi32_type); // 1 iff current row matched, else 0
146  func_args.push_back(pi32_type); // total rows matched from the caller
147  func_args.push_back(pi32_type); // total rows matched before atomic increment
148  func_args.push_back(pi32_type); // max number of slots in the output buffer
149  }
150 
151  func_args.push_back(pi64_type); // aggregate init values
152 
153  func_args.push_back(i64_type);
154  func_args.push_back(pi64_type);
155  func_args.push_back(pi64_type);
156  if (hoist_literals) {
157  func_args.push_back(PointerType::get(i8_type, 0));
158  }
159  FunctionType* func_type = FunctionType::get(
160  /*Result=*/i32_type,
161  /*Params=*/func_args,
162  /*isVarArg=*/false);
163 
164  std::string func_name{"row_process"};
165  auto func_ptr = mod->getFunction(func_name);
166 
167  if (!func_ptr) {
168  func_ptr = Function::Create(
169  /*Type=*/func_type,
170  /*Linkage=*/GlobalValue::ExternalLinkage,
171  /*Name=*/func_name,
172  mod); // (external, no body)
173  func_ptr->setCallingConv(CallingConv::C);
174 
175  Attributes func_pal;
176  {
177  SmallVector<Attributes, 4> Attrs;
178  Attributes PAS;
179  {
180  AttrBuilder B;
181  PAS = Attributes::get(mod->getContext(), ~0U, B);
182  }
183 
184  Attrs.push_back(PAS);
185  func_pal = Attributes::get(mod->getContext(), Attrs);
186  }
187  func_ptr->setAttributes(func_pal);
188  }
189 
190  return func_ptr;
191 }
192 
193 } // namespace
194 
195 template <class Attributes>
196 std::tuple<llvm::Function*, llvm::CallInst*> query_template_impl(
197  llvm::Module* mod,
198  const size_t aggr_col_count,
199  const bool hoist_literals,
200  const bool is_estimate_query,
201  const GpuSharedMemoryContext& gpu_smem_context) {
202  using namespace llvm;
203 
204  auto func_pos_start = pos_start<Attributes>(mod);
205  CHECK(func_pos_start);
206  auto func_pos_step = pos_step<Attributes>(mod);
207  CHECK(func_pos_step);
208  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
209  CHECK(func_group_buff_idx);
210  auto func_row_process = row_process<Attributes>(
211  mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
212  CHECK(func_row_process);
213 
214  auto i8_type = IntegerType::get(mod->getContext(), 8);
215  auto i32_type = IntegerType::get(mod->getContext(), 32);
216  auto i64_type = IntegerType::get(mod->getContext(), 64);
217  auto pi8_type = PointerType::get(i8_type, 0);
218  auto ppi8_type = PointerType::get(pi8_type, 0);
219  auto pi32_type = PointerType::get(i32_type, 0);
220  auto pi64_type = PointerType::get(i64_type, 0);
221  auto ppi64_type = PointerType::get(pi64_type, 0);
222 
223  std::vector<Type*> query_args;
224  query_args.push_back(ppi8_type);
225  if (hoist_literals) {
226  query_args.push_back(pi8_type);
227  }
228  query_args.push_back(pi64_type);
229  query_args.push_back(pi64_type);
230  query_args.push_back(pi32_type);
231 
232  query_args.push_back(pi64_type);
233  query_args.push_back(ppi64_type);
234  query_args.push_back(i32_type);
235  query_args.push_back(pi64_type);
236  query_args.push_back(pi32_type);
237  query_args.push_back(pi32_type);
238 
239  FunctionType* query_func_type = FunctionType::get(
240  /*Result=*/Type::getVoidTy(mod->getContext()),
241  /*Params=*/query_args,
242  /*isVarArg=*/false);
243 
244  std::string query_template_name{"query_template"};
245  auto query_func_ptr = mod->getFunction(query_template_name);
246  CHECK(!query_func_ptr);
247 
248  query_func_ptr = Function::Create(
249  /*Type=*/query_func_type,
250  /*Linkage=*/GlobalValue::ExternalLinkage,
251  /*Name=*/query_template_name,
252  mod);
253  query_func_ptr->setCallingConv(CallingConv::C);
254 
255  Attributes query_func_pal;
256  {
257  SmallVector<Attributes, 4> Attrs;
258  Attributes PAS;
259  {
260  AttrBuilder B;
261  B.addAttribute(Attribute::NoCapture);
262  PAS = Attributes::get(mod->getContext(), 1U, B);
263  }
264 
265  Attrs.push_back(PAS);
266  {
267  AttrBuilder B;
268  B.addAttribute(Attribute::NoCapture);
269  PAS = Attributes::get(mod->getContext(), 2U, B);
270  }
271 
272  Attrs.push_back(PAS);
273 
274  {
275  AttrBuilder B;
276  B.addAttribute(Attribute::NoCapture);
277  Attrs.push_back(Attributes::get(mod->getContext(), 3U, B));
278  }
279 
280  {
281  AttrBuilder B;
282  B.addAttribute(Attribute::NoCapture);
283  Attrs.push_back(Attributes::get(mod->getContext(), 4U, B));
284  }
285 
286  Attrs.push_back(PAS);
287 
288  query_func_pal = Attributes::get(mod->getContext(), Attrs);
289  }
290  query_func_ptr->setAttributes(query_func_pal);
291 
292  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
293  Value* byte_stream = &*query_arg_it;
294  byte_stream->setName("byte_stream");
295  Value* literals{nullptr};
296  if (hoist_literals) {
297  literals = &*(++query_arg_it);
298  literals->setName("literals");
299  }
300  Value* row_count_ptr = &*(++query_arg_it);
301  row_count_ptr->setName("row_count_ptr");
302  Value* frag_row_off_ptr = &*(++query_arg_it);
303  frag_row_off_ptr->setName("frag_row_off_ptr");
304  Value* max_matched_ptr = &*(++query_arg_it);
305  max_matched_ptr->setName("max_matched_ptr");
306  Value* agg_init_val = &*(++query_arg_it);
307  agg_init_val->setName("agg_init_val");
308  Value* out = &*(++query_arg_it);
309  out->setName("out");
310  Value* frag_idx = &*(++query_arg_it);
311  frag_idx->setName("frag_idx");
312  Value* join_hash_tables = &*(++query_arg_it);
313  join_hash_tables->setName("join_hash_tables");
314  Value* total_matched = &*(++query_arg_it);
315  total_matched->setName("total_matched");
316  Value* error_code = &*(++query_arg_it);
317  error_code->setName("error_code");
318 
319  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
320  auto bb_preheader =
321  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
322  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".for.body", query_func_ptr, 0);
323  auto bb_crit_edge =
324  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
325  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
326 
327  // Block (.entry)
328  std::vector<Value*> result_ptr_vec;
329  llvm::CallInst* smem_output_buffer{nullptr};
330  if (!is_estimate_query) {
331  for (size_t i = 0; i < aggr_col_count; ++i) {
332  auto result_ptr = new AllocaInst(i64_type, 0, "result", bb_entry);
333  result_ptr->setAlignment(LLVM_ALIGN(8));
334  result_ptr_vec.push_back(result_ptr);
335  }
336  if (gpu_smem_context.isSharedMemoryUsed()) {
337  auto init_smem_func = mod->getFunction("init_shared_mem");
338  CHECK(init_smem_func);
339  // only one slot per aggregate column is needed, and so we can initialize shared
340  // memory buffer for intermediate results to be exactly like the agg_init_val array
341  smem_output_buffer = CallInst::Create(
342  init_smem_func,
343  std::vector<llvm::Value*>{
344  agg_init_val,
345  llvm::ConstantInt::get(i32_type, aggr_col_count * sizeof(int64_t))},
346  "smem_buffer",
347  bb_entry);
348  }
349  }
350 
351  LoadInst* row_count = new LoadInst(get_pointer_element_type(row_count_ptr),
352  row_count_ptr,
353  "row_count",
354  false,
355  bb_entry);
356  row_count->setAlignment(LLVM_ALIGN(8));
357  row_count->setName("row_count");
358  std::vector<Value*> agg_init_val_vec;
359  if (!is_estimate_query) {
360  for (size_t i = 0; i < aggr_col_count; ++i) {
361  auto idx_lv = ConstantInt::get(i32_type, i);
362  auto agg_init_gep = GetElementPtrInst::CreateInBounds(
363  agg_init_val->getType()->getPointerElementType(),
364  agg_init_val,
365  idx_lv,
366  "",
367  bb_entry);
368  auto agg_init_val = new LoadInst(
369  get_pointer_element_type(agg_init_gep), agg_init_gep, "", false, bb_entry);
370  agg_init_val->setAlignment(LLVM_ALIGN(8));
371  agg_init_val_vec.push_back(agg_init_val);
372  auto init_val_st = new StoreInst(agg_init_val, result_ptr_vec[i], false, bb_entry);
373  init_val_st->setAlignment(LLVM_ALIGN(8));
374  }
375  }
376 
377  CallInst* pos_start = CallInst::Create(func_pos_start, "pos_start", bb_entry);
378  pos_start->setCallingConv(CallingConv::C);
379  pos_start->setTailCall(true);
380  Attributes pos_start_pal;
381  pos_start->setAttributes(pos_start_pal);
382 
383  CallInst* pos_step = CallInst::Create(func_pos_step, "pos_step", bb_entry);
384  pos_step->setCallingConv(CallingConv::C);
385  pos_step->setTailCall(true);
386  Attributes pos_step_pal;
387  pos_step->setAttributes(pos_step_pal);
388 
389  CallInst* group_buff_idx = nullptr;
390  if (!is_estimate_query) {
391  group_buff_idx = CallInst::Create(func_group_buff_idx, "group_buff_idx", bb_entry);
392  group_buff_idx->setCallingConv(CallingConv::C);
393  group_buff_idx->setTailCall(true);
394  Attributes group_buff_idx_pal;
395  group_buff_idx->setAttributes(group_buff_idx_pal);
396  }
397 
398  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
399  ICmpInst* enter_or_not =
400  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
401  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
402 
403  // Block .loop.preheader
404  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
405  BranchInst::Create(bb_forbody, bb_preheader);
406 
407  // Block .forbody
408  Argument* pos_inc_pre = new Argument(i64_type);
409  PHINode* pos = PHINode::Create(i64_type, 2, "pos", bb_forbody);
410  pos->addIncoming(pos_start_i64, bb_preheader);
411  pos->addIncoming(pos_inc_pre, bb_forbody);
412 
413  std::vector<Value*> row_process_params;
414  row_process_params.insert(
415  row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
416  if (is_estimate_query) {
417  row_process_params.push_back(
418  new LoadInst(get_pointer_element_type(out), out, "", false, bb_forbody));
419  }
420  row_process_params.push_back(agg_init_val);
421  row_process_params.push_back(pos);
422  row_process_params.push_back(frag_row_off_ptr);
423  row_process_params.push_back(row_count_ptr);
424  if (hoist_literals) {
425  CHECK(literals);
426  row_process_params.push_back(literals);
427  }
428  CallInst* row_process =
429  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
430  row_process->setCallingConv(CallingConv::C);
431  row_process->setTailCall(false);
432  Attributes row_process_pal;
433  row_process->setAttributes(row_process_pal);
434 
435  BinaryOperator* pos_inc =
436  BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
437  ICmpInst* loop_or_exit =
438  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
439  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
440 
441  // Block ._crit_edge
442  std::vector<Instruction*> result_vec_pre;
443  if (!is_estimate_query) {
444  for (size_t i = 0; i < aggr_col_count; ++i) {
445  auto result = new LoadInst(get_pointer_element_type(result_ptr_vec[i]),
446  result_ptr_vec[i],
447  ".pre.result",
448  false,
449  bb_crit_edge);
450  result->setAlignment(LLVM_ALIGN(8));
451  result_vec_pre.push_back(result);
452  }
453  }
454 
455  BranchInst::Create(bb_exit, bb_crit_edge);
456 
457  // Block .exit
469  if (!is_estimate_query) {
470  std::vector<PHINode*> result_vec;
471  for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
472  auto result =
473  PHINode::Create(IntegerType::get(mod->getContext(), 64), 2, "", bb_exit);
474  result->addIncoming(result_vec_pre[i], bb_crit_edge);
475  result->addIncoming(agg_init_val_vec[i], bb_entry);
476  result_vec.insert(result_vec.begin(), result);
477  }
478 
479  for (size_t i = 0; i < aggr_col_count; ++i) {
480  auto col_idx = ConstantInt::get(i32_type, i);
481  if (gpu_smem_context.isSharedMemoryUsed()) {
482  auto target_addr = GetElementPtrInst::CreateInBounds(
483  smem_output_buffer->getType()->getPointerElementType(),
484  smem_output_buffer,
485  col_idx,
486  "",
487  bb_exit);
488  // TODO: generalize this once we want to support other types of aggregate
489  // functions besides COUNT.
490  auto agg_func = mod->getFunction("agg_sum_shared");
491  CHECK(agg_func);
492  CallInst::Create(
493  agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]}, "", bb_exit);
494  } else {
495  auto out_gep = GetElementPtrInst::CreateInBounds(
496  out->getType()->getPointerElementType(), out, col_idx, "", bb_exit);
497  auto col_buffer =
498  new LoadInst(get_pointer_element_type(out_gep), out_gep, "", false, bb_exit);
499  col_buffer->setAlignment(LLVM_ALIGN(8));
500  auto slot_idx = BinaryOperator::CreateAdd(
502  BinaryOperator::CreateMul(frag_idx, pos_step, "", bb_exit),
503  "",
504  bb_exit);
505  auto target_addr = GetElementPtrInst::CreateInBounds(
506  col_buffer->getType()->getPointerElementType(),
507  col_buffer,
508  slot_idx,
509  "",
510  bb_exit);
511  StoreInst* result_st = new StoreInst(result_vec[i], target_addr, false, bb_exit);
512  result_st->setAlignment(LLVM_ALIGN(8));
513  }
514  }
515  if (gpu_smem_context.isSharedMemoryUsed()) {
516  // final reduction of results from shared memory buffer back into global memory.
517  auto sync_thread_func = mod->getFunction("sync_threadblock");
518  CHECK(sync_thread_func);
519  CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{}, "", bb_exit);
520  auto reduce_smem_to_gmem_func = mod->getFunction("write_back_non_grouped_agg");
521  CHECK(reduce_smem_to_gmem_func);
522  // each thread reduce the aggregate target corresponding to its own thread ID.
523  // If there are more targets than threads we do not currently use shared memory
524  // optimization. This can be relaxed if necessary
525  for (size_t i = 0; i < aggr_col_count; i++) {
526  auto out_gep =
527  GetElementPtrInst::CreateInBounds(out->getType()->getPointerElementType(),
528  out,
529  ConstantInt::get(i32_type, i),
530  "",
531  bb_exit);
532  auto gmem_output_buffer = new LoadInst(get_pointer_element_type(out_gep),
533  out_gep,
534  "gmem_output_buffer_" + std::to_string(i),
535  false,
536  bb_exit);
537  CallInst::Create(
538  reduce_smem_to_gmem_func,
539  std::vector<llvm::Value*>{
540  smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
541  "",
542  bb_exit);
543  }
544  }
545  }
546 
547  ReturnInst::Create(mod->getContext(), bb_exit);
548 
549  // Resolve Forward References
550  pos_inc_pre->replaceAllUsesWith(pos_inc);
551  delete pos_inc_pre;
552 
553  if (verifyFunction(*query_func_ptr)) {
554  LOG(FATAL) << "Generated invalid code. ";
555  }
556 
557  return {query_func_ptr, row_process};
558 }
559 
560 template <class Attributes>
561 std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template_impl(
562  llvm::Module* mod,
563  const bool hoist_literals,
565  const ExecutorDeviceType device_type,
566  const bool check_scan_limit,
567  const GpuSharedMemoryContext& gpu_smem_context) {
568  if (gpu_smem_context.isSharedMemoryUsed()) {
569  CHECK(device_type == ExecutorDeviceType::GPU);
570  }
571  using namespace llvm;
572 
573  auto func_pos_start = pos_start<Attributes>(mod);
574  CHECK(func_pos_start);
575  auto func_pos_step = pos_step<Attributes>(mod);
576  CHECK(func_pos_step);
577  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
578  CHECK(func_group_buff_idx);
579  auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
580  CHECK(func_row_process);
581  auto func_init_shared_mem = gpu_smem_context.isSharedMemoryUsed()
582  ? mod->getFunction("init_shared_mem")
583  : mod->getFunction("init_shared_mem_nop");
584  CHECK(func_init_shared_mem);
585 
586  auto func_write_back = mod->getFunction("write_back_nop");
587  CHECK(func_write_back);
588 
589  auto i32_type = IntegerType::get(mod->getContext(), 32);
590  auto i64_type = IntegerType::get(mod->getContext(), 64);
591  auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
592  auto pi32_type = PointerType::get(i32_type, 0);
593  auto pi64_type = PointerType::get(i64_type, 0);
594  auto ppi64_type = PointerType::get(pi64_type, 0);
595  auto ppi8_type = PointerType::get(pi8_type, 0);
596 
597  std::vector<Type*> query_args;
598  query_args.push_back(ppi8_type);
599  if (hoist_literals) {
600  query_args.push_back(pi8_type);
601  }
602  query_args.push_back(pi64_type);
603  query_args.push_back(pi64_type);
604  query_args.push_back(pi32_type);
605  query_args.push_back(pi64_type);
606 
607  query_args.push_back(ppi64_type);
608  query_args.push_back(i32_type);
609  query_args.push_back(pi64_type);
610  query_args.push_back(pi32_type);
611  query_args.push_back(pi32_type);
612 
613  FunctionType* query_func_type = FunctionType::get(
614  /*Result=*/Type::getVoidTy(mod->getContext()),
615  /*Params=*/query_args,
616  /*isVarArg=*/false);
617 
618  std::string query_name{"query_group_by_template"};
619  auto query_func_ptr = mod->getFunction(query_name);
620  CHECK(!query_func_ptr);
621 
622  query_func_ptr = Function::Create(
623  /*Type=*/query_func_type,
624  /*Linkage=*/GlobalValue::ExternalLinkage,
625  /*Name=*/"query_group_by_template",
626  mod);
627 
628  query_func_ptr->setCallingConv(CallingConv::C);
629 
630  Attributes query_func_pal;
631  {
632  SmallVector<Attributes, 4> Attrs;
633  Attributes PAS;
634  {
635  AttrBuilder B;
636  B.addAttribute(Attribute::ReadNone);
637  B.addAttribute(Attribute::NoCapture);
638  PAS = Attributes::get(mod->getContext(), 1U, B);
639  }
640 
641  Attrs.push_back(PAS);
642  {
643  AttrBuilder B;
644  B.addAttribute(Attribute::ReadOnly);
645  B.addAttribute(Attribute::NoCapture);
646  PAS = Attributes::get(mod->getContext(), 2U, B);
647  }
648 
649  Attrs.push_back(PAS);
650  {
651  AttrBuilder B;
652  B.addAttribute(Attribute::ReadNone);
653  B.addAttribute(Attribute::NoCapture);
654  PAS = Attributes::get(mod->getContext(), 3U, B);
655  }
656 
657  Attrs.push_back(PAS);
658  {
659  AttrBuilder B;
660  B.addAttribute(Attribute::ReadOnly);
661  B.addAttribute(Attribute::NoCapture);
662  PAS = Attributes::get(mod->getContext(), 4U, B);
663  }
664 
665  Attrs.push_back(PAS);
666  {
667  AttrBuilder B;
668  B.addAttribute(Attribute::UWTable);
669  PAS = Attributes::get(mod->getContext(), ~0U, B);
670  }
671 
672  Attrs.push_back(PAS);
673 
674  query_func_pal = Attributes::get(mod->getContext(), Attrs);
675  }
676  query_func_ptr->setAttributes(query_func_pal);
677 
678  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
679  Value* byte_stream = &*query_arg_it;
680  byte_stream->setName("byte_stream");
681  Value* literals{nullptr};
682  if (hoist_literals) {
683  literals = &*(++query_arg_it);
684  ;
685  literals->setName("literals");
686  }
687  Value* row_count_ptr = &*(++query_arg_it);
688  row_count_ptr->setName("row_count_ptr");
689  Value* frag_row_off_ptr = &*(++query_arg_it);
690  frag_row_off_ptr->setName("frag_row_off_ptr");
691  Value* max_matched_ptr = &*(++query_arg_it);
692  max_matched_ptr->setName("max_matched_ptr");
693  Value* agg_init_val = &*(++query_arg_it);
694  agg_init_val->setName("agg_init_val");
695  Value* group_by_buffers = &*(++query_arg_it);
696  group_by_buffers->setName("group_by_buffers");
697  Value* frag_idx = &*(++query_arg_it);
698  frag_idx->setName("frag_idx");
699  Value* join_hash_tables = &*(++query_arg_it);
700  join_hash_tables->setName("join_hash_tables");
701  Value* total_matched = &*(++query_arg_it);
702  total_matched->setName("total_matched");
703  Value* error_code = &*(++query_arg_it);
704  error_code->setName("error_code");
705 
706  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
707  auto bb_preheader =
708  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
709  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".forbody", query_func_ptr, 0);
710  auto bb_crit_edge =
711  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
712  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
713 
714  // Block .entry
715  LoadInst* row_count = new LoadInst(
716  get_pointer_element_type(row_count_ptr), row_count_ptr, "", false, bb_entry);
717  row_count->setAlignment(LLVM_ALIGN(8));
718  row_count->setName("row_count");
719 
720  LoadInst* max_matched = new LoadInst(
721  get_pointer_element_type(max_matched_ptr), max_matched_ptr, "", false, bb_entry);
722  max_matched->setAlignment(LLVM_ALIGN(8));
723 
724  auto crt_matched_ptr = new AllocaInst(i32_type, 0, "crt_matched", bb_entry);
725  auto old_total_matched_ptr = new AllocaInst(i32_type, 0, "old_total_matched", bb_entry);
726  CallInst* pos_start = CallInst::Create(func_pos_start, "", bb_entry);
727  pos_start->setCallingConv(CallingConv::C);
728  pos_start->setTailCall(true);
729  Attributes pos_start_pal;
730  pos_start->setAttributes(pos_start_pal);
731 
732  CallInst* pos_step = CallInst::Create(func_pos_step, "", bb_entry);
733  pos_step->setCallingConv(CallingConv::C);
734  pos_step->setTailCall(true);
735  Attributes pos_step_pal;
736  pos_step->setAttributes(pos_step_pal);
737 
738  CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx, "", bb_entry);
739  group_buff_idx_call->setCallingConv(CallingConv::C);
740  group_buff_idx_call->setTailCall(true);
741  Attributes group_buff_idx_pal;
742  group_buff_idx_call->setAttributes(group_buff_idx_pal);
743  Value* group_buff_idx = group_buff_idx_call;
744 
745  const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
746  CHECK(Ty);
747 
748  Value* varlen_output_buffer{nullptr};
749  if (query_mem_desc.hasVarlenOutput()) {
750  // make the varlen buffer the _first_ 8 byte value in the group by buffers double ptr,
751  // and offset the group by buffers index by 8 bytes
752  auto varlen_output_buffer_gep = GetElementPtrInst::Create(
753  Ty->getElementType(),
754  group_by_buffers,
755  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
756  "",
757  bb_entry);
758  varlen_output_buffer =
759  new LoadInst(get_pointer_element_type(varlen_output_buffer_gep),
760  varlen_output_buffer_gep,
761  "varlen_output_buffer",
762  false,
763  bb_entry);
764 
765  group_buff_idx = BinaryOperator::Create(
766  Instruction::Add,
768  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
769  "group_buff_idx_varlen_offset",
770  bb_entry);
771  } else {
772  varlen_output_buffer =
773  ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
774  }
775  CHECK(varlen_output_buffer);
776 
777  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
778  GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
779  Ty->getElementType(), group_by_buffers, group_buff_idx, "", bb_entry);
780  LoadInst* col_buffer = new LoadInst(get_pointer_element_type(group_by_buffers_gep),
781  group_by_buffers_gep,
782  "",
783  false,
784  bb_entry);
785  col_buffer->setName("col_buffer");
786  col_buffer->setAlignment(LLVM_ALIGN(8));
787 
788  llvm::ConstantInt* shared_mem_bytes_lv =
789  ConstantInt::get(i32_type, gpu_smem_context.getSharedMemorySize());
790  // TODO(Saman): change this further, normal path should not go through this
791  llvm::CallInst* result_buffer =
792  CallInst::Create(func_init_shared_mem,
793  std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
794  "result_buffer",
795  bb_entry);
796 
797  ICmpInst* enter_or_not =
798  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
799  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
800 
801  // Block .loop.preheader
802  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
803  BranchInst::Create(bb_forbody, bb_preheader);
804 
805  // Block .forbody
806  Argument* pos_pre = new Argument(i64_type);
807  PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2, "pos", bb_forbody);
808 
809  std::vector<Value*> row_process_params;
810  row_process_params.push_back(result_buffer);
811  row_process_params.push_back(varlen_output_buffer);
812  row_process_params.push_back(crt_matched_ptr);
813  row_process_params.push_back(total_matched);
814  row_process_params.push_back(old_total_matched_ptr);
815  row_process_params.push_back(max_matched_ptr);
816  row_process_params.push_back(agg_init_val);
817  row_process_params.push_back(pos);
818  row_process_params.push_back(frag_row_off_ptr);
819  row_process_params.push_back(row_count_ptr);
820  if (hoist_literals) {
821  CHECK(literals);
822  row_process_params.push_back(literals);
823  }
824  if (check_scan_limit) {
825  new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
826  crt_matched_ptr,
827  bb_forbody);
828  }
829  CallInst* row_process =
830  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
831  row_process->setCallingConv(CallingConv::C);
832  row_process->setTailCall(true);
833  Attributes row_process_pal;
834  row_process->setAttributes(row_process_pal);
835 
836  // Forcing all threads within a warp to be synchronized (Compute >= 7.x)
837  if (query_mem_desc.isWarpSyncRequired(device_type)) {
838  auto func_sync_warp_protected = mod->getFunction("sync_warp_protected");
839  CHECK(func_sync_warp_protected);
840  CallInst::Create(func_sync_warp_protected,
841  std::vector<llvm::Value*>{pos, row_count},
842  "",
843  bb_forbody);
844  }
845 
846  BinaryOperator* pos_inc =
847  BinaryOperator::Create(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
848  ICmpInst* loop_or_exit =
849  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
850  if (check_scan_limit) {
851  auto crt_matched = new LoadInst(get_pointer_element_type(crt_matched_ptr),
852  crt_matched_ptr,
853  "crt_matched",
854  false,
855  bb_forbody);
856  auto filter_match = BasicBlock::Create(
857  mod->getContext(), "filter_match", query_func_ptr, bb_crit_edge);
858  llvm::Value* new_total_matched =
859  new LoadInst(get_pointer_element_type(old_total_matched_ptr),
860  old_total_matched_ptr,
861  "",
862  false,
863  filter_match);
864  new_total_matched =
865  BinaryOperator::CreateAdd(new_total_matched, crt_matched, "", filter_match);
866  CHECK(new_total_matched);
867  ICmpInst* limit_not_reached = new ICmpInst(*filter_match,
868  ICmpInst::ICMP_SLT,
869  new_total_matched,
870  max_matched,
871  "limit_not_reached");
872  BranchInst::Create(
873  bb_forbody,
874  bb_crit_edge,
875  BinaryOperator::Create(
876  BinaryOperator::And, loop_or_exit, limit_not_reached, "", filter_match),
877  filter_match);
878  auto filter_nomatch = BasicBlock::Create(
879  mod->getContext(), "filter_nomatch", query_func_ptr, bb_crit_edge);
880  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
881  ICmpInst* crt_matched_nz = new ICmpInst(
882  *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0), "");
883  BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
884  pos->addIncoming(pos_start_i64, bb_preheader);
885  pos->addIncoming(pos_pre, filter_match);
886  pos->addIncoming(pos_pre, filter_nomatch);
887  } else {
888  pos->addIncoming(pos_start_i64, bb_preheader);
889  pos->addIncoming(pos_pre, bb_forbody);
890  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
891  }
892 
893  // Block ._crit_edge
894  BranchInst::Create(bb_exit, bb_crit_edge);
895 
896  // Block .exit
897  CallInst::Create(func_write_back,
898  std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
899  "",
900  bb_exit);
901 
902  ReturnInst::Create(mod->getContext(), bb_exit);
903 
904  // Resolve Forward References
905  pos_pre->replaceAllUsesWith(pos_inc);
906  delete pos_pre;
907 
908  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
909  LOG(FATAL) << "Generated invalid code. ";
910  }
911 
912  return {query_func_ptr, row_process};
913 }
914 
915 std::tuple<llvm::Function*, llvm::CallInst*> query_template(
916  llvm::Module* module,
917  const size_t aggr_col_count,
918  const bool hoist_literals,
919  const bool is_estimate_query,
920  const GpuSharedMemoryContext& gpu_smem_context) {
921  return query_template_impl<llvm::AttributeList>(
922  module, aggr_col_count, hoist_literals, is_estimate_query, gpu_smem_context);
923 }
924 std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template(
925  llvm::Module* module,
926  const bool hoist_literals,
928  const ExecutorDeviceType device_type,
929  const bool check_scan_limit,
930  const GpuSharedMemoryContext& gpu_smem_context) {
931  return query_group_by_template_impl<llvm::AttributeList>(module,
932  hoist_literals,
934  device_type,
935  check_scan_limit,
936  gpu_smem_context);
937 }
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template(llvm::Module *module, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template_impl(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
ExecutorDeviceType
#define LOG(tag)
Definition: Logger.h:205
size_t getSharedMemorySize() const
string name
Definition: setup.in.py:72
Type pointer_type(const Type pointee)
#define LLVM_ALIGN(alignment)
std::string to_string(char const *&&v)
llvm::Function * group_buff_idx(llvm::Module *mod)
llvm::Function * default_func_builder(llvm::Module *mod, const std::string &name)
std::tuple< llvm::Function *, llvm::CallInst * > query_template(llvm::Module *module, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
bool isWarpSyncRequired(const ExecutorDeviceType) const
#define CHECK(condition)
Definition: Logger.h:211
std::tuple< llvm::Function *, llvm::CallInst * > query_template_impl(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
llvm::Type * get_pointer_element_type(llvm::Value *value)
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)