// get_value_types body: returns the Type of each input Value, in input order.
// Capacity is reserved from `values` (the input), so the back_inserter below
// appends without reallocating; the output vector is empty at this point.
26 std::vector<Type> value_types;
27 value_types.reserve(values.size());
28 std::transform(values.begin(),
30 std::back_inserter(value_types),
31 [](
const Value* value) {
return value->type(); });
// get_element_size: byte size of one element for a pointer element type.
// The case labels are not part of this excerpt; the visible returns are
// sizeof(int8_t) and sizeof(int64_t*). Any other type is reported via
// LOG(FATAL).
// NOTE(review): presumably this supplies `element_size` for the GEP and Alloca
// handlers — confirm at their call sites (not shown here).
37 switch (element_type) {
39 return sizeof(int8_t);
42 return sizeof(int64_t*);
45 LOG(
FATAL) <<
"Base pointer type not supported: " <<
static_cast<int>(element_type);
61 std::optional<ReductionInterpreter::EvalValue>
ret()
const {
return ret_; }
// runGetElementPtr: offsets a base pointer by `index` elements.
// Like every handler, refuse to execute once a return value has been recorded.
66 CHECK(!interpreter->
ret_) <<
"Function has already returned";
67 const auto gep =
static_cast<const GetElementPtr*
>(instruction);
69 const auto base = interpreter->
vars_[gep->base()->id()];
70 const auto index = interpreter->
vars_[gep->index()->id()];
// Byte arithmetic through an int8_t pointer: advances base.ptr by
// index.int_val * element_size bytes. Where the result is stored is on a line
// not included in this excerpt.
72 reinterpret_cast<const int8_t*
>(base.ptr) + index.int_val * element_size;
// runLoad: dereferences the pointer held by the load's source operand.
// The switch is on the source's (pointer) type; each visible branch reads one
// value of the matching pointee type: int8_t, int32_t, int64_t, float, double,
// or int64_t*. The case labels and the stores of the loaded values are not
// visible in this excerpt.
// NOTE(review): the reinterpret_cast reads assume source.ptr is non-null and
// suitably aligned for the pointee type — confirm callers guarantee this.
78 CHECK(!interpreter->
ret_) <<
"Function has already returned";
79 const auto load =
static_cast<const Load*
>(instruction);
80 const auto source_type = load->
source()->
type();
82 const auto source = interpreter->
vars_[load->source()->id()];
83 switch (source_type) {
85 const auto int_val = *
reinterpret_cast<const int8_t*
>(source.ptr);
90 const auto int_val = *
reinterpret_cast<const int32_t*
>(source.ptr);
95 const auto int_val = *
reinterpret_cast<const int64_t*
>(source.ptr);
100 const auto float_val = *
reinterpret_cast<const float*
>(source.ptr);
106 const auto double_val = *
reinterpret_cast<const double*
>(source.ptr);
// Pointer-to-pointer source: reads the stored int64_t* itself.
112 const auto int_ptr_val = *
reinterpret_cast<const int64_t* const*
>(source.ptr);
// Any other source type is unsupported.
117 LOG(
FATAL) <<
"Source pointer type not supported: "
118 <<
static_cast<int>(source_type);
// runICmp: integer comparison of two operands through their int_val members.
// The visible branches compute == and !=; any other predicate hits LOG(FATAL).
// The declaration of `result` and its store are on lines not in this excerpt.
125 CHECK(!interpreter->
ret_) <<
"Function has already returned";
126 const auto icmp =
static_cast<const ICmp*
>(instruction);
129 const auto lhs = interpreter->
vars_[icmp->lhs()->id()];
130 const auto rhs = interpreter->
vars_[icmp->rhs()->id()];
132 switch (icmp->predicate()) {
134 result = lhs.int_val == rhs.int_val;
138 result = lhs.int_val != rhs.int_val;
142 LOG(
FATAL) <<
"Predicate not supported: " <<
static_cast<int>(icmp->predicate());
// runBinaryOperator: integer arithmetic on the operands' int_val members.
// The visible branches compute addition and multiplication; any other operator
// hits LOG(FATAL). The result is stored under the operator's own id via setVar.
// NOTE(review): if int_val is a signed integer type, overflow in + or * is
// undefined behavior — confirm the expected operand ranges.
150 CHECK(!interpreter->
ret_) <<
"Function has already returned";
151 const auto binary_operator =
static_cast<const BinaryOperator*
>(instruction);
153 const auto lhs = interpreter->
vars_[binary_operator->lhs()->id()];
154 const auto rhs = interpreter->
vars_[binary_operator->rhs()->id()];
156 switch (binary_operator->op()) {
158 result = lhs.int_val + rhs.int_val;
162 result = lhs.int_val * rhs.int_val;
166 LOG(
FATAL) <<
"Binary operator not supported: "
167 <<
static_cast<int>(binary_operator->op());
170 interpreter->
setVar(binary_operator,
// runCast: converts the source operand's value according to cast->op().
// The per-operator branches are not in this excerpt; unsupported operators
// hit LOG(FATAL).
176 CHECK(!interpreter->
ret_) <<
"Function has already returned";
177 const auto cast =
static_cast<const Cast*
>(instruction);
178 const auto source = interpreter->
vars_[cast->source()->id()];
181 switch (cast->op()) {
195 LOG(
FATAL) <<
"Cast operator not supported: " <<
static_cast<int>(cast->op());
// runRet: records the returned operand's value in interpreter->ret_. Once set,
// every later handler's CHECK fails, and ret() exposes the value to the caller.
// NOTE(review): lines between the cast and this assignment are not in this
// excerpt; there may be a branch (e.g. for a value-less return) — confirm.
202 CHECK(!interpreter->
ret_) <<
"Function has already returned";
203 const auto ret =
static_cast<const Ret*
>(instruction);
209 interpreter->
ret_ = interpreter->
vars_[ret->value()->id()];
// runCall: evaluates a call instruction.
// - With a callee (call->callee() non-null), the argument values are gathered.
//   The computation of `ret` is on lines not in this excerpt; it is presumably
//   a recursive interpretation of the callee — confirm.
// - Without a callee, a native stub is bound for the call and invoked as
//   func_ptr(&ret, &inputs).
// Either way, the result is stored under the call's id with setVar.
215 CHECK(!interpreter->
ret_) <<
"Function has already returned";
216 const auto call =
static_cast<const Call*
>(instruction);
217 if (call->callee()) {
219 const auto inputs = getCallInputs(call, interpreter);
223 interpreter->
setVar(call, ret);
227 const auto func_ptr = bindStub(call);
228 const auto inputs = getCallInputs(call, interpreter);
230 func_ptr(&ret, &inputs);
233 interpreter->
setVar(call, ret);
// runExternalCall: binds a native stub for the external function and invokes
// it with the argument values from getCallInputs. &output is passed as the
// stub's output handle, and `output` is then stored under the instruction's id.
// NOTE(review): `arguments` is not used on any visible line; lines 244 and 247
// are not shown — confirm the local is still needed.
241 CHECK(!interpreter->
ret_) <<
"Function has already returned";
242 const auto external_call =
static_cast<const ExternalCall*
>(instruction);
243 const auto& arguments = external_call->
arguments();
245 const auto func_ptr = bindStub(external_call);
246 const auto inputs = getCallInputs(external_call, interpreter);
248 func_ptr(&output, &inputs);
249 interpreter->
setVar(external_call, output);
// runAlloca: allocates element_size * array_size.int_val zero-initialized bytes
// as a new std::vector<int8_t> owned by the interpreter (alloca_buffers_). The
// storage therefore lives as long as the interpreter object. The value bound to
// the alloca via setVar is on a line not in this excerpt.
// NOTE(review): if that value is the new buffer's data() pointer, it stays
// valid when alloca_buffers_ grows, because relocating a std::vector moves its
// heap storage rather than copying it.
254 CHECK(!interpreter->
ret_) <<
"Function has already returned";
255 const auto alloca =
static_cast<const Alloca*
>(instruction);
258 const auto array_size = interpreter->
vars_[alloca->array_size()->id()];
259 interpreter->
alloca_buffers_.emplace_back(element_size * array_size.int_val);
260 interpreter->
setVar(alloca,
// runMemCpy: copies size.int_val bytes from source.ptr to dest.mutable_ptr.
// NOTE(review): memcpy requires non-overlapping ranges and a non-negative
// size; neither is checked here.
267 CHECK(!interpreter->
ret_) <<
"Function has already returned";
// The local is named `memcpy`, which shadows the C library function, hence
// the explicit global-scope ::memcpy call below.
268 const auto memcpy =
static_cast<const MemCpy*
>(instruction);
272 const auto dest = interpreter->
vars_[memcpy->dest()->id()];
273 const auto source = interpreter->
vars_[memcpy->source()->id()];
274 const auto size = interpreter->
vars_[memcpy->size()->id()];
275 ::memcpy(
dest.mutable_ptr, source.ptr, size.int_val);
// runReturnEarly: reads the condition operand and the error code's int_val
// (rc). How cond and rc are then used is on lines not in this excerpt.
280 CHECK(!interpreter->
ret_) <<
"Function has already returned";
281 const auto ret_early =
static_cast<const ReturnEarly*
>(instruction);
283 const auto cond = interpreter->
vars_[ret_early->cond()->id()];
285 auto error_code = ret_early->error_code();
288 auto rc = interpreter->
vars_[error_code->id()].int_val;
// runFor: iterates i over [start.int_val, end.int_val). Each iteration stores
// the zero-based offset i - start.int_val in the loop's iter variable, so the
// body sees 0, 1, 2, ... regardless of the start value.
// The body execution is on lines not shown. The final assignment copies a
// return value (`*ret`) into the interpreter; it is presumably produced by
// running the body — confirm.
295 CHECK(!interpreter->
ret_) <<
"Function has already returned";
297 const auto for_loop =
static_cast<const For*
>(instruction);
300 const auto start = interpreter->
vars_[for_loop->start()->id()];
301 const auto end = interpreter->
vars_[for_loop->end()->id()];
302 for (int64_t i = start.int_val; i < end.int_val; ++i) {
305 interpreter->
vars_[for_loop->iter()->id()] = {.int_val = i - start.int_val};
308 interpreter->
ret_ = *ret;
// setVar body: stores `value` in the slot indexed by the variable's id.
// NOTE(review): there is no bounds check; this assumes vars_ was sized to
// cover var->id().
319 vars_[var->
id()] = value;
// getCallInputs: collects the current value of each call argument, in argument
// order. Works for any instruction type exposing arguments() (Call,
// ExternalCall).
323 template <
class Call>
327 std::vector<ReductionInterpreter::EvalValue> inputs;
// One slot per call argument is all that is needed, not one per interpreter
// variable.
328 inputs.reserve(call->
arguments().size());
329 for (
const auto argument : call->
arguments()) {
330 inputs.push_back(interpreter->
vars_[argument->id()]);
// bindStub: obtains the native stub used to invoke a Call/ExternalCall. Most
// of the body is outside this excerpt.
// NOTE(review): presumably it uses cached_callee()/set_cached_callee() to avoid
// regenerating the stub — confirm.
336 template <
class Call>
338 const auto func_ptr =
// Current value of every Value, indexed by Value::id().
351 std::vector<ReductionInterpreter::EvalValue>
vars_;
// Return value of the interpreted function; std::nullopt while it is still
// executing.
355 std::optional<ReductionInterpreter::EvalValue> ret_ = std::nullopt;
// eval_constant: converts a Constant into an EvalValue according to its type.
// - Integer constants: int_val from ConstantInt::value().
// - One ConstantFP branch narrows the value with static_cast<float>; its field
//   initializer is on a line not shown (presumably float_val).
// - The other ConstantFP branch stores the value in double_val.
// Case labels are not in this excerpt; unsupported types hit LOG(FATAL).
410 switch (constant->
type()) {
414 return {.int_val =
static_cast<const ConstantInt*
>(constant)->value()};
418 static_cast<float>(
static_cast<const ConstantFP*
>(constant)->value())};
421 return {.double_val =
static_cast<const ConstantFP*
>(constant)->value()};
424 LOG(
FATAL) <<
"Constant type not supported: " <<
static_cast<int>(constant->
type());
// ReductionInterpreter::run(function, inputs) setup:
// - allocates one slot per value id, up to the id of the last body instruction;
// - binds argument i to inputs[i];
// - then (partly on lines not shown) fills in constants and runs the body.
// NOTE(review): function->body().back() is undefined behavior for an empty
// body, and inputs[i] is not bounds-checked against arg_types.size(). The slot
// count also assumes every argument and constant id is <= last_id — confirm
// that ids are allocated in that order.
434 const Function*
function,
435 const std::vector<ReductionInterpreter::EvalValue>& inputs) {
436 const auto last_id =
function->body().back()->id();
437 const auto& arg_types =
function->arg_types();
438 std::vector<ReductionInterpreter::EvalValue> vars(last_id + 1);
440 for (
size_t i = 0; i < arg_types.size(); ++i) {
441 vars[
function->arg(i)->id()] = inputs[i];
444 for (
const auto& constant : function->constants()) {
447 const auto maybe_ret =
run(function->body(), vars);
// run(body, vars): executes the instructions in order on interp_impl and reads
// interp_impl.ret() after each one. interp_impl is presumably constructed from
// `vars` on a line not shown, and the early-exit logic is likewise not shown.
// Finally it returns the recorded (possibly empty) return value.
453 const std::vector<std::unique_ptr<Instruction>>& body,
454 const std::vector<ReductionInterpreter::EvalValue>& vars) {
456 for (
const auto& instr : body) {
457 instr->run(&interp_impl);
458 const auto ret = interp_impl.ret();
463 return interp_impl.ret();
static void runReturnEarly(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
ReductionInterpreterImpl(const std::vector< ReductionInterpreter::EvalValue > &vars)
static void runFor(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
void setVar(const Value *var, ReductionInterpreter::EvalValue value)
std::vector< ReductionInterpreter::EvalValue > vars_
void run(ReductionInterpreterImpl *interpreter) override
size_t get_element_size(const Type element_type)
void run(ReductionInterpreterImpl *interpreter) override
void run(ReductionInterpreterImpl *interpreter) override
static void runBinaryOperator(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runCall(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static std::vector< ReductionInterpreter::EvalValue > getCallInputs(const Call *call, const ReductionInterpreterImpl *interpreter)
thread_local size_t g_value_id
static void runLoad(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
const Value * source() const
static EvalValue run(const Function *function, const std::vector< EvalValue > &inputs)
void run(ReductionInterpreterImpl *interpreter) override
void run(ReductionInterpreterImpl *interpreter) override
static Stub generateStub(const std::string &name, const std::vector< Type > &arg_types, const Type ret_type, const bool is_external)
const std::string & callee_name() const
ReductionInterpreter::EvalValue eval_constant(const Constant *constant)
const std::vector< const Value * > & arguments() const
std::vector< Type > get_value_types(const std::vector< const Value * > &values)
void run(ReductionInterpreterImpl *interpreter) override
void * cached_callee() const
std::optional< ReductionInterpreter::EvalValue > ret() const
ReductionInterpreter::EvalValue(*)(void *output_handle, const void *inputs_handle) Stub
static void runExternalCall(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runICmp(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
std::optional< ReductionInterpreter::EvalValue > ret_
static void runAlloca(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
std::vector< std::vector< int8_t > > alloca_buffers_
void run(ReductionInterpreterImpl *interpreter) override
void run(ReductionInterpreterImpl *interpreter) override
bool is_pointer_type(const Type type)
void run(ReductionInterpreterImpl *interpreter) override
static void runCast(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runMemCpy(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
void run(ReductionInterpreterImpl *interpreter) override
static void runGetElementPtr(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runRet(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
bool is_integer_type(const Type type)
void run(ReductionInterpreterImpl *interpreter) override
void set_cached_callee(void *cached_callee) const
void run(ReductionInterpreterImpl *interpreter) override
static StubGenerator::Stub bindStub(const Call *call)
const std::vector< const Value * > & arguments() const