OmniSciDB  2b310ab3b2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
QueryTemplateGenerator.cpp File Reference
#include "QueryTemplateGenerator.h"
#include "IRCodegenUtils.h"
#include "Logger/Logger.h"
#include <llvm/IR/Constants.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Verifier.h>
+ Include dependency graph for QueryTemplateGenerator.cpp:

Go to the source code of this file.

Namespaces

 anonymous_namespace{QueryTemplateGenerator.cpp}
 

Functions

template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::default_func_builder (llvm::Module *mod, const std::string &name)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::row_process (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)
 
template<class Attributes >
std::tuple< llvm::Function
*, llvm::CallInst * > 
query_template_impl (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
 
template<class Attributes >
std::tuple< llvm::Function
*, llvm::CallInst * > 
query_group_by_template_impl (llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
 
std::tuple< llvm::Function
*, llvm::CallInst * > 
query_template (llvm::Module *module, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
 
std::tuple< llvm::Function
*, llvm::CallInst * > 
query_group_by_template (llvm::Module *module, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
 

Function Documentation

std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template ( llvm::Module *  module,
const bool  hoist_literals,
const QueryMemoryDescriptor &  query_mem_desc,
const ExecutorDeviceType  device_type,
const bool  check_scan_limit,
const GpuSharedMemoryContext &  gpu_smem_context 
)

Definition at line 862 of file QueryTemplateGenerator.cpp.

References module(), and query_mem_desc.

868  {
869  return query_group_by_template_impl<llvm::AttributeSet>(module,
870  hoist_literals,
871  query_mem_desc,
872  device_type,
873  check_scan_limit,
874  gpu_smem_context);
875 }
std::unique_ptr< llvm::Module > module(runtime_module_shallow_copy(cgen_state))

+ Here is the call graph for this function:

template<class Attributes >
std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template_impl ( llvm::Module *  mod,
const bool  hoist_literals,
const QueryMemoryDescriptor &  query_mem_desc,
const ExecutorDeviceType  device_type,
const bool  check_scan_limit,
const GpuSharedMemoryContext &  gpu_smem_context 
)

Definition at line 520 of file QueryTemplateGenerator.cpp.

References CHECK(), error_code, logger::FATAL, frag_idx, GpuSharedMemoryContext::getSharedMemorySize(), GPU, anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx(), GpuSharedMemoryContext::isSharedMemoryUsed(), QueryMemoryDescriptor::isWarpSyncRequired(), join_hash_tables, literals, LLVM_ALIGN, LOG, max_matched, anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start(), anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step(), and anonymous_namespace{QueryTemplateGenerator.cpp}::row_process().

526  {
527  if (gpu_smem_context.isSharedMemoryUsed()) {
528  CHECK(device_type == ExecutorDeviceType::GPU);
529  }
530  using namespace llvm;
531 
532  auto func_pos_start = pos_start<Attributes>(mod);
533  CHECK(func_pos_start);
534  auto func_pos_step = pos_step<Attributes>(mod);
535  CHECK(func_pos_step);
536  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
537  CHECK(func_group_buff_idx);
538  auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
539  CHECK(func_row_process);
540  auto func_init_shared_mem = gpu_smem_context.isSharedMemoryUsed()
541  ? mod->getFunction("init_shared_mem")
542  : mod->getFunction("init_shared_mem_nop");
543  CHECK(func_init_shared_mem);
544 
545  auto func_write_back = mod->getFunction("write_back_nop");
546  CHECK(func_write_back);
547 
548  auto i32_type = IntegerType::get(mod->getContext(), 32);
549  auto i64_type = IntegerType::get(mod->getContext(), 64);
550  auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
551  auto pi32_type = PointerType::get(i32_type, 0);
552  auto pi64_type = PointerType::get(i64_type, 0);
553  auto ppi64_type = PointerType::get(pi64_type, 0);
554  auto ppi8_type = PointerType::get(pi8_type, 0);
555 
556  std::vector<Type*> query_args;
557  query_args.push_back(ppi8_type);
558  if (hoist_literals) {
559  query_args.push_back(pi8_type);
560  }
561  query_args.push_back(pi64_type);
562  query_args.push_back(pi64_type);
563  query_args.push_back(pi32_type);
564  query_args.push_back(pi64_type);
565 
566  query_args.push_back(ppi64_type);
567  query_args.push_back(i32_type);
568  query_args.push_back(pi64_type);
569  query_args.push_back(pi32_type);
570  query_args.push_back(pi32_type);
571 
572  FunctionType* query_func_type = FunctionType::get(
573  /*Result=*/Type::getVoidTy(mod->getContext()),
574  /*Params=*/query_args,
575  /*isVarArg=*/false);
576 
577  std::string query_name{"query_group_by_template"};
578  auto query_func_ptr = mod->getFunction(query_name);
579  CHECK(!query_func_ptr);
580 
581  query_func_ptr = Function::Create(
582  /*Type=*/query_func_type,
583  /*Linkage=*/GlobalValue::ExternalLinkage,
584  /*Name=*/"query_group_by_template",
585  mod);
586 
587  query_func_ptr->setCallingConv(CallingConv::C);
588 
589  Attributes query_func_pal;
590  {
591  SmallVector<Attributes, 4> Attrs;
592  Attributes PAS;
593  {
594  AttrBuilder B;
595  B.addAttribute(Attribute::ReadNone);
596  B.addAttribute(Attribute::NoCapture);
597  PAS = Attributes::get(mod->getContext(), 1U, B);
598  }
599 
600  Attrs.push_back(PAS);
601  {
602  AttrBuilder B;
603  B.addAttribute(Attribute::ReadOnly);
604  B.addAttribute(Attribute::NoCapture);
605  PAS = Attributes::get(mod->getContext(), 2U, B);
606  }
607 
608  Attrs.push_back(PAS);
609  {
610  AttrBuilder B;
611  B.addAttribute(Attribute::ReadNone);
612  B.addAttribute(Attribute::NoCapture);
613  PAS = Attributes::get(mod->getContext(), 3U, B);
614  }
615 
616  Attrs.push_back(PAS);
617  {
618  AttrBuilder B;
619  B.addAttribute(Attribute::ReadOnly);
620  B.addAttribute(Attribute::NoCapture);
621  PAS = Attributes::get(mod->getContext(), 4U, B);
622  }
623 
624  Attrs.push_back(PAS);
625  {
626  AttrBuilder B;
627  B.addAttribute(Attribute::UWTable);
628  PAS = Attributes::get(mod->getContext(), ~0U, B);
629  }
630 
631  Attrs.push_back(PAS);
632 
633  query_func_pal = Attributes::get(mod->getContext(), Attrs);
634  }
635  query_func_ptr->setAttributes(query_func_pal);
636 
637  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
638  Value* byte_stream = &*query_arg_it;
639  byte_stream->setName("byte_stream");
640  Value* literals{nullptr};
641  if (hoist_literals) {
642  literals = &*(++query_arg_it);
643  ;
644  literals->setName("literals");
645  }
646  Value* row_count_ptr = &*(++query_arg_it);
647  row_count_ptr->setName("row_count_ptr");
648  Value* frag_row_off_ptr = &*(++query_arg_it);
649  frag_row_off_ptr->setName("frag_row_off_ptr");
650  Value* max_matched_ptr = &*(++query_arg_it);
651  max_matched_ptr->setName("max_matched_ptr");
652  Value* agg_init_val = &*(++query_arg_it);
653  agg_init_val->setName("agg_init_val");
654  Value* group_by_buffers = &*(++query_arg_it);
655  group_by_buffers->setName("group_by_buffers");
656  Value* frag_idx = &*(++query_arg_it);
657  frag_idx->setName("frag_idx");
658  Value* join_hash_tables = &*(++query_arg_it);
659  join_hash_tables->setName("join_hash_tables");
660  Value* total_matched = &*(++query_arg_it);
661  total_matched->setName("total_matched");
662  Value* error_code = &*(++query_arg_it);
663  error_code->setName("error_code");
664 
665  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
666  auto bb_preheader =
667  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
668  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".forbody", query_func_ptr, 0);
669  auto bb_crit_edge =
670  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
671  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
672 
673  // Block .entry
674  LoadInst* row_count = new LoadInst(row_count_ptr, "", false, bb_entry);
675  row_count->setAlignment(LLVM_ALIGN(8));
676  row_count->setName("row_count");
677 
678  LoadInst* max_matched = new LoadInst(max_matched_ptr, "", false, bb_entry);
679  max_matched->setAlignment(LLVM_ALIGN(8));
680 
681  auto crt_matched_ptr = new AllocaInst(i32_type, 0, "crt_matched", bb_entry);
682  auto old_total_matched_ptr = new AllocaInst(i32_type, 0, "old_total_matched", bb_entry);
683  CallInst* pos_start = CallInst::Create(func_pos_start, "", bb_entry);
684  pos_start->setCallingConv(CallingConv::C);
685  pos_start->setTailCall(true);
686  Attributes pos_start_pal;
687  pos_start->setAttributes(pos_start_pal);
688 
689  CallInst* pos_step = CallInst::Create(func_pos_step, "", bb_entry);
690  pos_step->setCallingConv(CallingConv::C);
691  pos_step->setTailCall(true);
692  Attributes pos_step_pal;
693  pos_step->setAttributes(pos_step_pal);
694 
695  CallInst* group_buff_idx = CallInst::Create(func_group_buff_idx, "", bb_entry);
696  group_buff_idx->setCallingConv(CallingConv::C);
697  group_buff_idx->setTailCall(true);
698  Attributes group_buff_idx_pal;
699  group_buff_idx->setAttributes(group_buff_idx_pal);
700 
701  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
702  const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
703  CHECK(Ty);
704  GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
705  Ty->getElementType(), group_by_buffers, group_buff_idx, "", bb_entry);
706  LoadInst* col_buffer = new LoadInst(group_by_buffers_gep, "", false, bb_entry);
707  col_buffer->setName("col_buffer");
708  col_buffer->setAlignment(LLVM_ALIGN(8));
709 
710  llvm::ConstantInt* shared_mem_bytes_lv =
711  ConstantInt::get(i32_type, gpu_smem_context.getSharedMemorySize());
712  llvm::CallInst* result_buffer =
713  CallInst::Create(func_init_shared_mem,
714  std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
715  "result_buffer",
716  bb_entry);
717  // TODO(Saman): change this further, normal path should not go through this
718 
719  ICmpInst* enter_or_not =
720  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
721  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
722 
723  // Block .loop.preheader
724  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
725  BranchInst::Create(bb_forbody, bb_preheader);
726 
727  // Block .forbody
728  Argument* pos_pre = new Argument(i64_type);
729  PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2, "pos", bb_forbody);
730 
731  std::vector<Value*> row_process_params;
732  row_process_params.push_back(result_buffer);
733  row_process_params.push_back(crt_matched_ptr);
734  row_process_params.push_back(total_matched);
735  row_process_params.push_back(old_total_matched_ptr);
736  row_process_params.push_back(max_matched_ptr);
737  row_process_params.push_back(agg_init_val);
738  row_process_params.push_back(pos);
739  row_process_params.push_back(frag_row_off_ptr);
740  row_process_params.push_back(row_count_ptr);
741  if (hoist_literals) {
742  CHECK(literals);
743  row_process_params.push_back(literals);
744  }
745  if (check_scan_limit) {
746  new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
747  crt_matched_ptr,
748  bb_forbody);
749  }
750  CallInst* row_process =
751  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
752  row_process->setCallingConv(CallingConv::C);
753  row_process->setTailCall(true);
754  Attributes row_process_pal;
755  row_process->setAttributes(row_process_pal);
756 
757  // Forcing all threads within a warp to be synchronized (Compute >= 7.x)
758  if (query_mem_desc.isWarpSyncRequired(device_type)) {
759  auto func_sync_warp_protected = mod->getFunction("sync_warp_protected");
760  CHECK(func_sync_warp_protected);
761  CallInst::Create(func_sync_warp_protected,
762  std::vector<llvm::Value*>{pos, row_count},
763  "",
764  bb_forbody);
765  }
766 
767  BinaryOperator* pos_inc =
768  BinaryOperator::Create(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
769  ICmpInst* loop_or_exit =
770  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
771  if (check_scan_limit) {
772  auto crt_matched = new LoadInst(crt_matched_ptr, "crt_matched", false, bb_forbody);
773  auto filter_match = BasicBlock::Create(
774  mod->getContext(), "filter_match", query_func_ptr, bb_crit_edge);
775  llvm::Value* new_total_matched =
776  new LoadInst(old_total_matched_ptr, "", false, filter_match);
777  new_total_matched =
778  BinaryOperator::CreateAdd(new_total_matched, crt_matched, "", filter_match);
779  CHECK(new_total_matched);
780  ICmpInst* limit_not_reached = new ICmpInst(*filter_match,
781  ICmpInst::ICMP_SLT,
782  new_total_matched,
783  max_matched,
784  "limit_not_reached");
785  BranchInst::Create(
786  bb_forbody,
787  bb_crit_edge,
788  BinaryOperator::Create(
789  BinaryOperator::And, loop_or_exit, limit_not_reached, "", filter_match),
790  filter_match);
791  auto filter_nomatch = BasicBlock::Create(
792  mod->getContext(), "filter_nomatch", query_func_ptr, bb_crit_edge);
793  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
794  ICmpInst* crt_matched_nz = new ICmpInst(
795  *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0), "");
796  BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
797  pos->addIncoming(pos_start_i64, bb_preheader);
798  pos->addIncoming(pos_pre, filter_match);
799  pos->addIncoming(pos_pre, filter_nomatch);
800  } else {
801  pos->addIncoming(pos_start_i64, bb_preheader);
802  pos->addIncoming(pos_pre, bb_forbody);
803  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
804  }
805 
806  // Block ._crit_edge
807  BranchInst::Create(bb_exit, bb_crit_edge);
808 
809  // Block .exit
810  CallInst::Create(func_write_back,
811  std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
812  "",
813  bb_exit);
814 
815  ReturnInst::Create(mod->getContext(), bb_exit);
816 
817  // Resolve Forward References
818  pos_pre->replaceAllUsesWith(pos_inc);
819  delete pos_pre;
820 
821  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
822  LOG(FATAL) << "Generated invalid code. ";
823  }
824 
825  return {query_func_ptr, row_process};
826 }
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t * join_hash_tables
#define LOG(tag)
Definition: Logger.h:188
size_t getSharedMemorySize() const
#define LLVM_ALIGN(alignment)
llvm::Function * group_buff_idx(llvm::Module *mod)
CHECK(cgen_state)
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t int32_t * error_code
const int8_t const int64_t const uint64_t const int32_t * max_matched
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t frag_idx
bool isWarpSyncRequired(const ExecutorDeviceType) const
const int8_t * literals
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)

+ Here is the call graph for this function:

std::tuple<llvm::Function*, llvm::CallInst*> query_template ( llvm::Module *  module,
const size_t  aggr_col_count,
const bool  hoist_literals,
const bool  is_estimate_query,
const GpuSharedMemoryContext &  gpu_smem_context 
)

Definition at line 853 of file QueryTemplateGenerator.cpp.

References module().

858  {
859  return query_template_impl<llvm::AttributeSet>(
860  module, aggr_col_count, hoist_literals, is_estimate_query, gpu_smem_context);
861 }
std::unique_ptr< llvm::Module > module(runtime_module_shallow_copy(cgen_state))

+ Here is the call graph for this function:

template<class Attributes >
std::tuple<llvm::Function*, llvm::CallInst*> query_template_impl ( llvm::Module *  mod,
const size_t  aggr_col_count,
const bool  hoist_literals,
const bool  is_estimate_query,
const GpuSharedMemoryContext &  gpu_smem_context 
)

If GPU shared memory optimization is disabled, for each aggregate target, threads copy back their aggregate results (stored in registers) back into memory. This process is performed per processed fragment. In the host the final results are reduced (per target, for all threads and all fragments).

If GPU Shared memory optimization is enabled, we properly (atomically) aggregate all thread's results into memory, which makes the final reduction on host much cheaper. Here, we call a noop dummy write back function which will be properly replaced at runtime depending on the target expressions.

Definition at line 186 of file QueryTemplateGenerator.cpp.

References CHECK(), error_code, logger::FATAL, frag_idx, anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx(), GpuSharedMemoryContext::isSharedMemoryUsed(), join_hash_tables, literals, LLVM_ALIGN, LOG, out, anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start(), anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step(), run_benchmark_import::result, anonymous_namespace{QueryTemplateGenerator.cpp}::row_process(), and to_string().

191  {
192  using namespace llvm;
193 
194  auto func_pos_start = pos_start<Attributes>(mod);
195  CHECK(func_pos_start);
196  auto func_pos_step = pos_step<Attributes>(mod);
197  CHECK(func_pos_step);
198  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
199  CHECK(func_group_buff_idx);
200  auto func_row_process = row_process<Attributes>(
201  mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
202  CHECK(func_row_process);
203 
204  auto i8_type = IntegerType::get(mod->getContext(), 8);
205  auto i32_type = IntegerType::get(mod->getContext(), 32);
206  auto i64_type = IntegerType::get(mod->getContext(), 64);
207  auto pi8_type = PointerType::get(i8_type, 0);
208  auto ppi8_type = PointerType::get(pi8_type, 0);
209  auto pi32_type = PointerType::get(i32_type, 0);
210  auto pi64_type = PointerType::get(i64_type, 0);
211  auto ppi64_type = PointerType::get(pi64_type, 0);
212 
213  std::vector<Type*> query_args;
214  query_args.push_back(ppi8_type);
215  if (hoist_literals) {
216  query_args.push_back(pi8_type);
217  }
218  query_args.push_back(pi64_type);
219  query_args.push_back(pi64_type);
220  query_args.push_back(pi32_type);
221 
222  query_args.push_back(pi64_type);
223  query_args.push_back(ppi64_type);
224  query_args.push_back(i32_type);
225  query_args.push_back(pi64_type);
226  query_args.push_back(pi32_type);
227  query_args.push_back(pi32_type);
228 
229  FunctionType* query_func_type = FunctionType::get(
230  /*Result=*/Type::getVoidTy(mod->getContext()),
231  /*Params=*/query_args,
232  /*isVarArg=*/false);
233 
234  std::string query_template_name{"query_template"};
235  auto query_func_ptr = mod->getFunction(query_template_name);
236  CHECK(!query_func_ptr);
237 
238  query_func_ptr = Function::Create(
239  /*Type=*/query_func_type,
240  /*Linkage=*/GlobalValue::ExternalLinkage,
241  /*Name=*/query_template_name,
242  mod);
243  query_func_ptr->setCallingConv(CallingConv::C);
244 
245  Attributes query_func_pal;
246  {
247  SmallVector<Attributes, 4> Attrs;
248  Attributes PAS;
249  {
250  AttrBuilder B;
251  B.addAttribute(Attribute::NoCapture);
252  PAS = Attributes::get(mod->getContext(), 1U, B);
253  }
254 
255  Attrs.push_back(PAS);
256  {
257  AttrBuilder B;
258  B.addAttribute(Attribute::NoCapture);
259  PAS = Attributes::get(mod->getContext(), 2U, B);
260  }
261 
262  Attrs.push_back(PAS);
263 
264  {
265  AttrBuilder B;
266  B.addAttribute(Attribute::NoCapture);
267  Attrs.push_back(Attributes::get(mod->getContext(), 3U, B));
268  }
269 
270  {
271  AttrBuilder B;
272  B.addAttribute(Attribute::NoCapture);
273  Attrs.push_back(Attributes::get(mod->getContext(), 4U, B));
274  }
275 
276  Attrs.push_back(PAS);
277 
278  query_func_pal = Attributes::get(mod->getContext(), Attrs);
279  }
280  query_func_ptr->setAttributes(query_func_pal);
281 
282  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
283  Value* byte_stream = &*query_arg_it;
284  byte_stream->setName("byte_stream");
285  Value* literals{nullptr};
286  if (hoist_literals) {
287  literals = &*(++query_arg_it);
288  literals->setName("literals");
289  }
290  Value* row_count_ptr = &*(++query_arg_it);
291  row_count_ptr->setName("row_count_ptr");
292  Value* frag_row_off_ptr = &*(++query_arg_it);
293  frag_row_off_ptr->setName("frag_row_off_ptr");
294  Value* max_matched_ptr = &*(++query_arg_it);
295  max_matched_ptr->setName("max_matched_ptr");
296  Value* agg_init_val = &*(++query_arg_it);
297  agg_init_val->setName("agg_init_val");
298  Value* out = &*(++query_arg_it);
299  out->setName("out");
300  Value* frag_idx = &*(++query_arg_it);
301  frag_idx->setName("frag_idx");
302  Value* join_hash_tables = &*(++query_arg_it);
303  join_hash_tables->setName("join_hash_tables");
304  Value* total_matched = &*(++query_arg_it);
305  total_matched->setName("total_matched");
306  Value* error_code = &*(++query_arg_it);
307  error_code->setName("error_code");
308 
309  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
310  auto bb_preheader =
311  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
312  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".for.body", query_func_ptr, 0);
313  auto bb_crit_edge =
314  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
315  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
316 
317  // Block (.entry)
318  std::vector<Value*> result_ptr_vec;
319  llvm::CallInst* smem_output_buffer{nullptr};
320  if (!is_estimate_query) {
321  for (size_t i = 0; i < aggr_col_count; ++i) {
322  auto result_ptr = new AllocaInst(i64_type, 0, "result", bb_entry);
323  result_ptr->setAlignment(LLVM_ALIGN(8));
324  result_ptr_vec.push_back(result_ptr);
325  }
326  if (gpu_smem_context.isSharedMemoryUsed()) {
327  auto init_smem_func = mod->getFunction("init_shared_mem");
328  CHECK(init_smem_func);
329  // only one slot per aggregate column is needed, and so we can initialize shared
330  // memory buffer for intermediate results to be exactly like the agg_init_val array
331  smem_output_buffer = CallInst::Create(
332  init_smem_func,
333  std::vector<llvm::Value*>{
334  agg_init_val,
335  llvm::ConstantInt::get(i32_type, aggr_col_count * sizeof(int64_t))},
336  "smem_buffer",
337  bb_entry);
338  }
339  }
340 
341  LoadInst* row_count = new LoadInst(row_count_ptr, "row_count", false, bb_entry);
342  row_count->setAlignment(LLVM_ALIGN(8));
343  row_count->setName("row_count");
344  std::vector<Value*> agg_init_val_vec;
345  if (!is_estimate_query) {
346  for (size_t i = 0; i < aggr_col_count; ++i) {
347  auto idx_lv = ConstantInt::get(i32_type, i);
348  auto agg_init_gep =
349  GetElementPtrInst::CreateInBounds(agg_init_val, idx_lv, "", bb_entry);
350  auto agg_init_val = new LoadInst(agg_init_gep, "", false, bb_entry);
351  agg_init_val->setAlignment(LLVM_ALIGN(8));
352  agg_init_val_vec.push_back(agg_init_val);
353  auto init_val_st = new StoreInst(agg_init_val, result_ptr_vec[i], false, bb_entry);
354  init_val_st->setAlignment(LLVM_ALIGN(8));
355  }
356  }
357 
358  CallInst* pos_start = CallInst::Create(func_pos_start, "pos_start", bb_entry);
359  pos_start->setCallingConv(CallingConv::C);
360  pos_start->setTailCall(true);
361  Attributes pos_start_pal;
362  pos_start->setAttributes(pos_start_pal);
363 
364  CallInst* pos_step = CallInst::Create(func_pos_step, "pos_step", bb_entry);
365  pos_step->setCallingConv(CallingConv::C);
366  pos_step->setTailCall(true);
367  Attributes pos_step_pal;
368  pos_step->setAttributes(pos_step_pal);
369 
370  CallInst* group_buff_idx = nullptr;
371  if (!is_estimate_query) {
372  group_buff_idx = CallInst::Create(func_group_buff_idx, "group_buff_idx", bb_entry);
373  group_buff_idx->setCallingConv(CallingConv::C);
374  group_buff_idx->setTailCall(true);
375  Attributes group_buff_idx_pal;
376  group_buff_idx->setAttributes(group_buff_idx_pal);
377  }
378 
379  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
380  ICmpInst* enter_or_not =
381  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
382  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
383 
384  // Block .loop.preheader
385  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
386  BranchInst::Create(bb_forbody, bb_preheader);
387 
388  // Block .forbody
389  Argument* pos_inc_pre = new Argument(i64_type);
390  PHINode* pos = PHINode::Create(i64_type, 2, "pos", bb_forbody);
391  pos->addIncoming(pos_start_i64, bb_preheader);
392  pos->addIncoming(pos_inc_pre, bb_forbody);
393 
394  std::vector<Value*> row_process_params;
395  row_process_params.insert(
396  row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
397  if (is_estimate_query) {
398  row_process_params.push_back(new LoadInst(out, "", false, bb_forbody));
399  }
400  row_process_params.push_back(agg_init_val);
401  row_process_params.push_back(pos);
402  row_process_params.push_back(frag_row_off_ptr);
403  row_process_params.push_back(row_count_ptr);
404  if (hoist_literals) {
405  CHECK(literals);
406  row_process_params.push_back(literals);
407  }
408  CallInst* row_process =
409  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
410  row_process->setCallingConv(CallingConv::C);
411  row_process->setTailCall(false);
412  Attributes row_process_pal;
413  row_process->setAttributes(row_process_pal);
414 
415  BinaryOperator* pos_inc =
416  BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
417  ICmpInst* loop_or_exit =
418  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
419  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
420 
421  // Block ._crit_edge
422  std::vector<Instruction*> result_vec_pre;
423  if (!is_estimate_query) {
424  for (size_t i = 0; i < aggr_col_count; ++i) {
425  auto result = new LoadInst(result_ptr_vec[i], ".pre.result", false, bb_crit_edge);
426  result->setAlignment(LLVM_ALIGN(8));
427  result_vec_pre.push_back(result);
428  }
429  }
430 
431  BranchInst::Create(bb_exit, bb_crit_edge);
432 
433  // Block .exit
445  if (!is_estimate_query) {
446  std::vector<PHINode*> result_vec;
447  for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
448  auto result =
449  PHINode::Create(IntegerType::get(mod->getContext(), 64), 2, "", bb_exit);
450  result->addIncoming(result_vec_pre[i], bb_crit_edge);
451  result->addIncoming(agg_init_val_vec[i], bb_entry);
452  result_vec.insert(result_vec.begin(), result);
453  }
454 
455  for (size_t i = 0; i < aggr_col_count; ++i) {
456  auto col_idx = ConstantInt::get(i32_type, i);
457  if (gpu_smem_context.isSharedMemoryUsed()) {
458  auto target_addr =
459  GetElementPtrInst::CreateInBounds(smem_output_buffer, col_idx, "", bb_exit);
460  // TODO: generalize this once we want to support other types of aggregate
461  // functions besides COUNT.
462  auto agg_func = mod->getFunction("agg_sum_shared");
463  CHECK(agg_func);
464  CallInst::Create(
465  agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]}, "", bb_exit);
466  } else {
467  auto out_gep = GetElementPtrInst::CreateInBounds(out, col_idx, "", bb_exit);
468  auto col_buffer = new LoadInst(out_gep, "", false, bb_exit);
469  col_buffer->setAlignment(LLVM_ALIGN(8));
470  auto slot_idx = BinaryOperator::CreateAdd(
471  group_buff_idx,
472  BinaryOperator::CreateMul(frag_idx, pos_step, "", bb_exit),
473  "",
474  bb_exit);
475  auto target_addr =
476  GetElementPtrInst::CreateInBounds(col_buffer, slot_idx, "", bb_exit);
477  StoreInst* result_st = new StoreInst(result_vec[i], target_addr, false, bb_exit);
478  result_st->setAlignment(LLVM_ALIGN(8));
479  }
480  }
481  if (gpu_smem_context.isSharedMemoryUsed()) {
482  // final reduction of results from shared memory buffer back into global memory.
483  auto sync_thread_func = mod->getFunction("sync_threadblock");
484  CHECK(sync_thread_func);
485  CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{}, "", bb_exit);
486  auto reduce_smem_to_gmem_func = mod->getFunction("write_back_non_grouped_agg");
487  CHECK(reduce_smem_to_gmem_func);
488  // each thread reduce the aggregate target corresponding to its own thread ID.
489  // If there are more targets than threads we do not currently use shared memory
490  // optimization. This can be relaxed if necessary
491  for (size_t i = 0; i < aggr_col_count; i++) {
492  auto out_gep = GetElementPtrInst::CreateInBounds(
493  out, ConstantInt::get(i32_type, i), "", bb_exit);
494  auto gmem_output_buffer = new LoadInst(
495  out_gep, "gmem_output_buffer_" + std::to_string(i), false, bb_exit);
496  CallInst::Create(
497  reduce_smem_to_gmem_func,
498  std::vector<llvm::Value*>{
499  smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
500  "",
501  bb_exit);
502  }
503  }
504  }
505 
506  ReturnInst::Create(mod->getContext(), bb_exit);
507 
508  // Resolve Forward References
509  pos_inc_pre->replaceAllUsesWith(pos_inc);
510  delete pos_inc_pre;
511 
512  if (verifyFunction(*query_func_ptr)) {
513  LOG(FATAL) << "Generated invalid code. ";
514  }
515 
516  return {query_func_ptr, row_process};
517 }
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t * join_hash_tables
#define LOG(tag)
Definition: Logger.h:188
#define LLVM_ALIGN(alignment)
std::string to_string(char const *&&v)
llvm::Function * group_buff_idx(llvm::Module *mod)
CHECK(cgen_state)
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t int32_t * error_code
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t frag_idx
const int8_t * literals
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t ** out
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)

+ Here is the call graph for this function: