OmniSciDB  72c90bc290
QueryTemplateGenerator.cpp File Reference
#include "QueryTemplateGenerator.h"
#include "IRCodegenUtils.h"
#include "Logger/Logger.h"
#include <llvm/IR/Constants.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Verifier.h>

Classes

class  anonymous_namespace{QueryTemplateGenerator.cpp}::Params< NTYPES >
 

Namespaces

 anonymous_namespace{QueryTemplateGenerator.cpp}
 

Functions

template<typename... ATTRS>
llvm::AttributeList anonymous_namespace{QueryTemplateGenerator.cpp}::make_attribute_list (llvm::Module const *const mod, unsigned const index, ATTRS const ...attrs)
 
template<bool IS_GROUP_BY, size_t NTYPES = 13u>
Params< NTYPES > anonymous_namespace{QueryTemplateGenerator.cpp}::make_params (llvm::Module const *const mod, bool const hoist_literals)
 
llvm::Type * anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type (llvm::Value *value)
 
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::default_func_builder (llvm::Module *mod, const std::string &name)
 
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start (llvm::Module *mod)
 
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx (llvm::Module *mod)
 
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step (llvm::Module *mod)
 
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::row_process (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)
 
std::tuple< llvm::Function *, llvm::CallInst * > query_template (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
 
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template (llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
 

Function Documentation

std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template ( llvm::Module * mod,
const bool hoist_literals,
const QueryMemoryDescriptor & query_mem_desc,
const ExecutorDeviceType device_type,
const bool check_scan_limit,
const GpuSharedMemoryContext & gpu_smem_context
)

Definition at line 553 of file QueryTemplateGenerator.cpp.

References CHECK, logger::FATAL, get_arg_by_name(), anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type(), GpuSharedMemoryContext::getSharedMemorySize(), GPU, anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx(), QueryMemoryDescriptor::hasVarlenOutput(), GpuSharedMemoryContext::isSharedMemoryUsed(), QueryMemoryDescriptor::isWarpSyncRequired(), LLVM_ALIGN, LOG, anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start(), anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step(), and anonymous_namespace{QueryTemplateGenerator.cpp}::row_process().
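
Before the listing, a rough C-like sketch of the control flow this routine emits as LLVM IR may help orientation. This is not part of the source: the argument list is abbreviated, and the locals simply name the values built in the .entry block below. The helpers (pos_start, pos_step, group_buff_idx, row_process, init_shared_mem, write_back_nop) are the runtime functions referenced above.

// Sketch only: approximate C rendering of the IR that query_group_by_template builds.
void query_group_by_template(/* parameters built by make_params<true>(mod, hoist_literals) */) {
  // .entry
  const int64_t row_count = *row_count_ptr;
  int32_t crt_matched, old_total_matched;
  const int64_t start = pos_start(), step = pos_step();   // sign-extended from i32
  int32_t buff_idx = group_buff_idx();   // offset by 1 when a varlen output buffer is present
  int64_t* col_buffer = group_by_buffers[buff_idx];
  int64_t* result_buffer = init_shared_mem(col_buffer, shared_mem_bytes);  // or the _nop variant
  // .forbody
  for (int64_t pos = start; pos < row_count; pos += step) {
    row_process(result_buffer, varlen_output_buffer, &crt_matched, total_matched,
                &old_total_matched, max_matched_ptr, agg_init_val, pos,
                frag_row_off_ptr, row_count_ptr /*, literals if hoist_literals */);
    // with check_scan_limit: also leave the loop once
    // old_total_matched + crt_matched reaches *max_matched_ptr
  }
  // .exit
  write_back_nop(col_buffer, result_buffer, shared_mem_bytes);
}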

559  {
560  if (gpu_smem_context.isSharedMemoryUsed()) {
561  CHECK(device_type == ExecutorDeviceType::GPU);
562  }
563  using namespace llvm;
564 
565  auto* const i32_type = llvm::IntegerType::get(mod->getContext(), 32);
566  auto* const i64_type = llvm::IntegerType::get(mod->getContext(), 64);
567 
568  llvm::Function* const func_pos_start = pos_start(mod);
569  CHECK(func_pos_start);
570  llvm::Function* const func_pos_step = pos_step(mod);
571  CHECK(func_pos_step);
572  llvm::Function* const func_group_buff_idx = group_buff_idx(mod);
573  CHECK(func_group_buff_idx);
574  llvm::Function* const func_row_process = row_process(mod, 0, hoist_literals);
575  CHECK(func_row_process);
576  llvm::Function* const func_init_shared_mem =
577  gpu_smem_context.isSharedMemoryUsed() ? mod->getFunction("init_shared_mem")
578  : mod->getFunction("init_shared_mem_nop");
579  CHECK(func_init_shared_mem);
580 
581  auto func_write_back = mod->getFunction("write_back_nop");
582  CHECK(func_write_back);
583 
584  constexpr bool IS_GROUP_BY = true;
585  Params query_func_params = make_params<IS_GROUP_BY>(mod, hoist_literals);
586 
587  FunctionType* query_func_type = FunctionType::get(
588  /*Result=*/Type::getVoidTy(mod->getContext()),
589  /*Params=*/query_func_params.types(),
590  /*isVarArg=*/false);
591 
592  std::string query_name{"query_group_by_template"};
593  auto query_func_ptr = mod->getFunction(query_name);
594  CHECK(!query_func_ptr);
595 
596  query_func_ptr = Function::Create(
597  /*Type=*/query_func_type,
598  /*Linkage=*/GlobalValue::ExternalLinkage,
599  /*Name=*/"query_group_by_template",
600  mod);
601  query_func_ptr->setCallingConv(CallingConv::C);
602  query_func_ptr->setAttributes(query_func_params.attributeList());
603  query_func_params.setNames(query_func_ptr->arg_begin());
604 
605  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
606  auto bb_preheader =
607  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
608  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".forbody", query_func_ptr, 0);
609  auto bb_crit_edge =
610  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
611  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
612 
613  // Block .entry
614  llvm::Value* const row_count_ptr = get_arg_by_name(query_func_ptr, "row_count_ptr");
615  LoadInst* row_count = new LoadInst(
616  get_pointer_element_type(row_count_ptr), row_count_ptr, "", false, bb_entry);
617  row_count->setAlignment(LLVM_ALIGN(8));
618  row_count->setName("row_count");
619 
620  llvm::Value* const max_matched_ptr = get_arg_by_name(query_func_ptr, "max_matched_ptr");
621  LoadInst* max_matched = new LoadInst(
622  get_pointer_element_type(max_matched_ptr), max_matched_ptr, "", false, bb_entry);
623  max_matched->setAlignment(LLVM_ALIGN(8));
624 
625  auto crt_matched_ptr = new AllocaInst(i32_type, 0, "crt_matched", bb_entry);
626  auto old_total_matched_ptr = new AllocaInst(i32_type, 0, "old_total_matched", bb_entry);
627  CallInst* pos_start = CallInst::Create(func_pos_start, "", bb_entry);
628  pos_start->setCallingConv(CallingConv::C);
629  pos_start->setTailCall(true);
630  llvm::AttributeList pos_start_pal;
631  pos_start->setAttributes(pos_start_pal);
632 
633  CallInst* pos_step = CallInst::Create(func_pos_step, "", bb_entry);
634  pos_step->setCallingConv(CallingConv::C);
635  pos_step->setTailCall(true);
636  llvm::AttributeList pos_step_pal;
637  pos_step->setAttributes(pos_step_pal);
638 
639  CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx, "", bb_entry);
640  group_buff_idx_call->setCallingConv(CallingConv::C);
641  group_buff_idx_call->setTailCall(true);
642  llvm::AttributeList group_buff_idx_pal;
643  group_buff_idx_call->setAttributes(group_buff_idx_pal);
644  Value* group_buff_idx = group_buff_idx_call;
645 
646  auto* const group_by_buffers = get_arg_by_name(query_func_ptr, "group_by_buffers");
647  const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
648  CHECK(Ty);
649 
650  Value* varlen_output_buffer{nullptr};
651  if (query_mem_desc.hasVarlenOutput()) {
652  // make the varlen buffer the _first_ 8 byte value in the group by buffers double ptr,
653  // and offset the group by buffers index by 8 bytes
654  auto varlen_output_buffer_gep = GetElementPtrInst::Create(
655  Ty->getPointerElementType(),
656  group_by_buffers,
657  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
658  "",
659  bb_entry);
660  varlen_output_buffer =
661  new LoadInst(get_pointer_element_type(varlen_output_buffer_gep),
662  varlen_output_buffer_gep,
663  "varlen_output_buffer",
664  false,
665  bb_entry);
666 
667  group_buff_idx = BinaryOperator::Create(
 668  Instruction::Add,
 669  group_buff_idx,
 670  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
671  "group_buff_idx_varlen_offset",
672  bb_entry);
673  } else {
674  varlen_output_buffer =
675  ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
676  }
677  CHECK(varlen_output_buffer);
678 
679  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
680  GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
681  Ty->getPointerElementType(), group_by_buffers, group_buff_idx, "", bb_entry);
682  LoadInst* col_buffer = new LoadInst(get_pointer_element_type(group_by_buffers_gep),
683  group_by_buffers_gep,
684  "",
685  false,
686  bb_entry);
687  col_buffer->setName("col_buffer");
688  col_buffer->setAlignment(LLVM_ALIGN(8));
689 
690  llvm::ConstantInt* shared_mem_bytes_lv =
691  ConstantInt::get(i32_type, gpu_smem_context.getSharedMemorySize());
692  // TODO(Saman): change this further, normal path should not go through this
693  llvm::CallInst* result_buffer =
694  CallInst::Create(func_init_shared_mem,
695  std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
696  "result_buffer",
697  bb_entry);
698 
699  ICmpInst* enter_or_not =
700  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
701  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
702 
703  // Block .loop.preheader
704  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
705  BranchInst::Create(bb_forbody, bb_preheader);
706 
707  // Block .forbody
708  Argument* pos_pre = new Argument(i64_type);
709  PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2, "pos", bb_forbody);
710 
711  std::vector<Value*> row_process_params;
712  row_process_params.push_back(result_buffer);
713  row_process_params.push_back(varlen_output_buffer);
714  row_process_params.push_back(crt_matched_ptr);
715  row_process_params.push_back(get_arg_by_name(query_func_ptr, "total_matched"));
716  row_process_params.push_back(old_total_matched_ptr);
717  row_process_params.push_back(max_matched_ptr);
718  row_process_params.push_back(get_arg_by_name(query_func_ptr, "agg_init_val"));
719  row_process_params.push_back(pos);
720  row_process_params.push_back(get_arg_by_name(query_func_ptr, "frag_row_off_ptr"));
721  row_process_params.push_back(row_count_ptr);
722  if (hoist_literals) {
723  row_process_params.push_back(get_arg_by_name(query_func_ptr, "literals"));
724  }
725  if (check_scan_limit) {
726  new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
727  crt_matched_ptr,
728  bb_forbody);
729  }
730  CallInst* row_process =
731  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
732  row_process->setCallingConv(CallingConv::C);
733  row_process->setTailCall(true);
734  llvm::AttributeList row_process_pal;
735  row_process->setAttributes(row_process_pal);
736 
737  // Forcing all threads within a warp to be synchronized (Compute >= 7.x)
738  if (query_mem_desc.isWarpSyncRequired(device_type)) {
739  auto func_sync_warp_protected = mod->getFunction("sync_warp_protected");
740  CHECK(func_sync_warp_protected);
741  CallInst::Create(func_sync_warp_protected,
742  std::vector<llvm::Value*>{pos, row_count},
743  "",
744  bb_forbody);
745  }
746 
747  BinaryOperator* pos_inc =
748  BinaryOperator::Create(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
749  ICmpInst* loop_or_exit =
750  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
751  if (check_scan_limit) {
752  auto crt_matched = new LoadInst(get_pointer_element_type(crt_matched_ptr),
753  crt_matched_ptr,
754  "crt_matched",
755  false,
756  bb_forbody);
757  auto filter_match = BasicBlock::Create(
758  mod->getContext(), "filter_match", query_func_ptr, bb_crit_edge);
759  llvm::Value* new_total_matched =
760  new LoadInst(get_pointer_element_type(old_total_matched_ptr),
761  old_total_matched_ptr,
762  "",
763  false,
764  filter_match);
765  new_total_matched =
766  BinaryOperator::CreateAdd(new_total_matched, crt_matched, "", filter_match);
767  CHECK(new_total_matched);
768  ICmpInst* limit_not_reached = new ICmpInst(*filter_match,
769  ICmpInst::ICMP_SLT,
770  new_total_matched,
771  max_matched,
772  "limit_not_reached");
773  BranchInst::Create(
774  bb_forbody,
775  bb_crit_edge,
776  BinaryOperator::Create(
777  BinaryOperator::And, loop_or_exit, limit_not_reached, "", filter_match),
778  filter_match);
779  auto filter_nomatch = BasicBlock::Create(
780  mod->getContext(), "filter_nomatch", query_func_ptr, bb_crit_edge);
781  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
782  ICmpInst* crt_matched_nz = new ICmpInst(
783  *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0), "");
784  BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
785  pos->addIncoming(pos_start_i64, bb_preheader);
786  pos->addIncoming(pos_pre, filter_match);
787  pos->addIncoming(pos_pre, filter_nomatch);
788  } else {
789  pos->addIncoming(pos_start_i64, bb_preheader);
790  pos->addIncoming(pos_pre, bb_forbody);
791  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
792  }
793 
794  // Block ._crit_edge
795  BranchInst::Create(bb_exit, bb_crit_edge);
796 
797  // Block .exit
798  CallInst::Create(func_write_back,
799  std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
800  "",
801  bb_exit);
802 
803  ReturnInst::Create(mod->getContext(), bb_exit);
804 
805  // Resolve Forward References
806  pos_pre->replaceAllUsesWith(pos_inc);
807  delete pos_pre;
808 
809  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
810  LOG(FATAL) << "Generated invalid code. ";
811  }
812 
813  return {query_func_ptr, row_process};
814 }
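
A minimal calling sketch (assumptions: module already contains the runtime declarations such as pos_start and write_back_nop, and query_mem_desc / gpu_smem_context come from the executor; only the call itself follows the signature above):

// Sketch: emit the group-by query skeleton into an existing module (C++17).
auto [query_func, row_process_call] =
    query_group_by_template(module,
                            /*hoist_literals=*/true,
                            query_mem_desc,
                            ExecutorDeviceType::GPU,
                            /*check_scan_limit=*/false,
                            gpu_smem_context);
// query_func is the generated "query_group_by_template" function; row_process_call is the
// placeholder call to row_process inside the loop body, returned so the caller can wire in
// the per-query row-processing code.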


std::tuple<llvm::Function*, llvm::CallInst*> query_template ( llvm::Module * mod,
const size_t aggr_col_count,
const bool hoist_literals,
const bool is_estimate_query,
const GpuSharedMemoryContext & gpu_smem_context
)

If GPU shared memory optimization is disabled, then for each aggregate target each thread copies its aggregate result (held in registers) back into memory. This is done once per processed fragment. On the host, the final results are reduced (per target, across all threads and all fragments).

If GPU shared memory optimization is enabled, all threads' results are properly (atomically) aggregated into memory, which makes the final reduction on the host much cheaper. Here, we call a no-op dummy write-back function which is replaced at runtime depending on the target expressions.
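
The listing below implements this as two branches in the .exit block. A rough sketch for aggregate column i (orientation only; smem_buffer, result, and out correspond to smem_output_buffer, result_vec, and the out argument in the listing):

// Sketch only: the two write-back strategies for aggregate column i.
if (!gpu_smem_context.isSharedMemoryUsed()) {
  // Each thread writes its register-held partial result into its own slot; the host
  // performs the final reduction across threads and fragments afterwards.
  out[i][group_buff_idx + frag_idx * pos_step] = result[i];
} else {
  // Threads first aggregate atomically into the shared memory buffer ...
  agg_sum_shared(&smem_buffer[i], result[i]);
  sync_threadblock();
  // ... then one cheap per-target reduction writes it back to global memory.
  write_back_non_grouped_agg(smem_buffer, out[i], i);
}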

Definition at line 266 of file QueryTemplateGenerator.cpp.

References anonymous_namespace{RuntimeFunctions.cpp}::agg_func(), CHECK, logger::FATAL, get_arg_by_name(), anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type(), anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx(), GpuSharedMemoryContext::isSharedMemoryUsed(), LLVM_ALIGN, LOG, anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start(), anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step(), run_benchmark_import::result, anonymous_namespace{QueryTemplateGenerator.cpp}::row_process(), and to_string().

271  {
272  using namespace llvm;
273 
274  auto* const i32_type = llvm::IntegerType::get(mod->getContext(), 32);
275  auto* const i64_type = llvm::IntegerType::get(mod->getContext(), 64);
276 
277  llvm::Function* const func_pos_start = pos_start(mod);
278  CHECK(func_pos_start);
279  llvm::Function* const func_pos_step = pos_step(mod);
280  CHECK(func_pos_step);
281  llvm::Function* const func_group_buff_idx = group_buff_idx(mod);
282  CHECK(func_group_buff_idx);
283  llvm::Function* const func_row_process =
284  row_process(mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
285  CHECK(func_row_process);
286 
287  constexpr bool IS_GROUP_BY = false;
288  Params query_func_params = make_params<IS_GROUP_BY>(mod, hoist_literals);
289 
290  FunctionType* query_func_type = FunctionType::get(
291  /*Result=*/Type::getVoidTy(mod->getContext()),
292  /*Params=*/query_func_params.types(),
293  /*isVarArg=*/false);
294 
295  std::string query_template_name{"query_template"};
296  auto query_func_ptr = mod->getFunction(query_template_name);
297  CHECK(!query_func_ptr);
298 
299  query_func_ptr = Function::Create(
300  /*Type=*/query_func_type,
301  /*Linkage=*/GlobalValue::ExternalLinkage,
302  /*Name=*/query_template_name,
303  mod);
304  query_func_ptr->setCallingConv(CallingConv::C);
305  query_func_ptr->setAttributes(query_func_params.attributeList());
306  query_func_params.setNames(query_func_ptr->arg_begin());
307 
308  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
309  auto bb_preheader =
310  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
311  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".for.body", query_func_ptr, 0);
312  auto bb_crit_edge =
313  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
314  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
315 
316  // Block (.entry)
317  llvm::Value* const agg_init_val = get_arg_by_name(query_func_ptr, "agg_init_val");
318  std::vector<Value*> result_ptr_vec;
319  llvm::CallInst* smem_output_buffer{nullptr};
320  if (!is_estimate_query) {
321  for (size_t i = 0; i < aggr_col_count; ++i) {
322  auto result_ptr = new AllocaInst(i64_type, 0, "result", bb_entry);
323  result_ptr->setAlignment(LLVM_ALIGN(8));
324  result_ptr_vec.push_back(result_ptr);
325  }
326  if (gpu_smem_context.isSharedMemoryUsed()) {
327  auto init_smem_func = mod->getFunction("init_shared_mem");
328  CHECK(init_smem_func);
329  // only one slot per aggregate column is needed, and so we can initialize shared
330  // memory buffer for intermediate results to be exactly like the agg_init_val array
331  smem_output_buffer = CallInst::Create(
332  init_smem_func,
333  std::vector<llvm::Value*>{
334  agg_init_val,
335  llvm::ConstantInt::get(i32_type, aggr_col_count * sizeof(int64_t))},
336  "smem_buffer",
337  bb_entry);
338  }
339  }
340 
341  llvm::Value* const row_count_ptr = get_arg_by_name(query_func_ptr, "row_count_ptr");
342  LoadInst* row_count = new LoadInst(get_pointer_element_type(row_count_ptr),
343  row_count_ptr,
344  "row_count",
345  false,
346  bb_entry);
347  row_count->setAlignment(LLVM_ALIGN(8));
348  row_count->setName("row_count");
349  std::vector<Value*> agg_init_val_vec;
350  if (!is_estimate_query) {
351  for (size_t i = 0; i < aggr_col_count; ++i) {
352  auto idx_lv = ConstantInt::get(i32_type, i);
353  auto agg_init_gep = GetElementPtrInst::CreateInBounds(
354  agg_init_val->getType()->getPointerElementType(),
355  agg_init_val,
356  idx_lv,
357  "",
358  bb_entry);
359  auto agg_init_val = new LoadInst(
360  get_pointer_element_type(agg_init_gep), agg_init_gep, "", false, bb_entry);
361  agg_init_val->setAlignment(LLVM_ALIGN(8));
362  agg_init_val_vec.push_back(agg_init_val);
363  auto init_val_st = new StoreInst(agg_init_val, result_ptr_vec[i], false, bb_entry);
364  init_val_st->setAlignment(LLVM_ALIGN(8));
365  }
366  }
367 
368  CallInst* pos_start = CallInst::Create(func_pos_start, "pos_start", bb_entry);
369  pos_start->setCallingConv(CallingConv::C);
370  pos_start->setTailCall(true);
371  llvm::AttributeList pos_start_pal;
372  pos_start->setAttributes(pos_start_pal);
373 
374  CallInst* pos_step = CallInst::Create(func_pos_step, "pos_step", bb_entry);
375  pos_step->setCallingConv(CallingConv::C);
376  pos_step->setTailCall(true);
377  llvm::AttributeList pos_step_pal;
378  pos_step->setAttributes(pos_step_pal);
379 
380  CallInst* group_buff_idx = nullptr;
381  if (!is_estimate_query) {
382  group_buff_idx = CallInst::Create(func_group_buff_idx, "group_buff_idx", bb_entry);
383  group_buff_idx->setCallingConv(CallingConv::C);
384  group_buff_idx->setTailCall(true);
385  llvm::AttributeList group_buff_idx_pal;
386  group_buff_idx->setAttributes(group_buff_idx_pal);
387  }
388 
389  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
390  ICmpInst* enter_or_not =
391  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
392  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
393 
394  // Block .loop.preheader
395  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
396  BranchInst::Create(bb_forbody, bb_preheader);
397 
398  // Block .forbody
399  Argument* pos_inc_pre = new Argument(i64_type);
400  PHINode* pos = PHINode::Create(i64_type, 2, "pos", bb_forbody);
401  pos->addIncoming(pos_start_i64, bb_preheader);
402  pos->addIncoming(pos_inc_pre, bb_forbody);
403 
404  std::vector<Value*> row_process_params;
405  llvm::Value* const out = get_arg_by_name(query_func_ptr, "out");
406  row_process_params.insert(
407  row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
408  if (is_estimate_query) {
409  row_process_params.push_back(
410  new LoadInst(get_pointer_element_type(out), out, "", false, bb_forbody));
411  }
412  row_process_params.push_back(agg_init_val);
413  row_process_params.push_back(pos);
414  row_process_params.push_back(get_arg_by_name(query_func_ptr, "frag_row_off_ptr"));
415  row_process_params.push_back(row_count_ptr);
416  if (hoist_literals) {
417  row_process_params.push_back(get_arg_by_name(query_func_ptr, "literals"));
418  }
419  CallInst* row_process =
420  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
421  row_process->setCallingConv(CallingConv::C);
422  row_process->setTailCall(false);
423  llvm::AttributeList row_process_pal;
424  row_process->setAttributes(row_process_pal);
425 
426  BinaryOperator* pos_inc =
427  BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
428  ICmpInst* loop_or_exit =
429  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
430  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
431 
432  // Block ._crit_edge
433  std::vector<Instruction*> result_vec_pre;
434  if (!is_estimate_query) {
435  for (size_t i = 0; i < aggr_col_count; ++i) {
436  auto result = new LoadInst(get_pointer_element_type(result_ptr_vec[i]),
437  result_ptr_vec[i],
438  ".pre.result",
439  false,
440  bb_crit_edge);
441  result->setAlignment(LLVM_ALIGN(8));
442  result_vec_pre.push_back(result);
443  }
444  }
445 
446  BranchInst::Create(bb_exit, bb_crit_edge);
447 
448  // Block .exit
460  if (!is_estimate_query) {
461  std::vector<PHINode*> result_vec;
462  for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
463  auto result =
464  PHINode::Create(IntegerType::get(mod->getContext(), 64), 2, "", bb_exit);
465  result->addIncoming(result_vec_pre[i], bb_crit_edge);
466  result->addIncoming(agg_init_val_vec[i], bb_entry);
467  result_vec.insert(result_vec.begin(), result);
468  }
469 
470  llvm::Value* const frag_idx = get_arg_by_name(query_func_ptr, "frag_idx");
471  for (size_t i = 0; i < aggr_col_count; ++i) {
472  auto col_idx = ConstantInt::get(i32_type, i);
473  if (gpu_smem_context.isSharedMemoryUsed()) {
474  auto target_addr = GetElementPtrInst::CreateInBounds(
475  smem_output_buffer->getType()->getPointerElementType(),
476  smem_output_buffer,
477  col_idx,
478  "",
479  bb_exit);
480  // TODO: generalize this once we want to support other types of aggregate
481  // functions besides COUNT.
482  auto agg_func = mod->getFunction("agg_sum_shared");
483  CHECK(agg_func);
484  CallInst::Create(
485  agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]}, "", bb_exit);
486  } else {
487  auto out_gep = GetElementPtrInst::CreateInBounds(
488  out->getType()->getPointerElementType(), out, col_idx, "", bb_exit);
489  auto col_buffer =
490  new LoadInst(get_pointer_element_type(out_gep), out_gep, "", false, bb_exit);
491  col_buffer->setAlignment(LLVM_ALIGN(8));
492  auto slot_idx = BinaryOperator::CreateAdd(
 493  group_buff_idx,
 494  BinaryOperator::CreateMul(frag_idx, pos_step, "", bb_exit),
495  "",
496  bb_exit);
497  auto target_addr = GetElementPtrInst::CreateInBounds(
498  col_buffer->getType()->getPointerElementType(),
499  col_buffer,
500  slot_idx,
501  "",
502  bb_exit);
503  StoreInst* result_st = new StoreInst(result_vec[i], target_addr, false, bb_exit);
504  result_st->setAlignment(LLVM_ALIGN(8));
505  }
506  }
507  if (gpu_smem_context.isSharedMemoryUsed()) {
508  // final reduction of results from shared memory buffer back into global memory.
509  auto sync_thread_func = mod->getFunction("sync_threadblock");
510  CHECK(sync_thread_func);
511  CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{}, "", bb_exit);
512  auto reduce_smem_to_gmem_func = mod->getFunction("write_back_non_grouped_agg");
513  CHECK(reduce_smem_to_gmem_func);
 514  // each thread reduces the aggregate target corresponding to its own thread ID.
 515  // If there are more targets than threads, we do not currently use the shared memory
 516  // optimization. This can be relaxed if necessary.
517  for (size_t i = 0; i < aggr_col_count; i++) {
518  auto out_gep =
519  GetElementPtrInst::CreateInBounds(out->getType()->getPointerElementType(),
520  out,
521  ConstantInt::get(i32_type, i),
522  "",
523  bb_exit);
524  auto gmem_output_buffer = new LoadInst(get_pointer_element_type(out_gep),
525  out_gep,
526  "gmem_output_buffer_" + std::to_string(i),
527  false,
528  bb_exit);
529  CallInst::Create(
530  reduce_smem_to_gmem_func,
531  std::vector<llvm::Value*>{
532  smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
533  "",
534  bb_exit);
535  }
536  }
537  }
538 
539  ReturnInst::Create(mod->getContext(), bb_exit);
540 
541  // Resolve Forward References
542  pos_inc_pre->replaceAllUsesWith(pos_inc);
543  delete pos_inc_pre;
544 
545  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
546  LOG(FATAL) << "Generated invalid code.";
547  }
548 
549  return {query_func_ptr, row_process};
550 }
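
A minimal calling sketch for the non-grouped case, mirroring the group-by example above (module and gpu_smem_context are assumed to be provided by the caller):

// Sketch: emit the non-grouped aggregate skeleton for two aggregate targets,
// without hoisted literals and not as an estimator query.
auto [query_func, row_process_call] =
    query_template(module,
                   /*aggr_col_count=*/2,
                   /*hoist_literals=*/false,
                   /*is_estimate_query=*/false,
                   gpu_smem_context);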