OmniSciDB  8fa3bf436f
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CgenState.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CgenState.h"
19 
20 #include <llvm/IR/InstIterator.h>
21 #include <llvm/Transforms/Utils/Cloning.h>
22 
23 extern std::unique_ptr<llvm::Module> g_rt_module;
24 #ifdef ENABLE_GEOS
25 extern std::unique_ptr<llvm::Module> g_rt_geos_module;
26 #endif
27 
28 llvm::ConstantInt* CgenState::inlineIntNull(const SQLTypeInfo& type_info) {
29  auto type = type_info.get_type();
30  if (type_info.is_string()) {
31  switch (type_info.get_compression()) {
32  case kENCODING_DICT:
33  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
34  case kENCODING_NONE:
35  return llInt(int64_t(0));
36  default:
37  CHECK(false);
38  }
39  }
40  switch (type) {
41  case kBOOLEAN:
42  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
43  case kTINYINT:
44  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
45  case kSMALLINT:
46  return llInt(static_cast<int16_t>(inline_int_null_val(type_info)));
47  case kINT:
48  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
49  case kBIGINT:
50  case kTIME:
51  case kTIMESTAMP:
52  case kDATE:
53  case kINTERVAL_DAY_TIME:
55  return llInt(inline_int_null_val(type_info));
56  case kDECIMAL:
57  case kNUMERIC:
58  return llInt(inline_int_null_val(type_info));
59  case kARRAY:
60  return llInt(int64_t(0));
61  default:
62  abort();
63  }
64 }
65 
66 llvm::ConstantFP* CgenState::inlineFpNull(const SQLTypeInfo& type_info) {
67  CHECK(type_info.is_fp());
68  switch (type_info.get_type()) {
69  case kFLOAT:
70  return llFp(NULL_FLOAT);
71  case kDOUBLE:
72  return llFp(NULL_DOUBLE);
73  default:
74  abort();
75  }
76 }
77 
78 llvm::Constant* CgenState::inlineNull(const SQLTypeInfo& ti) {
79  return ti.is_fp() ? static_cast<llvm::Constant*>(inlineFpNull(ti))
80  : static_cast<llvm::Constant*>(inlineIntNull(ti));
81 }
82 
83 std::pair<llvm::ConstantInt*, llvm::ConstantInt*> CgenState::inlineIntMaxMin(
84  const size_t byte_width,
85  const bool is_signed) {
86  int64_t max_int{0}, min_int{0};
87  if (is_signed) {
88  std::tie(max_int, min_int) = inline_int_max_min(byte_width);
89  } else {
90  uint64_t max_uint{0}, min_uint{0};
91  std::tie(max_uint, min_uint) = inline_uint_max_min(byte_width);
92  max_int = static_cast<int64_t>(max_uint);
93  CHECK_EQ(uint64_t(0), min_uint);
94  }
95  switch (byte_width) {
96  case 1:
97  return std::make_pair(::ll_int(static_cast<int8_t>(max_int), context_),
98  ::ll_int(static_cast<int8_t>(min_int), context_));
99  case 2:
100  return std::make_pair(::ll_int(static_cast<int16_t>(max_int), context_),
101  ::ll_int(static_cast<int16_t>(min_int), context_));
102  case 4:
103  return std::make_pair(::ll_int(static_cast<int32_t>(max_int), context_),
104  ::ll_int(static_cast<int32_t>(min_int), context_));
105  case 8:
106  return std::make_pair(::ll_int(max_int, context_), ::ll_int(min_int, context_));
107  default:
108  abort();
109  }
110 }
111 
112 llvm::Value* CgenState::castToTypeIn(llvm::Value* val, const size_t dst_bits) {
113  auto src_bits = val->getType()->getScalarSizeInBits();
114  if (src_bits == dst_bits) {
115  return val;
116  }
117  if (val->getType()->isIntegerTy()) {
118  return ir_builder_.CreateIntCast(
119  val, get_int_type(dst_bits, context_), src_bits != 1);
120  }
121  // real (not dictionary-encoded) strings; store the pointer to the payload
122  if (val->getType()->isPointerTy()) {
123  return ir_builder_.CreatePointerCast(val, get_int_type(dst_bits, context_));
124  }
125 
126  CHECK(val->getType()->isFloatTy() || val->getType()->isDoubleTy());
127 
128  llvm::Type* dst_type = nullptr;
129  switch (dst_bits) {
130  case 64:
131  dst_type = llvm::Type::getDoubleTy(context_);
132  break;
133  case 32:
134  dst_type = llvm::Type::getFloatTy(context_);
135  break;
136  default:
137  CHECK(false);
138  }
139 
140  return ir_builder_.CreateFPCast(val, dst_type);
141 }
142 
143 void CgenState::maybeCloneFunctionRecursive(llvm::Function* fn) {
144  CHECK(fn);
145  if (!fn->isDeclaration()) {
146  return;
147  }
148 
149  // Get the implementation from the runtime module.
150  auto func_impl = g_rt_module->getFunction(fn->getName());
151  CHECK(func_impl) << fn->getName().str();
152 
153  if (func_impl->isDeclaration()) {
154  return;
155  }
156 
157  auto DestI = fn->arg_begin();
158  for (auto arg_it = func_impl->arg_begin(); arg_it != func_impl->arg_end(); ++arg_it) {
159  DestI->setName(arg_it->getName());
160  vmap_[&*arg_it] = &*DestI++;
161  }
162 
163  llvm::SmallVector<llvm::ReturnInst*, 8> Returns; // Ignore returns cloned.
164  llvm::CloneFunctionInto(fn, func_impl, vmap_, /*ModuleLevelChanges=*/true, Returns);
165 
166  for (auto it = llvm::inst_begin(fn), e = llvm::inst_end(fn); it != e; ++it) {
167  if (llvm::isa<llvm::CallInst>(*it)) {
168  auto& call = llvm::cast<llvm::CallInst>(*it);
169  maybeCloneFunctionRecursive(call.getCalledFunction());
170  }
171  }
172 }
173 
174 llvm::Value* CgenState::emitCall(const std::string& fname,
175  const std::vector<llvm::Value*>& args) {
176  // Get the function reference from the query module.
177  auto func = module_->getFunction(fname);
178  CHECK(func);
179  // If the function called isn't external, clone the implementation from the runtime
180  // module.
182 
183  return ir_builder_.CreateCall(func, args);
184 }
185 
186 void CgenState::emitErrorCheck(llvm::Value* condition,
187  llvm::Value* errorCode,
188  std::string label) {
189  needs_error_check_ = true;
190  auto check_ok = llvm::BasicBlock::Create(context_, label + "_ok", current_func_);
191  auto check_fail = llvm::BasicBlock::Create(context_, label + "_fail", current_func_);
192  ir_builder_.CreateCondBr(condition, check_ok, check_fail);
193  ir_builder_.SetInsertPoint(check_fail);
194  ir_builder_.CreateRet(errorCode);
195  ir_builder_.SetInsertPoint(check_ok);
196 }
#define CHECK_EQ(x, y)
Definition: Logger.h:211
llvm::Value * castToTypeIn(llvm::Value *val, const size_t bit_width)
Definition: CgenState.cpp:112
#define NULL_DOUBLE
Definition: sqltypes.h:48
#define NULL_FLOAT
void maybeCloneFunctionRecursive(llvm::Function *fn)
Definition: CgenState.cpp:143
bool is_fp() const
Definition: sqltypes.h:493
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:329
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:314
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
llvm::Module * module_
Definition: CgenState.h:318
llvm::LLVMContext & context_
Definition: CgenState.h:327
llvm::Function * current_func_
Definition: CgenState.h:321
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:28
std::unique_ptr< llvm::Module > g_rt_module
bool needs_error_check_
Definition: CgenState.h:344
llvm::ConstantFP * llFp(const float v) const
Definition: CgenState.h:304
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:174
std::pair< uint64_t, uint64_t > inline_uint_max_min(const size_t byte_width)
llvm::Constant * inlineNull(const SQLTypeInfo &)
Definition: CgenState.cpp:78
Definition: sqltypes.h:52
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:322
void emitErrorCheck(llvm::Value *condition, llvm::Value *errorCode, std::string label)
Definition: CgenState.cpp:186
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:300
#define CHECK(condition)
Definition: Logger.h:203
llvm::ValueToValueMapTy vmap_
Definition: CgenState.h:328
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
std::pair< int64_t, int64_t > inline_int_max_min(const size_t byte_width)
Definition: sqltypes.h:44
bool is_string() const
Definition: sqltypes.h:489
std::pair< llvm::ConstantInt *, llvm::ConstantInt * > inlineIntMaxMin(const size_t byte_width, const bool is_signed)
Definition: CgenState.cpp:83
llvm::ConstantFP * inlineFpNull(const SQLTypeInfo &)
Definition: CgenState.cpp:66