OmniSciDB  06b3bd477c
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ResultSetReductionInterpreter.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2019 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
19 
// Thread-local value-id counter. It is explicitly zero-initialized; thread_local
// variables at namespace scope are zero-initialized anyway, so this only makes the
// starting value visible.
thread_local size_t g_value_id{0};
21 
22 namespace {
23 
24 // Extract types of the given values.
25 std::vector<Type> get_value_types(const std::vector<const Value*>& values) {
26  std::vector<Type> value_types;
27  value_types.reserve(value_types.size());
28  std::transform(values.begin(),
29  values.end(),
30  std::back_inserter(value_types),
31  [](const Value* value) { return value->type(); });
32  return value_types;
33 }
34 
35 // For an alloca buffer, return the element size.
36 size_t get_element_size(const Type element_type) {
37  switch (element_type) {
38  case Type::Int8Ptr: {
39  return sizeof(int8_t);
40  }
41  case Type::Int64PtrPtr: {
42  return sizeof(int64_t*);
43  }
44  default: {
45  LOG(FATAL) << "Base pointer type not supported: " << static_cast<int>(element_type);
46  break;
47  }
48  }
49  return 0;
50 }
51 
52 } // namespace
53 
54 // Implements execution for all the operators. Caller is responsible for stopping
55 // evaluation when the return value is set.
57  public:
58  ReductionInterpreterImpl(const std::vector<ReductionInterpreter::EvalValue>& vars)
59  : vars_(vars) {}
60 
61  std::optional<ReductionInterpreter::EvalValue> ret() const { return ret_; }
62 
63  public:
64  static void runGetElementPtr(const Instruction* instruction,
65  ReductionInterpreterImpl* interpreter) {
66  CHECK(!interpreter->ret_) << "Function has already returned";
67  const auto gep = static_cast<const GetElementPtr*>(instruction);
68  const auto element_size = get_element_size(gep->base()->type());
69  const auto base = interpreter->vars_[gep->base()->id()];
70  const auto index = interpreter->vars_[gep->index()->id()];
71  auto result_ptr =
72  reinterpret_cast<const int8_t*>(base.ptr) + index.int_val * element_size;
73  interpreter->setVar(gep, ReductionInterpreter::EvalValue{.ptr = result_ptr});
74  }
75 
// Load: read through the source pointer and store the loaded value as this
// instruction's result. The pointee type is selected by the source's pointer type:
// - Int8/Int32/Int64 pointers give a signed integer widened into int_val.
// - Float/Double pointers give the floating-point value.
// - Int64PtrPtr gives a pointer in ptr.
// Any other pointer type is a fatal error.
76  static void runLoad(const Instruction* instruction,
77  ReductionInterpreterImpl* interpreter) {
78  CHECK(!interpreter->ret_) << "Function has already returned";
79  const auto load = static_cast<const Load*>(instruction);
80  const auto source_type = load->source()->type();
81  CHECK(is_pointer_type(source_type));
82  const auto source = interpreter->vars_[load->source()->id()];
// Dispatch on the pointer type to pick the width/kind of the dereference.
83  switch (source_type) {
84  case Type::Int8Ptr: {
85  const auto int_val = *reinterpret_cast<const int8_t*>(source.ptr);
86  interpreter->setVar(load, ReductionInterpreter::EvalValue{.int_val = int_val});
87  break;
88  }
89  case Type::Int32Ptr: {
90  const auto int_val = *reinterpret_cast<const int32_t*>(source.ptr);
91  interpreter->setVar(load, ReductionInterpreter::EvalValue{.int_val = int_val});
92  break;
93  }
94  case Type::Int64Ptr: {
95  const auto int_val = *reinterpret_cast<const int64_t*>(source.ptr);
96  interpreter->setVar(load, ReductionInterpreter::EvalValue{.int_val = int_val});
97  break;
98  }
99  case Type::FloatPtr: {
100  const auto float_val = *reinterpret_cast<const float*>(source.ptr);
101  interpreter->setVar(load,
103  break;
104  }
105  case Type::DoublePtr: {
106  const auto double_val = *reinterpret_cast<const double*>(source.ptr);
107  interpreter->setVar(load,
109  break;
110  }
// A pointer-to-pointer load yields a pointer value, kept in ptr.
111  case Type::Int64PtrPtr: {
112  const auto int_ptr_val = *reinterpret_cast<const int64_t* const*>(source.ptr);
113  interpreter->setVar(load, ReductionInterpreter::EvalValue{.ptr = int_ptr_val});
114  break;
115  }
116  default: {
117  LOG(FATAL) << "Source pointer type not supported: "
118  << static_cast<int>(source_type);
119  }
120  }
121  }
122 
123  static void runICmp(const Instruction* instruction,
124  ReductionInterpreterImpl* interpreter) {
125  CHECK(!interpreter->ret_) << "Function has already returned";
126  const auto icmp = static_cast<const ICmp*>(instruction);
127  CHECK(is_integer_type(icmp->lhs()->type()));
128  CHECK(is_integer_type(icmp->rhs()->type()));
129  const auto lhs = interpreter->vars_[icmp->lhs()->id()];
130  const auto rhs = interpreter->vars_[icmp->rhs()->id()];
131  bool result = false;
132  switch (icmp->predicate()) {
133  case ICmp::Predicate::EQ: {
134  result = lhs.int_val == rhs.int_val;
135  break;
136  }
137  case ICmp::Predicate::NE: {
138  result = lhs.int_val != rhs.int_val;
139  break;
140  }
141  default: {
142  LOG(FATAL) << "Predicate not supported: " << static_cast<int>(icmp->predicate());
143  }
144  }
145  interpreter->setVar(icmp, ReductionInterpreter::EvalValue{.int_val = result});
146  }
147 
// BinaryOperator: integer arithmetic on the operands' int_val. The supported cases
// compute lhs + rhs and lhs * rhs; any other operator is a fatal error. The result is
// stored as this instruction's value.
// NOTE(review): int64_t addition/multiplication here can overflow, which is undefined
// behavior for signed integers in C++. Confirm the operands are range-limited, or use
// wrapping (unsigned) arithmetic if results must match compiled code that wraps.
148  static void runBinaryOperator(const Instruction* instruction,
149  ReductionInterpreterImpl* interpreter) {
150  CHECK(!interpreter->ret_) << "Function has already returned";
151  const auto binary_operator = static_cast<const BinaryOperator*>(instruction);
152  CHECK(is_integer_type(binary_operator->type()));
153  const auto lhs = interpreter->vars_[binary_operator->lhs()->id()];
154  const auto rhs = interpreter->vars_[binary_operator->rhs()->id()];
155  int64_t result = 0;
156  switch (binary_operator->op()) {
158  result = lhs.int_val + rhs.int_val;
159  break;
160  }
162  result = lhs.int_val * rhs.int_val;
163  break;
164  }
165  default: {
166  LOG(FATAL) << "Binary operator not supported: "
167  << static_cast<int>(binary_operator->op());
168  }
169  }
170  interpreter->setVar(binary_operator,
172  }
173 
174  static void runCast(const Instruction* instruction,
175  ReductionInterpreterImpl* interpreter) {
176  CHECK(!interpreter->ret_) << "Function has already returned";
177  const auto cast = static_cast<const Cast*>(instruction);
178  const auto source = interpreter->vars_[cast->source()->id()];
179  // Given that evaluated values store all values as int64_t or void*, Trunc and SExt
180  // are no-op. The information about the type is already part of the destination.
181  switch (cast->op()) {
182  case Cast::CastOp::Trunc:
183  case Cast::CastOp::SExt: {
184  CHECK(is_integer_type(cast->source()->type()));
185  interpreter->setVar(cast,
186  ReductionInterpreter::EvalValue{.int_val = source.int_val});
187  break;
188  }
189  case Cast::CastOp::BitCast: {
190  CHECK(is_pointer_type(cast->source()->type()));
191  interpreter->setVar(cast, ReductionInterpreter::EvalValue{.ptr = source.ptr});
192  break;
193  }
194  default: {
195  LOG(FATAL) << "Cast operator not supported: " << static_cast<int>(cast->op());
196  }
197  }
198  }
199 
200  static void runRet(const Instruction* instruction,
201  ReductionInterpreterImpl* interpreter) {
202  CHECK(!interpreter->ret_) << "Function has already returned";
203  const auto ret = static_cast<const Ret*>(instruction);
204  if (ret->type() == Type::Void) {
205  // Even if the returned type is void, the return value still needs to be set to
206  // something to inform the caller that it should stop evaluating.
207  interpreter->ret_ = ReductionInterpreter::EvalValue{};
208  } else {
209  interpreter->ret_ = interpreter->vars_[ret->value()->id()];
210  }
211  }
212 
// Call: invoke another function. The two cases are:
// - The call has a callee Function (one of the generated reduction helpers): it is
//   interpreted recursively via ReductionInterpreter::run with the evaluated arguments.
// - Otherwise: a native stub for the runtime function is bound and cached (bindStub),
//   then invoked with the address of the result slot and of the inputs vector.
// In both cases a non-void result becomes this instruction's value.
// NOTE(review): the trailing `return;` is redundant.
213  static void runCall(const Instruction* instruction,
214  ReductionInterpreterImpl* interpreter) {
215  CHECK(!interpreter->ret_) << "Function has already returned";
216  const auto call = static_cast<const Call*>(instruction);
217  if (call->callee()) {
218  // Call one of the functions generated to implement reduction.
219  const auto inputs = getCallInputs(call, interpreter);
220  auto ret = ReductionInterpreter::run(call->callee(), inputs);
221  if (call->type() != Type::Void) {
222  // Assign the returned value.
223  interpreter->setVar(call, ret);
224  }
225  } else {
226  // Call an internal runtime function.
227  const auto func_ptr = bindStub(call);
228  const auto inputs = getCallInputs(call, interpreter);
230  func_ptr(&ret, &inputs);
231  if (call->type() != Type::Void) {
232  // Assign the returned value.
233  interpreter->setVar(call, ret);
234  }
235  }
236  return;
237  }
238 
// ExternalCall: bind and cache a native stub for the external function (bindStub),
// invoke it with the evaluated arguments, and store its output as this instruction's
// value.
// NOTE(review): unlike runCall, the output is stored even when the call's type is
// Void. Presumably external calls always return a value; confirm.
239  static void runExternalCall(const Instruction* instruction,
240  ReductionInterpreterImpl* interpreter) {
241  CHECK(!interpreter->ret_) << "Function has already returned";
242  const auto external_call = static_cast<const ExternalCall*>(instruction);
243  const auto& arguments = external_call->arguments();
// NOTE(review): argument_types is never used; bindStub derives the argument types
// itself via get_value_types(call->arguments()).
244  const auto argument_types = get_value_types(arguments);
245  const auto func_ptr = bindStub(external_call);
246  const auto inputs = getCallInputs(external_call, interpreter);
248  func_ptr(&output, &inputs);
249  interpreter->setVar(external_call, output);
250  }
251 
// Alloca: allocate a zero-filled buffer of element_size * array_size bytes and expose
// its address as this instruction's mutable_ptr. The element size comes from the
// alloca's pointer type. The buffer is owned by this interpreter instance
// (alloca_buffers_) and is released when the interpreter is destroyed.
252  static void runAlloca(const Instruction* instruction,
253  ReductionInterpreterImpl* interpreter) {
254  CHECK(!interpreter->ret_) << "Function has already returned";
255  const auto alloca = static_cast<const Alloca*>(instruction);
256  const auto element_size = get_element_size(alloca->type());
257  CHECK(is_integer_type(alloca->array_size()->type()));
258  const auto array_size = interpreter->vars_[alloca->array_size()->id()];
// std::vector<int8_t>(n) value-initializes, so the new buffer starts zeroed.
259  interpreter->alloca_buffers_.emplace_back(element_size * array_size.int_val);
260  interpreter->setVar(alloca,
262  .mutable_ptr = interpreter->alloca_buffers_.back().data()});
263  }
264 
265  static void runMemCpy(const Instruction* instruction,
266  ReductionInterpreterImpl* interpreter) {
267  CHECK(!interpreter->ret_) << "Function has already returned";
268  const auto memcpy = static_cast<const MemCpy*>(instruction);
269  CHECK(is_pointer_type(memcpy->dest()->type()));
270  CHECK(is_pointer_type(memcpy->source()->type()));
271  CHECK(is_integer_type(memcpy->size()->type()));
272  const auto dest = interpreter->vars_[memcpy->dest()->id()];
273  const auto source = interpreter->vars_[memcpy->source()->id()];
274  const auto size = interpreter->vars_[memcpy->size()->id()];
275  ::memcpy(dest.mutable_ptr, source.ptr, size.int_val);
276  }
277 
278  static void runReturnEarly(const Instruction* instruction,
279  ReductionInterpreterImpl* interpreter) {
280  CHECK(!interpreter->ret_) << "Function has already returned";
281  const auto ret_early = static_cast<const ReturnEarly*>(instruction);
282  CHECK(ret_early->cond()->type() == Type::Int1);
283  const auto cond = interpreter->vars_[ret_early->cond()->id()];
284 
285  auto error_code = ret_early->error_code();
286 
287  if (cond.int_val) {
288  auto rc = interpreter->vars_[error_code->id()].int_val;
289  interpreter->ret_ = ReductionInterpreter::EvalValue{.int_val = rc};
290  }
291  }
292 
293  static void runFor(const Instruction* instruction,
294  ReductionInterpreterImpl* interpreter) {
295  CHECK(!interpreter->ret_) << "Function has already returned";
296  const size_t saved_alloca_count = interpreter->alloca_buffers_.size();
297  const auto for_loop = static_cast<const For*>(instruction);
298  CHECK(is_integer_type(for_loop->start()->type()));
299  CHECK(is_integer_type(for_loop->end()->type()));
300  const auto start = interpreter->vars_[for_loop->start()->id()];
301  const auto end = interpreter->vars_[for_loop->end()->id()];
302  for (int64_t i = start.int_val; i < end.int_val; ++i) {
303  // The start and end indices are absolute, but the iteration happens from 0.
304  // Subtract the start index before setting the iterator.
305  interpreter->vars_[for_loop->iter()->id()] = {.int_val = i - start.int_val};
306  auto ret = ReductionInterpreter::run(for_loop->body(), interpreter->vars_);
307  if (ret) {
308  interpreter->ret_ = *ret;
309  break;
310  }
311  }
312  // Pop all the alloca buffers allocated by the code in the loop.
313  interpreter->alloca_buffers_.resize(saved_alloca_count);
314  }
315 
316  private:
317  // Set the variable based on its id.
318  void setVar(const Value* var, ReductionInterpreter::EvalValue value) {
319  vars_[var->id()] = value;
320  }
321 
322  // Seed the parameters of the callee.
323  template <class Call>
324  static std::vector<ReductionInterpreter::EvalValue> getCallInputs(
325  const Call* call,
326  const ReductionInterpreterImpl* interpreter) {
327  std::vector<ReductionInterpreter::EvalValue> inputs;
328  inputs.reserve(interpreter->vars_.size());
329  for (const auto argument : call->arguments()) {
330  inputs.push_back(interpreter->vars_[argument->id()]);
331  }
332  return inputs;
333  }
334 
335  // Bind and cache a stub call.
// Return the native stub implementing `call`.
// On first use, StubGenerator generates the stub from, among other inputs, the call's
// argument types, return type and external() flag. The stub is then cached on the call
// instruction (set_cached_callee) so later evaluations reuse it.
// A null stub is a fatal CHECK failure.
336  template <class Call>
337  static StubGenerator::Stub bindStub(const Call* call) {
338  const auto func_ptr =
339  call->cached_callee()
340  ? reinterpret_cast<StubGenerator::Stub>(call->cached_callee())
342  get_value_types(call->arguments()),
343  call->type(),
344  call->external());
345  CHECK(func_ptr);
// Stored unconditionally; on a cache hit this rewrites the same pointer.
346  call->set_cached_callee(reinterpret_cast<void*>(func_ptr));
347  return func_ptr;
348  }
349 
350  // Holds the evaluated values.
351  std::vector<ReductionInterpreter::EvalValue> vars_;
352  // Holds buffers allocated by the alloca instruction.
353  std::vector<std::vector<int8_t>> alloca_buffers_;
354  // Holds the value returned by the function.
355  std::optional<ReductionInterpreter::EvalValue> ret_ = std::nullopt;
356 };
357 
360 }
361 
363  ReductionInterpreterImpl::runLoad(this, interpreter);
364 }
365 
367  ReductionInterpreterImpl::runICmp(this, interpreter);
368 }
369 
372 }
373 
375  ReductionInterpreterImpl::runCast(this, interpreter);
376 }
377 
378 void Ret::run(ReductionInterpreterImpl* interpreter) {
379  ReductionInterpreterImpl::runRet(this, interpreter);
380 }
381 
383  ReductionInterpreterImpl::runCall(this, interpreter);
384 }
385 
388 }
389 
391  ReductionInterpreterImpl::runAlloca(this, interpreter);
392 }
393 
395  ReductionInterpreterImpl::runMemCpy(this, interpreter);
396 }
397 
399  ReductionInterpreterImpl::runReturnEarly(this, interpreter);
400 }
401 
402 void For::run(ReductionInterpreterImpl* interpreter) {
403  ReductionInterpreterImpl::runFor(this, interpreter);
404 }
405 
406 namespace {
407 
408 // Create an evaluated constant.
410  switch (constant->type()) {
411  case Type::Int8:
412  case Type::Int32:
413  case Type::Int64: {
414  return {.int_val = static_cast<const ConstantInt*>(constant)->value()};
415  }
416  case Type::Float: {
417  return {.float_val =
418  static_cast<float>(static_cast<const ConstantFP*>(constant)->value())};
419  }
420  case Type::Double: {
421  return {.double_val = static_cast<const ConstantFP*>(constant)->value()};
422  }
423  default: {
424  LOG(FATAL) << "Constant type not supported: " << static_cast<int>(constant->type());
425  break;
426  }
427  }
428  return {};
429 }
430 
431 } // namespace
432 
434  const Function* function,
435  const std::vector<ReductionInterpreter::EvalValue>& inputs) {
436  const auto last_id = function->body().back()->id();
437  const auto& arg_types = function->arg_types();
438  std::vector<ReductionInterpreter::EvalValue> vars(last_id + 1);
439  // Add the arguments to the variable map.
440  for (size_t i = 0; i < arg_types.size(); ++i) {
441  vars[function->arg(i)->id()] = inputs[i];
442  }
443  // Add constants to the variable map.
444  for (const auto& constant : function->constants()) {
445  vars[constant->id()] = eval_constant(constant.get());
446  }
447  const auto maybe_ret = run(function->body(), vars);
448  CHECK(maybe_ret);
449  return *maybe_ret;
450 }
451 
452 std::optional<ReductionInterpreter::EvalValue> ReductionInterpreter::run(
453  const std::vector<std::unique_ptr<Instruction>>& body,
454  const std::vector<ReductionInterpreter::EvalValue>& vars) {
455  ReductionInterpreterImpl interp_impl(vars);
456  for (const auto& instr : body) {
457  instr->run(&interp_impl);
458  const auto ret = interp_impl.ret();
459  if (ret) {
460  return *ret;
461  }
462  }
463  return interp_impl.ret();
464 }
static void runReturnEarly(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
ReductionInterpreterImpl(const std::vector< ReductionInterpreter::EvalValue > &vars)
bool external() const
size_t id() const
static void runFor(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
void setVar(const Value *var, ReductionInterpreter::EvalValue value)
#define LOG(tag)
Definition: Logger.h:188
std::vector< ReductionInterpreter::EvalValue > vars_
void run(ReductionInterpreterImpl *interpreter) override
Type type() const
void run(ReductionInterpreterImpl *interpreter) override
void run(ReductionInterpreterImpl *interpreter) override
static void runBinaryOperator(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runCall(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static std::vector< ReductionInterpreter::EvalValue > getCallInputs(const Call *call, const ReductionInterpreterImpl *interpreter)
thread_local size_t g_value_id
static void runLoad(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
const Value * source() const
static EvalValue run(const Function *function, const std::vector< EvalValue > &inputs)
CHECK(cgen_state)
void run(ReductionInterpreterImpl *interpreter) override
void run(ReductionInterpreterImpl *interpreter) override
const int8_t const int64_t const uint64_t const int32_t const int64_t int64_t uint32_t const int64_t int32_t * error_code
static Stub generateStub(const std::string &name, const std::vector< Type > &arg_types, const Type ret_type, const bool is_external)
const std::string & callee_name() const
ReductionInterpreter::EvalValue eval_constant(const Constant *constant)
const std::vector< const Value * > & arguments() const
std::vector< Type > get_value_types(const std::vector< const Value * > &values)
void run(ReductionInterpreterImpl *interpreter) override
void * cached_callee() const
std::optional< ReductionInterpreter::EvalValue > ret() const
ReductionInterpreter::EvalValue(*)(void *output_handle, const void *inputs_handle) Stub
static void runExternalCall(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runICmp(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
std::optional< ReductionInterpreter::EvalValue > ret_
static void runAlloca(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
std::vector< std::vector< int8_t > > alloca_buffers_
void run(ReductionInterpreterImpl *interpreter) override
void run(ReductionInterpreterImpl *interpreter) override
bool is_pointer_type(const Type type)
void run(ReductionInterpreterImpl *interpreter) override
static void runCast(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runMemCpy(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
void run(ReductionInterpreterImpl *interpreter) override
static void runGetElementPtr(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
static void runRet(const Instruction *instruction, ReductionInterpreterImpl *interpreter)
bool is_integer_type(const Type type)
void run(ReductionInterpreterImpl *interpreter) override
void set_cached_callee(void *cached_callee) const
void run(ReductionInterpreterImpl *interpreter) override
static StubGenerator::Stub bindStub(const Call *call)
const std::vector< const Value * > & arguments() const