OmniSciDB  21ac014ffc
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CgenState.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CgenState.h"
19 
20 #include <llvm/IR/InstIterator.h>
21 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
22 #include <llvm/Transforms/Utils/Cloning.h>
23 
24 extern std::unique_ptr<llvm::Module> g_rt_module;
25 #ifdef ENABLE_GEOS
26 extern std::unique_ptr<llvm::Module> g_rt_geos_module;
27 #endif
28 
29 llvm::ConstantInt* CgenState::inlineIntNull(const SQLTypeInfo& type_info) {
30  auto type = type_info.get_type();
31  if (type_info.is_string()) {
32  switch (type_info.get_compression()) {
33  case kENCODING_DICT:
34  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
35  case kENCODING_NONE:
36  return llInt(int64_t(0));
37  default:
38  CHECK(false);
39  }
40  }
41  switch (type) {
42  case kBOOLEAN:
43  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
44  case kTINYINT:
45  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
46  case kSMALLINT:
47  return llInt(static_cast<int16_t>(inline_int_null_val(type_info)));
48  case kINT:
49  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
50  case kBIGINT:
51  case kTIME:
52  case kTIMESTAMP:
53  case kDATE:
54  case kINTERVAL_DAY_TIME:
56  return llInt(inline_int_null_val(type_info));
57  case kDECIMAL:
58  case kNUMERIC:
59  return llInt(inline_int_null_val(type_info));
60  case kARRAY:
61  return llInt(int64_t(0));
62  default:
63  abort();
64  }
65 }
66 
67 llvm::ConstantFP* CgenState::inlineFpNull(const SQLTypeInfo& type_info) {
68  CHECK(type_info.is_fp());
69  switch (type_info.get_type()) {
70  case kFLOAT:
71  return llFp(NULL_FLOAT);
72  case kDOUBLE:
73  return llFp(NULL_DOUBLE);
74  default:
75  abort();
76  }
77 }
78 
79 llvm::Constant* CgenState::inlineNull(const SQLTypeInfo& ti) {
80  return ti.is_fp() ? static_cast<llvm::Constant*>(inlineFpNull(ti))
81  : static_cast<llvm::Constant*>(inlineIntNull(ti));
82 }
83 
84 std::pair<llvm::ConstantInt*, llvm::ConstantInt*> CgenState::inlineIntMaxMin(
85  const size_t byte_width,
86  const bool is_signed) {
87  int64_t max_int{0}, min_int{0};
88  if (is_signed) {
89  std::tie(max_int, min_int) = inline_int_max_min(byte_width);
90  } else {
91  uint64_t max_uint{0}, min_uint{0};
92  std::tie(max_uint, min_uint) = inline_uint_max_min(byte_width);
93  max_int = static_cast<int64_t>(max_uint);
94  CHECK_EQ(uint64_t(0), min_uint);
95  }
96  switch (byte_width) {
97  case 1:
98  return std::make_pair(::ll_int(static_cast<int8_t>(max_int), context_),
99  ::ll_int(static_cast<int8_t>(min_int), context_));
100  case 2:
101  return std::make_pair(::ll_int(static_cast<int16_t>(max_int), context_),
102  ::ll_int(static_cast<int16_t>(min_int), context_));
103  case 4:
104  return std::make_pair(::ll_int(static_cast<int32_t>(max_int), context_),
105  ::ll_int(static_cast<int32_t>(min_int), context_));
106  case 8:
107  return std::make_pair(::ll_int(max_int, context_), ::ll_int(min_int, context_));
108  default:
109  abort();
110  }
111 }
112 
113 llvm::Value* CgenState::castToTypeIn(llvm::Value* val, const size_t dst_bits) {
114  auto src_bits = val->getType()->getScalarSizeInBits();
115  if (src_bits == dst_bits) {
116  return val;
117  }
118  if (val->getType()->isIntegerTy()) {
119  return ir_builder_.CreateIntCast(
120  val, get_int_type(dst_bits, context_), src_bits != 1);
121  }
122  // real (not dictionary-encoded) strings; store the pointer to the payload
123  if (val->getType()->isPointerTy()) {
124  return ir_builder_.CreatePointerCast(val, get_int_type(dst_bits, context_));
125  }
126 
127  CHECK(val->getType()->isFloatTy() || val->getType()->isDoubleTy());
128 
129  llvm::Type* dst_type = nullptr;
130  switch (dst_bits) {
131  case 64:
132  dst_type = llvm::Type::getDoubleTy(context_);
133  break;
134  case 32:
135  dst_type = llvm::Type::getFloatTy(context_);
136  break;
137  default:
138  CHECK(false);
139  }
140 
141  return ir_builder_.CreateFPCast(val, dst_type);
142 }
143 
144 void CgenState::maybeCloneFunctionRecursive(llvm::Function* fn) {
145  CHECK(fn);
146  if (!fn->isDeclaration()) {
147  return;
148  }
149 
150  // Get the implementation from the runtime module.
151  auto func_impl = g_rt_module->getFunction(fn->getName());
152  CHECK(func_impl) << fn->getName().str();
153 
154  if (func_impl->isDeclaration()) {
155  return;
156  }
157 
158  auto DestI = fn->arg_begin();
159  for (auto arg_it = func_impl->arg_begin(); arg_it != func_impl->arg_end(); ++arg_it) {
160  DestI->setName(arg_it->getName());
161  vmap_[&*arg_it] = &*DestI++;
162  }
163 
164  llvm::SmallVector<llvm::ReturnInst*, 8> Returns; // Ignore returns cloned.
165  llvm::CloneFunctionInto(fn, func_impl, vmap_, /*ModuleLevelChanges=*/true, Returns);
166 
167  for (auto it = llvm::inst_begin(fn), e = llvm::inst_end(fn); it != e; ++it) {
168  if (llvm::isa<llvm::CallInst>(*it)) {
169  auto& call = llvm::cast<llvm::CallInst>(*it);
170  maybeCloneFunctionRecursive(call.getCalledFunction());
171  }
172  }
173 }
174 
175 llvm::Value* CgenState::emitCall(const std::string& fname,
176  const std::vector<llvm::Value*>& args) {
177  // Get the function reference from the query module.
178  auto func = module_->getFunction(fname);
179  CHECK(func);
180  // If the function called isn't external, clone the implementation from the runtime
181  // module.
183 
184  return ir_builder_.CreateCall(func, args);
185 }
186 
187 void CgenState::emitErrorCheck(llvm::Value* condition,
188  llvm::Value* errorCode,
189  std::string label) {
190  needs_error_check_ = true;
191  auto check_ok = llvm::BasicBlock::Create(context_, label + "_ok", current_func_);
192  auto check_fail = llvm::BasicBlock::Create(context_, label + "_fail", current_func_);
193  ir_builder_.CreateCondBr(condition, check_ok, check_fail);
194  ir_builder_.SetInsertPoint(check_fail);
195  ir_builder_.CreateRet(errorCode);
196  ir_builder_.SetInsertPoint(check_ok);
197 }
198 
199 namespace {
200 
// Interface for a GPU replacement of a runtime math function. `name` is the symbol the
// replacement call targets; getFunction() returns a callee for that symbol in the given
// module.
// NOTE(review): the `struct ...` header/constructor lines of the definitions in this
// namespace are missing from this listing; which subclass each getFunction() below
// belongs to can only be inferred from gpu_replacement_functions.
202  GpuFunctionDefinition(const std::string& name) : name(name) {}
203  std::string name;
204 
206 
207  virtual llvm::FunctionCallee getFunction(llvm::Module* module,
208  llvm::LLVMContext& context) const = 0;
209 };
210 
213 
// Looks up or declares (getOrInsertFunction) `name` in `module` with signature
// double(double, double). Presumably the "pow" replacement, the only two-argument
// math function in gpu_replacement_functions; confirm.
214  llvm::FunctionCallee getFunction(llvm::Module* module,
215  llvm::LLVMContext& context) const final {
216  return module->getOrInsertFunction(name,
217  /*ret_type=*/llvm::Type::getDoubleTy(context),
218  /*args=*/llvm::Type::getDoubleTy(context),
219  llvm::Type::getDoubleTy(context));
220  }
221 };
222 
225 
// Looks up or declares (getOrInsertFunction) `name` in `module` with signature
// double(double).
226  llvm::FunctionCallee getFunction(llvm::Module* module,
227  llvm::LLVMContext& context) const final {
228  return module->getOrInsertFunction(name,
229  /*ret_type=*/llvm::Type::getDoubleTy(context),
230  /*args=*/llvm::Type::getDoubleTy(context));
231  }
232 };
233 
236 
// Looks up or declares (getOrInsertFunction) `name` in `module` with signature
// double(double).
237  llvm::FunctionCallee getFunction(llvm::Module* module,
238  llvm::LLVMContext& context) const final {
239  return module->getOrInsertFunction(name,
240  /*ret_type=*/llvm::Type::getDoubleTy(context),
241  /*args=*/llvm::Type::getDoubleTy(context));
242  }
243 };
244 
247 
// Looks up or declares (getOrInsertFunction) `name` in `module` with signature
// double(double).
248  llvm::FunctionCallee getFunction(llvm::Module* module,
249  llvm::LLVMContext& context) const final {
250  return module->getOrInsertFunction(name,
251  /*ret_type=*/llvm::Type::getDoubleTy(context),
252  /*args=*/llvm::Type::getDoubleTy(context));
253  }
254 };
255 
258 
// Looks up or declares (getOrInsertFunction) `name` in `module` with signature
// double(double).
259  llvm::FunctionCallee getFunction(llvm::Module* module,
260  llvm::LLVMContext& context) const final {
261  return module->getOrInsertFunction(name,
262  /*ret_type=*/llvm::Type::getDoubleTy(context),
263  /*args=*/llvm::Type::getDoubleTy(context));
264  }
265 };
266 
// Maps the name of a called function, as matched by gpuFunctionsToReplace() and
// replaceFunctionForGpu(), to the definition of its GPU replacement. A called function
// whose name is not a key here has no GPU replacement; replaceFunctionForGpu() throws
// QueryMustRunOnCpu for such names.
267 static const std::unordered_map<std::string, std::shared_ptr<GpuFunctionDefinition>>
268  gpu_replacement_functions{{"pow", std::make_shared<GpuPowerFunction>()},
269  {"atan", std::make_shared<GpuAtanFunction>()},
270  {"log", std::make_shared<GpuLogFunction>()},
271  {"tan", std::make_shared<GpuTanFunction>()},
272  {"exp", std::make_shared<GpuExpFunction>()}};
273 
274 } // namespace
275 
276 std::vector<std::string> CgenState::gpuFunctionsToReplace(llvm::Function* fn) {
277  std::vector<std::string> ret;
278 
279  CHECK(fn);
280  CHECK(!fn->isDeclaration());
281 
282  for (auto& basic_block : *fn) {
283  auto& inst_list = basic_block.getInstList();
284  for (auto inst_itr = inst_list.begin(); inst_itr != inst_list.end(); ++inst_itr) {
285  if (auto call_inst = llvm::dyn_cast<llvm::CallInst>(inst_itr)) {
286  auto called_fcn = call_inst->getCalledFunction();
287  CHECK(called_fcn);
288 
289  if (gpu_replacement_functions.find(called_fcn->getName().str()) !=
291  ret.emplace_back(called_fcn->getName());
292  }
293  }
294  }
295  }
296  return ret;
297 }
298 
299 void CgenState::replaceFunctionForGpu(const std::string& fcn_to_replace,
300  llvm::Function* fn) {
301  CHECK(fn);
302  CHECK(!fn->isDeclaration());
303 
304  auto map_it = gpu_replacement_functions.find(fcn_to_replace);
305  if (map_it == gpu_replacement_functions.end()) {
306  throw QueryMustRunOnCpu("Codegen failed: Could not find replacement functon for " +
307  fcn_to_replace +
308  " to run on gpu. Query must run in cpu mode.");
309  }
310  const auto& gpu_fcn_obj = map_it->second;
311  CHECK(gpu_fcn_obj);
312  const auto& gpu_fcn_name = gpu_fcn_obj->name;
313  VLOG(1) << "Replacing " << fcn_to_replace << " with " << gpu_fcn_name
314  << " for parent function " << fn->getName().str();
315 
316  for (auto& basic_block : *fn) {
317  auto& inst_list = basic_block.getInstList();
318  for (auto inst_itr = inst_list.begin(); inst_itr != inst_list.end(); ++inst_itr) {
319  if (auto call_inst = llvm::dyn_cast<llvm::CallInst>(inst_itr)) {
320  auto called_fcn = call_inst->getCalledFunction();
321  CHECK(called_fcn);
322 
323  if (called_fcn->getName() == fcn_to_replace) {
324  std::vector<llvm::Value*> args;
325  std::vector<llvm::Type*> arg_types;
326  for (auto& arg : call_inst->args()) {
327  arg_types.push_back(arg.get()->getType());
328  args.push_back(arg.get());
329  }
330  auto gpu_func = gpu_fcn_obj->getFunction(module_, context_);
331  CHECK(gpu_func);
332  auto gpu_func_type = gpu_func.getFunctionType();
333  CHECK(gpu_func_type);
334  CHECK_EQ(gpu_func_type->getReturnType(), called_fcn->getReturnType());
335  llvm::ReplaceInstWithInst(call_inst,
336  llvm::CallInst::Create(gpu_func, args, ""));
337  return;
338  }
339  }
340  }
341  }
342 }
#define CHECK_EQ(x, y)
Definition: Logger.h:214
llvm::Value * castToTypeIn(llvm::Value *val, const size_t bit_width)
Definition: CgenState.cpp:113
#define NULL_DOUBLE
Definition: sqltypes.h:48
#define NULL_FLOAT
void maybeCloneFunctionRecursive(llvm::Function *fn)
Definition: CgenState.cpp:144
bool is_fp() const
Definition: sqltypes.h:502
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:340
string name
Definition: setup.in.py:72
llvm::FunctionCallee getFunction(llvm::Module *module, llvm::LLVMContext &context) const final
Definition: CgenState.cpp:214
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:323
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
static const std::unordered_map< std::string, std::shared_ptr< GpuFunctionDefinition > > gpu_replacement_functions
Definition: CgenState.cpp:268
llvm::Module * module_
Definition: CgenState.h:329
llvm::LLVMContext & context_
Definition: CgenState.h:338
llvm::Function * current_func_
Definition: CgenState.h:332
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:29
std::unique_ptr< llvm::Module > g_rt_module
llvm::FunctionCallee getFunction(llvm::Module *module, llvm::LLVMContext &context) const final
Definition: CgenState.cpp:226
void replaceFunctionForGpu(const std::string &fcn_to_replace, llvm::Function *fn)
Definition: CgenState.cpp:299
bool needs_error_check_
Definition: CgenState.h:358
llvm::ConstantFP * llFp(const float v) const
Definition: CgenState.h:311
std::vector< std::string > gpuFunctionsToReplace(llvm::Function *fn)
Definition: CgenState.cpp:276
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:175
std::pair< uint64_t, uint64_t > inline_uint_max_min(const size_t byte_width)
llvm::Constant * inlineNull(const SQLTypeInfo &)
Definition: CgenState.cpp:79
llvm::FunctionCallee getFunction(llvm::Module *module, llvm::LLVMContext &context) const final
Definition: CgenState.cpp:259
Definition: sqltypes.h:52
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:331
void emitErrorCheck(llvm::Value *condition, llvm::Value *errorCode, std::string label)
Definition: CgenState.cpp:187
llvm::FunctionCallee getFunction(llvm::Module *module, llvm::LLVMContext &context) const final
Definition: CgenState.cpp:237
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:307
#define CHECK(condition)
Definition: Logger.h:206
llvm::ValueToValueMapTy vmap_
Definition: CgenState.h:339
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
std::pair< int64_t, int64_t > inline_int_max_min(const size_t byte_width)
Definition: sqltypes.h:44
bool is_string() const
Definition: sqltypes.h:498
std::pair< llvm::ConstantInt *, llvm::ConstantInt * > inlineIntMaxMin(const size_t byte_width, const bool is_signed)
Definition: CgenState.cpp:84
llvm::FunctionCallee getFunction(llvm::Module *module, llvm::LLVMContext &context) const final
Definition: CgenState.cpp:248
#define VLOG(n)
Definition: Logger.h:300
llvm::ConstantFP * inlineFpNull(const SQLTypeInfo &)
Definition: CgenState.cpp:67