OmniSciDB  c07336695a
QueryTemplateGenerator.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "QueryTemplateGenerator.h"
18 #include "Shared/Logger.h"
19 
20 #include <llvm/IR/Constants.h>
21 #include <llvm/IR/Instructions.h>
22 #include <llvm/IR/Verifier.h>
23 
24 // This file was pretty much auto-generated by running:
25 // llc -march=cpp RuntimeFunctions.ll
26 // and formatting the results to be more readable.
27 
28 namespace {
29 
30 template <class Attributes>
31 llvm::Function* default_func_builder(llvm::Module* mod, const std::string& name) {
32  using namespace llvm;
33 
34  std::vector<Type*> func_args;
35  FunctionType* func_type = FunctionType::get(
36  /*Result=*/IntegerType::get(mod->getContext(), 32),
37  /*Params=*/func_args,
38  /*isVarArg=*/false);
39 
40  auto func_ptr = mod->getFunction(name);
41  if (!func_ptr) {
42  func_ptr = Function::Create(
43  /*Type=*/func_type,
44  /*Linkage=*/GlobalValue::ExternalLinkage,
45  /*Name=*/name,
46  mod); // (external, no body)
47  func_ptr->setCallingConv(CallingConv::C);
48  }
49 
50  Attributes func_pal;
51  {
52  SmallVector<Attributes, 4> Attrs;
53  Attributes PAS;
54  {
55  AttrBuilder B;
56  PAS = Attributes::get(mod->getContext(), ~0U, B);
57  }
58 
59  Attrs.push_back(PAS);
60  func_pal = Attributes::get(mod->getContext(), Attrs);
61  }
62  func_ptr->setAttributes(func_pal);
63 
64  return func_ptr;
65 }
66 
67 template <class Attributes>
68 llvm::Function* pos_start(llvm::Module* mod) {
69  return default_func_builder<Attributes>(mod, "pos_start");
70 }
71 
72 template <class Attributes>
73 llvm::Function* group_buff_idx(llvm::Module* mod) {
74  return default_func_builder<Attributes>(mod, "group_buff_idx");
75 }
76 
77 template <class Attributes>
78 llvm::Function* pos_step(llvm::Module* mod) {
79  using namespace llvm;
80 
81  std::vector<Type*> func_args;
82  FunctionType* func_type = FunctionType::get(
83  /*Result=*/IntegerType::get(mod->getContext(), 32),
84  /*Params=*/func_args,
85  /*isVarArg=*/false);
86 
87  auto func_ptr = mod->getFunction("pos_step");
88  if (!func_ptr) {
89  func_ptr = Function::Create(
90  /*Type=*/func_type,
91  /*Linkage=*/GlobalValue::ExternalLinkage,
92  /*Name=*/"pos_step",
93  mod); // (external, no body)
94  func_ptr->setCallingConv(CallingConv::C);
95  }
96 
97  Attributes func_pal;
98  {
99  SmallVector<Attributes, 4> Attrs;
100  Attributes PAS;
101  {
102  AttrBuilder B;
103  PAS = Attributes::get(mod->getContext(), ~0U, B);
104  }
105 
106  Attrs.push_back(PAS);
107  func_pal = Attributes::get(mod->getContext(), Attrs);
108  }
109  func_ptr->setAttributes(func_pal);
110 
111  return func_ptr;
112 }
113 
114 template <class Attributes>
115 llvm::Function* row_process(llvm::Module* mod,
116  const size_t aggr_col_count,
117  const bool hoist_literals) {
118  using namespace llvm;
119 
120  std::vector<Type*> func_args;
121  auto i8_type = IntegerType::get(mod->getContext(), 8);
122  auto i32_type = IntegerType::get(mod->getContext(), 32);
123  auto i64_type = IntegerType::get(mod->getContext(), 64);
124  auto pi32_type = PointerType::get(i32_type, 0);
125  auto pi64_type = PointerType::get(i64_type, 0);
126 
127  if (aggr_col_count) {
128  for (size_t i = 0; i < aggr_col_count; ++i) {
129  func_args.push_back(pi64_type);
130  }
131  } else { // group by query
132  func_args.push_back(pi64_type); // groups buffer
133  func_args.push_back(pi32_type); // 1 iff current row matched, else 0
134  func_args.push_back(pi32_type); // total rows matched from the caller
135  func_args.push_back(pi32_type); // total rows matched before atomic increment
136  func_args.push_back(pi32_type); // max number of slots in the output buffer
137  }
138 
139  func_args.push_back(pi64_type); // aggregate init values
140 
141  func_args.push_back(i64_type);
142  func_args.push_back(pi64_type);
143  func_args.push_back(pi64_type);
144  if (hoist_literals) {
145  func_args.push_back(PointerType::get(i8_type, 0));
146  }
147  FunctionType* func_type = FunctionType::get(
148  /*Result=*/i32_type,
149  /*Params=*/func_args,
150  /*isVarArg=*/false);
151 
152  std::string func_name{"row_process"};
153  auto func_ptr = mod->getFunction(func_name);
154 
155  if (!func_ptr) {
156  func_ptr = Function::Create(
157  /*Type=*/func_type,
158  /*Linkage=*/GlobalValue::ExternalLinkage,
159  /*Name=*/func_name,
160  mod); // (external, no body)
161  func_ptr->setCallingConv(CallingConv::C);
162 
163  Attributes func_pal;
164  {
165  SmallVector<Attributes, 4> Attrs;
166  Attributes PAS;
167  {
168  AttrBuilder B;
169  PAS = Attributes::get(mod->getContext(), ~0U, B);
170  }
171 
172  Attrs.push_back(PAS);
173  func_pal = Attributes::get(mod->getContext(), Attrs);
174  }
175  func_ptr->setAttributes(func_pal);
176  }
177 
178  return func_ptr;
179 }
180 
181 } // namespace
182 
183 template <class Attributes>
184 llvm::Function* query_template_impl(llvm::Module* mod,
185  const size_t aggr_col_count,
186  const bool hoist_literals,
187  const bool is_estimate_query) {
188  using namespace llvm;
189 
190  auto func_pos_start = pos_start<Attributes>(mod);
191  CHECK(func_pos_start);
192  auto func_pos_step = pos_step<Attributes>(mod);
193  CHECK(func_pos_step);
194  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
195  CHECK(func_group_buff_idx);
196  auto func_row_process = row_process<Attributes>(
197  mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
198  CHECK(func_row_process);
199 
200  auto i8_type = IntegerType::get(mod->getContext(), 8);
201  auto i32_type = IntegerType::get(mod->getContext(), 32);
202  auto i64_type = IntegerType::get(mod->getContext(), 64);
203  auto pi8_type = PointerType::get(i8_type, 0);
204  auto ppi8_type = PointerType::get(pi8_type, 0);
205  auto pi32_type = PointerType::get(i32_type, 0);
206  auto pi64_type = PointerType::get(i64_type, 0);
207  auto ppi64_type = PointerType::get(pi64_type, 0);
208 
209  std::vector<Type*> query_args;
210  query_args.push_back(ppi8_type);
211  if (hoist_literals) {
212  query_args.push_back(pi8_type);
213  }
214  query_args.push_back(pi64_type);
215  query_args.push_back(pi64_type);
216  query_args.push_back(pi32_type);
217 
218  query_args.push_back(pi64_type);
219  query_args.push_back(ppi64_type);
220  query_args.push_back(i32_type);
221  query_args.push_back(pi64_type);
222  query_args.push_back(pi32_type);
223  query_args.push_back(pi32_type);
224 
225  FunctionType* query_func_type = FunctionType::get(
226  /*Result=*/Type::getVoidTy(mod->getContext()),
227  /*Params=*/query_args,
228  /*isVarArg=*/false);
229 
230  std::string query_template_name{"query_template"};
231  auto query_func_ptr = mod->getFunction(query_template_name);
232  CHECK(!query_func_ptr);
233 
234  query_func_ptr = Function::Create(
235  /*Type=*/query_func_type,
236  /*Linkage=*/GlobalValue::ExternalLinkage,
237  /*Name=*/query_template_name,
238  mod);
239  query_func_ptr->setCallingConv(CallingConv::C);
240 
241  Attributes query_func_pal;
242  {
243  SmallVector<Attributes, 4> Attrs;
244  Attributes PAS;
245  {
246  AttrBuilder B;
247  B.addAttribute(Attribute::NoCapture);
248  PAS = Attributes::get(mod->getContext(), 1U, B);
249  }
250 
251  Attrs.push_back(PAS);
252  {
253  AttrBuilder B;
254  B.addAttribute(Attribute::NoCapture);
255  PAS = Attributes::get(mod->getContext(), 2U, B);
256  }
257 
258  Attrs.push_back(PAS);
259 
260  {
261  AttrBuilder B;
262  B.addAttribute(Attribute::NoCapture);
263  Attrs.push_back(Attributes::get(mod->getContext(), 3U, B));
264  }
265 
266  {
267  AttrBuilder B;
268  B.addAttribute(Attribute::NoCapture);
269  Attrs.push_back(Attributes::get(mod->getContext(), 4U, B));
270  }
271 
272  Attrs.push_back(PAS);
273 
274  query_func_pal = Attributes::get(mod->getContext(), Attrs);
275  }
276  query_func_ptr->setAttributes(query_func_pal);
277 
278  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
279  Value* byte_stream = &*query_arg_it;
280  byte_stream->setName("byte_stream");
281  Value* literals{nullptr};
282  if (hoist_literals) {
283  literals = &*(++query_arg_it);
284  literals->setName("literals");
285  }
286  Value* row_count_ptr = &*(++query_arg_it);
287  row_count_ptr->setName("row_count_ptr");
288  Value* frag_row_off_ptr = &*(++query_arg_it);
289  frag_row_off_ptr->setName("frag_row_off_ptr");
290  Value* max_matched_ptr = &*(++query_arg_it);
291  max_matched_ptr->setName("max_matched_ptr");
292  Value* agg_init_val = &*(++query_arg_it);
293  agg_init_val->setName("agg_init_val");
294  Value* out = &*(++query_arg_it);
295  out->setName("out");
296  Value* frag_idx = &*(++query_arg_it);
297  frag_idx->setName("frag_idx");
298  Value* join_hash_tables = &*(++query_arg_it);
299  join_hash_tables->setName("join_hash_tables");
300  Value* total_matched = &*(++query_arg_it);
301  total_matched->setName("total_matched");
302  Value* error_code = &*(++query_arg_it);
303  error_code->setName("error_code");
304 
305  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
306  auto bb_preheader =
307  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
308  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".for.body", query_func_ptr, 0);
309  auto bb_crit_edge =
310  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
311  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
312 
313  // Block (.entry)
314  std::vector<Value*> result_ptr_vec;
315  if (!is_estimate_query) {
316  for (size_t i = 0; i < aggr_col_count; ++i) {
317  auto result_ptr = new AllocaInst(i64_type, 0, "result", bb_entry);
318  result_ptr->setAlignment(8);
319  result_ptr_vec.push_back(result_ptr);
320  }
321  }
322 
323  LoadInst* row_count = new LoadInst(row_count_ptr, "row_count", false, bb_entry);
324  row_count->setAlignment(8);
325  row_count->setName("row_count");
326  std::vector<Value*> agg_init_val_vec;
327  if (!is_estimate_query) {
328  for (size_t i = 0; i < aggr_col_count; ++i) {
329  auto idx_lv = ConstantInt::get(i32_type, i);
330  auto agg_init_gep =
331  GetElementPtrInst::CreateInBounds(agg_init_val, idx_lv, "", bb_entry);
332  auto agg_init_val = new LoadInst(agg_init_gep, "", false, bb_entry);
333  agg_init_val->setAlignment(8);
334  agg_init_val_vec.push_back(agg_init_val);
335  auto init_val_st = new StoreInst(agg_init_val, result_ptr_vec[i], false, bb_entry);
336  init_val_st->setAlignment(8);
337  }
338  }
339 
340  CallInst* pos_start = CallInst::Create(func_pos_start, "pos_start", bb_entry);
341  pos_start->setCallingConv(CallingConv::C);
342  pos_start->setTailCall(true);
343  Attributes pos_start_pal;
344  pos_start->setAttributes(pos_start_pal);
345 
346  CallInst* pos_step = CallInst::Create(func_pos_step, "pos_step", bb_entry);
347  pos_step->setCallingConv(CallingConv::C);
348  pos_step->setTailCall(true);
349  Attributes pos_step_pal;
350  pos_step->setAttributes(pos_step_pal);
351 
352  CallInst* group_buff_idx = nullptr;
353  if (!is_estimate_query) {
354  group_buff_idx = CallInst::Create(func_group_buff_idx, "group_buff_idx", bb_entry);
355  group_buff_idx->setCallingConv(CallingConv::C);
356  group_buff_idx->setTailCall(true);
357  Attributes group_buff_idx_pal;
358  group_buff_idx->setAttributes(group_buff_idx_pal);
359  }
360 
361  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
362  ICmpInst* enter_or_not =
363  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
364  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
365 
366  // Block .loop.preheader
367  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
368  BranchInst::Create(bb_forbody, bb_preheader);
369 
370  // Block .forbody
371  Argument* pos_inc_pre = new Argument(i64_type);
372  PHINode* pos = PHINode::Create(i64_type, 2, "pos", bb_forbody);
373  pos->addIncoming(pos_start_i64, bb_preheader);
374  pos->addIncoming(pos_inc_pre, bb_forbody);
375 
376  std::vector<Value*> row_process_params;
377  row_process_params.insert(
378  row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
379  if (is_estimate_query) {
380  row_process_params.push_back(new LoadInst(out, "", false, bb_forbody));
381  }
382  row_process_params.push_back(agg_init_val);
383  row_process_params.push_back(pos);
384  row_process_params.push_back(frag_row_off_ptr);
385  row_process_params.push_back(row_count_ptr);
386  if (hoist_literals) {
387  CHECK(literals);
388  row_process_params.push_back(literals);
389  }
390  CallInst* row_process =
391  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
392  row_process->setCallingConv(CallingConv::C);
393  row_process->setTailCall(false);
394  Attributes row_process_pal;
395  row_process->setAttributes(row_process_pal);
396 
397  BinaryOperator* pos_inc =
398  BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
399  ICmpInst* loop_or_exit =
400  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
401  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
402 
403  // Block ._crit_edge
404  std::vector<Instruction*> result_vec_pre;
405  if (!is_estimate_query) {
406  for (size_t i = 0; i < aggr_col_count; ++i) {
407  auto result = new LoadInst(result_ptr_vec[i], ".pre.result", false, bb_crit_edge);
408  result->setAlignment(8);
409  result_vec_pre.push_back(result);
410  }
411  }
412 
413  BranchInst::Create(bb_exit, bb_crit_edge);
414 
415  // Block .exit
416  std::vector<PHINode*> result_vec;
417  if (!is_estimate_query) {
418  for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
419  auto result =
420  PHINode::Create(IntegerType::get(mod->getContext(), 64), 2, "", bb_exit);
421  result->addIncoming(result_vec_pre[i], bb_crit_edge);
422  result->addIncoming(agg_init_val_vec[i], bb_entry);
423  result_vec.insert(result_vec.begin(), result);
424  }
425  }
426 
427  if (!is_estimate_query) {
428  for (size_t i = 0; i < aggr_col_count; ++i) {
429  auto col_idx = ConstantInt::get(i32_type, i);
430  auto out_gep = GetElementPtrInst::CreateInBounds(out, col_idx, "", bb_exit);
431  auto col_buffer = new LoadInst(out_gep, "", false, bb_exit);
432  col_buffer->setAlignment(8);
433  auto slot_idx = BinaryOperator::CreateAdd(
434  group_buff_idx,
435  BinaryOperator::CreateMul(frag_idx, pos_step, "", bb_exit),
436  "",
437  bb_exit);
438  auto target_addr =
439  GetElementPtrInst::CreateInBounds(col_buffer, slot_idx, "", bb_exit);
440  StoreInst* result_st = new StoreInst(result_vec[i], target_addr, false, bb_exit);
441  result_st->setAlignment(8);
442  }
443  }
444 
445  ReturnInst::Create(mod->getContext(), bb_exit);
446 
447  // Resolve Forward References
448  pos_inc_pre->replaceAllUsesWith(pos_inc);
449  delete pos_inc_pre;
450 
451  if (verifyFunction(*query_func_ptr)) {
452  LOG(FATAL) << "Generated invalid code. ";
453  }
454 
455  return query_func_ptr;
456 }
457 
// Generates the "query_group_by_template" function (void return) used for
// group-by queries. The emitted code:
//   .entry    loads *row_count_ptr and *max_matched_ptr, calls pos_start(),
//             pos_step() and group_buff_idx(), loads
//             col_buffer = group_by_buffers[group_buff_idx()], and obtains
//             result_buffer from the selected init_shared_mem* helper;
//   .forbody  loops pos = pos_start(), pos + pos_step(), ... while
//             pos < row_count, calling row_process() once per position;
//   .exit     calls the selected write_back* helper (preceded by an
//             agg_from_smem_to_gmem_* call in one GpuMemSharing mode).
//
// @param mod               module receiving the function; must not already
//                          contain "query_group_by_template" (CHECKed)
// @param hoist_literals    adds an i8* "literals" parameter after byte_stream,
//                          forwarded as row_process's last argument
// @param query_mem_desc    chooses the shared-memory helpers and their sizes
// @param device_type       forwarded to sharedMemBytes() / isWarpSyncRequired()
// @param check_scan_limit  when true, iterations with a non-zero crt_matched
//                          also stop the loop once
//                          *old_total_matched + crt_matched >= *max_matched_ptr
// @return the generated function; LOG(FATAL) if LLVM verification fails
458 template <class Attributes>
459 llvm::Function* query_group_by_template_impl(llvm::Module* mod,
460  const bool hoist_literals,
461  const QueryMemoryDescriptor& query_mem_desc,
462  const ExecutorDeviceType device_type,
463  const bool check_scan_limit) {
464  using namespace llvm;
465 
  // Declarations of the runtime helpers called from the template.
466  auto func_pos_start = pos_start<Attributes>(mod);
467  CHECK(func_pos_start);
468  auto func_pos_step = pos_step<Attributes>(mod);
469  CHECK(func_pos_step);
470  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
471  CHECK(func_group_buff_idx);
472  auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
473  CHECK(func_row_process);
  // Shared-memory init helper: the real one when sharedMemBytes(device_type)
  // is non-zero, otherwise the nop variant; one GpuMemSharing mode overrides
  // both with init_shared_mem_dynamic. These must already exist in `mod`.
474  auto func_init_shared_mem = query_mem_desc.sharedMemBytes(device_type)
475  ? mod->getFunction("init_shared_mem")
476  : mod->getFunction("init_shared_mem_nop");
477  if (query_mem_desc.getGpuMemSharing() ==
479  func_init_shared_mem = mod->getFunction("init_shared_mem_dynamic");
480  }
481  CHECK(func_init_shared_mem);
482 
  // Write-back helper, chosen the same way (write_back_smem_nop for the
  // overriding GpuMemSharing mode).
483  auto func_write_back = query_mem_desc.sharedMemBytes(device_type)
484  ? mod->getFunction("write_back")
485  : mod->getFunction("write_back_nop");
486  if (query_mem_desc.getGpuMemSharing() ==
488  func_write_back = mod->getFunction("write_back_smem_nop");
489  }
490  CHECK(func_write_back);
491 
492  auto i32_type = IntegerType::get(mod->getContext(), 32);
493  auto i64_type = IntegerType::get(mod->getContext(), 64);
494  auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
495  auto pi32_type = PointerType::get(i32_type, 0);
496  auto pi64_type = PointerType::get(i64_type, 0);
497  auto ppi64_type = PointerType::get(pi64_type, 0);
498  auto ppi8_type = PointerType::get(pi8_type, 0);
499 
  // Parameter types, in the order the arguments are named below:
  // byte_stream, [literals], row_count_ptr, frag_row_off_ptr, max_matched_ptr,
  // agg_init_val, group_by_buffers, frag_idx, join_hash_tables,
  // total_matched, error_code.
500  std::vector<Type*> query_args;
501  query_args.push_back(ppi8_type);
502  if (hoist_literals) {
503  query_args.push_back(pi8_type);
504  }
505  query_args.push_back(pi64_type);
506  query_args.push_back(pi64_type);
507  query_args.push_back(pi32_type);
508  query_args.push_back(pi64_type);
509 
510  query_args.push_back(ppi64_type);
511  query_args.push_back(i32_type);
512  query_args.push_back(pi64_type);
513  query_args.push_back(pi32_type);
514  query_args.push_back(pi32_type);
515 
516  FunctionType* query_func_type = FunctionType::get(
517  /*Result=*/Type::getVoidTy(mod->getContext()),
518  /*Params=*/query_args,
519  /*isVarArg=*/false);
520 
521  std::string query_name{"query_group_by_template"};
522  auto query_func_ptr = mod->getFunction(query_name);
523  CHECK(!query_func_ptr);
524 
525  query_func_ptr = Function::Create(
526  /*Type=*/query_func_type,
527  /*Linkage=*/GlobalValue::ExternalLinkage,
528  /*Name=*/"query_group_by_template",
529  mod);
530 
531  query_func_ptr->setCallingConv(CallingConv::C);
532 
  // Parameter attributes at indices 1..4 (the first four parameters) plus
  // UWTable at the function index ~0U.
  // NOTE(review): the parameter at each index shifts with hoist_literals. With
  // literals hoisted, index 3 (ReadNone) is row_count_ptr, which is loaded in
  // .entry below; without them, it is frag_row_off_ptr, which is passed to
  // row_process. ReadNone on either looks inconsistent with that use --
  // confirm these attributes are intended.
533  Attributes query_func_pal;
534  {
535  SmallVector<Attributes, 4> Attrs;
536  Attributes PAS;
537  {
538  AttrBuilder B;
539  B.addAttribute(Attribute::ReadNone);
540  B.addAttribute(Attribute::NoCapture);
541  PAS = Attributes::get(mod->getContext(), 1U, B);
542  }
543 
544  Attrs.push_back(PAS);
545  {
546  AttrBuilder B;
547  B.addAttribute(Attribute::ReadOnly);
548  B.addAttribute(Attribute::NoCapture);
549  PAS = Attributes::get(mod->getContext(), 2U, B);
550  }
551 
552  Attrs.push_back(PAS);
553  {
554  AttrBuilder B;
555  B.addAttribute(Attribute::ReadNone);
556  B.addAttribute(Attribute::NoCapture);
557  PAS = Attributes::get(mod->getContext(), 3U, B);
558  }
559 
560  Attrs.push_back(PAS);
561  {
562  AttrBuilder B;
563  B.addAttribute(Attribute::ReadOnly);
564  B.addAttribute(Attribute::NoCapture);
565  PAS = Attributes::get(mod->getContext(), 4U, B);
566  }
567 
568  Attrs.push_back(PAS);
569  {
570  AttrBuilder B;
571  B.addAttribute(Attribute::UWTable);
572  PAS = Attributes::get(mod->getContext(), ~0U, B);
573  }
574 
575  Attrs.push_back(PAS);
576 
577  query_func_pal = Attributes::get(mod->getContext(), Attrs);
578  }
579  query_func_ptr->setAttributes(query_func_pal);
580 
581  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
582  Value* byte_stream = &*query_arg_it;
583  byte_stream->setName("byte_stream");
584  Value* literals{nullptr};
585  if (hoist_literals) {
586  literals = &*(++query_arg_it);
587  ;
588  literals->setName("literals");
589  }
590  Value* row_count_ptr = &*(++query_arg_it);
591  row_count_ptr->setName("row_count_ptr");
592  Value* frag_row_off_ptr = &*(++query_arg_it);
593  frag_row_off_ptr->setName("frag_row_off_ptr");
594  Value* max_matched_ptr = &*(++query_arg_it);
595  max_matched_ptr->setName("max_matched_ptr");
596  Value* agg_init_val = &*(++query_arg_it);
597  agg_init_val->setName("agg_init_val");
598  Value* group_by_buffers = &*(++query_arg_it);
599  group_by_buffers->setName("group_by_buffers");
600  Value* frag_idx = &*(++query_arg_it);
601  frag_idx->setName("frag_idx");
602  Value* join_hash_tables = &*(++query_arg_it);
603  join_hash_tables->setName("join_hash_tables");
604  Value* total_matched = &*(++query_arg_it);
605  total_matched->setName("total_matched");
606  Value* error_code = &*(++query_arg_it);
607  error_code->setName("error_code");
608 
609  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
610  auto bb_preheader =
611  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
612  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".forbody", query_func_ptr, 0);
613  auto bb_crit_edge =
614  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
615  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
616 
617  // Block .entry
618  LoadInst* row_count = new LoadInst(row_count_ptr, "", false, bb_entry);
619  row_count->setAlignment(8);
620  row_count->setName("row_count");
621 
  // max_matched is only consumed by the check_scan_limit path below.
622  LoadInst* max_matched = new LoadInst(max_matched_ptr, "", false, bb_entry);
623  max_matched->setAlignment(4);
624 
  // Stack slots handed to row_process. This function stores to crt_matched
  // only when check_scan_limit is set, and never stores to old_total_matched;
  // presumably row_process writes them (verify against the runtime).
625  auto crt_matched_ptr = new AllocaInst(i32_type, 0, "crt_matched", bb_entry);
626  auto old_total_matched_ptr = new AllocaInst(i32_type, 0, "old_total_matched", bb_entry);
627  CallInst* pos_start = CallInst::Create(func_pos_start, "", bb_entry);
628  pos_start->setCallingConv(CallingConv::C);
629  pos_start->setTailCall(true);
630  Attributes pos_start_pal;
631  pos_start->setAttributes(pos_start_pal);
632 
633  CallInst* pos_step = CallInst::Create(func_pos_step, "", bb_entry);
634  pos_step->setCallingConv(CallingConv::C);
635  pos_step->setTailCall(true);
636  Attributes pos_step_pal;
637  pos_step->setAttributes(pos_step_pal);
638 
639  CallInst* group_buff_idx = CallInst::Create(func_group_buff_idx, "", bb_entry);
640  group_buff_idx->setCallingConv(CallingConv::C);
641  group_buff_idx->setTailCall(true);
642  Attributes group_buff_idx_pal;
643  group_buff_idx->setAttributes(group_buff_idx_pal);
644 
645  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
  // col_buffer = group_by_buffers[group_buff_idx()] (an i64* loaded from the i64** argument).
646  const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
647  CHECK(Ty);
648  GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
649  Ty->getElementType(), group_by_buffers, group_buff_idx, "", bb_entry);
650  LoadInst* col_buffer = new LoadInst(group_by_buffers_gep, "", false, bb_entry);
651  col_buffer->setName("col_buffer");
652  col_buffer->setAlignment(8);
653 
  // result_buffer comes from the selected init helper. One GpuMemSharing mode
  // passes the element count (getEntryCount() + 1); all others pass the byte
  // size from sharedMemBytes(device_type). shared_mem_bytes_lv is set on both
  // paths because write_back receives it in .exit.
654  llvm::ConstantInt* shared_mem_num_elements_lv = nullptr;
655  llvm::ConstantInt* shared_mem_bytes_lv = nullptr;
656  llvm::CallInst* result_buffer = nullptr;
657  if (query_mem_desc.getGpuMemSharing() ==
659  int32_t num_shared_mem_buckets = query_mem_desc.getEntryCount() + 1;
660  shared_mem_bytes_lv =
661  ConstantInt::get(i32_type, query_mem_desc.sharedMemBytes(device_type));
662  shared_mem_num_elements_lv = ConstantInt::get(i32_type, num_shared_mem_buckets);
663  result_buffer = CallInst::Create(
664  func_init_shared_mem,
665  std::vector<llvm::Value*>{col_buffer, shared_mem_num_elements_lv},
666  "",
667  bb_entry);
668  } else {
669  shared_mem_bytes_lv =
670  ConstantInt::get(i32_type, query_mem_desc.sharedMemBytes(device_type));
671  result_buffer =
672  CallInst::Create(func_init_shared_mem,
673  std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
674  "",
675  bb_entry);
676  }
677  result_buffer->setName("result_buffer");
678 
679  ICmpInst* enter_or_not =
680  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
681  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
682 
683  // Block .loop.preheader
684  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
685  BranchInst::Create(bb_forbody, bb_preheader);
686 
687  // Block .forbody
  // pos_pre is a placeholder for pos_inc, which is created later; it is
  // replaced and deleted under "Resolve Forward References". The phi has a
  // third incoming edge when the scan-limit blocks are generated.
688  Argument* pos_pre = new Argument(i64_type);
689  PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2, "pos", bb_forbody);
690 
691  std::vector<Value*> row_process_params;
692  row_process_params.push_back(result_buffer);
693  row_process_params.push_back(crt_matched_ptr);
694  row_process_params.push_back(total_matched);
695  row_process_params.push_back(old_total_matched_ptr);
696  row_process_params.push_back(max_matched_ptr);
697  row_process_params.push_back(agg_init_val);
698  row_process_params.push_back(pos);
699  row_process_params.push_back(frag_row_off_ptr);
700  row_process_params.push_back(row_count_ptr);
701  if (hoist_literals) {
702  CHECK(literals);
703  row_process_params.push_back(literals);
704  }
  // Reset the per-row match flag before each row_process call.
705  if (check_scan_limit) {
706  new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
707  crt_matched_ptr,
708  bb_forbody);
709  }
710  CallInst* row_process =
711  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
712  row_process->setCallingConv(CallingConv::C);
713  row_process->setTailCall(true);
714  Attributes row_process_pal;
715  row_process->setAttributes(row_process_pal);
716 
717  // Forcing all threads within a warp to be synchronized (Compute >= 7.x)
718  if (query_mem_desc.isWarpSyncRequired(device_type)) {
719  auto func_sync_warp_protected = mod->getFunction("sync_warp_protected");
720  CHECK(func_sync_warp_protected);
721  CallInst::Create(func_sync_warp_protected,
722  std::vector<llvm::Value*>{pos, row_count},
723  "",
724  bb_forbody);
725  }
726 
727  BinaryOperator* pos_inc =
728  BinaryOperator::Create(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
729  ICmpInst* loop_or_exit =
730  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
  // With check_scan_limit, .forbody branches on crt_matched != 0:
  //   filter_match:   continue while pos_inc < row_count AND
  //                   *old_total_matched + crt_matched < max_matched;
  //   filter_nomatch: continue while pos_inc < row_count.
  // Otherwise .forbody loops on pos_inc < row_count directly.
731  if (check_scan_limit) {
732  auto crt_matched = new LoadInst(crt_matched_ptr, "crt_matched", false, bb_forbody);
733  auto filter_match = BasicBlock::Create(
734  mod->getContext(), "filter_match", query_func_ptr, bb_crit_edge);
735  llvm::Value* new_total_matched =
736  new LoadInst(old_total_matched_ptr, "", false, filter_match);
737  new_total_matched =
738  BinaryOperator::CreateAdd(new_total_matched, crt_matched, "", filter_match);
739  CHECK(new_total_matched);
740  ICmpInst* limit_not_reached = new ICmpInst(*filter_match,
741  ICmpInst::ICMP_SLT,
742  new_total_matched,
743  max_matched,
744  "limit_not_reached");
745  BranchInst::Create(
746  bb_forbody,
747  bb_crit_edge,
748  BinaryOperator::Create(
749  BinaryOperator::And, loop_or_exit, limit_not_reached, "", filter_match),
750  filter_match);
751  auto filter_nomatch = BasicBlock::Create(
752  mod->getContext(), "filter_nomatch", query_func_ptr, bb_crit_edge);
753  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
754  ICmpInst* crt_matched_nz = new ICmpInst(
755  *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0), "");
756  BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
757  pos->addIncoming(pos_start_i64, bb_preheader);
758  pos->addIncoming(pos_pre, filter_match);
759  pos->addIncoming(pos_pre, filter_nomatch);
760  } else {
761  pos->addIncoming(pos_start_i64, bb_preheader);
762  pos->addIncoming(pos_pre, bb_forbody);
763  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
764  }
765 
766  // Block ._crit_edge
767  BranchInst::Create(bb_exit, bb_crit_edge);
768 
769  // Block .exit
  // In one GpuMemSharing mode, results are first aggregated from result_buffer
  // into col_buffer by agg_from_smem_to_gmem_count_binId (target key index 0)
  // or agg_from_smem_to_gmem_binId_count (index 1).
770  if (query_mem_desc.getGpuMemSharing() ==
772  result_buffer->setName("shared_mem_result");
773  col_buffer->setName("col_buffer_global");
774  CHECK_LT(query_mem_desc.getTargetIdxForKey(),
775  2); // Saman: not expected for the shared memory design if more than 1
776  // Depending on the aggregate's target expression index, we choose different memory
777  // layout for the shared memory
778  auto func_agg_from_smem_to_gmem =
779  (query_mem_desc.getTargetIdxForKey() == 0)
780  ? mod->getFunction("agg_from_smem_to_gmem_count_binId")
781  : mod->getFunction("agg_from_smem_to_gmem_binId_count");
782  CHECK(func_agg_from_smem_to_gmem);
783  CallInst::Create(
784  func_agg_from_smem_to_gmem,
785  std::vector<Value*>{col_buffer, result_buffer, (shared_mem_num_elements_lv)},
786  "",
787  bb_exit);
788  }
789 
  // Always hand the buffers to the selected write-back helper before returning.
790  CallInst::Create(func_write_back,
791  std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
792  "",
793  bb_exit);
794  ReturnInst::Create(mod->getContext(), bb_exit);
795 
796  // Resolve Forward References
797  pos_pre->replaceAllUsesWith(pos_inc);
798  delete pos_pre;
799 
800  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
801  LOG(FATAL) << "Generated invalid code. ";
802  }
803 
804  return query_func_ptr;
805 }
806 
807 #if LLVM_VERSION_MAJOR >= 6
808 llvm::Function* query_template(llvm::Module* module,
809  const size_t aggr_col_count,
810  const bool hoist_literals,
811  const bool is_estimate_query) {
812  return query_template_impl<llvm::AttributeList>(
813  module, aggr_col_count, hoist_literals, is_estimate_query);
814 }
815 llvm::Function* query_group_by_template(llvm::Module* module,
816  const bool hoist_literals,
817  const QueryMemoryDescriptor& query_mem_desc,
818  const ExecutorDeviceType device_type,
819  const bool check_scan_limit) {
820  return query_group_by_template_impl<llvm::AttributeList>(
821  module, hoist_literals, query_mem_desc, device_type, check_scan_limit);
822 }
823 #else
824 llvm::Function* query_template(llvm::Module* module,
825  const size_t aggr_col_count,
826  const bool hoist_literals,
827  const bool is_estimate_query) {
828  return query_template_impl<llvm::AttributeSet>(
829  module, aggr_col_count, hoist_literals, is_estimate_query);
830 }
831 llvm::Function* query_group_by_template(llvm::Module* module,
832  const bool hoist_literals,
833  const QueryMemoryDescriptor& query_mem_desc,
834  const ExecutorDeviceType device_type,
835  const bool check_scan_limit) {
836  return query_group_by_template_impl<llvm::AttributeSet>(
837  module, hoist_literals, query_mem_desc, device_type, check_scan_limit);
838 }
839 #endif
llvm::Function * query_template_impl(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query)
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t * join_hash_tables
size_t sharedMemBytes(const ExecutorDeviceType) const
ExecutorDeviceType
#define LOG(tag)
Definition: Logger.h:182
GroupByMemSharing getGpuMemSharing() const
llvm::Function * query_template(llvm::Module *module, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query)
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t int32_t int32_t * total_matched
llvm::Function * group_buff_idx(llvm::Module *mod)
llvm::Function * default_func_builder(llvm::Module *mod, const std::string &name)
llvm::Function * query_group_by_template_impl(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit)
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t int32_t * error_code
bool isWarpSyncRequired(const ExecutorDeviceType) const
llvm::Function * query_group_by_template(llvm::Module *module, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit)
const int8_t const int64_t const uint64_t const int32_t * max_matched
int32_t getTargetIdxForKey() const
#define CHECK_LT(x, y)
Definition: Logger.h:197
const int64_t const uint32_t const uint32_t const uint32_t const bool const bool const int32_t frag_idx
#define CHECK(condition)
Definition: Logger.h:187
const int8_t * literals
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t ** out
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)