OmniSciDB  c1a53651b2
QueryTemplateGenerator.cpp File Reference
#include "QueryTemplateGenerator.h"
#include "IRCodegenUtils.h"
#include "Logger/Logger.h"
#include <llvm/IR/Constants.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Verifier.h>

Namespaces

 anonymous_namespace{QueryTemplateGenerator.cpp}
 

Functions

llvm::Type * anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type (llvm::Value *value)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::default_func_builder (llvm::Module *mod, const std::string &name)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step (llvm::Module *mod)
 
template<class Attributes >
llvm::Function * anonymous_namespace{QueryTemplateGenerator.cpp}::row_process (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)
 
template<class Attributes >
std::tuple< llvm::Function *, llvm::CallInst * > query_template_impl (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
 
template<class Attributes >
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template_impl (llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
 
std::tuple< llvm::Function *, llvm::CallInst * > query_template (llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
 
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template (llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
 

Function Documentation

std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template ( llvm::Module *  mod,
const bool  hoist_literals,
const QueryMemoryDescriptor &  query_mem_desc,
const ExecutorDeviceType  device_type,
const bool  check_scan_limit,
const GpuSharedMemoryContext &  gpu_smem_context 
)

Definition at line 978 of file QueryTemplateGenerator.cpp.

References query_mem_desc.

984  {
985  return query_group_by_template_impl<llvm::AttributeList>(mod,
986  hoist_literals,
987  query_mem_desc,
988  device_type,
989  check_scan_limit,
990  gpu_smem_context);
991 }
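For context, a minimal sketch of how this entry point might be driven (the wrapper name build_group_by_query is hypothetical). It assumes the module already contains the runtime helpers the template retrieves with getFunction (such as write_back_nop and init_shared_mem_nop), that GpuSharedMemoryContext default-constructs to the "no shared memory" state, and that ExecutorDeviceType::CPU is a valid device value.

#include "QueryTemplateGenerator.h"

// Hypothetical driver; `mod` must already hold the runtime helper declarations.
std::tuple<llvm::Function*, llvm::CallInst*> build_group_by_query(
    llvm::Module* mod,
    const QueryMemoryDescriptor& query_mem_desc) {
  const bool hoist_literals = true;
  const bool check_scan_limit = false;
  GpuSharedMemoryContext gpu_smem_context;  // assumed default: shared memory unused
  auto [query_func, row_process_call] =
      query_group_by_template(mod,
                              hoist_literals,
                              query_mem_desc,
                              ExecutorDeviceType::CPU,
                              check_scan_limit,
                              gpu_smem_context);
  // query_func is the newly created "query_group_by_template" function;
  // row_process_call marks where the per-row code is later patched in.
  return {query_func, row_process_call};
}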
template<class Attributes >
std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template_impl ( llvm::Module *  mod,
const bool  hoist_literals,
const QueryMemoryDescriptor &  query_mem_desc,
const ExecutorDeviceType  device_type,
const bool  check_scan_limit,
const GpuSharedMemoryContext &  gpu_smem_context 
)

Definition at line 592 of file QueryTemplateGenerator.cpp.

References CHECK, logger::FATAL, anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type(), GpuSharedMemoryContext::getSharedMemorySize(), GPU, anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx(), QueryMemoryDescriptor::hasVarlenOutput(), GpuSharedMemoryContext::isSharedMemoryUsed(), QueryMemoryDescriptor::isWarpSyncRequired(), LLVM_ALIGN, LOG, anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start(), anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step(), and anonymous_namespace{QueryTemplateGenerator.cpp}::row_process().

598  {
599  if (gpu_smem_context.isSharedMemoryUsed()) {
600  CHECK(device_type == ExecutorDeviceType::GPU);
601  }
602  using namespace llvm;
603 
604  auto func_pos_start = pos_start<Attributes>(mod);
605  CHECK(func_pos_start);
606  auto func_pos_step = pos_step<Attributes>(mod);
607  CHECK(func_pos_step);
608  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
609  CHECK(func_group_buff_idx);
610  auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
611  CHECK(func_row_process);
612  auto func_init_shared_mem = gpu_smem_context.isSharedMemoryUsed()
613  ? mod->getFunction("init_shared_mem")
614  : mod->getFunction("init_shared_mem_nop");
615  CHECK(func_init_shared_mem);
616 
617  auto func_write_back = mod->getFunction("write_back_nop");
618  CHECK(func_write_back);
619 
620  auto i32_type = IntegerType::get(mod->getContext(), 32);
621  auto i64_type = IntegerType::get(mod->getContext(), 64);
622  auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
623  auto pi32_type = PointerType::get(i32_type, 0);
624  auto pi64_type = PointerType::get(i64_type, 0);
625  auto ppi64_type = PointerType::get(pi64_type, 0);
626  auto ppi8_type = PointerType::get(pi8_type, 0);
627 
628  std::vector<Type*> query_args;
629  query_args.push_back(ppi8_type); // col_buffers
630  if (hoist_literals) {
631  query_args.push_back(pi8_type); // literals
632  }
633  query_args.push_back(pi64_type); // num_rows
634  query_args.push_back(pi64_type); // frag_row_offsets
635  query_args.push_back(pi32_type); // max_matched
636  query_args.push_back(pi64_type); // init_agg_value
637 
638  query_args.push_back(ppi64_type); // out
639  query_args.push_back(i32_type); // frag_idx
640  query_args.push_back(pi64_type); // join_hash_tables
641  query_args.push_back(pi32_type); // total_matched
642  query_args.push_back(pi32_type); // error_code
643  query_args.push_back(pi8_type); // row_func_mgr
644 
645  FunctionType* query_func_type = FunctionType::get(
646  /*Result=*/Type::getVoidTy(mod->getContext()),
647  /*Params=*/query_args,
648  /*isVarArg=*/false);
649 
650  std::string query_name{"query_group_by_template"};
651  auto query_func_ptr = mod->getFunction(query_name);
652  CHECK(!query_func_ptr);
653 
654  query_func_ptr = Function::Create(
655  /*Type=*/query_func_type,
656  /*Linkage=*/GlobalValue::ExternalLinkage,
657  /*Name=*/"query_group_by_template",
658  mod);
659 
660  query_func_ptr->setCallingConv(CallingConv::C);
661 
662  Attributes query_func_pal;
663  {
664  SmallVector<Attributes, 4> Attrs;
665  Attributes PAS;
666  {
667 #if 14 <= LLVM_VERSION_MAJOR
668  AttrBuilder B(mod->getContext());
669 #else
670  AttrBuilder B;
671 #endif
672  B.addAttribute(Attribute::ReadNone);
673  B.addAttribute(Attribute::NoCapture);
674  PAS = Attributes::get(mod->getContext(), 1U, B);
675  }
676 
677  Attrs.push_back(PAS);
678  {
679 #if 14 <= LLVM_VERSION_MAJOR
680  AttrBuilder B(mod->getContext());
681 #else
682  AttrBuilder B;
683 #endif
684  B.addAttribute(Attribute::ReadOnly);
685  B.addAttribute(Attribute::NoCapture);
686  PAS = Attributes::get(mod->getContext(), 2U, B);
687  }
688 
689  Attrs.push_back(PAS);
690  {
691 #if 14 <= LLVM_VERSION_MAJOR
692  AttrBuilder B(mod->getContext());
693 #else
694  AttrBuilder B;
695 #endif
696  B.addAttribute(Attribute::ReadNone);
697  B.addAttribute(Attribute::NoCapture);
698  PAS = Attributes::get(mod->getContext(), 3U, B);
699  }
700 
701  Attrs.push_back(PAS);
702  {
703 #if 14 <= LLVM_VERSION_MAJOR
704  AttrBuilder B(mod->getContext());
705 #else
706  AttrBuilder B;
707 #endif
708  B.addAttribute(Attribute::ReadOnly);
709  B.addAttribute(Attribute::NoCapture);
710  PAS = Attributes::get(mod->getContext(), 4U, B);
711  }
712 
713  Attrs.push_back(PAS);
714  {
715 #if 14 <= LLVM_VERSION_MAJOR
716  AttrBuilder B(mod->getContext());
717 #else
718  AttrBuilder B;
719 #endif
720  B.addAttribute(Attribute::UWTable);
721  PAS = Attributes::get(mod->getContext(), ~0U, B);
722  }
723 
724  Attrs.push_back(PAS);
725 
726  query_func_pal = Attributes::get(mod->getContext(), Attrs);
727  }
728  query_func_ptr->setAttributes(query_func_pal);
729 
730  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
731  Value* byte_stream = &*query_arg_it;
732  byte_stream->setName("byte_stream");
733  Value* literals{nullptr};
734  if (hoist_literals) {
735  literals = &*(++query_arg_it);
736  ;
737  literals->setName("literals");
738  }
739  Value* row_count_ptr = &*(++query_arg_it);
740  row_count_ptr->setName("row_count_ptr");
741  Value* frag_row_off_ptr = &*(++query_arg_it);
742  frag_row_off_ptr->setName("frag_row_off_ptr");
743  Value* max_matched_ptr = &*(++query_arg_it);
744  max_matched_ptr->setName("max_matched_ptr");
745  Value* agg_init_val = &*(++query_arg_it);
746  agg_init_val->setName("agg_init_val");
747  Value* group_by_buffers = &*(++query_arg_it);
748  group_by_buffers->setName("group_by_buffers");
749  Value* frag_idx = &*(++query_arg_it);
750  frag_idx->setName("frag_idx");
751  Value* join_hash_tables = &*(++query_arg_it);
752  join_hash_tables->setName("join_hash_tables");
753  Value* total_matched = &*(++query_arg_it);
754  total_matched->setName("total_matched");
755  Value* error_code = &*(++query_arg_it);
756  error_code->setName("error_code");
757  Value* row_func_mgr = &*(++query_arg_it);
758  row_func_mgr->setName("row_func_mgr");
759 
760  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
761  auto bb_preheader =
762  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
763  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".forbody", query_func_ptr, 0);
764  auto bb_crit_edge =
765  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
766  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
767 
768  // Block .entry
769  LoadInst* row_count = new LoadInst(
770  get_pointer_element_type(row_count_ptr), row_count_ptr, "", false, bb_entry);
771  row_count->setAlignment(LLVM_ALIGN(8));
772  row_count->setName("row_count");
773 
774  LoadInst* max_matched = new LoadInst(
775  get_pointer_element_type(max_matched_ptr), max_matched_ptr, "", false, bb_entry);
776  max_matched->setAlignment(LLVM_ALIGN(8));
777 
778  auto crt_matched_ptr = new AllocaInst(i32_type, 0, "crt_matched", bb_entry);
779  auto old_total_matched_ptr = new AllocaInst(i32_type, 0, "old_total_matched", bb_entry);
780  CallInst* pos_start = CallInst::Create(func_pos_start, "", bb_entry);
781  pos_start->setCallingConv(CallingConv::C);
782  pos_start->setTailCall(true);
783  Attributes pos_start_pal;
784  pos_start->setAttributes(pos_start_pal);
785 
786  CallInst* pos_step = CallInst::Create(func_pos_step, "", bb_entry);
787  pos_step->setCallingConv(CallingConv::C);
788  pos_step->setTailCall(true);
789  Attributes pos_step_pal;
790  pos_step->setAttributes(pos_step_pal);
791 
792  CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx, "", bb_entry);
793  group_buff_idx_call->setCallingConv(CallingConv::C);
794  group_buff_idx_call->setTailCall(true);
795  Attributes group_buff_idx_pal;
796  group_buff_idx_call->setAttributes(group_buff_idx_pal);
797  Value* group_buff_idx = group_buff_idx_call;
798 
799  const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
800  CHECK(Ty);
801 
802  Value* varlen_output_buffer{nullptr};
803  if (query_mem_desc.hasVarlenOutput()) {
804  // make the varlen buffer the _first_ 8 byte value in the group by buffers double ptr,
805  // and offset the group by buffers index by 8 bytes
806  auto varlen_output_buffer_gep = GetElementPtrInst::Create(
807  Ty->getPointerElementType(),
808  group_by_buffers,
809  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
810  "",
811  bb_entry);
812  varlen_output_buffer =
813  new LoadInst(get_pointer_element_type(varlen_output_buffer_gep),
814  varlen_output_buffer_gep,
815  "varlen_output_buffer",
816  false,
817  bb_entry);
818 
819  group_buff_idx = BinaryOperator::Create(
820  Instruction::Add,
821  group_buff_idx,
822  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
823  "group_buff_idx_varlen_offset",
824  bb_entry);
825  } else {
826  varlen_output_buffer =
827  ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
828  }
829  CHECK(varlen_output_buffer);
830 
831  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
832  GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
833  Ty->getPointerElementType(), group_by_buffers, group_buff_idx, "", bb_entry);
834  LoadInst* col_buffer = new LoadInst(get_pointer_element_type(group_by_buffers_gep),
835  group_by_buffers_gep,
836  "",
837  false,
838  bb_entry);
839  col_buffer->setName("col_buffer");
840  col_buffer->setAlignment(LLVM_ALIGN(8));
841 
842  llvm::ConstantInt* shared_mem_bytes_lv =
843  ConstantInt::get(i32_type, gpu_smem_context.getSharedMemorySize());
844  // TODO(Saman): change this further, normal path should not go through this
845  llvm::CallInst* result_buffer =
846  CallInst::Create(func_init_shared_mem,
847  std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
848  "result_buffer",
849  bb_entry);
850 
851  ICmpInst* enter_or_not =
852  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
853  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
854 
855  // Block .loop.preheader
856  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
857  BranchInst::Create(bb_forbody, bb_preheader);
858 
859  // Block .forbody
860  Argument* pos_pre = new Argument(i64_type);
861  PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2, "pos", bb_forbody);
862 
863  std::vector<Value*> row_process_params;
864  row_process_params.push_back(result_buffer);
865  row_process_params.push_back(varlen_output_buffer);
866  row_process_params.push_back(crt_matched_ptr);
867  row_process_params.push_back(total_matched);
868  row_process_params.push_back(old_total_matched_ptr);
869  row_process_params.push_back(max_matched_ptr);
870  row_process_params.push_back(agg_init_val);
871  row_process_params.push_back(pos);
872  row_process_params.push_back(frag_row_off_ptr);
873  row_process_params.push_back(row_count_ptr);
874  if (hoist_literals) {
875  CHECK(literals);
876  row_process_params.push_back(literals);
877  }
878  if (check_scan_limit) {
879  new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
880  crt_matched_ptr,
881  bb_forbody);
882  }
883  CallInst* row_process =
884  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
885  row_process->setCallingConv(CallingConv::C);
886  row_process->setTailCall(true);
887  Attributes row_process_pal;
888  row_process->setAttributes(row_process_pal);
889 
890  // Forcing all threads within a warp to be synchronized (Compute >= 7.x)
891  if (query_mem_desc.isWarpSyncRequired(device_type)) {
892  auto func_sync_warp_protected = mod->getFunction("sync_warp_protected");
893  CHECK(func_sync_warp_protected);
894  CallInst::Create(func_sync_warp_protected,
895  std::vector<llvm::Value*>{pos, row_count},
896  "",
897  bb_forbody);
898  }
899 
900  BinaryOperator* pos_inc =
901  BinaryOperator::Create(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
902  ICmpInst* loop_or_exit =
903  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
904  if (check_scan_limit) {
905  auto crt_matched = new LoadInst(get_pointer_element_type(crt_matched_ptr),
906  crt_matched_ptr,
907  "crt_matched",
908  false,
909  bb_forbody);
910  auto filter_match = BasicBlock::Create(
911  mod->getContext(), "filter_match", query_func_ptr, bb_crit_edge);
912  llvm::Value* new_total_matched =
913  new LoadInst(get_pointer_element_type(old_total_matched_ptr),
914  old_total_matched_ptr,
915  "",
916  false,
917  filter_match);
918  new_total_matched =
919  BinaryOperator::CreateAdd(new_total_matched, crt_matched, "", filter_match);
920  CHECK(new_total_matched);
921  ICmpInst* limit_not_reached = new ICmpInst(*filter_match,
922  ICmpInst::ICMP_SLT,
923  new_total_matched,
924  max_matched,
925  "limit_not_reached");
926  BranchInst::Create(
927  bb_forbody,
928  bb_crit_edge,
929  BinaryOperator::Create(
930  BinaryOperator::And, loop_or_exit, limit_not_reached, "", filter_match),
931  filter_match);
932  auto filter_nomatch = BasicBlock::Create(
933  mod->getContext(), "filter_nomatch", query_func_ptr, bb_crit_edge);
934  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
935  ICmpInst* crt_matched_nz = new ICmpInst(
936  *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0), "");
937  BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
938  pos->addIncoming(pos_start_i64, bb_preheader);
939  pos->addIncoming(pos_pre, filter_match);
940  pos->addIncoming(pos_pre, filter_nomatch);
941  } else {
942  pos->addIncoming(pos_start_i64, bb_preheader);
943  pos->addIncoming(pos_pre, bb_forbody);
944  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
945  }
946 
947  // Block ._crit_edge
948  BranchInst::Create(bb_exit, bb_crit_edge);
949 
950  // Block .exit
951  CallInst::Create(func_write_back,
952  std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
953  "",
954  bb_exit);
955 
956  ReturnInst::Create(mod->getContext(), bb_exit);
957 
958  // Resolve Forward References
959  pos_pre->replaceAllUsesWith(pos_inc);
960  delete pos_pre;
961 
962  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
963  LOG(FATAL) << "Generated invalid code. ";
964  }
965 
966  return {query_func_ptr, row_process};
967 }
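To make the builder sequence above easier to read, here is a hedged C-level paraphrase of the function it emits, restricted to the common path (no hoisted literals, no scan limit, no varlen output, no shared memory, no warp sync). The extern declarations are simplified stand-ins inferred from the call sites above, not the real runtime prototypes, and the remaining query arguments (col_buffers, frag_idx, join_hash_tables, error_code, row_func_mgr) are omitted because the template itself does not touch them.

#include <cstdint>

// Stand-in declarations, inferred from the call sites in the generated IR.
extern "C" int32_t pos_start_helper();       // result of pos_start<Attributes>()
extern "C" int32_t pos_step_helper();        // result of pos_step<Attributes>()
extern "C" int32_t group_buff_idx_helper();  // result of group_buff_idx<Attributes>()
extern "C" int64_t* init_shared_mem_nop(int64_t* buffer, int32_t smem_bytes);
extern "C" void write_back_nop(int64_t* dest, int64_t* src, int32_t smem_bytes);
extern "C" void row_process(int64_t* groups,
                            int64_t* varlen_buffer,
                            int32_t* crt_matched,
                            int32_t* total_matched,
                            int32_t* old_total_matched,
                            const int32_t* max_matched,
                            const int64_t* agg_init_val,
                            int64_t pos,
                            const int64_t* frag_row_off,
                            const int64_t* row_count_ptr);

// C-level paraphrase of the emitted "query_group_by_template" (common path).
void query_group_by_paraphrase(const int64_t* num_rows,
                               const int64_t* frag_row_offsets,
                               const int32_t* max_matched,
                               const int64_t* init_agg_value,
                               int64_t** group_by_buffers,
                               int32_t* total_matched) {
  // .entry: load the row count, pick this thread's group-by buffer, and set up
  // the (here: pass-through) shared memory result buffer.
  const int64_t row_count = *num_rows;
  int32_t crt_matched = 0;
  int32_t old_total_matched = 0;
  int64_t* col_buffer = group_by_buffers[group_buff_idx_helper()];
  int64_t* result_buffer = init_shared_mem_nop(col_buffer, /*smem_bytes=*/0);
  const int64_t start = pos_start_helper();
  const int64_t step = pos_step_helper();  // .loop.preheader
  // .forbody: run the (later generated) per-row code for every assigned row.
  for (int64_t pos = start; pos < row_count; pos += step) {
    row_process(result_buffer, /*varlen_buffer=*/nullptr, &crt_matched,
                total_matched, &old_total_matched, max_matched, init_agg_value,
                pos, frag_row_offsets, num_rows);
  }
  // .exit: no-op write-back, replaced at runtime when shared memory is used.
  write_back_nop(col_buffer, result_buffer, /*smem_bytes=*/0);
}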

std::tuple<llvm::Function*, llvm::CallInst*> query_template ( llvm::Module *  mod,
const size_t  aggr_col_count,
const bool  hoist_literals,
const bool  is_estimate_query,
const GpuSharedMemoryContext &  gpu_smem_context 
)

Definition at line 969 of file QueryTemplateGenerator.cpp.

974  {
975  return query_template_impl<llvm::AttributeList>(
976  mod, aggr_col_count, hoist_literals, is_estimate_query, gpu_smem_context);
977 }
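A correspondingly minimal sketch for the non-grouped entry point; as above, the wrapper name is hypothetical, the module is assumed to already contain the helpers the template looks up, and GpuSharedMemoryContext is assumed to default-construct to the "no shared memory" state.

#include "QueryTemplateGenerator.h"

// Hypothetical call site for a query with a single non-grouped aggregate target.
std::tuple<llvm::Function*, llvm::CallInst*> build_plain_query(llvm::Module* mod) {
  const size_t aggr_col_count = 1;        // e.g. a single COUNT(*) target
  const bool hoist_literals = true;
  const bool is_estimate_query = false;
  GpuSharedMemoryContext gpu_smem_context;
  return query_template(
      mod, aggr_col_count, hoist_literals, is_estimate_query, gpu_smem_context);
}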
template<class Attributes >
std::tuple<llvm::Function*, llvm::CallInst*> query_template_impl ( llvm::Module *  mod,
const size_t  aggr_col_count,
const bool  hoist_literals,
const bool  is_estimate_query,
const GpuSharedMemoryContext &  gpu_smem_context 
)

If the GPU shared memory optimization is disabled, each thread copies its aggregate results (held in registers) back into memory, one slot per aggregate target. This is done for every processed fragment. On the host, the final results are then reduced per target, across all threads and all fragments.

If the GPU shared memory optimization is enabled, all threads' results are aggregated (atomically) into shared memory, which makes the final reduction on the host much cheaper. In that case the template calls a no-op dummy write-back function that is replaced at runtime depending on the target expressions.
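As a rough illustration of the two strategies (not OmniSciDB code; the function names below are placeholders for what the generated IR does), the per-thread copy-back and the shared-memory aggregation differ roughly as follows:

#include <cstdint>

// (a) Shared memory disabled: every thread stores its register result into its
//     own output slot; the host later reduces one slot per thread per fragment.
inline void write_back_per_thread(int64_t* out_col,
                                  int64_t thread_result,
                                  int64_t slot_idx) {
  out_col[slot_idx] = thread_result;
}

// (b) Shared memory enabled: threads aggregate atomically into a single
//     shared-memory slot per target, so the host reduction sees far fewer
//     values. The template itself only emits a no-op write-back here, which is
//     swapped for the real reduction code at runtime.
inline void write_back_shared(int64_t* smem_slot, int64_t thread_result) {
  __sync_fetch_and_add(smem_slot, thread_result);  // stand-in for agg_sum_shared
}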

Definition at line 208 of file QueryTemplateGenerator.cpp.

References anonymous_namespace{RuntimeFunctions.cpp}::agg_func(), CHECK, logger::FATAL, anonymous_namespace{QueryTemplateGenerator.cpp}::get_pointer_element_type(), anonymous_namespace{QueryTemplateGenerator.cpp}::group_buff_idx(), GpuSharedMemoryContext::isSharedMemoryUsed(), LLVM_ALIGN, LOG, anonymous_namespace{QueryTemplateGenerator.cpp}::pos_start(), anonymous_namespace{QueryTemplateGenerator.cpp}::pos_step(), run_benchmark_import::result, anonymous_namespace{QueryTemplateGenerator.cpp}::row_process(), and to_string().

213  {
214  using namespace llvm;
215 
216  auto func_pos_start = pos_start<Attributes>(mod);
217  CHECK(func_pos_start);
218  auto func_pos_step = pos_step<Attributes>(mod);
219  CHECK(func_pos_step);
220  auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
221  CHECK(func_group_buff_idx);
222  auto func_row_process = row_process<Attributes>(
223  mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
224  CHECK(func_row_process);
225 
226  auto i8_type = IntegerType::get(mod->getContext(), 8);
227  auto i32_type = IntegerType::get(mod->getContext(), 32);
228  auto i64_type = IntegerType::get(mod->getContext(), 64);
229  auto pi8_type = PointerType::get(i8_type, 0);
230  auto ppi8_type = PointerType::get(pi8_type, 0);
231  auto pi32_type = PointerType::get(i32_type, 0);
232  auto pi64_type = PointerType::get(i64_type, 0);
233  auto ppi64_type = PointerType::get(pi64_type, 0);
234 
235  std::vector<Type*> query_args;
236  query_args.push_back(ppi8_type); // byte_stream
237  if (hoist_literals) {
238  query_args.push_back(pi8_type); // literals
239  }
240  query_args.push_back(pi64_type); // row_count_ptr
241  query_args.push_back(pi64_type); // frag_row_off_ptr
242  query_args.push_back(pi32_type); // max_matched_ptr
243 
244  query_args.push_back(pi64_type); // agg_init_val
245  query_args.push_back(ppi64_type); // group_by_buffers
246  query_args.push_back(i32_type); // frag_idx
247  query_args.push_back(pi64_type); // join_hash_tables
248  query_args.push_back(pi32_type); // total_matched
249  query_args.push_back(pi32_type); // error_code
250  query_args.push_back(pi8_type); // row_func_mgr
251 
252  FunctionType* query_func_type = FunctionType::get(
253  /*Result=*/Type::getVoidTy(mod->getContext()),
254  /*Params=*/query_args,
255  /*isVarArg=*/false);
256 
257  std::string query_template_name{"query_template"};
258  auto query_func_ptr = mod->getFunction(query_template_name);
259  CHECK(!query_func_ptr);
260 
261  query_func_ptr = Function::Create(
262  /*Type=*/query_func_type,
263  /*Linkage=*/GlobalValue::ExternalLinkage,
264  /*Name=*/query_template_name,
265  mod);
266  query_func_ptr->setCallingConv(CallingConv::C);
267 
268  Attributes query_func_pal;
269  {
270  SmallVector<Attributes, 4> Attrs;
271  Attributes PAS;
272  {
273 #if 14 <= LLVM_VERSION_MAJOR
274  AttrBuilder B(mod->getContext());
275 #else
276  AttrBuilder B;
277 #endif
278  B.addAttribute(Attribute::NoCapture);
279  PAS = Attributes::get(mod->getContext(), 1U, B);
280  }
281 
282  Attrs.push_back(PAS);
283  {
284 #if 14 <= LLVM_VERSION_MAJOR
285  AttrBuilder B(mod->getContext());
286 #else
287  AttrBuilder B;
288 #endif
289  B.addAttribute(Attribute::NoCapture);
290  PAS = Attributes::get(mod->getContext(), 2U, B);
291  }
292 
293  Attrs.push_back(PAS);
294 
295  {
296 #if 14 <= LLVM_VERSION_MAJOR
297  AttrBuilder B(mod->getContext());
298 #else
299  AttrBuilder B;
300 #endif
301  B.addAttribute(Attribute::NoCapture);
302  Attrs.push_back(Attributes::get(mod->getContext(), 3U, B));
303  }
304 
305  {
306 #if 14 <= LLVM_VERSION_MAJOR
307  AttrBuilder B(mod->getContext());
308 #else
309  AttrBuilder B;
310 #endif
311  B.addAttribute(Attribute::NoCapture);
312  Attrs.push_back(Attributes::get(mod->getContext(), 4U, B));
313  }
314 
315  Attrs.push_back(PAS);
316 
317  query_func_pal = Attributes::get(mod->getContext(), Attrs);
318  }
319  query_func_ptr->setAttributes(query_func_pal);
320 
321  Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
322  Value* byte_stream = &*query_arg_it;
323  byte_stream->setName("byte_stream");
324  Value* literals{nullptr};
325  if (hoist_literals) {
326  literals = &*(++query_arg_it);
327  literals->setName("literals");
328  }
329  Value* row_count_ptr = &*(++query_arg_it);
330  row_count_ptr->setName("row_count_ptr");
331  Value* frag_row_off_ptr = &*(++query_arg_it);
332  frag_row_off_ptr->setName("frag_row_off_ptr");
333  Value* max_matched_ptr = &*(++query_arg_it);
334  max_matched_ptr->setName("max_matched_ptr");
335  Value* agg_init_val = &*(++query_arg_it);
336  agg_init_val->setName("agg_init_val");
337  Value* out = &*(++query_arg_it);
338  out->setName("out");
339  Value* frag_idx = &*(++query_arg_it);
340  frag_idx->setName("frag_idx");
341  Value* join_hash_tables = &*(++query_arg_it);
342  join_hash_tables->setName("join_hash_tables");
343  Value* total_matched = &*(++query_arg_it);
344  total_matched->setName("total_matched");
345  Value* error_code = &*(++query_arg_it);
346  error_code->setName("error_code");
347  Value* row_func_mgr = &*(++query_arg_it);
348  row_func_mgr->setName("row_func_mgr");
349 
350  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
351  auto bb_preheader =
352  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
353  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".for.body", query_func_ptr, 0);
354  auto bb_crit_edge =
355  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
356  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
357 
358  // Block (.entry)
359  std::vector<Value*> result_ptr_vec;
360  llvm::CallInst* smem_output_buffer{nullptr};
361  if (!is_estimate_query) {
362  for (size_t i = 0; i < aggr_col_count; ++i) {
363  auto result_ptr = new AllocaInst(i64_type, 0, "result", bb_entry);
364  result_ptr->setAlignment(LLVM_ALIGN(8));
365  result_ptr_vec.push_back(result_ptr);
366  }
367  if (gpu_smem_context.isSharedMemoryUsed()) {
368  auto init_smem_func = mod->getFunction("init_shared_mem");
369  CHECK(init_smem_func);
370  // only one slot per aggregate column is needed, and so we can initialize shared
371  // memory buffer for intermediate results to be exactly like the agg_init_val array
372  smem_output_buffer = CallInst::Create(
373  init_smem_func,
374  std::vector<llvm::Value*>{
375  agg_init_val,
376  llvm::ConstantInt::get(i32_type, aggr_col_count * sizeof(int64_t))},
377  "smem_buffer",
378  bb_entry);
379  }
380  }
381 
382  LoadInst* row_count = new LoadInst(get_pointer_element_type(row_count_ptr),
383  row_count_ptr,
384  "row_count",
385  false,
386  bb_entry);
387  row_count->setAlignment(LLVM_ALIGN(8));
388  row_count->setName("row_count");
389  std::vector<Value*> agg_init_val_vec;
390  if (!is_estimate_query) {
391  for (size_t i = 0; i < aggr_col_count; ++i) {
392  auto idx_lv = ConstantInt::get(i32_type, i);
393  auto agg_init_gep = GetElementPtrInst::CreateInBounds(
394  agg_init_val->getType()->getPointerElementType(),
395  agg_init_val,
396  idx_lv,
397  "",
398  bb_entry);
399  auto agg_init_val = new LoadInst(
400  get_pointer_element_type(agg_init_gep), agg_init_gep, "", false, bb_entry);
401  agg_init_val->setAlignment(LLVM_ALIGN(8));
402  agg_init_val_vec.push_back(agg_init_val);
403  auto init_val_st = new StoreInst(agg_init_val, result_ptr_vec[i], false, bb_entry);
404  init_val_st->setAlignment(LLVM_ALIGN(8));
405  }
406  }
407 
408  CallInst* pos_start = CallInst::Create(func_pos_start, "pos_start", bb_entry);
409  pos_start->setCallingConv(CallingConv::C);
410  pos_start->setTailCall(true);
411  Attributes pos_start_pal;
412  pos_start->setAttributes(pos_start_pal);
413 
414  CallInst* pos_step = CallInst::Create(func_pos_step, "pos_step", bb_entry);
415  pos_step->setCallingConv(CallingConv::C);
416  pos_step->setTailCall(true);
417  Attributes pos_step_pal;
418  pos_step->setAttributes(pos_step_pal);
419 
420  CallInst* group_buff_idx = nullptr;
421  if (!is_estimate_query) {
422  group_buff_idx = CallInst::Create(func_group_buff_idx, "group_buff_idx", bb_entry);
423  group_buff_idx->setCallingConv(CallingConv::C);
424  group_buff_idx->setTailCall(true);
425  Attributes group_buff_idx_pal;
426  group_buff_idx->setAttributes(group_buff_idx_pal);
427  }
428 
429  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
430  ICmpInst* enter_or_not =
431  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
432  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
433 
434  // Block .loop.preheader
435  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
436  BranchInst::Create(bb_forbody, bb_preheader);
437 
438  // Block .forbody
439  Argument* pos_inc_pre = new Argument(i64_type);
440  PHINode* pos = PHINode::Create(i64_type, 2, "pos", bb_forbody);
441  pos->addIncoming(pos_start_i64, bb_preheader);
442  pos->addIncoming(pos_inc_pre, bb_forbody);
443 
444  std::vector<Value*> row_process_params;
445  row_process_params.insert(
446  row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
447  if (is_estimate_query) {
448  row_process_params.push_back(
449  new LoadInst(get_pointer_element_type(out), out, "", false, bb_forbody));
450  }
451  row_process_params.push_back(agg_init_val);
452  row_process_params.push_back(pos);
453  row_process_params.push_back(frag_row_off_ptr);
454  row_process_params.push_back(row_count_ptr);
455  if (hoist_literals) {
456  CHECK(literals);
457  row_process_params.push_back(literals);
458  }
459  CallInst* row_process =
460  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
461  row_process->setCallingConv(CallingConv::C);
462  row_process->setTailCall(false);
463  Attributes row_process_pal;
464  row_process->setAttributes(row_process_pal);
465 
466  BinaryOperator* pos_inc =
467  BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
468  ICmpInst* loop_or_exit =
469  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
470  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
471 
472  // Block ._crit_edge
473  std::vector<Instruction*> result_vec_pre;
474  if (!is_estimate_query) {
475  for (size_t i = 0; i < aggr_col_count; ++i) {
476  auto result = new LoadInst(get_pointer_element_type(result_ptr_vec[i]),
477  result_ptr_vec[i],
478  ".pre.result",
479  false,
480  bb_crit_edge);
481  result->setAlignment(LLVM_ALIGN(8));
482  result_vec_pre.push_back(result);
483  }
484  }
485 
486  BranchInst::Create(bb_exit, bb_crit_edge);
487 
488  // Block .exit
500  if (!is_estimate_query) {
501  std::vector<PHINode*> result_vec;
502  for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
503  auto result =
504  PHINode::Create(IntegerType::get(mod->getContext(), 64), 2, "", bb_exit);
505  result->addIncoming(result_vec_pre[i], bb_crit_edge);
506  result->addIncoming(agg_init_val_vec[i], bb_entry);
507  result_vec.insert(result_vec.begin(), result);
508  }
509 
510  for (size_t i = 0; i < aggr_col_count; ++i) {
511  auto col_idx = ConstantInt::get(i32_type, i);
512  if (gpu_smem_context.isSharedMemoryUsed()) {
513  auto target_addr = GetElementPtrInst::CreateInBounds(
514  smem_output_buffer->getType()->getPointerElementType(),
515  smem_output_buffer,
516  col_idx,
517  "",
518  bb_exit);
519  // TODO: generalize this once we want to support other types of aggregate
520  // functions besides COUNT.
521  auto agg_func = mod->getFunction("agg_sum_shared");
522  CHECK(agg_func);
523  CallInst::Create(
524  agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]}, "", bb_exit);
525  } else {
526  auto out_gep = GetElementPtrInst::CreateInBounds(
527  out->getType()->getPointerElementType(), out, col_idx, "", bb_exit);
528  auto col_buffer =
529  new LoadInst(get_pointer_element_type(out_gep), out_gep, "", false, bb_exit);
530  col_buffer->setAlignment(LLVM_ALIGN(8));
531  auto slot_idx = BinaryOperator::CreateAdd(
532  group_buff_idx,
533  BinaryOperator::CreateMul(frag_idx, pos_step, "", bb_exit),
534  "",
535  bb_exit);
536  auto target_addr = GetElementPtrInst::CreateInBounds(
537  col_buffer->getType()->getPointerElementType(),
538  col_buffer,
539  slot_idx,
540  "",
541  bb_exit);
542  StoreInst* result_st = new StoreInst(result_vec[i], target_addr, false, bb_exit);
543  result_st->setAlignment(LLVM_ALIGN(8));
544  }
545  }
546  if (gpu_smem_context.isSharedMemoryUsed()) {
547  // final reduction of results from shared memory buffer back into global memory.
548  auto sync_thread_func = mod->getFunction("sync_threadblock");
549  CHECK(sync_thread_func);
550  CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{}, "", bb_exit);
551  auto reduce_smem_to_gmem_func = mod->getFunction("write_back_non_grouped_agg");
552  CHECK(reduce_smem_to_gmem_func);
553  // each thread reduce the aggregate target corresponding to its own thread ID.
554  // If there are more targets than threads we do not currently use shared memory
555  // optimization. This can be relaxed if necessary
556  for (size_t i = 0; i < aggr_col_count; i++) {
557  auto out_gep =
558  GetElementPtrInst::CreateInBounds(out->getType()->getPointerElementType(),
559  out,
560  ConstantInt::get(i32_type, i),
561  "",
562  bb_exit);
563  auto gmem_output_buffer = new LoadInst(get_pointer_element_type(out_gep),
564  out_gep,
565  "gmem_output_buffer_" + std::to_string(i),
566  false,
567  bb_exit);
568  CallInst::Create(
569  reduce_smem_to_gmem_func,
570  std::vector<llvm::Value*>{
571  smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
572  "",
573  bb_exit);
574  }
575  }
576  }
577 
578  ReturnInst::Create(mod->getContext(), bb_exit);
579 
580  // Resolve Forward References
581  pos_inc_pre->replaceAllUsesWith(pos_inc);
582  delete pos_inc_pre;
583 
584  if (verifyFunction(*query_func_ptr)) {
585  LOG(FATAL) << "Generated invalid code. ";
586  }
587 
588  return {query_func_ptr, row_process};
589 }
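Again as a hedged C-level paraphrase of the emitted function, restricted to the common path (is_estimate_query = false, no hoisted literals, no shared memory) and to a single aggregate target; the stand-in declarations are inferred from the call sites above rather than taken from the real runtime headers.

#include <cstdint>

// Stand-in declarations, inferred from the call sites in the generated IR.
extern "C" int32_t pos_start_helper();
extern "C" int32_t pos_step_helper();
extern "C" int32_t group_buff_idx_helper();
extern "C" void row_process(int64_t* result,
                            const int64_t* agg_init_val,
                            int64_t pos,
                            const int64_t* frag_row_off,
                            const int64_t* row_count_ptr);

// C-level paraphrase of the emitted "query_template" (common path, one target).
void query_template_paraphrase(const int64_t* num_rows,
                               const int64_t* frag_row_offsets,
                               const int64_t* agg_init_val,
                               int64_t** out,
                               int32_t frag_idx) {
  // .entry: initialize the per-thread accumulator from agg_init_val.
  const int64_t row_count = *num_rows;
  int64_t result = agg_init_val[0];
  const int64_t start = pos_start_helper();
  const int64_t step = pos_step_helper();  // .loop.preheader
  // .forbody: run the per-row code; it updates the accumulator in place.
  for (int64_t pos = start; pos < row_count; pos += step) {
    row_process(&result, agg_init_val, pos, frag_row_offsets, num_rows);
  }
  // .exit: each thread writes its register result into its own output slot
  // (group_buff_idx + frag_idx * step); the host reduces the slots afterwards.
  const int64_t slot_idx = group_buff_idx_helper() + frag_idx * step;
  out[0][slot_idx] = result;
}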
