OmniSciDB  c0231cc57d
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ResultSetReductionInterpreter.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
19 
// Thread-local size_t with one instance per thread.
// NOTE(review): nothing in this listing reads or writes it — presumably it is consumed
// by the IR value classes declared elsewhere; confirm before relying on that.
20 thread_local size_t g_value_id;
21 
22 namespace {
23 
24 // Extract types of the given values.
25 std::vector<Type> get_value_types(const std::vector<const Value*>& values) {
26  std::vector<Type> value_types;
27  value_types.reserve(value_types.size());
28  std::transform(values.begin(),
29  values.end(),
30  std::back_inserter(value_types),
31  [](const Value* value) { return value->type(); });
32  return value_types;
33 }
34 
35 // For an alloca buffer, return the element size.
36 size_t get_element_size(const Type element_type) {
37  switch (element_type) {
38  case Type::Int8Ptr: {
39  return sizeof(int8_t);
40  }
41  case Type::Int64PtrPtr: {
42  return sizeof(int64_t*);
43  }
44  default: {
45  LOG(FATAL) << "Base pointer type not supported: " << static_cast<int>(element_type);
46  break;
47  }
48  }
49  return 0;
50 }
51 
52 } // namespace
53 
54 // Implements execution for all the operators. Caller is responsible for stopping
55 // evaluation when the return value is set.
57  public:
58  ReductionInterpreterImpl(const size_t executor_id,
59  const std::vector<ReductionInterpreter::EvalValue>& vars)
60  : executor_id_(executor_id), vars_(vars) {}
61 
62  std::optional<ReductionInterpreter::EvalValue> ret() const { return ret_; }
63 
64  public:
65  size_t getExecutorId() const { return executor_id_; }
66  static void runGetElementPtr(const Instruction* instruction,
67  ReductionInterpreterImpl* interpreter) {
68  CHECK(!interpreter->ret_) << "Function has already returned";
69  const auto gep = static_cast<const GetElementPtr*>(instruction);
70  const auto element_size = get_element_size(gep->base()->type());
71  const auto base = interpreter->vars_[gep->base()->id()];
72  const auto index = interpreter->vars_[gep->index()->id()];
73  auto result_ptr =
74  reinterpret_cast<const int8_t*>(base.ptr) + index.int_val * element_size;
75  interpreter->setVar(gep, ReductionInterpreter::MakeEvalValue(result_ptr));
76  }
77 
78  static void runLoad(const Instruction* instruction,
79  ReductionInterpreterImpl* interpreter) {
80  CHECK(!interpreter->ret_) << "Function has already returned";
81  const auto load = static_cast<const Load*>(instruction);
82  const auto source_type = load->source()->type();
83  CHECK(is_pointer_type(source_type));
84  const auto source = interpreter->vars_[load->source()->id()];
85  switch (source_type) {
86  case Type::Int8Ptr: {
87  const auto int_val = *reinterpret_cast<const int8_t*>(source.ptr);
88  interpreter->setVar(load, ReductionInterpreter::MakeEvalValue(int_val));
89  break;
90  }
91  case Type::Int32Ptr: {
92  const auto int_val = *reinterpret_cast<const int32_t*>(source.ptr);
93  interpreter->setVar(load, ReductionInterpreter::MakeEvalValue(int_val));
94  break;
95  }
96  case Type::Int64Ptr: {
97  const auto int_val = *reinterpret_cast<const int64_t*>(source.ptr);
98  interpreter->setVar(load, ReductionInterpreter::MakeEvalValue(int_val));
99  break;
100  }
101  case Type::FloatPtr: {
102  const auto float_val = *reinterpret_cast<const float*>(source.ptr);
103  interpreter->setVar(load, ReductionInterpreter::MakeEvalValue(float_val));
104  break;
105  }
106  case Type::DoublePtr: {
107  const auto double_val = *reinterpret_cast<const double*>(source.ptr);
108  interpreter->setVar(load, ReductionInterpreter::MakeEvalValue(double_val));
109  break;
110  }
111  case Type::Int64PtrPtr: {
112  const auto int_ptr_val = *reinterpret_cast<const int64_t* const*>(source.ptr);
113  interpreter->setVar(load, ReductionInterpreter::MakeEvalValue(int_ptr_val));
114  break;
115  }
116  default: {
117  LOG(FATAL) << "Source pointer type not supported: "
118  << static_cast<int>(source_type);
119  }
120  }
121  }
122 
123  static void runICmp(const Instruction* instruction,
124  ReductionInterpreterImpl* interpreter) {
125  CHECK(!interpreter->ret_) << "Function has already returned";
126  const auto icmp = static_cast<const ICmp*>(instruction);
127  CHECK(is_integer_type(icmp->lhs()->type()));
128  CHECK(is_integer_type(icmp->rhs()->type()));
129  const auto lhs = interpreter->vars_[icmp->lhs()->id()];
130  const auto rhs = interpreter->vars_[icmp->rhs()->id()];
131  bool result = false;
132  switch (icmp->predicate()) {
133  case ICmp::Predicate::EQ: {
134  result = lhs.int_val == rhs.int_val;
135  break;
136  }
137  case ICmp::Predicate::NE: {
138  result = lhs.int_val != rhs.int_val;
139  break;
140  }
141  default: {
142  LOG(FATAL) << "Predicate not supported: " << static_cast<int>(icmp->predicate());
143  }
144  }
145  interpreter->setVar(icmp, ReductionInterpreter::MakeEvalValue(result));
146  }
147 
// Evaluates a BinaryOperator instruction on two operands read as int64 values and
// stores the int64_t result in the slot of the instruction. The instruction's own
// type must be an integer type; an unsupported operator is reported via LOG(FATAL).
// NOTE(review): the two `case` labels that open the addition and multiplication
// branches (embedded listing lines 157 and 161) are absent from this listing, so the
// switch below is incomplete as shown — restore them from the upstream file.
148  static void runBinaryOperator(const Instruction* instruction,
149  ReductionInterpreterImpl* interpreter) {
150  CHECK(!interpreter->ret_) << "Function has already returned";
151  const auto binary_operator = static_cast<const BinaryOperator*>(instruction);
152  CHECK(is_integer_type(binary_operator->type()));
153  const auto lhs = interpreter->vars_[binary_operator->lhs()->id()];
154  const auto rhs = interpreter->vars_[binary_operator->rhs()->id()];
155  int64_t result = 0;
156  switch (binary_operator->op()) {
 // (missing case label: addition branch)
158  result = lhs.int_val + rhs.int_val;
159  break;
160  }
 // (missing case label: multiplication branch)
162  result = lhs.int_val * rhs.int_val;
163  break;
164  }
165  default: {
166  LOG(FATAL) << "Binary operator not supported: "
167  << static_cast<int>(binary_operator->op());
168  }
169  }
170  interpreter->setVar(binary_operator, ReductionInterpreter::MakeEvalValue(result));
171  }
172 
173  static void runCast(const Instruction* instruction,
174  ReductionInterpreterImpl* interpreter) {
175  CHECK(!interpreter->ret_) << "Function has already returned";
176  const auto cast = static_cast<const Cast*>(instruction);
177  const auto source = interpreter->vars_[cast->source()->id()];
178  // Given that evaluated values store all values as int64_t or void*, Trunc and SExt
179  // are no-op. The information about the type is already part of the destination.
180  switch (cast->op()) {
181  case Cast::CastOp::Trunc:
182  case Cast::CastOp::SExt: {
183  CHECK(is_integer_type(cast->source()->type()));
184  interpreter->setVar(cast, ReductionInterpreter::MakeEvalValue(source.int_val));
185  break;
186  }
187  case Cast::CastOp::BitCast: {
188  CHECK(is_pointer_type(cast->source()->type()));
189  interpreter->setVar(cast, ReductionInterpreter::MakeEvalValue(source.ptr));
190  break;
191  }
192  default: {
193  LOG(FATAL) << "Cast operator not supported: " << static_cast<int>(cast->op());
194  }
195  }
196  }
197 
198  static void runRet(const Instruction* instruction,
199  ReductionInterpreterImpl* interpreter) {
200  CHECK(!interpreter->ret_) << "Function has already returned";
201  const auto ret = static_cast<const Ret*>(instruction);
202  if (ret->type() == Type::Void) {
203  // Even if the returned type is void, the return value still needs to be set to
204  // something to inform the caller that it should stop evaluating.
205  interpreter->ret_ = ReductionInterpreter::EvalValue{};
206  } else {
207  interpreter->ret_ = interpreter->vars_[ret->value()->id()];
208  }
209  }
210 
211  static void runCall(const Instruction* instruction,
212  ReductionInterpreterImpl* interpreter) {
213  auto executor_id = interpreter->getExecutorId();
214  CHECK(!interpreter->ret_) << "Function has already returned";
215  const auto call = static_cast<const Call*>(instruction);
216  if (call->callee()) {
217  // Call one of the functions generated to implement reduction.
218  const auto inputs = getCallInputs(call, interpreter);
219  auto ret = ReductionInterpreter::run(executor_id, call->callee(), inputs);
220  if (call->type() != Type::Void) {
221  // Assign the returned value.
222  interpreter->setVar(call, ret);
223  }
224  } else {
225  // Call an internal runtime function.
226  const auto func_ptr = bindStub(executor_id, call);
227  const auto inputs = getCallInputs(call, interpreter);
229  func_ptr(&ret, &inputs);
230  if (call->type() != Type::Void) {
231  // Assign the returned value.
232  interpreter->setVar(call, ret);
233  }
234  }
235  return;
236  }
237 
238  static void runExternalCall(const Instruction* instruction,
239  ReductionInterpreterImpl* interpreter) {
240  auto executor_id = interpreter->getExecutorId();
241  CHECK(!interpreter->ret_) << "Function has already returned";
242  const auto external_call = static_cast<const ExternalCall*>(instruction);
243  const auto& arguments = external_call->arguments();
244  const auto argument_types = get_value_types(arguments);
245  const auto func_ptr = bindStub(executor_id, external_call);
246  const auto inputs = getCallInputs(external_call, interpreter);
248  func_ptr(&output, &inputs);
249  interpreter->setVar(external_call, output);
250  }
251 
252  static void runAlloca(const Instruction* instruction,
253  ReductionInterpreterImpl* interpreter) {
254  CHECK(!interpreter->ret_) << "Function has already returned";
255  const auto alloca = static_cast<const Alloca*>(instruction);
256  const auto element_size = get_element_size(alloca->type());
257  CHECK(is_integer_type(alloca->array_size()->type()));
258  const auto array_size = interpreter->vars_[alloca->array_size()->id()];
259  interpreter->alloca_buffers_.emplace_back(element_size * array_size.int_val);
261  eval_value.mutable_ptr = interpreter->alloca_buffers_.back().data();
262  interpreter->setVar(alloca, eval_value);
263  }
264 
265  static void runMemCpy(const Instruction* instruction,
266  ReductionInterpreterImpl* interpreter) {
267  CHECK(!interpreter->ret_) << "Function has already returned";
268  const auto memcpy = static_cast<const MemCpy*>(instruction);
269  CHECK(is_pointer_type(memcpy->dest()->type()));
270  CHECK(is_pointer_type(memcpy->source()->type()));
271  CHECK(is_integer_type(memcpy->size()->type()));
272  const auto dest = interpreter->vars_[memcpy->dest()->id()];
273  const auto source = interpreter->vars_[memcpy->source()->id()];
274  const auto size = interpreter->vars_[memcpy->size()->id()];
275  ::memcpy(dest.mutable_ptr, source.ptr, size.int_val);
276  }
277 
278  static void runReturnEarly(const Instruction* instruction,
279  ReductionInterpreterImpl* interpreter) {
280  CHECK(!interpreter->ret_) << "Function has already returned";
281  const auto ret_early = static_cast<const ReturnEarly*>(instruction);
282  CHECK(ret_early->cond()->type() == Type::Int1);
283  const auto cond = interpreter->vars_[ret_early->cond()->id()];
284 
285  auto error_code = ret_early->error_code();
286 
287  if (cond.int_val) {
288  auto rc = interpreter->vars_[error_code->id()].int_val;
289  interpreter->ret_ = ReductionInterpreter::MakeEvalValue(rc);
290  }
291  }
292 
293  static void runFor(const Instruction* instruction,
294  ReductionInterpreterImpl* interpreter) {
295  auto executor_id = interpreter->getExecutorId();
296  CHECK(!interpreter->ret_) << "Function has already returned";
297  const size_t saved_alloca_count = interpreter->alloca_buffers_.size();
298  const auto for_loop = static_cast<const For*>(instruction);
299  CHECK(is_integer_type(for_loop->start()->type()));
300  CHECK(is_integer_type(for_loop->end()->type()));
301  const auto start = interpreter->vars_[for_loop->start()->id()];
302  const auto end = interpreter->vars_[for_loop->end()->id()];
303  for (int64_t i = start.int_val; i < end.int_val; ++i) {
304  // The start and end indices are absolute, but the iteration happens from 0.
305  // Subtract the start index before setting the iterator.
306  interpreter->vars_[for_loop->iter()->id()].int_val = i - start.int_val;
307  auto ret =
308  ReductionInterpreter::run(executor_id, for_loop->body(), interpreter->vars_);
309  if (ret) {
310  interpreter->ret_ = *ret;
311  break;
312  }
313  }
314  // Pop all the alloca buffers allocated by the code in the loop.
315  interpreter->alloca_buffers_.resize(saved_alloca_count);
316  }
317 
318  private:
319  // Set the variable based on its id.
320  void setVar(const Value* var, ReductionInterpreter::EvalValue value) {
321  vars_[var->id()] = value;
322  }
323 
324  // Seed the parameters of the callee.
325  template <class Call>
326  static std::vector<ReductionInterpreter::EvalValue> getCallInputs(
327  const Call* call,
328  const ReductionInterpreterImpl* interpreter) {
329  std::vector<ReductionInterpreter::EvalValue> inputs;
330  inputs.reserve(interpreter->vars_.size());
331  for (const auto argument : call->arguments()) {
332  inputs.push_back(interpreter->vars_[argument->id()]);
333  }
334  return inputs;
335  }
336 
337  // Bind and cache a stub call.
338  template <class Call>
339  static StubGenerator::Stub bindStub(const size_t executor_id, const Call* call) {
340  const auto func_ptr =
341  call->cached_callee()
342  ? reinterpret_cast<StubGenerator::Stub>(call->cached_callee())
343  : StubGenerator::generateStub(executor_id,
344  call->callee_name(),
345  get_value_types(call->arguments()),
346  call->type(),
347  call->external());
348  CHECK(func_ptr);
349  call->set_cached_callee(reinterpret_cast<void*>(func_ptr));
350  return func_ptr;
351  }
352 
353  // Holds executor id
354  size_t executor_id_;
355  // Holds the evaluated values.
356  std::vector<ReductionInterpreter::EvalValue> vars_;
357  // Holds buffers allocated by the alloca instruction.
358  std::vector<std::vector<int8_t>> alloca_buffers_;
359  // Holds the value returned by the function.
360  std::optional<ReductionInterpreter::EvalValue> ret_ = std::nullopt;
361 };
362 
365 }
366 
368  ReductionInterpreterImpl::runLoad(this, interpreter);
369 }
370 
372  ReductionInterpreterImpl::runICmp(this, interpreter);
373 }
374 
377 }
378 
380  ReductionInterpreterImpl::runCast(this, interpreter);
381 }
382 
383 void Ret::run(ReductionInterpreterImpl* interpreter) {
384  ReductionInterpreterImpl::runRet(this, interpreter);
385 }
386 
388  ReductionInterpreterImpl::runCall(this, interpreter);
389 }
390 
393 }
394 
396  ReductionInterpreterImpl::runAlloca(this, interpreter);
397 }
398 
400  ReductionInterpreterImpl::runMemCpy(this, interpreter);
401 }
402 
404  ReductionInterpreterImpl::runReturnEarly(this, interpreter);
405 }
406 
407 void For::run(ReductionInterpreterImpl* interpreter) {
408  ReductionInterpreterImpl::runFor(this, interpreter);
409 }
410 
411 namespace {
412 
413 // Create an evaluated constant.
415  switch (constant->type()) {
416  case Type::Int8:
417  case Type::Int32:
418  case Type::Int64: {
420  static_cast<const ConstantInt*>(constant)->value());
421  }
422  case Type::Float: {
424  static_cast<float>(static_cast<const ConstantFP*>(constant)->value()));
425  }
426  case Type::Double: {
428  static_cast<const ConstantFP*>(constant)->value());
429  }
430  default: {
431  LOG(FATAL) << "Constant type not supported: " << static_cast<int>(constant->type());
432  break;
433  }
434  }
435  return {};
436 }
437 
438 } // namespace
439 
441  const size_t executor_id,
442  const Function* function,
443  const std::vector<ReductionInterpreter::EvalValue>& inputs) {
444  const auto last_id = function->body().back()->id();
445  const auto& arg_types = function->arg_types();
446  std::vector<ReductionInterpreter::EvalValue> vars(last_id + 1);
447  // Add the arguments to the variable map.
448  for (size_t i = 0; i < arg_types.size(); ++i) {
449  vars[function->arg(i)->id()] = inputs[i];
450  }
451  // Add constants to the variable map.
452  for (const auto& constant : function->constants()) {
453  vars[constant->id()] = eval_constant(constant.get());
454  }
455  const auto maybe_ret = run(executor_id, function->body(), vars);
456  CHECK(maybe_ret);
457  return *maybe_ret;
458 }
459 
460 std::optional<ReductionInterpreter::EvalValue> ReductionInterpreter::run(
461  const size_t executor_id,
462  const std::vector<std::unique_ptr<Instruction>>& body,
463  const std::vector<ReductionInterpreter::EvalValue>& vars) {
464  ReductionInterpreterImpl interp_impl(executor_id, vars);
465  for (const auto& instr : body) {
466  instr->run(&interp_impl);
467  const auto ret = interp_impl.ret();
468  if (ret) {
469  return *ret;
470  }
471  }
472  return interp_impl.ret();
473 }
static void runReturnEarly(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
bool external() const
size_t id() const
static void runFor(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
ReductionInterpreterImpl(const size_t executor_id, const std::vector< ReductionInterpreter::EvalValue > &vars)
void setVar(const Value *var, ReductionInterpreter::EvalValue value)
#define LOG(tag)
Definition: Logger.h:216
std::vector< ReductionInterpreter::EvalValue > vars_
void run(ReductionInterpreterImpl *interpreter) override
Type type() const
void run(ReductionInterpreterImpl *interpreter) override
void run(ReductionInterpreterImpl *interpreter) override
static StubGenerator::Stub bindStub(const size_t executor_id, const Call *call)
static void runBinaryOperator(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runCall(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static std::vector< ReductionInterpreter::EvalValue > getCallInputs(const Call *call, const ReductionInterpreterImpl *interpreter)
thread_local size_t g_value_id
static void runLoad(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static Stub generateStub(const size_t executor_id, const std::string &name, const std::vector< Type > &arg_types, const Type ret_type, const bool is_external)
const Value * source() const
void run(ReductionInterpreterImpl *interpreter) override
void run(ReductionInterpreterImpl *interpreter) override
static EvalValue MakeEvalValue(const T &val)
OUTPUT transform(INPUT const &input, FUNC const &func)
Definition: misc.h:296
const std::string & callee_name() const
ReductionInterpreter::EvalValue eval_constant(const Constant *constant)
const std::vector< const Value * > & arguments() const
std::vector< Type > get_value_types(const std::vector< const Value * > &values)
void run(ReductionInterpreterImpl *interpreter) override
void * cached_callee() const
std::optional< ReductionInterpreter::EvalValue > ret() const
ReductionInterpreter::EvalValue(*)(void *output_handle, const void *inputs_handle) Stub
static void runExternalCall(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runICmp(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
std::optional< ReductionInterpreter::EvalValue > ret_
#define CHECK(condition)
Definition: Logger.h:222
static void runAlloca(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
std::vector< std::vector< int8_t > > alloca_buffers_
void run(ReductionInterpreterImpl *interpreter) override
static EvalValue run(const size_t execution_id, const Function *function, const std::vector< EvalValue > &inputs)
void run(ReductionInterpreterImpl *interpreter) override
bool is_pointer_type(const Type type)
void run(ReductionInterpreterImpl *interpreter) override
static void runCast(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runMemCpy(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
void run(ReductionInterpreterImpl *interpreter) override
static void runGetElementPtr(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runRet(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
bool is_integer_type(const Type type)
void run(ReductionInterpreterImpl *interpreter) override
void set_cached_callee(void *cached_callee) const
void run(ReductionInterpreterImpl *interpreter) override
const std::vector< const Value * > & arguments() const