OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ResultSetReductionCodegen.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include "IRCodegenUtils.h"
21 #include "ResultSetReductionJIT.h"
22 #include "ResultSetReductionOps.h"
23 
24 #include <llvm/IR/Instructions.h>
25 
26 llvm::Type* llvm_type(const Type type, llvm::LLVMContext& ctx) {
27  switch (type) {
28  case Type::Int1: {
29  return get_int_type(1, ctx);
30  }
31  case Type::Int8: {
32  return get_int_type(8, ctx);
33  }
34  case Type::Int32: {
35  return get_int_type(32, ctx);
36  }
37  case Type::Int64: {
38  return get_int_type(64, ctx);
39  }
40  case Type::Float: {
41  return get_fp_type(32, ctx);
42  }
43  case Type::Double: {
44  return get_fp_type(64, ctx);
45  }
46  case Type::Void: {
47  return llvm::Type::getVoidTy(ctx);
48  }
49  case Type::Int8Ptr: {
50  return llvm::PointerType::get(get_int_type(8, ctx), 0);
51  }
52  case Type::Int32Ptr: {
53  return llvm::PointerType::get(get_int_type(32, ctx), 0);
54  }
55  case Type::Int64Ptr: {
56  return llvm::PointerType::get(get_int_type(64, ctx), 0);
57  }
58  case Type::FloatPtr: {
59  return llvm::Type::getFloatPtrTy(ctx);
60  }
61  case Type::DoublePtr: {
62  return llvm::Type::getDoublePtrTy(ctx);
63  }
64  case Type::VoidPtr: {
65  return llvm::PointerType::get(get_int_type(8, ctx), 0);
66  }
67  case Type::Int64PtrPtr: {
68  return llvm::PointerType::get(llvm::PointerType::get(get_int_type(64, ctx), 0), 0);
69  }
70  default: {
71  LOG(FATAL) << "Argument type not supported: " << static_cast<int>(type);
72  break;
73  }
74  }
75  UNREACHABLE();
76  return nullptr;
77 }
78 
79 namespace {
80 
81 // Convert an IR predicate to the corresponding LLVM one.
82 llvm::ICmpInst::Predicate llvm_predicate(const ICmp::Predicate predicate) {
83  switch (predicate) {
84  case ICmp::Predicate::EQ: {
85  return llvm::ICmpInst::ICMP_EQ;
86  }
87  case ICmp::Predicate::NE: {
88  return llvm::ICmpInst::ICMP_NE;
89  }
90  default: {
91  LOG(FATAL) << "Invalid predicate: " << static_cast<int>(predicate);
92  }
93  }
94  UNREACHABLE();
95  return llvm::ICmpInst::ICMP_EQ;
96 }
97 
98 // Convert an IR binary operator type to the corresponding LLVM one.
99 llvm::BinaryOperator::BinaryOps llvm_binary_op(const BinaryOperator::BinaryOp op) {
100  switch (op) {
102  return llvm::Instruction::Add;
103  }
105  return llvm::Instruction::Mul;
106  }
107  default: {
108  LOG(FATAL) << "Invalid binary operator: " << static_cast<int>(op);
109  }
110  }
111  UNREACHABLE();
112  return llvm::Instruction::Add;
113 }
114 
115 // Convert an IR cast operator type to the corresponding LLVM one.
116 llvm::Instruction::CastOps llvm_cast_op(const Cast::CastOp op) {
117  switch (op) {
118  case Cast::CastOp::Trunc: {
119  return llvm::Instruction::Trunc;
120  }
121  case Cast::CastOp::SExt: {
122  return llvm::Instruction::SExt;
123  }
124  case Cast::CastOp::BitCast: {
125  return llvm::Instruction::BitCast;
126  }
127  default: {
128  LOG(FATAL) << "Invalid cast operator: " << static_cast<int>(op);
129  }
130  }
131  UNREACHABLE();
132  return llvm::Instruction::SExt;
133 }
134 
135 // Emit an early return from a function when the provided 'cond' is true, which the caller
136 // code can use when entries are empty or the watchdog is triggered. For functions which
137 // return void, the specified error code is ignored. For functions which return an
138 // integer, the error code is returned.
139 void return_early(llvm::Value* cond,
140  const ReductionCode& reduction_code,
141  llvm::Function* func,
142  llvm::Value* error_code) {
143  auto cgen_state = reduction_code.cgen_state;
144  AUTOMATIC_IR_METADATA(cgen_state);
145  auto& ctx = cgen_state->context_;
146  const auto early_return = llvm::BasicBlock::Create(ctx, ".early_return", func, 0);
147  const auto do_reduction = llvm::BasicBlock::Create(ctx, ".do_reduction", func, 0);
148  cgen_state->ir_builder_.CreateCondBr(cond, early_return, do_reduction);
149  cgen_state->ir_builder_.SetInsertPoint(early_return);
150 
151  if (func->getReturnType()->isVoidTy()) {
152  cgen_state->ir_builder_.CreateRetVoid();
153  } else {
154  CHECK(error_code);
155  cgen_state->ir_builder_.CreateRet(error_code);
156  }
157 
158  cgen_state->ir_builder_.SetInsertPoint(do_reduction);
159 }
160 
161 // Returns the corresponding LLVM value for the given IR value.
162 llvm::Value* mapped_value(const Value* val,
163  const std::unordered_map<const Value*, llvm::Value*>& m) {
164  if (val) {
165  const auto it = m.find(val);
166  CHECK(it != m.end());
167  return it->second;
168  } else {
169  return nullptr;
170  }
171 }
172 
173 // Returns the corresponding LLVM function for the given IR function.
174 llvm::Function* mapped_function(
175  const Function* function,
176  const std::unordered_map<const Function*, llvm::Function*>& f) {
177  const auto it = f.find(function);
178  CHECK(it != f.end()) << function->name() << " not found.";
179  return it->second;
180 }
181 
182 // Given a list of IR values and the mapping, return the list of corresponding LLVM IR
183 // values.
184 std::vector<llvm::Value*> llvm_args(
185  const std::vector<const Value*> args,
186  const std::unordered_map<const Value*, llvm::Value*>& m) {
187  std::vector<llvm::Value*> llvm_args;
189  args.begin(), args.end(), std::back_inserter(llvm_args), [&m](const Value* value) {
190  return mapped_value(value, m);
191  });
192  return llvm_args;
193 }
194 
195 void translate_for(const For* for_loop,
196  Function* ir_reduce_loop,
197  const ReductionCode& reduction_code,
198  std::unordered_map<const Value*, llvm::Value*>& m,
199  const std::unordered_map<const Function*, llvm::Function*>& f);
200 
201 // Translate a list of instructions to LLVM IR.
202 void translate_body(const std::vector<std::unique_ptr<Instruction>>& body,
203  const Function* function,
204  llvm::Function* llvm_function,
205  const ReductionCode& reduction_code,
206  std::unordered_map<const Value*, llvm::Value*>& m,
207  const std::unordered_map<const Function*, llvm::Function*>& f) {
208  auto cgen_state = reduction_code.cgen_state;
209  AUTOMATIC_IR_METADATA(cgen_state);
210  auto& ctx = cgen_state->context_;
211  for (const auto& instr : body) {
212  const auto instr_ptr = instr.get();
213  llvm::Value* translated{nullptr};
214  if (auto gep = dynamic_cast<const GetElementPtr*>(instr_ptr)) {
215  auto* base = mapped_value(gep->base(), m);
216  translated = cgen_state->ir_builder_.CreateGEP(
217  base->getType()->getScalarType()->getPointerElementType(),
218  base,
219  mapped_value(gep->index(), m),
220  gep->label());
221  } else if (auto load = dynamic_cast<const Load*>(instr_ptr)) {
222  auto* value = mapped_value(load->source(), m);
223  translated = cgen_state->ir_builder_.CreateLoad(
224  value->getType()->getPointerElementType(), value, load->label());
225  } else if (auto icmp = dynamic_cast<const ICmp*>(instr_ptr)) {
226  translated = cgen_state->ir_builder_.CreateICmp(llvm_predicate(icmp->predicate()),
227  mapped_value(icmp->lhs(), m),
228  mapped_value(icmp->rhs(), m),
229  icmp->label());
230  } else if (auto binary_operator = dynamic_cast<const BinaryOperator*>(instr_ptr)) {
231  translated =
232  cgen_state->ir_builder_.CreateBinOp(llvm_binary_op(binary_operator->op()),
233  mapped_value(binary_operator->lhs(), m),
234  mapped_value(binary_operator->rhs(), m),
235  binary_operator->label());
236  } else if (auto cast = dynamic_cast<const Cast*>(instr_ptr)) {
237  translated = cgen_state->ir_builder_.CreateCast(llvm_cast_op(cast->op()),
238  mapped_value(cast->source(), m),
239  llvm_type(cast->type(), ctx),
240  cast->label());
241  } else if (auto ret = dynamic_cast<const Ret*>(instr_ptr)) {
242  if (ret->value()) {
243  cgen_state->ir_builder_.CreateRet(mapped_value(ret->value(), m));
244  } else {
245  cgen_state->ir_builder_.CreateRetVoid();
246  }
247  } else if (auto call = dynamic_cast<const Call*>(instr_ptr)) {
248  std::vector<llvm::Value*> llvm_args;
249  const auto args = call->arguments();
250  std::transform(args.begin(),
251  args.end(),
252  std::back_inserter(llvm_args),
253  [&m](const Value* value) { return mapped_value(value, m); });
254  if (call->callee()) {
255  translated = cgen_state->ir_builder_.CreateCall(
256  mapped_function(call->callee(), f), llvm_args, call->label());
257  } else {
258  translated = cgen_state->emitCall(call->callee_name(), llvm_args);
259  }
260  } else if (auto external_call = dynamic_cast<const ExternalCall*>(instr_ptr)) {
261  translated = cgen_state->emitExternalCall(external_call->callee_name(),
262  llvm_type(external_call->type(), ctx),
263  llvm_args(external_call->arguments(), m));
264  } else if (auto alloca = dynamic_cast<const Alloca*>(instr_ptr)) {
265  translated = cgen_state->ir_builder_.CreateAlloca(
266  llvm_type(pointee_type(alloca->type()), ctx),
267  mapped_value(alloca->array_size(), m),
268  alloca->label());
269  } else if (auto memcpy = dynamic_cast<const MemCpy*>(instr_ptr)) {
270  cgen_state->ir_builder_.CreateMemCpy(mapped_value(memcpy->dest(), m),
271  LLVM_MAYBE_ALIGN(0),
272  mapped_value(memcpy->source(), m),
273  LLVM_MAYBE_ALIGN(0),
274  mapped_value(memcpy->size(), m));
275  } else if (auto ret_early = dynamic_cast<const ReturnEarly*>(instr_ptr)) {
276  return_early(mapped_value(ret_early->cond(), m),
277  reduction_code,
278  llvm_function,
279  mapped_value(ret_early->error_code(), m));
280  } else if (auto for_loop = dynamic_cast<const For*>(instr_ptr)) {
281  translate_for(for_loop, reduction_code.ir_reduce_loop.get(), reduction_code, m, f);
282  } else {
283  LOG(FATAL) << "Instruction not supported yet";
284  }
285  if (translated) {
286  const auto it_ok = m.emplace(instr_ptr, translated);
287  CHECK(it_ok.second);
288  }
289  }
290 }
291 
// Translate a loop to LLVM IR, using existing loop construction facilities.
//
// The loop runs (end - start) iterations of `for_loop`, and its body is
// translated with translate_body() against `ir_reduce_loop`. On return, the
// builder is positioned at a new ".exit" block appended to the LLVM function
// mapped from `ir_reduce_loop`.
void translate_for(const For* for_loop,
                   Function* ir_reduce_loop,
                   const ReductionCode& reduction_code,
                   std::unordered_map<const Value*, llvm::Value*>& m,
                   const std::unordered_map<const Function*, llvm::Function*>& f) {
  auto cgen_state = reduction_code.cgen_state;
  AUTOMATIC_IR_METADATA(cgen_state);
  // Remember the current block: JoinLoop::codegen() moves the insert point, and
  // the branch into the loop is emitted from here afterwards.
  const auto bb_entry = cgen_state->ir_builder_.GetInsertBlock();
  auto& ctx = cgen_state->context_;
  const auto i64_type = get_int_type(64, cgen_state->context_);
  const auto end_index = mapped_value(for_loop->end(), m);
  const auto start_index = mapped_value(for_loop->start(), m);
  // The start and end indices are absolute. Subtract the start index from the iterator.
  const auto iteration_count =
      cgen_state->ir_builder_.CreateSub(end_index, start_index, "iteration_count");
  // The loop bound is handed to JoinLoop as a sign-extended 64-bit value.
  const auto upper_bound = cgen_state->ir_builder_.CreateSExt(iteration_count, i64_type);
  const auto bb_exit =
      llvm::BasicBlock::Create(ctx, ".exit", mapped_function(ir_reduce_loop, f));
  JoinLoop join_loop(
      // NOTE(review): the leading constructor arguments (lines 312-313 of the
      // original file) are missing from this listing. Given the upper-bound
      // domain below they are presumably JoinLoopKind::UpperBound and
      // JoinType::INNER; confirm against the actual source.
      [upper_bound](const std::vector<llvm::Value*>& v) {
        // Domain bounded by the iteration count; the iterator is relative to the
        // start index.
        JoinLoopDomain domain{{0}};
        domain.upper_bound = upper_bound;
        return domain;
      },
      nullptr,
      nullptr,
      nullptr,
      nullptr,
      false,
      "reduction_loop");
  const auto bb_loop_body = JoinLoop::codegen(
      {join_loop},
      // Body callback: creates a ".loop_body" block in the function currently
      // being built, maps the For's IR iterator to a 32-bit truncation of the
      // innermost loop iterator, and translates the For body into that block.
      [cgen_state, for_loop, ir_reduce_loop, &f, &m, &reduction_code](
          const std::vector<llvm::Value*>& iterators) {
        const auto loop_body_bb = llvm::BasicBlock::Create(
            cgen_state->context_,
            ".loop_body",
            cgen_state->ir_builder_.GetInsertBlock()->getParent());
        cgen_state->ir_builder_.SetInsertPoint(loop_body_bb);
        // Make the iterator the same type as start and end indices (32-bit integer).
        const auto loop_iter =
            cgen_state->ir_builder_.CreateTrunc(iterators.back(),
                                                get_int_type(32, cgen_state->context_),
                                                "relative_entry_idx");
        // NOTE(review): unlike the other m.emplace() calls in this file, the
        // insertion result is not CHECKed here.
        m.emplace(for_loop->iter(), loop_iter);
        translate_body(for_loop->body(),
                       ir_reduce_loop,
                       mapped_function(ir_reduce_loop, f),
                       reduction_code,
                       m,
                       f);
        return loop_body_bb;
      },
      nullptr,
      bb_exit,
      cgen_state);
  // Branch from the block we started in into the generated loop, then continue
  // emitting code after the loop.
  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
  cgen_state->ir_builder_.CreateBr(bb_loop_body);
  cgen_state->ir_builder_.SetInsertPoint(bb_exit);
}
355 
356 // Create the entry basic block into an initially empty function.
357 void create_entry_block(llvm::Function* function, CgenState* cgen_state) {
358  AUTOMATIC_IR_METADATA(cgen_state);
359  const auto bb_entry =
360  llvm::BasicBlock::Create(cgen_state->context_, ".entry", function, 0);
361  cgen_state->ir_builder_.SetInsertPoint(bb_entry);
362 }
363 
364 } // namespace
365 
366 void translate_function(const Function* function,
367  llvm::Function* llvm_function,
368  const ReductionCode& reduction_code,
369  const std::unordered_map<const Function*, llvm::Function*>& f) {
370  auto cgen_state = reduction_code.cgen_state;
371  AUTOMATIC_IR_METADATA(cgen_state);
372  create_entry_block(llvm_function, cgen_state);
373  // Set the value mapping based on the input arguments.
374  std::unordered_map<const Value*, llvm::Value*> m;
375  auto llvm_arg_it = llvm_function->arg_begin();
376  for (size_t arg_idx = 0; arg_idx < function->arg_types().size(); ++arg_idx) {
377  llvm::Value* llvm_arg = &(*llvm_arg_it);
378  const auto it_ok = m.emplace(function->arg(arg_idx), llvm_arg);
379  CHECK(it_ok.second);
380  ++llvm_arg_it;
381  }
382  // Add mapping for the constants used by the function.
383  for (const auto& constant : function->constants()) {
384  llvm::Value* constant_llvm{nullptr};
385  switch (constant->type()) {
386  case Type::Int8: {
387  constant_llvm =
388  cgen_state->llInt<int8_t>(static_cast<ConstantInt*>(constant.get())->value());
389  break;
390  }
391  case Type::Int32: {
392  constant_llvm = cgen_state->llInt<int32_t>(
393  static_cast<ConstantInt*>(constant.get())->value());
394  break;
395  }
396  case Type::Int64: {
397  constant_llvm = cgen_state->llInt<int64_t>(
398  static_cast<ConstantInt*>(constant.get())->value());
399  break;
400  }
401  case Type::Float: {
402  constant_llvm = cgen_state->llFp(
403  static_cast<float>(static_cast<ConstantFP*>(constant.get())->value()));
404  break;
405  }
406  case Type::Double: {
407  constant_llvm =
408  cgen_state->llFp(static_cast<ConstantFP*>(constant.get())->value());
409  break;
410  }
411  default: {
412  LOG(FATAL) << "Constant type not supported: "
413  << static_cast<int>(constant->type());
414  }
415  }
416  CHECK(constant_llvm);
417  const auto it_ok = m.emplace(constant.get(), constant_llvm);
418  CHECK(it_ok.second);
419  }
420  translate_body(function->body(), function, llvm_function, reduction_code, m, f);
421  verify_function_ir(llvm_function);
422 }
DEVICE auto upper_bound(ARGS &&...args)
Definition: gpu_enabled.h:123
CgenState * cgen_state
void create_entry_block(llvm::Function *function, CgenState *cgen_state)
void translate_for(const For *for_loop, Function *ir_reduce_loop, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
std::unique_ptr< Function > ir_reduce_loop
void load(Archive &ar, ExplainedQueryHint &query_hint, const unsigned int version)
#define LOG(tag)
Definition: Logger.h:285
llvm::IRBuilder ir_builder_
Definition: CgenState.h:384
#define UNREACHABLE()
Definition: Logger.h:338
std::vector< llvm::Value * > llvm_args(const std::vector< const Value * > args, const std::unordered_map< const Value *, llvm::Value * > &m)
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
llvm::ICmpInst::Predicate llvm_predicate(const ICmp::Predicate predicate)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
void translate_function(const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, const std::unordered_map< const Function *, llvm::Function * > &f)
const Value * end() const
const Value * start() const
void verify_function_ir(const llvm::Function *func)
llvm::LLVMContext & context_
Definition: CgenState.h:382
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
Definition: JoinLoop.cpp:50
Type pointee_type(const Type pointer)
llvm::Instruction::CastOps llvm_cast_op(const Cast::CastOp op)
OUTPUT transform(INPUT const &input, FUNC const &func)
Definition: misc.h:320
#define AUTOMATIC_IR_METADATA(CGENSTATE)
llvm::Type * llvm_type(const Type type, llvm::LLVMContext &ctx)
void return_early(llvm::Value *cond, const ReductionCode &reduction_code, llvm::Function *func, llvm::Value *error_code)
llvm::Value * upper_bound
Definition: JoinLoop.h:45
void translate_body(const std::vector< std::unique_ptr< Instruction >> &body, const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
#define LLVM_MAYBE_ALIGN(alignment)
torch::Tensor f(torch::Tensor x, torch::Tensor W_target, torch::Tensor b_target)
def error_code
Definition: report.py:244
#define CHECK(condition)
Definition: Logger.h:291
llvm::Function * mapped_function(const Function *function, const std::unordered_map< const Function *, llvm::Function * > &f)
llvm::Value * mapped_value(const Value *val, const std::unordered_map< const Value *, llvm::Value * > &m)
llvm::BinaryOperator::BinaryOps llvm_binary_op(const BinaryOperator::BinaryOp op)