OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ResultSetReductionCodegen.cpp File Reference
#include "ResultSetReductionCodegen.h"
#include "IRCodegenUtils.h"
#include "LoopControlFlow/JoinLoop.h"
#include "ResultSetReductionJIT.h"
#include "ResultSetReductionOps.h"
#include <llvm/IR/Instructions.h>
+ Include dependency graph for ResultSetReductionCodegen.cpp:

Go to the source code of this file.

Namespaces

 anonymous_namespace{ResultSetReductionCodegen.cpp}
 

Functions

llvm::Type * llvm_type (const Type type, llvm::LLVMContext &ctx)
 
llvm::ICmpInst::Predicate anonymous_namespace{ResultSetReductionCodegen.cpp}::llvm_predicate (const ICmp::Predicate predicate)
 
llvm::BinaryOperator::BinaryOps anonymous_namespace{ResultSetReductionCodegen.cpp}::llvm_binary_op (const BinaryOperator::BinaryOp op)
 
llvm::Instruction::CastOps anonymous_namespace{ResultSetReductionCodegen.cpp}::llvm_cast_op (const Cast::CastOp op)
 
void anonymous_namespace{ResultSetReductionCodegen.cpp}::return_early (llvm::Value *cond, const ReductionCode &reduction_code, llvm::Function *func, llvm::Value *error_code)
 
llvm::Value * anonymous_namespace{ResultSetReductionCodegen.cpp}::mapped_value (const Value *val, const std::unordered_map< const Value *, llvm::Value * > &m)
 
llvm::Function * anonymous_namespace{ResultSetReductionCodegen.cpp}::mapped_function (const Function *function, const std::unordered_map< const Function *, llvm::Function * > &f)
 
std::vector< llvm::Value * > anonymous_namespace{ResultSetReductionCodegen.cpp}::llvm_args (const std::vector< const Value * > args, const std::unordered_map< const Value *, llvm::Value * > &m)
 
void anonymous_namespace{ResultSetReductionCodegen.cpp}::translate_for (const For *for_loop, Function *ir_reduce_loop, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
 
void anonymous_namespace{ResultSetReductionCodegen.cpp}::translate_body (const std::vector< std::unique_ptr< Instruction >> &body, const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
 
void anonymous_namespace{ResultSetReductionCodegen.cpp}::create_entry_block (llvm::Function *function, CgenState *cgen_state)
 
void translate_function (const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, const std::unordered_map< const Function *, llvm::Function * > &f)
 

Function Documentation

llvm::Type* llvm_type ( const Type  type,
llvm::LLVMContext &  ctx 
)

Definition at line 26 of file ResultSetReductionCodegen.cpp.

References Double, DoublePtr, logger::FATAL, Float, FloatPtr, get_fp_type(), get_int_type(), Int1, Int32, Int32Ptr, Int64, Int64Ptr, Int64PtrPtr, Int8, Int8Ptr, LOG, run_benchmark_import::type, UNREACHABLE, Void, and VoidPtr.

Referenced by anonymous_namespace{ResultSetReductionJIT.cpp}::create_llvm_function(), StubGenerator::generateStub(), and anonymous_namespace{ResultSetReductionCodegen.cpp}::translate_body().

26  {  // Map a reduction-IR Type enumerator to the corresponding llvm::Type created in ctx.
27  switch (type) {
28  case Type::Int1: {  // Scalar integers are built with get_int_type(width, ctx).
29  return get_int_type(1, ctx);
30  }
31  case Type::Int8: {
32  return get_int_type(8, ctx);
33  }
34  case Type::Int32: {
35  return get_int_type(32, ctx);
36  }
37  case Type::Int64: {
38  return get_int_type(64, ctx);
39  }
40  case Type::Float: {  // Scalar floating point is built with get_fp_type(width, ctx).
41  return get_fp_type(32, ctx);
42  }
43  case Type::Double: {
44  return get_fp_type(64, ctx);
45  }
46  case Type::Void: {
47  return llvm::Type::getVoidTy(ctx);
48  }
49  case Type::Int8Ptr: {  // Integer pointer types are built with PointerType::get(..., 0), i.e. address space 0.
50  return llvm::PointerType::get(get_int_type(8, ctx), 0);
51  }
52  case Type::Int32Ptr: {
53  return llvm::PointerType::get(get_int_type(32, ctx), 0);
54  }
55  case Type::Int64Ptr: {
56  return llvm::PointerType::get(get_int_type(64, ctx), 0);
57  }
58  case Type::FloatPtr: {  // NOTE(review): unlike the other pointer cases this uses the typed-pointer helper Type::getFloatPtrTy, which newer (opaque-pointer) LLVM releases deprecate/remove; PointerType::get(get_fp_type(32, ctx), 0) would be consistent — verify against the LLVM version in use.
59  return llvm::Type::getFloatPtrTy(ctx);
60  }
61  case Type::DoublePtr: {  // NOTE(review): same typed-pointer helper concern as the FloatPtr case; PointerType::get(get_fp_type(64, ctx), 0) would be consistent.
62  return llvm::Type::getDoublePtrTy(ctx);
63  }
64  case Type::VoidPtr: {  // VoidPtr is lowered with exactly the same construction as Int8Ptr (pointer to the 8-bit integer type).
65  return llvm::PointerType::get(get_int_type(8, ctx), 0);
66  }
67  case Type::Int64PtrPtr: {  // Pointer to a pointer to the 64-bit integer type.
68  return llvm::PointerType::get(llvm::PointerType::get(get_int_type(64, ctx), 0), 0);
69  }
70  default: {  // Any enumerator not listed above is reported as unsupported.
71  LOG(FATAL) << "Argument type not supported: " << static_cast<int>(type);
72  break;
73  }
74  }
75  UNREACHABLE();  // Only reached if LOG(FATAL) returns; presumably it aborts — confirm in Logger.h.
76  return nullptr;  // Keeps every control path returning a value.
77 }
#define LOG(tag)
Definition: Logger.h:285
#define UNREACHABLE()
Definition: Logger.h:338
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void translate_function ( const Function *  function,
llvm::Function *  llvm_function,
const ReductionCode &  reduction_code,
const std::unordered_map< const Function *, llvm::Function * > &  f 
)

Definition at line 366 of file ResultSetReductionCodegen.cpp.

References AUTOMATIC_IR_METADATA, ReductionCode::cgen_state, CHECK, anonymous_namespace{ResultSetReductionCodegen.cpp}::create_entry_block(), Double, f(), logger::FATAL, Float, Int32, Int64, Int8, LOG, anonymous_namespace{ResultSetReductionCodegen.cpp}::translate_body(), and verify_function_ir().

Referenced by ResultSetReductionJIT::codegen(), and GpuReductionHelperJIT::codegen().

369  {  // Lower the reduction-IR `function` into `llvm_function`: map arguments and constants to LLVM values, emit the body, then verify the resulting IR.
370  auto cgen_state = reduction_code.cgen_state;  // Code-generation state taken from reduction_code; used for constants below.
371  AUTOMATIC_IR_METADATA(cgen_state);
372  create_entry_block(llvm_function, cgen_state);
373  // Set the value mapping based on the input arguments.
374  std::unordered_map<const Value*, llvm::Value*> m;  // IR value -> LLVM value; passed by non-const reference to translate_body below.
375  auto llvm_arg_it = llvm_function->arg_begin();  // IR arguments are paired with LLVM arguments by position.
376  for (size_t arg_idx = 0; arg_idx < function->arg_types().size(); ++arg_idx) {  // NOTE(review): llvm_arg_it is never compared to arg_end(); assumes llvm_function has at least as many arguments as function->arg_types().
377  llvm::Value* llvm_arg = &(*llvm_arg_it);
378  const auto it_ok = m.emplace(function->arg(arg_idx), llvm_arg);
379  CHECK(it_ok.second);  // Each IR argument must be mapped exactly once.
380  ++llvm_arg_it;
381  }
382  // Add mapping for the constants used by the function.
383  for (const auto& constant : function->constants()) {
384  llvm::Value* constant_llvm{nullptr};  // Stays null for unsupported constant types.
385  switch (constant->type()) {  // Only Int8, Int32, Int64, Float and Double constants are handled (no Int1 or pointer constants).
386  case Type::Int8: {
387  constant_llvm =
388  cgen_state->llInt<int8_t>(static_cast<ConstantInt*>(constant.get())->value());
389  break;
390  }
391  case Type::Int32: {
392  constant_llvm = cgen_state->llInt<int32_t>(
393  static_cast<ConstantInt*>(constant.get())->value());
394  break;
395  }
396  case Type::Int64: {
397  constant_llvm = cgen_state->llInt<int64_t>(
398  static_cast<ConstantInt*>(constant.get())->value());
399  break;
400  }
401  case Type::Float: {  // The stored ConstantFP value is narrowed to float before the LLVM constant is built.
402  constant_llvm = cgen_state->llFp(
403  static_cast<float>(static_cast<ConstantFP*>(constant.get())->value()));
404  break;
405  }
406  case Type::Double: {
407  constant_llvm =
408  cgen_state->llFp(static_cast<ConstantFP*>(constant.get())->value());
409  break;
410  }
411  default: {  // Unsupported constant type: LOG(FATAL) presumably aborts; if it did not, CHECK(constant_llvm) below would still fail on the null value.
412  LOG(FATAL) << "Constant type not supported: "
413  << static_cast<int>(constant->type());
414  }
415  }
416  CHECK(constant_llvm);
417  const auto it_ok = m.emplace(constant.get(), constant_llvm);
418  CHECK(it_ok.second);  // Each constant must be mapped exactly once and must not collide with an argument mapping.
419  }
420  translate_body(function->body(), function, llvm_function, reduction_code, m, f);  // Emit the body instructions; f maps IR functions to their LLVM counterparts.
421  verify_function_ir(llvm_function);  // Verify the generated IR for llvm_function.
422 }
CgenState * cgen_state
void create_entry_block(llvm::Function *function, CgenState *cgen_state)
#define LOG(tag)
Definition: Logger.h:285
void verify_function_ir(const llvm::Function *func)
#define AUTOMATIC_IR_METADATA(CGENSTATE)
void translate_body(const std::vector< std::unique_ptr< Instruction >> &body, const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
const std::unordered_map< const Function *, llvm::Function * > & f
#define CHECK(condition)
Definition: Logger.h:291

+ Here is the call graph for this function:

+ Here is the caller graph for this function: