OmniSciDB  0264ff685a
QueryTemplateGenerator.cpp File Reference
#include "QueryTemplateGenerator.h"
#include "IRCodegenUtils.h"
#include "Logger/Logger.h"
#include <llvm/IR/Constants.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Verifier.h>
+ Include dependency graph for QueryTemplateGenerator.cpp:

Go to the source code of this file.

Namespaces

 anonymous_namespace{QueryTemplateGenerator.cpp}
 

Functions

llvm::Type * anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type (llvm::Value *value)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::default_func_builder (llvm::Module *mod, const std::string &name)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::row_process (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)
 
template<class Attributes >
std::tuple< llvm::Function *, llvm::CallInst * > query_template_impl (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
 
template<class Attributes >
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template_impl (llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
 
std::tuple< llvm::Function *, llvm::CallInst * > query_template (llvm::Module *module, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
 
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template (llvm::Module *module, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
 

Function Documentation

◆ query_group_by_template()

std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template ( llvm::Module *  module,
const bool  hoist_literals,
const QueryMemoryDescriptor query_mem_desc,
const ExecutorDeviceType  device_type,
const bool  check_scan_limit,
const GpuSharedMemoryContext gpu_smem_context 
)

Definition at line 874 of file QueryTemplateGenerator.cpp.

Referenced by Executor::compileWorkUnit().

880  {
881  return query_group_by_template_impl<llvm::AttributeList>(module,
882  hoist_literals,
883  query_mem_desc,
884  device_type,
885  check_scan_limit,
886  gpu_smem_context);
887 }
+ Here is the caller graph for this function:

◆ query_group_by_template_impl()

template<class Attributes >
std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template_impl ( llvm::Module *  mod,
const bool  hoist_literals,
const QueryMemoryDescriptor query_mem_desc,
const ExecutorDeviceType  device_type,
const bool  check_scan_limit,
const GpuSharedMemoryContext gpu_smem_context 
)

Definition at line 543 of file QueryTemplateGenerator.cpp.

References CHECK, logger::FATAL, anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type(), GpuSharedMemoryContext::getSharedMemorySize(), GPU, anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx(), GpuSharedMemoryContext::isSharedMemoryUsed(), QueryMemoryDescriptor::isWarpSyncRequired(), LLVM_ALIGN, LOG, anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start(), anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step(), and anonymous_namespace{QueryTemplateGenerator.cpp}::row_process().

549  {
550  if (gpu_smem_context.isSharedMemoryUsed()) {
551  CHECK(device_type == ExecutorDeviceType::GPU);
552  }
553  using namespace llvm;
554 
555  auto func_pos_start = pos_start<Attributes>(mod);
556  CHECK(func_pos_start);
557  auto func_pos_step = pos_step<Attributes>(mod);
558  CHECK(func_pos_step);
559  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
560  CHECK(func_group_buff_idx);
561  auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
562  CHECK(func_row_process);
563  auto func_init_shared_mem = gpu_smem_context.isSharedMemoryUsed()
564  ? mod->getFunction("init_shared_mem")
565  : mod->getFunction("init_shared_mem_nop");
566  CHECK(func_init_shared_mem);
567 
568  auto func_write_back = mod->getFunction("write_back_nop");
569  CHECK(func_write_back);
570 
571  auto i32_type = IntegerType::get(mod->getContext(), 32);
572  auto i64_type = IntegerType::get(mod->getContext(), 64);
573  auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
574  auto pi32_type = PointerType::get(i32_type, 0);
575  auto pi64_type = PointerType::get(i64_type, 0);
576  auto ppi64_type = PointerType::get(pi64_type, 0);
577  auto ppi8_type = PointerType::get(pi8_type, 0);
578 
579  std::vector<Type*> query_args;
580  query_args.push_back(ppi8_type);
581  if (hoist_literals) {
582  query_args.push_back(pi8_type);
583  }
584  query_args.push_back(pi64_type);
585  query_args.push_back(pi64_type);
586  query_args.push_back(pi32_type);
587  query_args.push_back(pi64_type);
588 
589  query_args.push_back(ppi64_type);
590  query_args.push_back(i32_type);
591  query_args.push_back(pi64_type);
592  query_args.push_back(pi32_type);
593  query_args.push_back(pi32_type);
594 
595  FunctionType* query_func_type = FunctionType::get(
596  /*Result=*/Type::getVoidTy(mod->getContext()),
597  /*Params=*/query_args,
598  /*isVarArg=*/false);
599 
600  std::string query_name{"query_group_by_template"};
601  auto query_func_ptr = mod->getFunction(query_name);
602  CHECK(!query_func_ptr);
603 
604  query_func_ptr = Function::Create(
605  /*Type=*/query_func_type,
606  /*Linkage=*/GlobalValue::ExternalLinkage,
607  /*Name=*/"query_group_by_template",
608  mod);
609 
610  query_func_ptr->setCallingConv(CallingConv::C);
611 
612  Attributes query_func_pal;
613  {
614  SmallVector<Attributes, 4> Attrs;
615  Attributes PAS;
616  {
617  AttrBuilder B;
618  B.addAttribute(Attribute::ReadNone);
619  B.addAttribute(Attribute::NoCapture);
620  PAS = Attributes::get(mod->getContext(), 1U, B);
621  }
622 
623  Attrs.push_back(PAS);
624  {
625  AttrBuilder B;
626  B.addAttribute(Attribute::ReadOnly);
627  B.addAttribute(Attribute::NoCapture);
628  PAS = Attributes::get(mod->getContext(), 2U, B);
629  }
630 
631  Attrs.push_back(PAS);
632  {
633  AttrBuilder B;
634  B.addAttribute(Attribute::ReadNone);
635  B.addAttribute(Attribute::NoCapture);
636  PAS = Attributes::get(mod->getContext(), 3U, B);
637  }
638 
639  Attrs.push_back(PAS);
640  {
641  AttrBuilder B;
642  B.addAttribute(Attribute::ReadOnly);
643  B.addAttribute(Attribute::NoCapture);
644  PAS = Attributes::get(mod->getContext(), 4U, B);
645  }
646 
647  Attrs.push_back(PAS);
648  {
649  AttrBuilder B;
650  B.addAttribute(Attribute::UWTable);
651  PAS = Attributes::get(mod->getContext(), ~0U, B);
652  }
653 
654  Attrs.push_back(PAS);
655 
656  query_func_pal = Attributes::get(mod->getContext(), Attrs);
657  }
658  query_func_ptr->setAttributes(query_func_pal);
659 
660  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
661  Value* byte_stream = &*query_arg_it;
662  byte_stream->setName("byte_stream");
663  Value* literals{nullptr};
664  if (hoist_literals) {
665  literals = &*(++query_arg_it);
666  ;
667  literals->setName("literals");
668  }
669  Value* row_count_ptr = &*(++query_arg_it);
670  row_count_ptr->setName("row_count_ptr");
671  Value* frag_row_off_ptr = &*(++query_arg_it);
672  frag_row_off_ptr->setName("frag_row_off_ptr");
673  Value* max_matched_ptr = &*(++query_arg_it);
674  max_matched_ptr->setName("max_matched_ptr");
675  Value* agg_init_val = &*(++query_arg_it);
676  agg_init_val->setName("agg_init_val");
677  Value* group_by_buffers = &*(++query_arg_it);
678  group_by_buffers->setName("group_by_buffers");
679  Value* frag_idx = &*(++query_arg_it);
680  frag_idx->setName("frag_idx");
681  Value* join_hash_tables = &*(++query_arg_it);
682  join_hash_tables->setName("join_hash_tables");
683  Value* total_matched = &*(++query_arg_it);
684  total_matched->setName("total_matched");
685  Value* error_code = &*(++query_arg_it);
686  error_code->setName("error_code");
687 
688  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
689  auto bb_preheader =
690  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
691  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".forbody", query_func_ptr, 0);
692  auto bb_crit_edge =
693  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
694  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
695 
696  // Block .entry
697  LoadInst* row_count = new LoadInst(
698  get_pointer_element_type(row_count_ptr), row_count_ptr, "", false, bb_entry);
699  row_count->setAlignment(LLVM_ALIGN(8));
700  row_count->setName("row_count");
701 
702  LoadInst* max_matched = new LoadInst(
703  get_pointer_element_type(max_matched_ptr), max_matched_ptr, "", false, bb_entry);
704  max_matched->setAlignment(LLVM_ALIGN(8));
705 
706  auto crt_matched_ptr = new AllocaInst(i32_type, 0, "crt_matched", bb_entry);
707  auto old_total_matched_ptr = new AllocaInst(i32_type, 0, "old_total_matched", bb_entry);
708  CallInst* pos_start = CallInst::Create(func_pos_start, "", bb_entry);
709  pos_start->setCallingConv(CallingConv::C);
710  pos_start->setTailCall(true);
711  Attributes pos_start_pal;
712  pos_start->setAttributes(pos_start_pal);
713 
714  CallInst* pos_step = CallInst::Create(func_pos_step, "", bb_entry);
715  pos_step->setCallingConv(CallingConv::C);
716  pos_step->setTailCall(true);
717  Attributes pos_step_pal;
718  pos_step->setAttributes(pos_step_pal);
719 
720  CallInst* group_buff_idx = CallInst::Create(func_group_buff_idx, "", bb_entry);
721  group_buff_idx->setCallingConv(CallingConv::C);
722  group_buff_idx->setTailCall(true);
723  Attributes group_buff_idx_pal;
724  group_buff_idx->setAttributes(group_buff_idx_pal);
725 
726  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
727  const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
728  CHECK(Ty);
729  GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
730  Ty->getElementType(), group_by_buffers, group_buff_idx, "", bb_entry);
731  LoadInst* col_buffer = new LoadInst(get_pointer_element_type(group_by_buffers_gep),
732  group_by_buffers_gep,
733  "",
734  false,
735  bb_entry);
736  col_buffer->setName("col_buffer");
737  col_buffer->setAlignment(LLVM_ALIGN(8));
738 
739  llvm::ConstantInt* shared_mem_bytes_lv =
740  ConstantInt::get(i32_type, gpu_smem_context.getSharedMemorySize());
741  llvm::CallInst* result_buffer =
742  CallInst::Create(func_init_shared_mem,
743  std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
744  "result_buffer",
745  bb_entry);
746  // TODO(Saman): change this further, normal path should not go through this
747 
748  ICmpInst* enter_or_not =
749  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
750  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
751 
752  // Block .loop.preheader
753  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
754  BranchInst::Create(bb_forbody, bb_preheader);
755 
756  // Block .forbody
757  Argument* pos_pre = new Argument(i64_type);
758  PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2, "pos", bb_forbody);
759 
760  std::vector<Value*> row_process_params;
761  row_process_params.push_back(result_buffer);
762  row_process_params.push_back(crt_matched_ptr);
763  row_process_params.push_back(total_matched);
764  row_process_params.push_back(old_total_matched_ptr);
765  row_process_params.push_back(max_matched_ptr);
766  row_process_params.push_back(agg_init_val);
767  row_process_params.push_back(pos);
768  row_process_params.push_back(frag_row_off_ptr);
769  row_process_params.push_back(row_count_ptr);
770  if (hoist_literals) {
771  CHECK(literals);
772  row_process_params.push_back(literals);
773  }
774  if (check_scan_limit) {
775  new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
776  crt_matched_ptr,
777  bb_forbody);
778  }
779  CallInst* row_process =
780  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
781  row_process->setCallingConv(CallingConv::C);
782  row_process->setTailCall(true);
783  Attributes row_process_pal;
784  row_process->setAttributes(row_process_pal);
785 
786  // Forcing all threads within a warp to be synchronized (Compute >= 7.x)
787  if (query_mem_desc.isWarpSyncRequired(device_type)) {
788  auto func_sync_warp_protected = mod->getFunction("sync_warp_protected");
789  CHECK(func_sync_warp_protected);
790  CallInst::Create(func_sync_warp_protected,
791  std::vector<llvm::Value*>{pos, row_count},
792  "",
793  bb_forbody);
794  }
795 
796  BinaryOperator* pos_inc =
797  BinaryOperator::Create(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
798  ICmpInst* loop_or_exit =
799  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
800  if (check_scan_limit) {
801  auto crt_matched = new LoadInst(get_pointer_element_type(crt_matched_ptr),
802  crt_matched_ptr,
803  "crt_matched",
804  false,
805  bb_forbody);
806  auto filter_match = BasicBlock::Create(
807  mod->getContext(), "filter_match", query_func_ptr, bb_crit_edge);
808  llvm::Value* new_total_matched =
809  new LoadInst(get_pointer_element_type(old_total_matched_ptr),
810  old_total_matched_ptr,
811  "",
812  false,
813  filter_match);
814  new_total_matched =
815  BinaryOperator::CreateAdd(new_total_matched, crt_matched, "", filter_match);
816  CHECK(new_total_matched);
817  ICmpInst* limit_not_reached = new ICmpInst(*filter_match,
818  ICmpInst::ICMP_SLT,
819  new_total_matched,
820  max_matched,
821  "limit_not_reached");
822  BranchInst::Create(
823  bb_forbody,
824  bb_crit_edge,
825  BinaryOperator::Create(
826  BinaryOperator::And, loop_or_exit, limit_not_reached, "", filter_match),
827  filter_match);
828  auto filter_nomatch = BasicBlock::Create(
829  mod->getContext(), "filter_nomatch", query_func_ptr, bb_crit_edge);
830  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
831  ICmpInst* crt_matched_nz = new ICmpInst(
832  *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0), "");
833  BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
834  pos->addIncoming(pos_start_i64, bb_preheader);
835  pos->addIncoming(pos_pre, filter_match);
836  pos->addIncoming(pos_pre, filter_nomatch);
837  } else {
838  pos->addIncoming(pos_start_i64, bb_preheader);
839  pos->addIncoming(pos_pre, bb_forbody);
840  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
841  }
842 
843  // Block ._crit_edge
844  BranchInst::Create(bb_exit, bb_crit_edge);
845 
846  // Block .exit
847  CallInst::Create(func_write_back,
848  std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
849  "",
850  bb_exit);
851 
852  ReturnInst::Create(mod->getContext(), bb_exit);
853 
854  // Resolve Forward References
855  pos_pre->replaceAllUsesWith(pos_inc);
856  delete pos_pre;
857 
858  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
859  LOG(FATAL) << "Generated invalid code. ";
860  }
861 
862  return {query_func_ptr, row_process};
863 }
#define LOG(tag)
Definition: Logger.h:188
#define LLVM_ALIGN(alignment)
llvm::Function * group_buff_idx(llvm::Module *mod)
bool isWarpSyncRequired(const ExecutorDeviceType) const
#define CHECK(condition)
Definition: Logger.h:197
llvm::Type * get_pointer_element_type(llvm::Value *value)
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)
+ Here is the call graph for this function:

◆ query_template()

std::tuple<llvm::Function*, llvm::CallInst*> query_template ( llvm::Module *  module,
const size_t  aggr_col_count,
const bool  hoist_literals,
const bool  is_estimate_query,
const GpuSharedMemoryContext gpu_smem_context 
)

Definition at line 865 of file QueryTemplateGenerator.cpp.

Referenced by Executor::compileWorkUnit().

870  {
871  return query_template_impl<llvm::AttributeList>(
872  module, aggr_col_count, hoist_literals, is_estimate_query, gpu_smem_context);
873 }
+ Here is the caller graph for this function:

◆ query_template_impl()

template<class Attributes >
std::tuple<llvm::Function*, llvm::CallInst*> query_template_impl ( llvm::Module *  mod,
const size_t  aggr_col_count,
const bool  hoist_literals,
const bool  is_estimate_query,
const GpuSharedMemoryContext gpu_smem_context 
)

If GPU shared memory optimization is disabled, for each aggregate target, threads copy their aggregate results (stored in registers) back into memory. This process is performed per processed fragment. On the host, the final results are reduced (per target, for all threads and all fragments).

If GPU shared memory optimization is enabled, we properly (atomically) aggregate all threads' results into memory, which makes the final reduction on the host much cheaper. Here, we call a no-op dummy write-back function which will be properly replaced at runtime depending on the target expressions.

Definition at line 195 of file QueryTemplateGenerator.cpp.

References CHECK, logger::FATAL, anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type(), anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx(), GpuSharedMemoryContext::isSharedMemoryUsed(), LLVM_ALIGN, LOG, anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start(), anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step(), run_benchmark_import::result, anonymous_namespace{QueryTemplateGenerator.cpp}::row_process(), and to_string().

200  {
201  using namespace llvm;
202 
203  auto func_pos_start = pos_start<Attributes>(mod);
204  CHECK(func_pos_start);
205  auto func_pos_step = pos_step<Attributes>(mod);
206  CHECK(func_pos_step);
207  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
208  CHECK(func_group_buff_idx);
209  auto func_row_process = row_process<Attributes>(
210  mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
211  CHECK(func_row_process);
212 
213  auto i8_type = IntegerType::get(mod->getContext(), 8);
214  auto i32_type = IntegerType::get(mod->getContext(), 32);
215  auto i64_type = IntegerType::get(mod->getContext(), 64);
216  auto pi8_type = PointerType::get(i8_type, 0);
217  auto ppi8_type = PointerType::get(pi8_type, 0);
218  auto pi32_type = PointerType::get(i32_type, 0);
219  auto pi64_type = PointerType::get(i64_type, 0);
220  auto ppi64_type = PointerType::get(pi64_type, 0);
221 
222  std::vector<Type*> query_args;
223  query_args.push_back(ppi8_type);
224  if (hoist_literals) {
225  query_args.push_back(pi8_type);
226  }
227  query_args.push_back(pi64_type);
228  query_args.push_back(pi64_type);
229  query_args.push_back(pi32_type);
230 
231  query_args.push_back(pi64_type);
232  query_args.push_back(ppi64_type);
233  query_args.push_back(i32_type);
234  query_args.push_back(pi64_type);
235  query_args.push_back(pi32_type);
236  query_args.push_back(pi32_type);
237 
238  FunctionType* query_func_type = FunctionType::get(
239  /*Result=*/Type::getVoidTy(mod->getContext()),
240  /*Params=*/query_args,
241  /*isVarArg=*/false);
242 
243  std::string query_template_name{"query_template"};
244  auto query_func_ptr = mod->getFunction(query_template_name);
245  CHECK(!query_func_ptr);
246 
247  query_func_ptr = Function::Create(
248  /*Type=*/query_func_type,
249  /*Linkage=*/GlobalValue::ExternalLinkage,
250  /*Name=*/query_template_name,
251  mod);
252  query_func_ptr->setCallingConv(CallingConv::C);
253 
254  Attributes query_func_pal;
255  {
256  SmallVector<Attributes, 4> Attrs;
257  Attributes PAS;
258  {
259  AttrBuilder B;
260  B.addAttribute(Attribute::NoCapture);
261  PAS = Attributes::get(mod->getContext(), 1U, B);
262  }
263 
264  Attrs.push_back(PAS);
265  {
266  AttrBuilder B;
267  B.addAttribute(Attribute::NoCapture);
268  PAS = Attributes::get(mod->getContext(), 2U, B);
269  }
270 
271  Attrs.push_back(PAS);
272 
273  {
274  AttrBuilder B;
275  B.addAttribute(Attribute::NoCapture);
276  Attrs.push_back(Attributes::get(mod->getContext(), 3U, B));
277  }
278 
279  {
280  AttrBuilder B;
281  B.addAttribute(Attribute::NoCapture);
282  Attrs.push_back(Attributes::get(mod->getContext(), 4U, B));
283  }
284 
285  Attrs.push_back(PAS);
286 
287  query_func_pal = Attributes::get(mod->getContext(), Attrs);
288  }
289  query_func_ptr->setAttributes(query_func_pal);
290 
291  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
292  Value* byte_stream = &*query_arg_it;
293  byte_stream->setName("byte_stream");
294  Value* literals{nullptr};
295  if (hoist_literals) {
296  literals = &*(++query_arg_it);
297  literals->setName("literals");
298  }
299  Value* row_count_ptr = &*(++query_arg_it);
300  row_count_ptr->setName("row_count_ptr");
301  Value* frag_row_off_ptr = &*(++query_arg_it);
302  frag_row_off_ptr->setName("frag_row_off_ptr");
303  Value* max_matched_ptr = &*(++query_arg_it);
304  max_matched_ptr->setName("max_matched_ptr");
305  Value* agg_init_val = &*(++query_arg_it);
306  agg_init_val->setName("agg_init_val");
307  Value* out = &*(++query_arg_it);
308  out->setName("out");
309  Value* frag_idx = &*(++query_arg_it);
310  frag_idx->setName("frag_idx");
311  Value* join_hash_tables = &*(++query_arg_it);
312  join_hash_tables->setName("join_hash_tables");
313  Value* total_matched = &*(++query_arg_it);
314  total_matched->setName("total_matched");
315  Value* error_code = &*(++query_arg_it);
316  error_code->setName("error_code");
317 
318  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
319  auto bb_preheader =
320  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
321  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".for.body", query_func_ptr, 0);
322  auto bb_crit_edge =
323  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
324  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
325 
326  // Block (.entry)
327  std::vector<Value*> result_ptr_vec;
328  llvm::CallInst* smem_output_buffer{nullptr};
329  if (!is_estimate_query) {
330  for (size_t i = 0; i < aggr_col_count; ++i) {
331  auto result_ptr = new AllocaInst(i64_type, 0, "result", bb_entry);
332  result_ptr->setAlignment(LLVM_ALIGN(8));
333  result_ptr_vec.push_back(result_ptr);
334  }
335  if (gpu_smem_context.isSharedMemoryUsed()) {
336  auto init_smem_func = mod->getFunction("init_shared_mem");
337  CHECK(init_smem_func);
338  // only one slot per aggregate column is needed, and so we can initialize shared
339  // memory buffer for intermediate results to be exactly like the agg_init_val array
340  smem_output_buffer = CallInst::Create(
341  init_smem_func,
342  std::vector<llvm::Value*>{
343  agg_init_val,
344  llvm::ConstantInt::get(i32_type, aggr_col_count * sizeof(int64_t))},
345  "smem_buffer",
346  bb_entry);
347  }
348  }
349 
350  LoadInst* row_count = new LoadInst(get_pointer_element_type(row_count_ptr),
351  row_count_ptr,
352  "row_count",
353  false,
354  bb_entry);
355  row_count->setAlignment(LLVM_ALIGN(8));
356  row_count->setName("row_count");
357  std::vector<Value*> agg_init_val_vec;
358  if (!is_estimate_query) {
359  for (size_t i = 0; i < aggr_col_count; ++i) {
360  auto idx_lv = ConstantInt::get(i32_type, i);
361  auto agg_init_gep =
362  GetElementPtrInst::CreateInBounds(agg_init_val, idx_lv, "", bb_entry);
363  auto agg_init_val = new LoadInst(
364  get_pointer_element_type(agg_init_gep), agg_init_gep, "", false, bb_entry);
365  agg_init_val->setAlignment(LLVM_ALIGN(8));
366  agg_init_val_vec.push_back(agg_init_val);
367  auto init_val_st = new StoreInst(agg_init_val, result_ptr_vec[i], false, bb_entry);
368  init_val_st->setAlignment(LLVM_ALIGN(8));
369  }
370  }
371 
372  CallInst* pos_start = CallInst::Create(func_pos_start, "pos_start", bb_entry);
373  pos_start->setCallingConv(CallingConv::C);
374  pos_start->setTailCall(true);
375  Attributes pos_start_pal;
376  pos_start->setAttributes(pos_start_pal);
377 
378  CallInst* pos_step = CallInst::Create(func_pos_step, "pos_step", bb_entry);
379  pos_step->setCallingConv(CallingConv::C);
380  pos_step->setTailCall(true);
381  Attributes pos_step_pal;
382  pos_step->setAttributes(pos_step_pal);
383 
384  CallInst* group_buff_idx = nullptr;
385  if (!is_estimate_query) {
386  group_buff_idx = CallInst::Create(func_group_buff_idx, "group_buff_idx", bb_entry);
387  group_buff_idx->setCallingConv(CallingConv::C);
388  group_buff_idx->setTailCall(true);
389  Attributes group_buff_idx_pal;
390  group_buff_idx->setAttributes(group_buff_idx_pal);
391  }
392 
393  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
394  ICmpInst* enter_or_not =
395  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
396  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
397 
398  // Block .loop.preheader
399  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
400  BranchInst::Create(bb_forbody, bb_preheader);
401 
402  // Block .forbody
403  Argument* pos_inc_pre = new Argument(i64_type);
404  PHINode* pos = PHINode::Create(i64_type, 2, "pos", bb_forbody);
405  pos->addIncoming(pos_start_i64, bb_preheader);
406  pos->addIncoming(pos_inc_pre, bb_forbody);
407 
408  std::vector<Value*> row_process_params;
409  row_process_params.insert(
410  row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
411  if (is_estimate_query) {
412  row_process_params.push_back(
413  new LoadInst(get_pointer_element_type(out), out, "", false, bb_forbody));
414  }
415  row_process_params.push_back(agg_init_val);
416  row_process_params.push_back(pos);
417  row_process_params.push_back(frag_row_off_ptr);
418  row_process_params.push_back(row_count_ptr);
419  if (hoist_literals) {
420  CHECK(literals);
421  row_process_params.push_back(literals);
422  }
423  CallInst* row_process =
424  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
425  row_process->setCallingConv(CallingConv::C);
426  row_process->setTailCall(false);
427  Attributes row_process_pal;
428  row_process->setAttributes(row_process_pal);
429 
430  BinaryOperator* pos_inc =
431  BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
432  ICmpInst* loop_or_exit =
433  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
434  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
435 
436  // Block ._crit_edge
437  std::vector<Instruction*> result_vec_pre;
438  if (!is_estimate_query) {
439  for (size_t i = 0; i < aggr_col_count; ++i) {
440  auto result = new LoadInst(get_pointer_element_type(result_ptr_vec[i]),
441  result_ptr_vec[i],
442  ".pre.result",
443  false,
444  bb_crit_edge);
445  result->setAlignment(LLVM_ALIGN(8));
446  result_vec_pre.push_back(result);
447  }
448  }
449 
450  BranchInst::Create(bb_exit, bb_crit_edge);
451 
452  // Block .exit
464  if (!is_estimate_query) {
465  std::vector<PHINode*> result_vec;
466  for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
467  auto result =
468  PHINode::Create(IntegerType::get(mod->getContext(), 64), 2, "", bb_exit);
469  result->addIncoming(result_vec_pre[i], bb_crit_edge);
470  result->addIncoming(agg_init_val_vec[i], bb_entry);
471  result_vec.insert(result_vec.begin(), result);
472  }
473 
474  for (size_t i = 0; i < aggr_col_count; ++i) {
475  auto col_idx = ConstantInt::get(i32_type, i);
476  if (gpu_smem_context.isSharedMemoryUsed()) {
477  auto target_addr =
478  GetElementPtrInst::CreateInBounds(smem_output_buffer, col_idx, "", bb_exit);
479  // TODO: generalize this once we want to support other types of aggregate
480  // functions besides COUNT.
481  auto agg_func = mod->getFunction("agg_sum_shared");
482  CHECK(agg_func);
483  CallInst::Create(
484  agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]}, "", bb_exit);
485  } else {
486  auto out_gep = GetElementPtrInst::CreateInBounds(out, col_idx, "", bb_exit);
487  auto col_buffer =
488  new LoadInst(get_pointer_element_type(out_gep), out_gep, "", false, bb_exit);
489  col_buffer->setAlignment(LLVM_ALIGN(8));
490  auto slot_idx = BinaryOperator::CreateAdd(
491  group_buff_idx,
492  BinaryOperator::CreateMul(frag_idx, pos_step, "", bb_exit),
493  "",
494  bb_exit);
495  auto target_addr =
496  GetElementPtrInst::CreateInBounds(col_buffer, slot_idx, "", bb_exit);
497  StoreInst* result_st = new StoreInst(result_vec[i], target_addr, false, bb_exit);
498  result_st->setAlignment(LLVM_ALIGN(8));
499  }
500  }
501  if (gpu_smem_context.isSharedMemoryUsed()) {
502  // final reduction of results from shared memory buffer back into global memory.
503  auto sync_thread_func = mod->getFunction("sync_threadblock");
504  CHECK(sync_thread_func);
505  CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{}, "", bb_exit);
506  auto reduce_smem_to_gmem_func = mod->getFunction("write_back_non_grouped_agg");
507  CHECK(reduce_smem_to_gmem_func);
508  // each thread reduce the aggregate target corresponding to its own thread ID.
509  // If there are more targets than threads we do not currently use shared memory
510  // optimization. This can be relaxed if necessary
511  for (size_t i = 0; i < aggr_col_count; i++) {
512  auto out_gep = GetElementPtrInst::CreateInBounds(
513  out, ConstantInt::get(i32_type, i), "", bb_exit);
514  auto gmem_output_buffer = new LoadInst(get_pointer_element_type(out_gep),
515  out_gep,
516  "gmem_output_buffer_" + std::to_string(i),
517  false,
518  bb_exit);
519  CallInst::Create(
520  reduce_smem_to_gmem_func,
521  std::vector<llvm::Value*>{
522  smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
523  "",
524  bb_exit);
525  }
526  }
527  }
528 
529  ReturnInst::Create(mod->getContext(), bb_exit);
530 
531  // Resolve Forward References
532  pos_inc_pre->replaceAllUsesWith(pos_inc);
533  delete pos_inc_pre;
534 
535  if (verifyFunction(*query_func_ptr)) {
536  LOG(FATAL) << "Generated invalid code. ";
537  }
538 
539  return {query_func_ptr, row_process};
540 }
#define LOG(tag)
Definition: Logger.h:188
#define LLVM_ALIGN(alignment)
std::string to_string(char const *&&v)
llvm::Function * group_buff_idx(llvm::Module *mod)
#define CHECK(condition)
Definition: Logger.h:197
llvm::Type * get_pointer_element_type(llvm::Value *value)
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)
+ Here is the call graph for this function: