602 using namespace llvm;
604 auto func_pos_start = pos_start<Attributes>(mod);
605 CHECK(func_pos_start);
606 auto func_pos_step = pos_step<Attributes>(mod);
607 CHECK(func_pos_step);
608 auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
609 CHECK(func_group_buff_idx);
610 auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
611 CHECK(func_row_process);
613 ? mod->getFunction(
"init_shared_mem")
614 : mod->getFunction(
"init_shared_mem_nop");
615 CHECK(func_init_shared_mem);
617 auto func_write_back = mod->getFunction(
"write_back_nop");
618 CHECK(func_write_back);
620 auto i32_type = IntegerType::get(mod->getContext(), 32);
621 auto i64_type = IntegerType::get(mod->getContext(), 64);
622 auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
623 auto pi32_type = PointerType::get(i32_type, 0);
624 auto pi64_type = PointerType::get(i64_type, 0);
625 auto ppi64_type = PointerType::get(pi64_type, 0);
626 auto ppi8_type = PointerType::get(pi8_type, 0);
628 std::vector<Type*> query_args;
629 query_args.push_back(ppi8_type);
630 if (hoist_literals) {
631 query_args.push_back(pi8_type);
633 query_args.push_back(pi64_type);
634 query_args.push_back(pi64_type);
635 query_args.push_back(pi32_type);
636 query_args.push_back(pi64_type);
638 query_args.push_back(ppi64_type);
639 query_args.push_back(i32_type);
640 query_args.push_back(pi64_type);
641 query_args.push_back(pi32_type);
642 query_args.push_back(pi32_type);
643 query_args.push_back(pi8_type);
645 FunctionType* query_func_type = FunctionType::get(
646 Type::getVoidTy(mod->getContext()),
650 std::string query_name{
"query_group_by_template"};
651 auto query_func_ptr = mod->getFunction(query_name);
652 CHECK(!query_func_ptr);
654 query_func_ptr = Function::Create(
656 GlobalValue::ExternalLinkage,
657 "query_group_by_template",
660 query_func_ptr->setCallingConv(CallingConv::C);
662 Attributes query_func_pal;
664 SmallVector<Attributes, 4> Attrs;
667 #if 14 <= LLVM_VERSION_MAJOR
668 AttrBuilder B(mod->getContext());
672 B.addAttribute(Attribute::ReadNone);
673 B.addAttribute(Attribute::NoCapture);
674 PAS = Attributes::get(mod->getContext(), 1U, B);
677 Attrs.push_back(PAS);
679 #if 14 <= LLVM_VERSION_MAJOR
680 AttrBuilder B(mod->getContext());
684 B.addAttribute(Attribute::ReadOnly);
685 B.addAttribute(Attribute::NoCapture);
686 PAS = Attributes::get(mod->getContext(), 2U, B);
689 Attrs.push_back(PAS);
691 #if 14 <= LLVM_VERSION_MAJOR
692 AttrBuilder B(mod->getContext());
696 B.addAttribute(Attribute::ReadNone);
697 B.addAttribute(Attribute::NoCapture);
698 PAS = Attributes::get(mod->getContext(), 3U, B);
701 Attrs.push_back(PAS);
703 #if 14 <= LLVM_VERSION_MAJOR
704 AttrBuilder B(mod->getContext());
708 B.addAttribute(Attribute::ReadOnly);
709 B.addAttribute(Attribute::NoCapture);
710 PAS = Attributes::get(mod->getContext(), 4U, B);
713 Attrs.push_back(PAS);
715 #if 14 <= LLVM_VERSION_MAJOR
716 AttrBuilder B(mod->getContext());
720 B.addAttribute(Attribute::UWTable);
721 PAS = Attributes::get(mod->getContext(), ~0U, B);
724 Attrs.push_back(PAS);
726 query_func_pal = Attributes::get(mod->getContext(), Attrs);
728 query_func_ptr->setAttributes(query_func_pal);
730 Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
731 Value* byte_stream = &*query_arg_it;
732 byte_stream->setName(
"byte_stream");
733 Value* literals{
nullptr};
734 if (hoist_literals) {
735 literals = &*(++query_arg_it);
737 literals->setName(
"literals");
739 Value* row_count_ptr = &*(++query_arg_it);
740 row_count_ptr->setName(
"row_count_ptr");
741 Value* frag_row_off_ptr = &*(++query_arg_it);
742 frag_row_off_ptr->setName(
"frag_row_off_ptr");
743 Value* max_matched_ptr = &*(++query_arg_it);
744 max_matched_ptr->setName(
"max_matched_ptr");
745 Value* agg_init_val = &*(++query_arg_it);
746 agg_init_val->setName(
"agg_init_val");
747 Value* group_by_buffers = &*(++query_arg_it);
748 group_by_buffers->setName(
"group_by_buffers");
749 Value* frag_idx = &*(++query_arg_it);
750 frag_idx->setName(
"frag_idx");
751 Value* join_hash_tables = &*(++query_arg_it);
752 join_hash_tables->setName(
"join_hash_tables");
753 Value* total_matched = &*(++query_arg_it);
754 total_matched->setName(
"total_matched");
755 Value* error_code = &*(++query_arg_it);
756 error_code->setName(
"error_code");
757 Value* row_func_mgr = &*(++query_arg_it);
758 row_func_mgr->setName(
"row_func_mgr");
760 auto bb_entry = BasicBlock::Create(mod->getContext(),
".entry", query_func_ptr, 0);
762 BasicBlock::Create(mod->getContext(),
".loop.preheader", query_func_ptr, 0);
763 auto bb_forbody = BasicBlock::Create(mod->getContext(),
".forbody", query_func_ptr, 0);
765 BasicBlock::Create(mod->getContext(),
"._crit_edge", query_func_ptr, 0);
766 auto bb_exit = BasicBlock::Create(mod->getContext(),
".exit", query_func_ptr, 0);
769 LoadInst* row_count =
new LoadInst(
772 row_count->setName(
"row_count");
774 LoadInst* max_matched =
new LoadInst(
778 auto crt_matched_ptr =
new AllocaInst(i32_type, 0,
"crt_matched", bb_entry);
779 auto old_total_matched_ptr =
new AllocaInst(i32_type, 0,
"old_total_matched", bb_entry);
780 CallInst*
pos_start = CallInst::Create(func_pos_start,
"", bb_entry);
781 pos_start->setCallingConv(CallingConv::C);
783 Attributes pos_start_pal;
786 CallInst*
pos_step = CallInst::Create(func_pos_step,
"", bb_entry);
787 pos_step->setCallingConv(CallingConv::C);
789 Attributes pos_step_pal;
790 pos_step->setAttributes(pos_step_pal);
792 CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx,
"", bb_entry);
793 group_buff_idx_call->setCallingConv(CallingConv::C);
794 group_buff_idx_call->setTailCall(
true);
795 Attributes group_buff_idx_pal;
796 group_buff_idx_call->setAttributes(group_buff_idx_pal);
799 const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
802 Value* varlen_output_buffer{
nullptr};
806 auto varlen_output_buffer_gep = GetElementPtrInst::Create(
807 Ty->getPointerElementType(),
809 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
812 varlen_output_buffer =
814 varlen_output_buffer_gep,
815 "varlen_output_buffer",
822 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
823 "group_buff_idx_varlen_offset",
826 varlen_output_buffer =
827 ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
829 CHECK(varlen_output_buffer);
831 CastInst* pos_start_i64 =
new SExtInst(
pos_start, i64_type,
"", bb_entry);
832 GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
833 Ty->getPointerElementType(), group_by_buffers,
group_buff_idx,
"", bb_entry);
835 group_by_buffers_gep,
839 col_buffer->setName(
"col_buffer");
842 llvm::ConstantInt* shared_mem_bytes_lv =
845 llvm::CallInst* result_buffer =
846 CallInst::Create(func_init_shared_mem,
847 std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
851 ICmpInst* enter_or_not =
852 new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count,
"");
853 BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
856 CastInst* pos_step_i64 =
new SExtInst(
pos_step, i64_type,
"", bb_preheader);
857 BranchInst::Create(bb_forbody, bb_preheader);
861 PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2,
"pos", bb_forbody);
863 std::vector<Value*> row_process_params;
864 row_process_params.push_back(result_buffer);
865 row_process_params.push_back(varlen_output_buffer);
866 row_process_params.push_back(crt_matched_ptr);
867 row_process_params.push_back(total_matched);
868 row_process_params.push_back(old_total_matched_ptr);
869 row_process_params.push_back(max_matched_ptr);
870 row_process_params.push_back(agg_init_val);
871 row_process_params.push_back(pos);
872 row_process_params.push_back(frag_row_off_ptr);
873 row_process_params.push_back(row_count_ptr);
874 if (hoist_literals) {
876 row_process_params.push_back(literals);
878 if (check_scan_limit) {
879 new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
884 CallInst::Create(func_row_process, row_process_params,
"", bb_forbody);
887 Attributes row_process_pal;
892 auto func_sync_warp_protected = mod->getFunction(
"sync_warp_protected");
893 CHECK(func_sync_warp_protected);
894 CallInst::Create(func_sync_warp_protected,
895 std::vector<llvm::Value*>{pos, row_count},
901 BinaryOperator::Create(Instruction::Add, pos, pos_step_i64,
"", bb_forbody);
902 ICmpInst* loop_or_exit =
903 new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count,
"");
904 if (check_scan_limit) {
910 auto filter_match = BasicBlock::Create(
911 mod->getContext(),
"filter_match", query_func_ptr, bb_crit_edge);
912 llvm::Value* new_total_matched =
914 old_total_matched_ptr,
919 BinaryOperator::CreateAdd(new_total_matched, crt_matched,
"", filter_match);
920 CHECK(new_total_matched);
921 ICmpInst* limit_not_reached =
new ICmpInst(*filter_match,
925 "limit_not_reached");
929 BinaryOperator::Create(
930 BinaryOperator::And, loop_or_exit, limit_not_reached,
"", filter_match),
932 auto filter_nomatch = BasicBlock::Create(
933 mod->getContext(),
"filter_nomatch", query_func_ptr, bb_crit_edge);
934 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
935 ICmpInst* crt_matched_nz =
new ICmpInst(
936 *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0),
"");
937 BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
938 pos->addIncoming(pos_start_i64, bb_preheader);
939 pos->addIncoming(pos_pre, filter_match);
940 pos->addIncoming(pos_pre, filter_nomatch);
942 pos->addIncoming(pos_start_i64, bb_preheader);
943 pos->addIncoming(pos_pre, bb_forbody);
944 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
948 BranchInst::Create(bb_exit, bb_crit_edge);
951 CallInst::Create(func_write_back,
952 std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
956 ReturnInst::Create(mod->getContext(), bb_exit);
959 pos_pre->replaceAllUsesWith(pos_inc);
962 if (verifyFunction(*query_func_ptr, &llvm::errs())) {
963 LOG(
FATAL) <<
"Generated invalid code. ";
llvm::Function * pos_start(llvm::Module *mod)
bool hasVarlenOutput() const
size_t getSharedMemorySize() const
#define LLVM_ALIGN(alignment)
bool isSharedMemoryUsed() const
llvm::Function * group_buff_idx(llvm::Module *mod)
llvm::Function * pos_step(llvm::Module *mod)
bool isWarpSyncRequired(const ExecutorDeviceType) const
llvm::Type * get_pointer_element_type(llvm::Value *value)
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)