24 #include <llvm/IR/Instructions.h>
// llvm_type (fragment): the visible return statements produce LLVM types in `ctx`:
// void, float*, double*, and a pointer-to-pointer built on get_int_type(64, ctx).
// NOTE(review): the case labels that select each return are not visible in this
// excerpt, so which Type enumerator maps to which LLVM type must be confirmed
// against the complete source.
47 return llvm::Type::getVoidTy(ctx);
59 return llvm::Type::getFloatPtrTy(ctx);
62 return llvm::Type::getDoublePtrTy(ctx);
// Two nested PointerType::get calls, both with 0 as the second argument
// (the address space in the LLVM API).
68 return llvm::PointerType::get(llvm::PointerType::get(
get_int_type(64, ctx), 0), 0);
// Presumably the fall-through for an unhandled type: logs the numeric value of
// `type` at FATAL severity -- TODO confirm this sits in a default branch.
71 LOG(
FATAL) <<
"Argument type not supported: " <<
static_cast<int>(
type);
// llvm_predicate (fragment): maps an ICmp::Predicate to an LLVM integer comparison
// predicate. The visible returns yield ICMP_EQ and ICMP_NE; the case labels that
// select them are not visible in this excerpt.
85 return llvm::ICmpInst::ICMP_EQ;
88 return llvm::ICmpInst::ICMP_NE;
// An unrecognized predicate is reported at FATAL severity with its numeric value.
91 LOG(
FATAL) <<
"Invalid predicate: " <<
static_cast<int>(predicate);
// NOTE(review): presumably only reached if LOG(FATAL) returns, and present so that
// every path yields a value -- confirm against the full function.
95 return llvm::ICmpInst::ICMP_EQ;
// llvm_binary_op (fragment): maps a BinaryOperator::BinaryOp to an LLVM binary
// opcode. The visible returns yield Add and Mul; the case labels that select them
// are not visible in this excerpt.
102 return llvm::Instruction::Add;
105 return llvm::Instruction::Mul;
// An unrecognized operator is reported at FATAL severity with its numeric value.
108 LOG(
FATAL) <<
"Invalid binary operator: " <<
static_cast<int>(op);
// NOTE(review): presumably a placeholder return after the fatal log so every path
// yields a value -- confirm against the full function.
112 return llvm::Instruction::Add;
// llvm_cast_op (fragment): maps a Cast::CastOp to an LLVM cast opcode. The visible
// returns yield Trunc, SExt and BitCast; the case labels that select them are not
// visible in this excerpt.
119 return llvm::Instruction::Trunc;
122 return llvm::Instruction::SExt;
125 return llvm::Instruction::BitCast;
// An unrecognized cast operator is reported at FATAL severity with its numeric value.
128 LOG(
FATAL) <<
"Invalid cast operator: " <<
static_cast<int>(op);
// NOTE(review): presumably a placeholder return after the fatal log so every path
// yields a value -- confirm against the full function.
132 return llvm::Instruction::SExt;
// return_early (fragment; the leading part of the signature is not visible here):
// emits a conditional early exit from `func`. Two basic blocks named
// ".early_return" and ".do_reduction" are created in `func`, a conditional branch
// on `cond` targets ".early_return" when true and ".do_reduction" otherwise, and
// the builder is left positioned in ".do_reduction" for the code that follows.
141 llvm::Function* func,
142 llvm::Value* error_code) {
143 auto cgen_state = reduction_code.
cgen_state.get();
145 auto& ctx = cgen_state->context_;
146 const auto early_return = llvm::BasicBlock::Create(ctx,
".early_return", func, 0);
147 const auto do_reduction = llvm::BasicBlock::Create(ctx,
".do_reduction", func, 0);
148 cgen_state->ir_builder_.CreateCondBr(cond, early_return, do_reduction);
// Fill in the early-return block: a void function gets `ret void`.
149 cgen_state->ir_builder_.SetInsertPoint(early_return);
151 if (func->getReturnType()->isVoidTy()) {
152 cgen_state->ir_builder_.CreateRetVoid();
// NOTE(review): the `else` introducing this statement is not visible in this
// excerpt; presumably non-void functions return `error_code` -- confirm.
155 cgen_state->ir_builder_.CreateRet(error_code);
// Subsequent instructions are emitted into the ".do_reduction" block.
158 cgen_state->ir_builder_.SetInsertPoint(do_reduction);
// mapped_value (fragment): looks `val` up in `m`, the map from reduction-IR values
// to LLVM values, and CHECKs that an entry exists.
// NOTE(review): the return statement is not visible in this excerpt; presumably the
// mapped llvm::Value* is returned -- confirm.
163 const std::unordered_map<const Value*, llvm::Value*>& m) {
165 const auto it = m.find(val);
166 CHECK(it != m.end());
// mapped_function (fragment): looks `function` up in `f`, the map from reduction-IR
// functions to LLVM functions. A missing entry fails the CHECK, whose message
// carries the function's name followed by " not found.".
// NOTE(review): the return statement is not visible in this excerpt; presumably the
// mapped llvm::Function* is returned -- confirm.
175 const Function*
function,
176 const std::unordered_map<const Function*, llvm::Function*>& f) {
177 const auto it = f.find(
function);
178 CHECK(it != f.end()) << function->name() <<
" not found.";
// llvm_args (fragment): walks `args` and appends one element per argument to a
// local `llvm_args` container through std::back_inserter, using a lambda that
// captures the value map `m` by reference.
// NOTE(review): the lambda body is not visible in this excerpt; presumably it
// resolves each Value through `m` -- confirm.
// NOTE(review): `args` is taken by value (const std::vector), so the vector is
// copied on each call; a const reference would avoid the copy.
185 const std::vector<const Value*>
args,
186 const std::unordered_map<const Value*, llvm::Value*>& m) {
189 args.begin(), args.end(), std::back_inserter(llvm_args), [&m](
const Value* value) {
// Tail of a declaration that ends in `;` without a body. NOTE(review): the
// parameter list matches translate_for, so this is presumably its forward
// declaration, needed because translate_body refers to it -- confirm.
196 Function* ir_reduce_loop,
198 std::unordered_map<const Value*, llvm::Value*>& m,
199 const std::unordered_map<const Function*, llvm::Function*>& f);
// translate_body (fragment): lowers a sequence of reduction-IR instructions to LLVM
// IR. Each instruction in `body` is tested with dynamic_cast against the known
// instruction classes (GetElementPtr, Load, ICmp, BinaryOperator, Cast, Ret, Call,
// ExternalCall, Alloca, MemCpy, ReturnEarly, For) and the matching IRBuilder or
// CgenState call is emitted. The result is kept in `translated`, and an
// (instruction -> translated) entry is inserted into `m`.
// NOTE(review): many argument lists and some assignments are cut off in this
// excerpt; the comments below describe only what is visible.
203 const Function*
function,
204 llvm::Function* llvm_function,
206 std::unordered_map<const Value*, llvm::Value*>& m,
207 const std::unordered_map<const Function*, llvm::Function*>& f) {
208 auto cgen_state = reduction_code.
cgen_state.get();
210 auto& ctx = cgen_state->context_;
211 for (
const auto& instr : body) {
212 const auto instr_ptr = instr.get();
// Stays nullptr for instruction kinds whose branch does not assign it.
213 llvm::Value* translated{
nullptr};
214 if (
auto gep = dynamic_cast<const GetElementPtr*>(instr_ptr)) {
215 translated = cgen_state->ir_builder_.CreateGEP(
217 }
else if (
auto load = dynamic_cast<const Load*>(instr_ptr)) {
218 translated = cgen_state->ir_builder_.CreateLoad(
mapped_value(load->source(), m),
220 }
else if (
auto icmp = dynamic_cast<const ICmp*>(instr_ptr)) {
221 translated = cgen_state->ir_builder_.CreateICmp(
llvm_predicate(icmp->predicate()),
225 }
else if (
auto binary_operator = dynamic_cast<const BinaryOperator*>(instr_ptr)) {
// NOTE(review): the left-hand side of this call is on a line not visible here;
// presumably the result is assigned to `translated` -- confirm.
227 cgen_state->ir_builder_.CreateBinOp(
llvm_binary_op(binary_operator->op()),
230 binary_operator->label());
231 }
else if (
auto cast = dynamic_cast<const Cast*>(instr_ptr)) {
232 translated = cgen_state->ir_builder_.CreateCast(
llvm_cast_op(cast->op()),
236 }
else if (
auto ret = dynamic_cast<const Ret*>(instr_ptr)) {
// Both a value-returning and a void return are emitted in this branch; the
// condition choosing between them is not visible in this excerpt.
238 cgen_state->ir_builder_.CreateRet(
mapped_value(ret->value(), m));
240 cgen_state->ir_builder_.CreateRetVoid();
242 }
else if (
auto call = dynamic_cast<const Call*>(instr_ptr)) {
244 const auto args = call->arguments();
245 std::transform(
args.begin(),
247 std::back_inserter(llvm_args),
// A call with a callee object goes through IRBuilder::CreateCall; the other
// visible path emits the call by name via CgenState::emitCall (the `else`
// separating the two is not visible -- TODO confirm).
249 if (call->callee()) {
250 translated = cgen_state->ir_builder_.CreateCall(
253 translated = cgen_state->emitCall(call->callee_name(),
llvm_args);
255 }
else if (
auto external_call = dynamic_cast<const ExternalCall*>(instr_ptr)) {
256 translated = cgen_state->emitExternalCall(external_call->callee_name(),
258 llvm_args(external_call->arguments(), m));
259 }
else if (
auto alloca = dynamic_cast<const Alloca*>(instr_ptr)) {
260 translated = cgen_state->ir_builder_.CreateAlloca(
264 }
else if (
auto memcpy = dynamic_cast<const MemCpy*>(instr_ptr)) {
265 cgen_state->ir_builder_.CreateMemCpy(
mapped_value(memcpy->dest(), m),
270 }
else if (
auto ret_early = dynamic_cast<const ReturnEarly*>(instr_ptr)) {
275 }
else if (
auto for_loop = dynamic_cast<const For*>(instr_ptr)) {
// NOTE(review): the body of the For branch and the `else` that presumably guards
// this fatal log are not visible in this excerpt -- confirm.
278 LOG(
FATAL) <<
"Instruction not supported yet";
// Insert the instruction's translation into `m`.
281 const auto it_ok = m.emplace(instr_ptr, translated);
// translate_for (fragment; the leading part of the signature is not visible here):
// emits the LLVM loop for a reduction-IR `For` instruction. Visible steps: remember
// the current insert block, compute the iteration count as end_index - start_index
// and sign-extend it to the type from get_int_type(64, ...), create an ".exit"
// block in the LLVM function mapped from `ir_reduce_loop`, supply two lambdas (one
// capturing `upper_bound`, one generating the loop body), then branch from the
// original block into the loop and leave the builder positioned at the exit block.
// NOTE(review): the definitions of start_index, end_index, bb_exit and bb_loop_body
// and the call that consumes the lambdas are not visible in this excerpt.
289 Function* ir_reduce_loop,
291 std::unordered_map<const Value*, llvm::Value*>& m,
292 const std::unordered_map<const Function*, llvm::Function*>& f) {
293 auto cgen_state = reduction_code.
cgen_state.get();
// Block the builder is in on entry; used at the end to branch to bb_loop_body.
295 const auto bb_entry = cgen_state->ir_builder_.GetInsertBlock();
296 auto& ctx = cgen_state->context_;
297 const auto i64_type =
get_int_type(64, cgen_state->context_);
301 const auto iteration_count =
302 cgen_state->ir_builder_.CreateSub(end_index, start_index,
"iteration_count");
303 const auto upper_bound = cgen_state->ir_builder_.CreateSExt(iteration_count, i64_type);
305 llvm::BasicBlock::Create(ctx,
".exit",
mapped_function(ir_reduce_loop, f));
// Lambda capturing `upper_bound` by value; its body is not visible here.
309 [
upper_bound](
const std::vector<llvm::Value*>& v) {
// Loop-body generator: creates a new block in the current function, truncates
// the last iterator into a value named "relative_entry_idx", and binds the
// For instruction's iterator to that value in `m`.
320 [cgen_state, for_loop, ir_reduce_loop, &f, &m, &reduction_code](
321 const std::vector<llvm::Value*>& iterators) {
322 const auto loop_body_bb = llvm::BasicBlock::Create(
323 cgen_state->context_,
325 cgen_state->ir_builder_.GetInsertBlock()->getParent());
326 cgen_state->ir_builder_.SetInsertPoint(loop_body_bb);
328 const auto loop_iter =
329 cgen_state->ir_builder_.CreateTrunc(iterators.back(),
331 "relative_entry_idx");
332 m.emplace(for_loop->iter(), loop_iter);
// Branch from the block saved on entry to bb_loop_body, then move the builder
// to bb_exit.
344 cgen_state->ir_builder_.SetInsertPoint(bb_entry);
345 cgen_state->ir_builder_.CreateBr(bb_loop_body);
346 cgen_state->ir_builder_.SetInsertPoint(bb_exit);
// create_entry_block (fragment): creates a basic block named ".entry" in `function`
// using the CgenState's LLVM context.
// NOTE(review): what is done with `bb_entry` afterwards is not visible in this
// excerpt; presumably the builder's insert point is set to it -- confirm.
352 const auto bb_entry =
353 llvm::BasicBlock::Create(cgen_state->
context_,
".entry",
function, 0);
// translate_function (fragment; the leading part of the signature is not visible
// here): lowers one reduction-IR function into `llvm_function`. A fresh value map
// `m` is seeded with (a) the IR arguments paired with LLVM arguments and (b) each
// IR constant materialized as an LLVM constant; translate_body is then invoked on
// the function's instruction list with that map.
360 llvm::Function* llvm_function,
362 const std::unordered_map<const Function*, llvm::Function*>& f) {
363 auto cgen_state = reduction_code.
cgen_state.get();
367 std::unordered_map<const Value*, llvm::Value*> m;
// Pair each IR argument with the LLVM argument referenced by `llvm_arg_it`.
// NOTE(review): the iterator increment is not visible in this excerpt.
368 auto llvm_arg_it = llvm_function->arg_begin();
369 for (
size_t arg_idx = 0; arg_idx <
function->arg_types().size(); ++arg_idx) {
370 llvm::Value* llvm_arg = &(*llvm_arg_it);
371 const auto it_ok = m.emplace(function->arg(arg_idx), llvm_arg);
// Materialize constants. The visible arms build int8_t, int32_t and int64_t
// integers via llInt and floating-point values via llFp; the case labels are
// not visible in this excerpt.
376 for (
const auto& constant : function->constants()) {
377 llvm::Value* constant_llvm{
nullptr};
378 switch (constant->type()) {
381 cgen_state->llInt<int8_t>(
static_cast<ConstantInt*
>(constant.get())->value());
385 constant_llvm = cgen_state->llInt<int32_t>(
386 static_cast<ConstantInt*
>(constant.get())->value());
390 constant_llvm = cgen_state->llInt<int64_t>(
391 static_cast<ConstantInt*
>(constant.get())->value());
// This arm narrows the stored value to float before building the constant;
// the following llFp call passes the stored value through unchanged.
395 constant_llvm = cgen_state->llFp(
396 static_cast<float>(static_cast<ConstantFP*>(constant.get())->value()));
401 cgen_state->llFp(static_cast<ConstantFP*>(constant.get())->value());
405 LOG(
FATAL) <<
"Constant type not supported: "
406 <<
static_cast<int>(constant->type());
// Every constant must have produced an LLVM value before it is inserted into `m`.
409 CHECK(constant_llvm);
410 const auto it_ok = m.emplace(constant.get(), constant_llvm);
413 translate_body(function->body(),
function, llvm_function, reduction_code, m, f);
DEVICE auto upper_bound(ARGS &&...args)
std::unique_ptr< CgenState > cgen_state
void create_entry_block(llvm::Function *function, CgenState *cgen_state)
void translate_for(const For *for_loop, Function *ir_reduce_loop, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
std::unique_ptr< Function > ir_reduce_loop
llvm::IRBuilder ir_builder_
std::vector< llvm::Value * > llvm_args(const std::vector< const Value * > args, const std::unordered_map< const Value *, llvm::Value * > &m)
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
llvm::ICmpInst::Predicate llvm_predicate(const ICmp::Predicate predicate)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
void translate_function(const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, const std::unordered_map< const Function *, llvm::Function * > &f)
const Value * end() const
const Value * start() const
void verify_function_ir(const llvm::Function *func)
llvm::LLVMContext & context_
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
Type pointee_type(const Type pointer)
llvm::Instruction::CastOps llvm_cast_op(const Cast::CastOp op)
llvm::Type * llvm_type(const Type type, llvm::LLVMContext &ctx)
void return_early(llvm::Value *cond, const ReductionCode &reduction_code, llvm::Function *func, llvm::Value *error_code)
llvm::Value * upper_bound
void translate_body(const std::vector< std::unique_ptr< Instruction >> &body, const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
#define LLVM_MAYBE_ALIGN(alignment)
llvm::Function * mapped_function(const Function *function, const std::unordered_map< const Function *, llvm::Function * > &f)
llvm::Value * mapped_value(const Value *val, const std::unordered_map< const Value *, llvm::Value * > &m)
llvm::BinaryOperator::BinaryOps llvm_binary_op(const BinaryOperator::BinaryOp op)