21 #include <llvm/IR/Constants.h>
22 #include <llvm/IR/IRBuilder.h>
23 #include <llvm/IR/Instructions.h>
24 #include <llvm/IR/Verifier.h>
34 auto type = value->getType();
38 return pointer_type->getPointerElementType();
41 template <
class Attributes>
45 std::vector<Type*> func_args;
46 FunctionType* func_type = FunctionType::get(
47 IntegerType::get(mod->getContext(), 32),
51 auto func_ptr = mod->getFunction(name);
53 func_ptr = Function::Create(
55 GlobalValue::ExternalLinkage,
58 func_ptr->setCallingConv(CallingConv::C);
63 SmallVector<Attributes, 4> Attrs;
66 #if 14 <= LLVM_VERSION_MAJOR
67 AttrBuilder B(mod->getContext());
71 PAS = Attributes::get(mod->getContext(), ~0U, B);
75 func_pal = Attributes::get(mod->getContext(), Attrs);
77 func_ptr->setAttributes(func_pal);
82 template <
class Attributes>
84 return default_func_builder<Attributes>(mod,
"pos_start");
87 template <
class Attributes>
89 return default_func_builder<Attributes>(mod,
"group_buff_idx");
92 template <
class Attributes>
96 std::vector<Type*> func_args;
97 FunctionType* func_type = FunctionType::get(
98 IntegerType::get(mod->getContext(), 32),
102 auto func_ptr = mod->getFunction(
"pos_step");
104 func_ptr = Function::Create(
106 GlobalValue::ExternalLinkage,
109 func_ptr->setCallingConv(CallingConv::C);
114 SmallVector<Attributes, 4> Attrs;
117 #if 14 <= LLVM_VERSION_MAJOR
118 AttrBuilder B(mod->getContext());
122 PAS = Attributes::get(mod->getContext(), ~0U, B);
125 Attrs.push_back(PAS);
126 func_pal = Attributes::get(mod->getContext(), Attrs);
128 func_ptr->setAttributes(func_pal);
133 template <
class Attributes>
135 const size_t aggr_col_count,
136 const bool hoist_literals) {
137 using namespace llvm;
139 std::vector<Type*> func_args;
140 auto i8_type = IntegerType::get(mod->getContext(), 8);
141 auto i32_type = IntegerType::get(mod->getContext(), 32);
142 auto i64_type = IntegerType::get(mod->getContext(), 64);
143 auto pi32_type = PointerType::get(i32_type, 0);
144 auto pi64_type = PointerType::get(i64_type, 0);
146 if (aggr_col_count) {
147 for (
size_t i = 0; i < aggr_col_count; ++i) {
148 func_args.push_back(pi64_type);
151 func_args.push_back(pi64_type);
152 func_args.push_back(pi64_type);
153 func_args.push_back(pi32_type);
154 func_args.push_back(pi32_type);
155 func_args.push_back(pi32_type);
156 func_args.push_back(pi32_type);
159 func_args.push_back(pi64_type);
161 func_args.push_back(i64_type);
162 func_args.push_back(pi64_type);
163 func_args.push_back(pi64_type);
164 if (hoist_literals) {
165 func_args.push_back(PointerType::get(i8_type, 0));
167 FunctionType* func_type = FunctionType::get(
172 std::string func_name{
"row_process"};
173 auto func_ptr = mod->getFunction(func_name);
176 func_ptr = Function::Create(
178 GlobalValue::ExternalLinkage,
181 func_ptr->setCallingConv(CallingConv::C);
185 SmallVector<Attributes, 4> Attrs;
188 #if 14 <= LLVM_VERSION_MAJOR
189 AttrBuilder B(mod->getContext());
193 PAS = Attributes::get(mod->getContext(), ~0U, B);
196 Attrs.push_back(PAS);
197 func_pal = Attributes::get(mod->getContext(), Attrs);
199 func_ptr->setAttributes(func_pal);
207 template <
class Attributes>
210 const size_t aggr_col_count,
211 const bool hoist_literals,
212 const bool is_estimate_query,
214 using namespace llvm;
216 auto func_pos_start = pos_start<Attributes>(mod);
217 CHECK(func_pos_start);
218 auto func_pos_step = pos_step<Attributes>(mod);
219 CHECK(func_pos_step);
220 auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
221 CHECK(func_group_buff_idx);
222 auto func_row_process = row_process<Attributes>(
223 mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
224 CHECK(func_row_process);
226 auto i8_type = IntegerType::get(mod->getContext(), 8);
227 auto i32_type = IntegerType::get(mod->getContext(), 32);
228 auto i64_type = IntegerType::get(mod->getContext(), 64);
229 auto pi8_type = PointerType::get(i8_type, 0);
230 auto ppi8_type = PointerType::get(pi8_type, 0);
231 auto pi32_type = PointerType::get(i32_type, 0);
232 auto pi64_type = PointerType::get(i64_type, 0);
233 auto ppi64_type = PointerType::get(pi64_type, 0);
235 std::vector<Type*> query_args;
236 query_args.push_back(ppi8_type);
237 if (hoist_literals) {
238 query_args.push_back(pi8_type);
240 query_args.push_back(pi64_type);
241 query_args.push_back(pi64_type);
242 query_args.push_back(pi32_type);
244 query_args.push_back(pi64_type);
245 query_args.push_back(ppi64_type);
246 query_args.push_back(i32_type);
247 query_args.push_back(pi64_type);
248 query_args.push_back(pi32_type);
249 query_args.push_back(pi32_type);
250 query_args.push_back(pi8_type);
252 FunctionType* query_func_type = FunctionType::get(
253 Type::getVoidTy(mod->getContext()),
257 std::string query_template_name{
"query_template"};
258 auto query_func_ptr = mod->getFunction(query_template_name);
259 CHECK(!query_func_ptr);
261 query_func_ptr = Function::Create(
263 GlobalValue::ExternalLinkage,
266 query_func_ptr->setCallingConv(CallingConv::C);
268 Attributes query_func_pal;
270 SmallVector<Attributes, 4> Attrs;
273 #if 14 <= LLVM_VERSION_MAJOR
274 AttrBuilder B(mod->getContext());
278 B.addAttribute(Attribute::NoCapture);
279 PAS = Attributes::get(mod->getContext(), 1U, B);
282 Attrs.push_back(PAS);
284 #if 14 <= LLVM_VERSION_MAJOR
285 AttrBuilder B(mod->getContext());
289 B.addAttribute(Attribute::NoCapture);
290 PAS = Attributes::get(mod->getContext(), 2U, B);
293 Attrs.push_back(PAS);
296 #if 14 <= LLVM_VERSION_MAJOR
297 AttrBuilder B(mod->getContext());
301 B.addAttribute(Attribute::NoCapture);
302 Attrs.push_back(Attributes::get(mod->getContext(), 3U, B));
306 #if 14 <= LLVM_VERSION_MAJOR
307 AttrBuilder B(mod->getContext());
311 B.addAttribute(Attribute::NoCapture);
312 Attrs.push_back(Attributes::get(mod->getContext(), 4U, B));
315 Attrs.push_back(PAS);
317 query_func_pal = Attributes::get(mod->getContext(), Attrs);
319 query_func_ptr->setAttributes(query_func_pal);
321 Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
322 Value* byte_stream = &*query_arg_it;
323 byte_stream->setName(
"byte_stream");
324 Value* literals{
nullptr};
325 if (hoist_literals) {
326 literals = &*(++query_arg_it);
327 literals->setName(
"literals");
329 Value* row_count_ptr = &*(++query_arg_it);
330 row_count_ptr->setName(
"row_count_ptr");
331 Value* frag_row_off_ptr = &*(++query_arg_it);
332 frag_row_off_ptr->setName(
"frag_row_off_ptr");
333 Value* max_matched_ptr = &*(++query_arg_it);
334 max_matched_ptr->setName(
"max_matched_ptr");
335 Value* agg_init_val = &*(++query_arg_it);
336 agg_init_val->setName(
"agg_init_val");
337 Value* out = &*(++query_arg_it);
339 Value* frag_idx = &*(++query_arg_it);
340 frag_idx->setName(
"frag_idx");
341 Value* join_hash_tables = &*(++query_arg_it);
342 join_hash_tables->setName(
"join_hash_tables");
343 Value* total_matched = &*(++query_arg_it);
344 total_matched->setName(
"total_matched");
345 Value* error_code = &*(++query_arg_it);
346 error_code->setName(
"error_code");
347 Value* row_func_mgr = &*(++query_arg_it);
348 row_func_mgr->setName(
"row_func_mgr");
350 auto bb_entry = BasicBlock::Create(mod->getContext(),
".entry", query_func_ptr, 0);
352 BasicBlock::Create(mod->getContext(),
".loop.preheader", query_func_ptr, 0);
353 auto bb_forbody = BasicBlock::Create(mod->getContext(),
".for.body", query_func_ptr, 0);
355 BasicBlock::Create(mod->getContext(),
"._crit_edge", query_func_ptr, 0);
356 auto bb_exit = BasicBlock::Create(mod->getContext(),
".exit", query_func_ptr, 0);
359 std::vector<Value*> result_ptr_vec;
360 llvm::CallInst* smem_output_buffer{
nullptr};
361 if (!is_estimate_query) {
362 for (
size_t i = 0; i < aggr_col_count; ++i) {
363 auto result_ptr =
new AllocaInst(i64_type, 0,
"result", bb_entry);
365 result_ptr_vec.push_back(result_ptr);
368 auto init_smem_func = mod->getFunction(
"init_shared_mem");
369 CHECK(init_smem_func);
372 smem_output_buffer = CallInst::Create(
374 std::vector<llvm::Value*>{
376 llvm::ConstantInt::get(i32_type, aggr_col_count *
sizeof(int64_t))},
388 row_count->setName(
"row_count");
389 std::vector<Value*> agg_init_val_vec;
390 if (!is_estimate_query) {
391 for (
size_t i = 0; i < aggr_col_count; ++i) {
392 auto idx_lv = ConstantInt::get(i32_type, i);
393 auto agg_init_gep = GetElementPtrInst::CreateInBounds(
394 agg_init_val->getType()->getPointerElementType(),
399 auto agg_init_val =
new LoadInst(
402 agg_init_val_vec.push_back(agg_init_val);
403 auto init_val_st =
new StoreInst(agg_init_val, result_ptr_vec[i],
false, bb_entry);
408 CallInst*
pos_start = CallInst::Create(func_pos_start,
"pos_start", bb_entry);
409 pos_start->setCallingConv(CallingConv::C);
411 Attributes pos_start_pal;
414 CallInst*
pos_step = CallInst::Create(func_pos_step,
"pos_step", bb_entry);
415 pos_step->setCallingConv(CallingConv::C);
417 Attributes pos_step_pal;
418 pos_step->setAttributes(pos_step_pal);
421 if (!is_estimate_query) {
422 group_buff_idx = CallInst::Create(func_group_buff_idx,
"group_buff_idx", bb_entry);
425 Attributes group_buff_idx_pal;
429 CastInst* pos_start_i64 =
new SExtInst(
pos_start, i64_type,
"", bb_entry);
430 ICmpInst* enter_or_not =
431 new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count,
"");
432 BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
435 CastInst* pos_step_i64 =
new SExtInst(
pos_step, i64_type,
"", bb_preheader);
436 BranchInst::Create(bb_forbody, bb_preheader);
440 PHINode* pos = PHINode::Create(i64_type, 2,
"pos", bb_forbody);
441 pos->addIncoming(pos_start_i64, bb_preheader);
442 pos->addIncoming(pos_inc_pre, bb_forbody);
444 std::vector<Value*> row_process_params;
445 row_process_params.insert(
446 row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
447 if (is_estimate_query) {
448 row_process_params.push_back(
451 row_process_params.push_back(agg_init_val);
452 row_process_params.push_back(pos);
453 row_process_params.push_back(frag_row_off_ptr);
454 row_process_params.push_back(row_count_ptr);
455 if (hoist_literals) {
457 row_process_params.push_back(literals);
460 CallInst::Create(func_row_process, row_process_params,
"", bb_forbody);
463 Attributes row_process_pal;
467 BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64,
"", bb_forbody);
468 ICmpInst* loop_or_exit =
469 new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count,
"");
470 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
473 std::vector<Instruction*> result_vec_pre;
474 if (!is_estimate_query) {
475 for (
size_t i = 0; i < aggr_col_count; ++i) {
482 result_vec_pre.push_back(
result);
486 BranchInst::Create(bb_exit, bb_crit_edge);
500 if (!is_estimate_query) {
501 std::vector<PHINode*> result_vec;
502 for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
504 PHINode::Create(IntegerType::get(mod->getContext(), 64), 2,
"", bb_exit);
505 result->addIncoming(result_vec_pre[i], bb_crit_edge);
506 result->addIncoming(agg_init_val_vec[i], bb_entry);
507 result_vec.insert(result_vec.begin(),
result);
510 for (
size_t i = 0; i < aggr_col_count; ++i) {
511 auto col_idx = ConstantInt::get(i32_type, i);
513 auto target_addr = GetElementPtrInst::CreateInBounds(
514 smem_output_buffer->getType()->getPointerElementType(),
521 auto agg_func = mod->getFunction(
"agg_sum_shared");
524 agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]},
"", bb_exit);
526 auto out_gep = GetElementPtrInst::CreateInBounds(
527 out->getType()->getPointerElementType(), out, col_idx,
"", bb_exit);
531 auto slot_idx = BinaryOperator::CreateAdd(
533 BinaryOperator::CreateMul(frag_idx,
pos_step,
"", bb_exit),
536 auto target_addr = GetElementPtrInst::CreateInBounds(
537 col_buffer->getType()->getPointerElementType(),
542 StoreInst* result_st =
new StoreInst(result_vec[i], target_addr,
false, bb_exit);
548 auto sync_thread_func = mod->getFunction(
"sync_threadblock");
549 CHECK(sync_thread_func);
550 CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{},
"", bb_exit);
551 auto reduce_smem_to_gmem_func = mod->getFunction(
"write_back_non_grouped_agg");
552 CHECK(reduce_smem_to_gmem_func);
556 for (
size_t i = 0; i < aggr_col_count; i++) {
558 GetElementPtrInst::CreateInBounds(out->getType()->getPointerElementType(),
560 ConstantInt::get(i32_type, i),
569 reduce_smem_to_gmem_func,
570 std::vector<llvm::Value*>{
571 smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
578 ReturnInst::Create(mod->getContext(), bb_exit);
581 pos_inc_pre->replaceAllUsesWith(pos_inc);
584 if (verifyFunction(*query_func_ptr)) {
585 LOG(
FATAL) <<
"Generated invalid code. ";
591 template <
class Attributes>
594 const bool hoist_literals,
597 const bool check_scan_limit,
602 using namespace llvm;
604 auto func_pos_start = pos_start<Attributes>(mod);
605 CHECK(func_pos_start);
606 auto func_pos_step = pos_step<Attributes>(mod);
607 CHECK(func_pos_step);
608 auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
609 CHECK(func_group_buff_idx);
610 auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
611 CHECK(func_row_process);
613 ? mod->getFunction(
"init_shared_mem")
614 : mod->getFunction(
"init_shared_mem_nop");
615 CHECK(func_init_shared_mem);
617 auto func_write_back = mod->getFunction(
"write_back_nop");
618 CHECK(func_write_back);
620 auto i32_type = IntegerType::get(mod->getContext(), 32);
621 auto i64_type = IntegerType::get(mod->getContext(), 64);
622 auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
623 auto pi32_type = PointerType::get(i32_type, 0);
624 auto pi64_type = PointerType::get(i64_type, 0);
625 auto ppi64_type = PointerType::get(pi64_type, 0);
626 auto ppi8_type = PointerType::get(pi8_type, 0);
628 std::vector<Type*> query_args;
629 query_args.push_back(ppi8_type);
630 if (hoist_literals) {
631 query_args.push_back(pi8_type);
633 query_args.push_back(pi64_type);
634 query_args.push_back(pi64_type);
635 query_args.push_back(pi32_type);
636 query_args.push_back(pi64_type);
638 query_args.push_back(ppi64_type);
639 query_args.push_back(i32_type);
640 query_args.push_back(pi64_type);
641 query_args.push_back(pi32_type);
642 query_args.push_back(pi32_type);
643 query_args.push_back(pi8_type);
645 FunctionType* query_func_type = FunctionType::get(
646 Type::getVoidTy(mod->getContext()),
650 std::string query_name{
"query_group_by_template"};
651 auto query_func_ptr = mod->getFunction(query_name);
652 CHECK(!query_func_ptr);
654 query_func_ptr = Function::Create(
656 GlobalValue::ExternalLinkage,
657 "query_group_by_template",
660 query_func_ptr->setCallingConv(CallingConv::C);
662 Attributes query_func_pal;
664 SmallVector<Attributes, 4> Attrs;
667 #if 14 <= LLVM_VERSION_MAJOR
668 AttrBuilder B(mod->getContext());
672 B.addAttribute(Attribute::ReadNone);
673 B.addAttribute(Attribute::NoCapture);
674 PAS = Attributes::get(mod->getContext(), 1U, B);
677 Attrs.push_back(PAS);
679 #if 14 <= LLVM_VERSION_MAJOR
680 AttrBuilder B(mod->getContext());
684 B.addAttribute(Attribute::ReadOnly);
685 B.addAttribute(Attribute::NoCapture);
686 PAS = Attributes::get(mod->getContext(), 2U, B);
689 Attrs.push_back(PAS);
691 #if 14 <= LLVM_VERSION_MAJOR
692 AttrBuilder B(mod->getContext());
696 B.addAttribute(Attribute::ReadNone);
697 B.addAttribute(Attribute::NoCapture);
698 PAS = Attributes::get(mod->getContext(), 3U, B);
701 Attrs.push_back(PAS);
703 #if 14 <= LLVM_VERSION_MAJOR
704 AttrBuilder B(mod->getContext());
708 B.addAttribute(Attribute::ReadOnly);
709 B.addAttribute(Attribute::NoCapture);
710 PAS = Attributes::get(mod->getContext(), 4U, B);
713 Attrs.push_back(PAS);
715 #if 14 <= LLVM_VERSION_MAJOR
716 AttrBuilder B(mod->getContext());
720 B.addAttribute(Attribute::UWTable);
721 PAS = Attributes::get(mod->getContext(), ~0U, B);
724 Attrs.push_back(PAS);
726 query_func_pal = Attributes::get(mod->getContext(), Attrs);
728 query_func_ptr->setAttributes(query_func_pal);
730 Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
731 Value* byte_stream = &*query_arg_it;
732 byte_stream->setName(
"byte_stream");
733 Value* literals{
nullptr};
734 if (hoist_literals) {
735 literals = &*(++query_arg_it);
737 literals->setName(
"literals");
739 Value* row_count_ptr = &*(++query_arg_it);
740 row_count_ptr->setName(
"row_count_ptr");
741 Value* frag_row_off_ptr = &*(++query_arg_it);
742 frag_row_off_ptr->setName(
"frag_row_off_ptr");
743 Value* max_matched_ptr = &*(++query_arg_it);
744 max_matched_ptr->setName(
"max_matched_ptr");
745 Value* agg_init_val = &*(++query_arg_it);
746 agg_init_val->setName(
"agg_init_val");
747 Value* group_by_buffers = &*(++query_arg_it);
748 group_by_buffers->setName(
"group_by_buffers");
749 Value* frag_idx = &*(++query_arg_it);
750 frag_idx->setName(
"frag_idx");
751 Value* join_hash_tables = &*(++query_arg_it);
752 join_hash_tables->setName(
"join_hash_tables");
753 Value* total_matched = &*(++query_arg_it);
754 total_matched->setName(
"total_matched");
755 Value* error_code = &*(++query_arg_it);
756 error_code->setName(
"error_code");
757 Value* row_func_mgr = &*(++query_arg_it);
758 row_func_mgr->setName(
"row_func_mgr");
760 auto bb_entry = BasicBlock::Create(mod->getContext(),
".entry", query_func_ptr, 0);
762 BasicBlock::Create(mod->getContext(),
".loop.preheader", query_func_ptr, 0);
763 auto bb_forbody = BasicBlock::Create(mod->getContext(),
".forbody", query_func_ptr, 0);
765 BasicBlock::Create(mod->getContext(),
"._crit_edge", query_func_ptr, 0);
766 auto bb_exit = BasicBlock::Create(mod->getContext(),
".exit", query_func_ptr, 0);
769 LoadInst* row_count =
new LoadInst(
772 row_count->setName(
"row_count");
774 LoadInst* max_matched =
new LoadInst(
778 auto crt_matched_ptr =
new AllocaInst(i32_type, 0,
"crt_matched", bb_entry);
779 auto old_total_matched_ptr =
new AllocaInst(i32_type, 0,
"old_total_matched", bb_entry);
780 CallInst*
pos_start = CallInst::Create(func_pos_start,
"", bb_entry);
781 pos_start->setCallingConv(CallingConv::C);
783 Attributes pos_start_pal;
786 CallInst*
pos_step = CallInst::Create(func_pos_step,
"", bb_entry);
787 pos_step->setCallingConv(CallingConv::C);
789 Attributes pos_step_pal;
790 pos_step->setAttributes(pos_step_pal);
792 CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx,
"", bb_entry);
793 group_buff_idx_call->setCallingConv(CallingConv::C);
794 group_buff_idx_call->setTailCall(
true);
795 Attributes group_buff_idx_pal;
796 group_buff_idx_call->setAttributes(group_buff_idx_pal);
799 const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
802 Value* varlen_output_buffer{
nullptr};
806 auto varlen_output_buffer_gep = GetElementPtrInst::Create(
807 Ty->getPointerElementType(),
809 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
812 varlen_output_buffer =
814 varlen_output_buffer_gep,
815 "varlen_output_buffer",
822 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
823 "group_buff_idx_varlen_offset",
826 varlen_output_buffer =
827 ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
829 CHECK(varlen_output_buffer);
831 CastInst* pos_start_i64 =
new SExtInst(
pos_start, i64_type,
"", bb_entry);
832 GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
833 Ty->getPointerElementType(), group_by_buffers,
group_buff_idx,
"", bb_entry);
835 group_by_buffers_gep,
839 col_buffer->setName(
"col_buffer");
842 llvm::ConstantInt* shared_mem_bytes_lv =
845 llvm::CallInst* result_buffer =
846 CallInst::Create(func_init_shared_mem,
847 std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
851 ICmpInst* enter_or_not =
852 new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count,
"");
853 BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
856 CastInst* pos_step_i64 =
new SExtInst(
pos_step, i64_type,
"", bb_preheader);
857 BranchInst::Create(bb_forbody, bb_preheader);
861 PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2,
"pos", bb_forbody);
863 std::vector<Value*> row_process_params;
864 row_process_params.push_back(result_buffer);
865 row_process_params.push_back(varlen_output_buffer);
866 row_process_params.push_back(crt_matched_ptr);
867 row_process_params.push_back(total_matched);
868 row_process_params.push_back(old_total_matched_ptr);
869 row_process_params.push_back(max_matched_ptr);
870 row_process_params.push_back(agg_init_val);
871 row_process_params.push_back(pos);
872 row_process_params.push_back(frag_row_off_ptr);
873 row_process_params.push_back(row_count_ptr);
874 if (hoist_literals) {
876 row_process_params.push_back(literals);
878 if (check_scan_limit) {
879 new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
884 CallInst::Create(func_row_process, row_process_params,
"", bb_forbody);
887 Attributes row_process_pal;
892 auto func_sync_warp_protected = mod->getFunction(
"sync_warp_protected");
893 CHECK(func_sync_warp_protected);
894 CallInst::Create(func_sync_warp_protected,
895 std::vector<llvm::Value*>{pos, row_count},
901 BinaryOperator::Create(Instruction::Add, pos, pos_step_i64,
"", bb_forbody);
902 ICmpInst* loop_or_exit =
903 new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count,
"");
904 if (check_scan_limit) {
910 auto filter_match = BasicBlock::Create(
911 mod->getContext(),
"filter_match", query_func_ptr, bb_crit_edge);
912 llvm::Value* new_total_matched =
914 old_total_matched_ptr,
919 BinaryOperator::CreateAdd(new_total_matched, crt_matched,
"", filter_match);
920 CHECK(new_total_matched);
921 ICmpInst* limit_not_reached =
new ICmpInst(*filter_match,
925 "limit_not_reached");
929 BinaryOperator::Create(
930 BinaryOperator::And, loop_or_exit, limit_not_reached,
"", filter_match),
932 auto filter_nomatch = BasicBlock::Create(
933 mod->getContext(),
"filter_nomatch", query_func_ptr, bb_crit_edge);
934 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
935 ICmpInst* crt_matched_nz =
new ICmpInst(
936 *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0),
"");
937 BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
938 pos->addIncoming(pos_start_i64, bb_preheader);
939 pos->addIncoming(pos_pre, filter_match);
940 pos->addIncoming(pos_pre, filter_nomatch);
942 pos->addIncoming(pos_start_i64, bb_preheader);
943 pos->addIncoming(pos_pre, bb_forbody);
944 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
948 BranchInst::Create(bb_exit, bb_crit_edge);
951 CallInst::Create(func_write_back,
952 std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
956 ReturnInst::Create(mod->getContext(), bb_exit);
959 pos_pre->replaceAllUsesWith(pos_inc);
962 if (verifyFunction(*query_func_ptr, &llvm::errs())) {
963 LOG(
FATAL) <<
"Generated invalid code. ";
971 const size_t aggr_col_count,
972 const bool hoist_literals,
973 const bool is_estimate_query,
975 return query_template_impl<llvm::AttributeList>(
976 mod, aggr_col_count, hoist_literals, is_estimate_query, gpu_smem_context);
980 const bool hoist_literals,
983 const bool check_scan_limit,
985 return query_group_by_template_impl<llvm::AttributeList>(mod,
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template_impl(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
llvm::Function * pos_start(llvm::Module *mod)
bool hasVarlenOutput() const
size_t getSharedMemorySize() const
std::tuple< llvm::Function *, llvm::CallInst * > query_template(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
AGG_TYPE agg_func(AGG_TYPE const lhs, AGG_TYPE const rhs)
Type pointer_type(const Type pointee)
#define LLVM_ALIGN(alignment)
bool isSharedMemoryUsed() const
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
llvm::Function * group_buff_idx(llvm::Module *mod)
llvm::Function * pos_step(llvm::Module *mod)
llvm::Function * default_func_builder(llvm::Module *mod, const std::string &name)
bool isWarpSyncRequired(const ExecutorDeviceType) const
std::tuple< llvm::Function *, llvm::CallInst * > query_template_impl(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
llvm::Type * get_pointer_element_type(llvm::Value *value)
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)