OmniSciDB  6686921089
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CgenState.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CgenState.h"
19 
20 #include <llvm/IR/InstIterator.h>
21 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
22 #include <llvm/Transforms/Utils/Cloning.h>
23 
24 extern std::unique_ptr<llvm::Module> g_rt_module;
25 #ifdef ENABLE_GEOS
26 extern std::unique_ptr<llvm::Module> g_rt_geos_module;
27 #endif
28 
29 llvm::ConstantInt* CgenState::inlineIntNull(const SQLTypeInfo& type_info) {
30  auto type = type_info.get_type();
31  if (type_info.is_string()) {
32  switch (type_info.get_compression()) {
33  case kENCODING_DICT:
34  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
35  case kENCODING_NONE:
36  return llInt(int64_t(0));
37  default:
38  CHECK(false);
39  }
40  }
41  switch (type) {
42  case kBOOLEAN:
43  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
44  case kTINYINT:
45  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
46  case kSMALLINT:
47  return llInt(static_cast<int16_t>(inline_int_null_val(type_info)));
48  case kINT:
49  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
50  case kBIGINT:
51  case kTIME:
52  case kTIMESTAMP:
53  case kDATE:
54  case kINTERVAL_DAY_TIME:
56  return llInt(inline_int_null_val(type_info));
57  case kDECIMAL:
58  case kNUMERIC:
59  return llInt(inline_int_null_val(type_info));
60  case kARRAY:
61  return llInt(int64_t(0));
62  default:
63  abort();
64  }
65 }
66 
67 llvm::ConstantFP* CgenState::inlineFpNull(const SQLTypeInfo& type_info) {
68  CHECK(type_info.is_fp());
69  switch (type_info.get_type()) {
70  case kFLOAT:
71  return llFp(NULL_FLOAT);
72  case kDOUBLE:
73  return llFp(NULL_DOUBLE);
74  default:
75  abort();
76  }
77 }
78 
79 llvm::Constant* CgenState::inlineNull(const SQLTypeInfo& ti) {
80  return ti.is_fp() ? static_cast<llvm::Constant*>(inlineFpNull(ti))
81  : static_cast<llvm::Constant*>(inlineIntNull(ti));
82 }
83 
84 std::pair<llvm::ConstantInt*, llvm::ConstantInt*> CgenState::inlineIntMaxMin(
85  const size_t byte_width,
86  const bool is_signed) {
87  int64_t max_int{0}, min_int{0};
88  if (is_signed) {
89  std::tie(max_int, min_int) = inline_int_max_min(byte_width);
90  } else {
91  uint64_t max_uint{0}, min_uint{0};
92  std::tie(max_uint, min_uint) = inline_uint_max_min(byte_width);
93  max_int = static_cast<int64_t>(max_uint);
94  CHECK_EQ(uint64_t(0), min_uint);
95  }
96  switch (byte_width) {
97  case 1:
98  return std::make_pair(::ll_int(static_cast<int8_t>(max_int), context_),
99  ::ll_int(static_cast<int8_t>(min_int), context_));
100  case 2:
101  return std::make_pair(::ll_int(static_cast<int16_t>(max_int), context_),
102  ::ll_int(static_cast<int16_t>(min_int), context_));
103  case 4:
104  return std::make_pair(::ll_int(static_cast<int32_t>(max_int), context_),
105  ::ll_int(static_cast<int32_t>(min_int), context_));
106  case 8:
107  return std::make_pair(::ll_int(max_int, context_), ::ll_int(min_int, context_));
108  default:
109  abort();
110  }
111 }
112 
113 llvm::Value* CgenState::castToTypeIn(llvm::Value* val, const size_t dst_bits) {
114  auto src_bits = val->getType()->getScalarSizeInBits();
115  if (src_bits == dst_bits) {
116  return val;
117  }
118  if (val->getType()->isIntegerTy()) {
119  return ir_builder_.CreateIntCast(
120  val, get_int_type(dst_bits, context_), src_bits != 1);
121  }
122  // real (not dictionary-encoded) strings; store the pointer to the payload
123  if (val->getType()->isPointerTy()) {
124  return ir_builder_.CreatePointerCast(val, get_int_type(dst_bits, context_));
125  }
126 
127  CHECK(val->getType()->isFloatTy() || val->getType()->isDoubleTy());
128 
129  llvm::Type* dst_type = nullptr;
130  switch (dst_bits) {
131  case 64:
132  dst_type = llvm::Type::getDoubleTy(context_);
133  break;
134  case 32:
135  dst_type = llvm::Type::getFloatTy(context_);
136  break;
137  default:
138  CHECK(false);
139  }
140 
141  return ir_builder_.CreateFPCast(val, dst_type);
142 }
143 
144 void CgenState::maybeCloneFunctionRecursive(llvm::Function* fn) {
145  CHECK(fn);
146  if (!fn->isDeclaration()) {
147  return;
148  }
149 
150  // Get the implementation from the runtime module.
151  auto func_impl = g_rt_module->getFunction(fn->getName());
152  CHECK(func_impl) << fn->getName().str();
153 
154  if (func_impl->isDeclaration()) {
155  return;
156  }
157 
158  auto DestI = fn->arg_begin();
159  for (auto arg_it = func_impl->arg_begin(); arg_it != func_impl->arg_end(); ++arg_it) {
160  DestI->setName(arg_it->getName());
161  vmap_[&*arg_it] = &*DestI++;
162  }
163 
164  llvm::SmallVector<llvm::ReturnInst*, 8> Returns; // Ignore returns cloned.
165  llvm::CloneFunctionInto(fn, func_impl, vmap_, /*ModuleLevelChanges=*/true, Returns);
166 
167  for (auto it = llvm::inst_begin(fn), e = llvm::inst_end(fn); it != e; ++it) {
168  if (llvm::isa<llvm::CallInst>(*it)) {
169  auto& call = llvm::cast<llvm::CallInst>(*it);
170  maybeCloneFunctionRecursive(call.getCalledFunction());
171  }
172  }
173 }
174 
175 llvm::Value* CgenState::emitCall(const std::string& fname,
176  const std::vector<llvm::Value*>& args) {
177  // Get the function reference from the query module.
178  auto func = module_->getFunction(fname);
179  CHECK(func);
180  // If the function called isn't external, clone the implementation from the runtime
181  // module.
183 
184  return ir_builder_.CreateCall(func, args);
185 }
186 
187 void CgenState::emitErrorCheck(llvm::Value* condition,
188  llvm::Value* errorCode,
189  std::string label) {
190  needs_error_check_ = true;
191  auto check_ok = llvm::BasicBlock::Create(context_, label + "_ok", current_func_);
192  auto check_fail = llvm::BasicBlock::Create(context_, label + "_fail", current_func_);
193  ir_builder_.CreateCondBr(condition, check_ok, check_fail);
194  ir_builder_.SetInsertPoint(check_fail);
195  ir_builder_.CreateRet(errorCode);
196  ir_builder_.SetInsertPoint(check_ok);
197 }
198 
199 namespace {
200 
201 // clang-format off
202 template <typename T>
203 llvm::Type* getTy(llvm::LLVMContext& ctx) { return getTy<std::remove_pointer_t<T>>(ctx)->getPointerTo(); }
204 // Commented out to avoid -Wunused-function warnings, but enable as needed.
205 // template<> llvm::Type* getTy<bool>(llvm::LLVMContext& ctx) { return llvm::Type::getInt1Ty(ctx); }
206 //template<> llvm::Type* getTy<int8_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt8Ty(ctx); }
207 // template<> llvm::Type* getTy<int16_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt16Ty(ctx); }
208 //template<> llvm::Type* getTy<int32_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt32Ty(ctx); }
209 // template<> llvm::Type* getTy<int64_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt64Ty(ctx); }
210 // template<> llvm::Type* getTy<float>(llvm::LLVMContext& ctx) { return llvm::Type::getFloatTy(ctx); }
211 template<> llvm::Type* getTy<double>(llvm::LLVMContext& ctx) { return llvm::Type::getDoubleTy(ctx); }
212 //template<> llvm::Type* getTy<void>(llvm::LLVMContext& ctx) { return llvm::Type::getVoidTy(ctx); }
213 // clang-format on
214 
216  GpuFunctionDefinition(char const* name) : name_(name) {}
217  char const* const name_;
218 
219  virtual ~GpuFunctionDefinition() = default;
220 
221  virtual llvm::FunctionCallee getFunction(llvm::Module* module,
222  llvm::LLVMContext& context) const = 0;
223 };
224 
225 // TYPES = return_type, arg0_type, arg1_type, arg2_type, ...
226 template <typename... TYPES>
227 struct GpuFunction final : public GpuFunctionDefinition {
228  GpuFunction(char const* name) : GpuFunctionDefinition(name) {}
229 
230  llvm::FunctionCallee getFunction(llvm::Module* module,
231  llvm::LLVMContext& context) const {
232  return module->getOrInsertFunction(name_, getTy<TYPES>(context)...);
233  }
234 };
235 
236 static const std::unordered_map<std::string, std::shared_ptr<GpuFunctionDefinition>>
238  {"asin", std::make_shared<GpuFunction<double, double>>("Asin")},
239  {"atanh", std::make_shared<GpuFunction<double, double>>("Atanh")},
240  {"atan", std::make_shared<GpuFunction<double, double>>("Atan")},
241  {"cosh", std::make_shared<GpuFunction<double, double>>("Cosh")},
242  {"cos", std::make_shared<GpuFunction<double, double>>("Cos")},
243  {"exp", std::make_shared<GpuFunction<double, double>>("Exp")},
244  {"log", std::make_shared<GpuFunction<double, double>>("ln")},
245  {"pow", std::make_shared<GpuFunction<double, double, double>>("power")},
246  {"sinh", std::make_shared<GpuFunction<double, double>>("Sinh")},
247  {"sin", std::make_shared<GpuFunction<double, double>>("Sin")},
248  {"sqrt", std::make_shared<GpuFunction<double, double>>("Sqrt")},
249  {"tan", std::make_shared<GpuFunction<double, double>>("Tan")}};
250 } // namespace
251 
252 std::vector<std::string> CgenState::gpuFunctionsToReplace(llvm::Function* fn) {
253  std::vector<std::string> ret;
254 
255  CHECK(fn);
256  CHECK(!fn->isDeclaration());
257 
258  for (auto& basic_block : *fn) {
259  auto& inst_list = basic_block.getInstList();
260  for (auto inst_itr = inst_list.begin(); inst_itr != inst_list.end(); ++inst_itr) {
261  if (auto call_inst = llvm::dyn_cast<llvm::CallInst>(inst_itr)) {
262  auto called_fcn = call_inst->getCalledFunction();
263  CHECK(called_fcn);
264 
265  if (gpu_replacement_functions.find(called_fcn->getName().str()) !=
267  ret.emplace_back(called_fcn->getName());
268  }
269  }
270  }
271  }
272  return ret;
273 }
274 
275 void CgenState::replaceFunctionForGpu(const std::string& fcn_to_replace,
276  llvm::Function* fn) {
277  CHECK(fn);
278  CHECK(!fn->isDeclaration());
279 
280  auto map_it = gpu_replacement_functions.find(fcn_to_replace);
281  if (map_it == gpu_replacement_functions.end()) {
282  throw QueryMustRunOnCpu("Codegen failed: Could not find replacement functon for " +
283  fcn_to_replace +
284  " to run on gpu. Query step must run in cpu mode.");
285  }
286  const auto& gpu_fcn_obj = map_it->second;
287  CHECK(gpu_fcn_obj);
288  VLOG(1) << "Replacing " << fcn_to_replace << " with " << gpu_fcn_obj->name_
289  << " for parent function " << fn->getName().str();
290 
291  for (auto& basic_block : *fn) {
292  auto& inst_list = basic_block.getInstList();
293  for (auto inst_itr = inst_list.begin(); inst_itr != inst_list.end(); ++inst_itr) {
294  if (auto call_inst = llvm::dyn_cast<llvm::CallInst>(inst_itr)) {
295  auto called_fcn = call_inst->getCalledFunction();
296  CHECK(called_fcn);
297 
298  if (called_fcn->getName() == fcn_to_replace) {
299  std::vector<llvm::Value*> args;
300  std::vector<llvm::Type*> arg_types;
301  for (auto& arg : call_inst->args()) {
302  arg_types.push_back(arg.get()->getType());
303  args.push_back(arg.get());
304  }
305  auto gpu_func = gpu_fcn_obj->getFunction(module_, context_);
306  CHECK(gpu_func);
307  auto gpu_func_type = gpu_func.getFunctionType();
308  CHECK(gpu_func_type);
309  CHECK_EQ(gpu_func_type->getReturnType(), called_fcn->getReturnType());
310  llvm::ReplaceInstWithInst(call_inst,
311  llvm::CallInst::Create(gpu_func, args, ""));
312  return;
313  }
314  }
315  }
316  }
317 }
#define CHECK_EQ(x, y)
Definition: Logger.h:217
llvm::Value * castToTypeIn(llvm::Value *val, const size_t bit_width)
Definition: CgenState.cpp:113
#define NULL_DOUBLE
Definition: sqltypes.h:49
#define NULL_FLOAT
void maybeCloneFunctionRecursive(llvm::Function *fn)
Definition: CgenState.cpp:144
bool is_fp() const
Definition: sqltypes.h:513
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:340
string name
Definition: setup.in.py:72
llvm::Type * getTy(llvm::LLVMContext &ctx)
Definition: CgenState.cpp:203
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:329
llvm::FunctionCallee getFunction(llvm::Module *module, llvm::LLVMContext &context) const
Definition: CgenState.cpp:230
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
llvm::Type * getTy< double >(llvm::LLVMContext &ctx)
Definition: CgenState.cpp:211
static const std::unordered_map< std::string, std::shared_ptr< GpuFunctionDefinition > > gpu_replacement_functions
Definition: CgenState.cpp:237
llvm::Module * module_
Definition: CgenState.h:329
llvm::LLVMContext & context_
Definition: CgenState.h:338
llvm::Function * current_func_
Definition: CgenState.h:332
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:29
std::unique_ptr< llvm::Module > g_rt_module
void replaceFunctionForGpu(const std::string &fcn_to_replace, llvm::Function *fn)
Definition: CgenState.cpp:275
bool needs_error_check_
Definition: CgenState.h:358
llvm::ConstantFP * llFp(const float v) const
Definition: CgenState.h:311
std::vector< std::string > gpuFunctionsToReplace(llvm::Function *fn)
Definition: CgenState.cpp:252
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:175
std::pair< uint64_t, uint64_t > inline_uint_max_min(const size_t byte_width)
llvm::Constant * inlineNull(const SQLTypeInfo &)
Definition: CgenState.cpp:79
Definition: sqltypes.h:53
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:337
void emitErrorCheck(llvm::Value *condition, llvm::Value *errorCode, std::string label)
Definition: CgenState.cpp:187
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:307
#define CHECK(condition)
Definition: Logger.h:209
llvm::ValueToValueMapTy vmap_
Definition: CgenState.h:339
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
std::pair< int64_t, int64_t > inline_int_max_min(const size_t byte_width)
Definition: sqltypes.h:45
bool is_string() const
Definition: sqltypes.h:509
std::pair< llvm::ConstantInt *, llvm::ConstantInt * > inlineIntMaxMin(const size_t byte_width, const bool is_signed)
Definition: CgenState.cpp:84
#define VLOG(n)
Definition: Logger.h:303
llvm::ConstantFP * inlineFpNull(const SQLTypeInfo &)
Definition: CgenState.cpp:67