21 #include <llvm/IR/Constants.h>
22 #include <llvm/IR/IRBuilder.h>
23 #include <llvm/IR/Instructions.h>
24 #include <llvm/IR/Verifier.h>
34 auto type = value->getType();
38 return pointer_type->getPointerElementType();
41 template <
class Attributes>
45 std::vector<Type*> func_args;
46 FunctionType* func_type = FunctionType::get(
47 IntegerType::get(mod->getContext(), 32),
51 auto func_ptr = mod->getFunction(name);
53 func_ptr = Function::Create(
55 GlobalValue::ExternalLinkage,
58 func_ptr->setCallingConv(CallingConv::C);
63 SmallVector<Attributes, 4> Attrs;
66 #if 14 <= LLVM_VERSION_MAJOR
67 AttrBuilder B(mod->getContext());
71 PAS = Attributes::get(mod->getContext(), ~0U, B);
75 func_pal = Attributes::get(mod->getContext(), Attrs);
77 func_ptr->setAttributes(func_pal);
82 template <
class Attributes>
84 return default_func_builder<Attributes>(mod,
"pos_start");
87 template <
class Attributes>
89 return default_func_builder<Attributes>(mod,
"group_buff_idx");
92 template <
class Attributes>
96 std::vector<Type*> func_args;
97 FunctionType* func_type = FunctionType::get(
98 IntegerType::get(mod->getContext(), 32),
102 auto func_ptr = mod->getFunction(
"pos_step");
104 func_ptr = Function::Create(
106 GlobalValue::ExternalLinkage,
109 func_ptr->setCallingConv(CallingConv::C);
114 SmallVector<Attributes, 4> Attrs;
117 #if 14 <= LLVM_VERSION_MAJOR
118 AttrBuilder B(mod->getContext());
122 PAS = Attributes::get(mod->getContext(), ~0U, B);
125 Attrs.push_back(PAS);
126 func_pal = Attributes::get(mod->getContext(), Attrs);
128 func_ptr->setAttributes(func_pal);
133 template <
class Attributes>
135 const size_t aggr_col_count,
136 const bool hoist_literals) {
137 using namespace llvm;
139 std::vector<Type*> func_args;
140 auto i8_type = IntegerType::get(mod->getContext(), 8);
141 auto i32_type = IntegerType::get(mod->getContext(), 32);
142 auto i64_type = IntegerType::get(mod->getContext(), 64);
143 auto pi32_type = PointerType::get(i32_type, 0);
144 auto pi64_type = PointerType::get(i64_type, 0);
146 if (aggr_col_count) {
147 for (
size_t i = 0; i < aggr_col_count; ++i) {
148 func_args.push_back(pi64_type);
151 func_args.push_back(pi64_type);
152 func_args.push_back(pi64_type);
153 func_args.push_back(pi32_type);
154 func_args.push_back(pi32_type);
155 func_args.push_back(pi32_type);
156 func_args.push_back(pi32_type);
159 func_args.push_back(pi64_type);
161 func_args.push_back(i64_type);
162 func_args.push_back(pi64_type);
163 func_args.push_back(pi64_type);
164 if (hoist_literals) {
165 func_args.push_back(PointerType::get(i8_type, 0));
167 FunctionType* func_type = FunctionType::get(
172 std::string func_name{
"row_process"};
173 auto func_ptr = mod->getFunction(func_name);
176 func_ptr = Function::Create(
178 GlobalValue::ExternalLinkage,
181 func_ptr->setCallingConv(CallingConv::C);
185 SmallVector<Attributes, 4> Attrs;
188 #if 14 <= LLVM_VERSION_MAJOR
189 AttrBuilder B(mod->getContext());
193 PAS = Attributes::get(mod->getContext(), ~0U, B);
196 Attrs.push_back(PAS);
197 func_pal = Attributes::get(mod->getContext(), Attrs);
199 func_ptr->setAttributes(func_pal);
207 template <
class Attributes>
210 const size_t aggr_col_count,
211 const bool hoist_literals,
212 const bool is_estimate_query,
214 using namespace llvm;
216 auto func_pos_start = pos_start<Attributes>(mod);
217 CHECK(func_pos_start);
218 auto func_pos_step = pos_step<Attributes>(mod);
219 CHECK(func_pos_step);
220 auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
221 CHECK(func_group_buff_idx);
222 auto func_row_process = row_process<Attributes>(
223 mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
224 CHECK(func_row_process);
226 auto i8_type = IntegerType::get(mod->getContext(), 8);
227 auto i32_type = IntegerType::get(mod->getContext(), 32);
228 auto i64_type = IntegerType::get(mod->getContext(), 64);
229 auto pi8_type = PointerType::get(i8_type, 0);
230 auto ppi8_type = PointerType::get(pi8_type, 0);
231 auto pi32_type = PointerType::get(i32_type, 0);
232 auto pi64_type = PointerType::get(i64_type, 0);
233 auto ppi64_type = PointerType::get(pi64_type, 0);
235 std::vector<Type*> query_args;
236 query_args.push_back(ppi8_type);
237 if (hoist_literals) {
238 query_args.push_back(pi8_type);
240 query_args.push_back(pi64_type);
241 query_args.push_back(pi64_type);
242 query_args.push_back(pi32_type);
244 query_args.push_back(pi64_type);
245 query_args.push_back(ppi64_type);
246 query_args.push_back(i32_type);
247 query_args.push_back(pi64_type);
248 query_args.push_back(pi32_type);
249 query_args.push_back(pi32_type);
251 FunctionType* query_func_type = FunctionType::get(
252 Type::getVoidTy(mod->getContext()),
256 std::string query_template_name{
"query_template"};
257 auto query_func_ptr = mod->getFunction(query_template_name);
258 CHECK(!query_func_ptr);
260 query_func_ptr = Function::Create(
262 GlobalValue::ExternalLinkage,
265 query_func_ptr->setCallingConv(CallingConv::C);
267 Attributes query_func_pal;
269 SmallVector<Attributes, 4> Attrs;
272 #if 14 <= LLVM_VERSION_MAJOR
273 AttrBuilder B(mod->getContext());
277 B.addAttribute(Attribute::NoCapture);
278 PAS = Attributes::get(mod->getContext(), 1U, B);
281 Attrs.push_back(PAS);
283 #if 14 <= LLVM_VERSION_MAJOR
284 AttrBuilder B(mod->getContext());
288 B.addAttribute(Attribute::NoCapture);
289 PAS = Attributes::get(mod->getContext(), 2U, B);
292 Attrs.push_back(PAS);
295 #if 14 <= LLVM_VERSION_MAJOR
296 AttrBuilder B(mod->getContext());
300 B.addAttribute(Attribute::NoCapture);
301 Attrs.push_back(Attributes::get(mod->getContext(), 3U, B));
305 #if 14 <= LLVM_VERSION_MAJOR
306 AttrBuilder B(mod->getContext());
310 B.addAttribute(Attribute::NoCapture);
311 Attrs.push_back(Attributes::get(mod->getContext(), 4U, B));
314 Attrs.push_back(PAS);
316 query_func_pal = Attributes::get(mod->getContext(), Attrs);
318 query_func_ptr->setAttributes(query_func_pal);
320 Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
321 Value* byte_stream = &*query_arg_it;
322 byte_stream->setName(
"byte_stream");
323 Value* literals{
nullptr};
324 if (hoist_literals) {
325 literals = &*(++query_arg_it);
326 literals->setName(
"literals");
328 Value* row_count_ptr = &*(++query_arg_it);
329 row_count_ptr->setName(
"row_count_ptr");
330 Value* frag_row_off_ptr = &*(++query_arg_it);
331 frag_row_off_ptr->setName(
"frag_row_off_ptr");
332 Value* max_matched_ptr = &*(++query_arg_it);
333 max_matched_ptr->setName(
"max_matched_ptr");
334 Value* agg_init_val = &*(++query_arg_it);
335 agg_init_val->setName(
"agg_init_val");
336 Value* out = &*(++query_arg_it);
338 Value* frag_idx = &*(++query_arg_it);
339 frag_idx->setName(
"frag_idx");
340 Value* join_hash_tables = &*(++query_arg_it);
341 join_hash_tables->setName(
"join_hash_tables");
342 Value* total_matched = &*(++query_arg_it);
343 total_matched->setName(
"total_matched");
344 Value* error_code = &*(++query_arg_it);
345 error_code->setName(
"error_code");
347 auto bb_entry = BasicBlock::Create(mod->getContext(),
".entry", query_func_ptr, 0);
349 BasicBlock::Create(mod->getContext(),
".loop.preheader", query_func_ptr, 0);
350 auto bb_forbody = BasicBlock::Create(mod->getContext(),
".for.body", query_func_ptr, 0);
352 BasicBlock::Create(mod->getContext(),
"._crit_edge", query_func_ptr, 0);
353 auto bb_exit = BasicBlock::Create(mod->getContext(),
".exit", query_func_ptr, 0);
356 std::vector<Value*> result_ptr_vec;
357 llvm::CallInst* smem_output_buffer{
nullptr};
358 if (!is_estimate_query) {
359 for (
size_t i = 0; i < aggr_col_count; ++i) {
360 auto result_ptr =
new AllocaInst(i64_type, 0,
"result", bb_entry);
362 result_ptr_vec.push_back(result_ptr);
365 auto init_smem_func = mod->getFunction(
"init_shared_mem");
366 CHECK(init_smem_func);
369 smem_output_buffer = CallInst::Create(
371 std::vector<llvm::Value*>{
373 llvm::ConstantInt::get(i32_type, aggr_col_count *
sizeof(int64_t))},
385 row_count->setName(
"row_count");
386 std::vector<Value*> agg_init_val_vec;
387 if (!is_estimate_query) {
388 for (
size_t i = 0; i < aggr_col_count; ++i) {
389 auto idx_lv = ConstantInt::get(i32_type, i);
390 auto agg_init_gep = GetElementPtrInst::CreateInBounds(
391 agg_init_val->getType()->getPointerElementType(),
396 auto agg_init_val =
new LoadInst(
399 agg_init_val_vec.push_back(agg_init_val);
400 auto init_val_st =
new StoreInst(agg_init_val, result_ptr_vec[i],
false, bb_entry);
405 CallInst*
pos_start = CallInst::Create(func_pos_start,
"pos_start", bb_entry);
406 pos_start->setCallingConv(CallingConv::C);
408 Attributes pos_start_pal;
411 CallInst*
pos_step = CallInst::Create(func_pos_step,
"pos_step", bb_entry);
412 pos_step->setCallingConv(CallingConv::C);
414 Attributes pos_step_pal;
415 pos_step->setAttributes(pos_step_pal);
418 if (!is_estimate_query) {
419 group_buff_idx = CallInst::Create(func_group_buff_idx,
"group_buff_idx", bb_entry);
422 Attributes group_buff_idx_pal;
426 CastInst* pos_start_i64 =
new SExtInst(
pos_start, i64_type,
"", bb_entry);
427 ICmpInst* enter_or_not =
428 new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count,
"");
429 BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
432 CastInst* pos_step_i64 =
new SExtInst(
pos_step, i64_type,
"", bb_preheader);
433 BranchInst::Create(bb_forbody, bb_preheader);
437 PHINode* pos = PHINode::Create(i64_type, 2,
"pos", bb_forbody);
438 pos->addIncoming(pos_start_i64, bb_preheader);
439 pos->addIncoming(pos_inc_pre, bb_forbody);
441 std::vector<Value*> row_process_params;
442 row_process_params.insert(
443 row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
444 if (is_estimate_query) {
445 row_process_params.push_back(
448 row_process_params.push_back(agg_init_val);
449 row_process_params.push_back(pos);
450 row_process_params.push_back(frag_row_off_ptr);
451 row_process_params.push_back(row_count_ptr);
452 if (hoist_literals) {
454 row_process_params.push_back(literals);
457 CallInst::Create(func_row_process, row_process_params,
"", bb_forbody);
460 Attributes row_process_pal;
464 BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64,
"", bb_forbody);
465 ICmpInst* loop_or_exit =
466 new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count,
"");
467 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
470 std::vector<Instruction*> result_vec_pre;
471 if (!is_estimate_query) {
472 for (
size_t i = 0; i < aggr_col_count; ++i) {
479 result_vec_pre.push_back(
result);
483 BranchInst::Create(bb_exit, bb_crit_edge);
497 if (!is_estimate_query) {
498 std::vector<PHINode*> result_vec;
499 for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
501 PHINode::Create(IntegerType::get(mod->getContext(), 64), 2,
"", bb_exit);
502 result->addIncoming(result_vec_pre[i], bb_crit_edge);
503 result->addIncoming(agg_init_val_vec[i], bb_entry);
504 result_vec.insert(result_vec.begin(),
result);
507 for (
size_t i = 0; i < aggr_col_count; ++i) {
508 auto col_idx = ConstantInt::get(i32_type, i);
510 auto target_addr = GetElementPtrInst::CreateInBounds(
511 smem_output_buffer->getType()->getPointerElementType(),
518 auto agg_func = mod->getFunction(
"agg_sum_shared");
521 agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]},
"", bb_exit);
523 auto out_gep = GetElementPtrInst::CreateInBounds(
524 out->getType()->getPointerElementType(), out, col_idx,
"", bb_exit);
528 auto slot_idx = BinaryOperator::CreateAdd(
530 BinaryOperator::CreateMul(frag_idx,
pos_step,
"", bb_exit),
533 auto target_addr = GetElementPtrInst::CreateInBounds(
534 col_buffer->getType()->getPointerElementType(),
539 StoreInst* result_st =
new StoreInst(result_vec[i], target_addr,
false, bb_exit);
545 auto sync_thread_func = mod->getFunction(
"sync_threadblock");
546 CHECK(sync_thread_func);
547 CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{},
"", bb_exit);
548 auto reduce_smem_to_gmem_func = mod->getFunction(
"write_back_non_grouped_agg");
549 CHECK(reduce_smem_to_gmem_func);
553 for (
size_t i = 0; i < aggr_col_count; i++) {
555 GetElementPtrInst::CreateInBounds(out->getType()->getPointerElementType(),
557 ConstantInt::get(i32_type, i),
566 reduce_smem_to_gmem_func,
567 std::vector<llvm::Value*>{
568 smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
575 ReturnInst::Create(mod->getContext(), bb_exit);
578 pos_inc_pre->replaceAllUsesWith(pos_inc);
581 if (verifyFunction(*query_func_ptr)) {
582 LOG(
FATAL) <<
"Generated invalid code. ";
588 template <
class Attributes>
591 const bool hoist_literals,
594 const bool check_scan_limit,
599 using namespace llvm;
601 auto func_pos_start = pos_start<Attributes>(mod);
602 CHECK(func_pos_start);
603 auto func_pos_step = pos_step<Attributes>(mod);
604 CHECK(func_pos_step);
605 auto func_group_buff_idx = group_buff_idx<Attributes>(mod);
606 CHECK(func_group_buff_idx);
607 auto func_row_process = row_process<Attributes>(mod, 0, hoist_literals);
608 CHECK(func_row_process);
610 ? mod->getFunction(
"init_shared_mem")
611 : mod->getFunction(
"init_shared_mem_nop");
612 CHECK(func_init_shared_mem);
614 auto func_write_back = mod->getFunction(
"write_back_nop");
615 CHECK(func_write_back);
617 auto i32_type = IntegerType::get(mod->getContext(), 32);
618 auto i64_type = IntegerType::get(mod->getContext(), 64);
619 auto pi8_type = PointerType::get(IntegerType::get(mod->getContext(), 8), 0);
620 auto pi32_type = PointerType::get(i32_type, 0);
621 auto pi64_type = PointerType::get(i64_type, 0);
622 auto ppi64_type = PointerType::get(pi64_type, 0);
623 auto ppi8_type = PointerType::get(pi8_type, 0);
625 std::vector<Type*> query_args;
626 query_args.push_back(ppi8_type);
627 if (hoist_literals) {
628 query_args.push_back(pi8_type);
630 query_args.push_back(pi64_type);
631 query_args.push_back(pi64_type);
632 query_args.push_back(pi32_type);
633 query_args.push_back(pi64_type);
635 query_args.push_back(ppi64_type);
636 query_args.push_back(i32_type);
637 query_args.push_back(pi64_type);
638 query_args.push_back(pi32_type);
639 query_args.push_back(pi32_type);
641 FunctionType* query_func_type = FunctionType::get(
642 Type::getVoidTy(mod->getContext()),
646 std::string query_name{
"query_group_by_template"};
647 auto query_func_ptr = mod->getFunction(query_name);
648 CHECK(!query_func_ptr);
650 query_func_ptr = Function::Create(
652 GlobalValue::ExternalLinkage,
653 "query_group_by_template",
656 query_func_ptr->setCallingConv(CallingConv::C);
658 Attributes query_func_pal;
660 SmallVector<Attributes, 4> Attrs;
663 #if 14 <= LLVM_VERSION_MAJOR
664 AttrBuilder B(mod->getContext());
668 B.addAttribute(Attribute::ReadNone);
669 B.addAttribute(Attribute::NoCapture);
670 PAS = Attributes::get(mod->getContext(), 1U, B);
673 Attrs.push_back(PAS);
675 #if 14 <= LLVM_VERSION_MAJOR
676 AttrBuilder B(mod->getContext());
680 B.addAttribute(Attribute::ReadOnly);
681 B.addAttribute(Attribute::NoCapture);
682 PAS = Attributes::get(mod->getContext(), 2U, B);
685 Attrs.push_back(PAS);
687 #if 14 <= LLVM_VERSION_MAJOR
688 AttrBuilder B(mod->getContext());
692 B.addAttribute(Attribute::ReadNone);
693 B.addAttribute(Attribute::NoCapture);
694 PAS = Attributes::get(mod->getContext(), 3U, B);
697 Attrs.push_back(PAS);
699 #if 14 <= LLVM_VERSION_MAJOR
700 AttrBuilder B(mod->getContext());
704 B.addAttribute(Attribute::ReadOnly);
705 B.addAttribute(Attribute::NoCapture);
706 PAS = Attributes::get(mod->getContext(), 4U, B);
709 Attrs.push_back(PAS);
711 #if 14 <= LLVM_VERSION_MAJOR
712 AttrBuilder B(mod->getContext());
716 B.addAttribute(Attribute::UWTable);
717 PAS = Attributes::get(mod->getContext(), ~0U, B);
720 Attrs.push_back(PAS);
722 query_func_pal = Attributes::get(mod->getContext(), Attrs);
724 query_func_ptr->setAttributes(query_func_pal);
726 Function::arg_iterator query_arg_it = query_func_ptr->arg_begin();
727 Value* byte_stream = &*query_arg_it;
728 byte_stream->setName(
"byte_stream");
729 Value* literals{
nullptr};
730 if (hoist_literals) {
731 literals = &*(++query_arg_it);
733 literals->setName(
"literals");
735 Value* row_count_ptr = &*(++query_arg_it);
736 row_count_ptr->setName(
"row_count_ptr");
737 Value* frag_row_off_ptr = &*(++query_arg_it);
738 frag_row_off_ptr->setName(
"frag_row_off_ptr");
739 Value* max_matched_ptr = &*(++query_arg_it);
740 max_matched_ptr->setName(
"max_matched_ptr");
741 Value* agg_init_val = &*(++query_arg_it);
742 agg_init_val->setName(
"agg_init_val");
743 Value* group_by_buffers = &*(++query_arg_it);
744 group_by_buffers->setName(
"group_by_buffers");
745 Value* frag_idx = &*(++query_arg_it);
746 frag_idx->setName(
"frag_idx");
747 Value* join_hash_tables = &*(++query_arg_it);
748 join_hash_tables->setName(
"join_hash_tables");
749 Value* total_matched = &*(++query_arg_it);
750 total_matched->setName(
"total_matched");
751 Value* error_code = &*(++query_arg_it);
752 error_code->setName(
"error_code");
754 auto bb_entry = BasicBlock::Create(mod->getContext(),
".entry", query_func_ptr, 0);
756 BasicBlock::Create(mod->getContext(),
".loop.preheader", query_func_ptr, 0);
757 auto bb_forbody = BasicBlock::Create(mod->getContext(),
".forbody", query_func_ptr, 0);
759 BasicBlock::Create(mod->getContext(),
"._crit_edge", query_func_ptr, 0);
760 auto bb_exit = BasicBlock::Create(mod->getContext(),
".exit", query_func_ptr, 0);
763 LoadInst* row_count =
new LoadInst(
766 row_count->setName(
"row_count");
768 LoadInst* max_matched =
new LoadInst(
772 auto crt_matched_ptr =
new AllocaInst(i32_type, 0,
"crt_matched", bb_entry);
773 auto old_total_matched_ptr =
new AllocaInst(i32_type, 0,
"old_total_matched", bb_entry);
774 CallInst*
pos_start = CallInst::Create(func_pos_start,
"", bb_entry);
775 pos_start->setCallingConv(CallingConv::C);
777 Attributes pos_start_pal;
780 CallInst*
pos_step = CallInst::Create(func_pos_step,
"", bb_entry);
781 pos_step->setCallingConv(CallingConv::C);
783 Attributes pos_step_pal;
784 pos_step->setAttributes(pos_step_pal);
786 CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx,
"", bb_entry);
787 group_buff_idx_call->setCallingConv(CallingConv::C);
788 group_buff_idx_call->setTailCall(
true);
789 Attributes group_buff_idx_pal;
790 group_buff_idx_call->setAttributes(group_buff_idx_pal);
793 const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
796 Value* varlen_output_buffer{
nullptr};
800 auto varlen_output_buffer_gep = GetElementPtrInst::Create(
801 Ty->getPointerElementType(),
803 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
806 varlen_output_buffer =
808 varlen_output_buffer_gep,
809 "varlen_output_buffer",
816 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
817 "group_buff_idx_varlen_offset",
820 varlen_output_buffer =
821 ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
823 CHECK(varlen_output_buffer);
825 CastInst* pos_start_i64 =
new SExtInst(
pos_start, i64_type,
"", bb_entry);
826 GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
827 Ty->getPointerElementType(), group_by_buffers,
group_buff_idx,
"", bb_entry);
829 group_by_buffers_gep,
833 col_buffer->setName(
"col_buffer");
836 llvm::ConstantInt* shared_mem_bytes_lv =
839 llvm::CallInst* result_buffer =
840 CallInst::Create(func_init_shared_mem,
841 std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
845 ICmpInst* enter_or_not =
846 new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count,
"");
847 BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
850 CastInst* pos_step_i64 =
new SExtInst(
pos_step, i64_type,
"", bb_preheader);
851 BranchInst::Create(bb_forbody, bb_preheader);
855 PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2,
"pos", bb_forbody);
857 std::vector<Value*> row_process_params;
858 row_process_params.push_back(result_buffer);
859 row_process_params.push_back(varlen_output_buffer);
860 row_process_params.push_back(crt_matched_ptr);
861 row_process_params.push_back(total_matched);
862 row_process_params.push_back(old_total_matched_ptr);
863 row_process_params.push_back(max_matched_ptr);
864 row_process_params.push_back(agg_init_val);
865 row_process_params.push_back(pos);
866 row_process_params.push_back(frag_row_off_ptr);
867 row_process_params.push_back(row_count_ptr);
868 if (hoist_literals) {
870 row_process_params.push_back(literals);
872 if (check_scan_limit) {
873 new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
878 CallInst::Create(func_row_process, row_process_params,
"", bb_forbody);
881 Attributes row_process_pal;
886 auto func_sync_warp_protected = mod->getFunction(
"sync_warp_protected");
887 CHECK(func_sync_warp_protected);
888 CallInst::Create(func_sync_warp_protected,
889 std::vector<llvm::Value*>{pos, row_count},
895 BinaryOperator::Create(Instruction::Add, pos, pos_step_i64,
"", bb_forbody);
896 ICmpInst* loop_or_exit =
897 new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count,
"");
898 if (check_scan_limit) {
904 auto filter_match = BasicBlock::Create(
905 mod->getContext(),
"filter_match", query_func_ptr, bb_crit_edge);
906 llvm::Value* new_total_matched =
908 old_total_matched_ptr,
913 BinaryOperator::CreateAdd(new_total_matched, crt_matched,
"", filter_match);
914 CHECK(new_total_matched);
915 ICmpInst* limit_not_reached =
new ICmpInst(*filter_match,
919 "limit_not_reached");
923 BinaryOperator::Create(
924 BinaryOperator::And, loop_or_exit, limit_not_reached,
"", filter_match),
926 auto filter_nomatch = BasicBlock::Create(
927 mod->getContext(),
"filter_nomatch", query_func_ptr, bb_crit_edge);
928 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
929 ICmpInst* crt_matched_nz =
new ICmpInst(
930 *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0),
"");
931 BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
932 pos->addIncoming(pos_start_i64, bb_preheader);
933 pos->addIncoming(pos_pre, filter_match);
934 pos->addIncoming(pos_pre, filter_nomatch);
936 pos->addIncoming(pos_start_i64, bb_preheader);
937 pos->addIncoming(pos_pre, bb_forbody);
938 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
942 BranchInst::Create(bb_exit, bb_crit_edge);
945 CallInst::Create(func_write_back,
946 std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
950 ReturnInst::Create(mod->getContext(), bb_exit);
953 pos_pre->replaceAllUsesWith(pos_inc);
956 if (verifyFunction(*query_func_ptr, &llvm::errs())) {
957 LOG(
FATAL) <<
"Generated invalid code. ";
965 const size_t aggr_col_count,
966 const bool hoist_literals,
967 const bool is_estimate_query,
969 return query_template_impl<llvm::AttributeList>(
970 mod, aggr_col_count, hoist_literals, is_estimate_query, gpu_smem_context);
974 const bool hoist_literals,
977 const bool check_scan_limit,
979 return query_group_by_template_impl<llvm::AttributeList>(mod,
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template_impl(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
llvm::Function * pos_start(llvm::Module *mod)
bool hasVarlenOutput() const
size_t getSharedMemorySize() const
std::tuple< llvm::Function *, llvm::CallInst * > query_template(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
Type pointer_type(const Type pointee)
#define LLVM_ALIGN(alignment)
bool isSharedMemoryUsed() const
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
llvm::Function * group_buff_idx(llvm::Module *mod)
llvm::Function * pos_step(llvm::Module *mod)
llvm::Function * default_func_builder(llvm::Module *mod, const std::string &name)
bool isWarpSyncRequired(const ExecutorDeviceType) const
std::tuple< llvm::Function *, llvm::CallInst * > query_template_impl(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
llvm::Type * get_pointer_element_type(llvm::Value *value)
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)