OmniSciDB  c0231cc57d
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CgenState.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CgenState.h"
18 #include "CodeGenerator.h"
20 
21 #include <llvm/IR/InstIterator.h>
22 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
23 #include <llvm/Transforms/Utils/Cloning.h>
24 
25 CgenState::CgenState(const size_t num_query_infos,
26  const bool contains_left_deep_outer_join,
27  Executor* executor)
28  : executor_id_(executor->getExecutorId())
29  , module_(nullptr)
30  , row_func_(nullptr)
31  , filter_func_(nullptr)
32  , current_func_(nullptr)
33  , row_func_bb_(nullptr)
34  , filter_func_bb_(nullptr)
35  , row_func_call_(nullptr)
36  , filter_func_call_(nullptr)
37  , context_(executor->getContext())
38  , ir_builder_(context_)
39  , contains_left_deep_outer_join_(contains_left_deep_outer_join)
40  , outer_join_match_found_per_level_(std::max(num_query_infos, size_t(1)) - 1)
41  , needs_error_check_(false)
42  , needs_geos_(false)
43  , query_func_(nullptr)
44  , query_func_entry_ir_builder_(context_){};
45 
46 CgenState::CgenState(const size_t num_query_infos,
47  const bool contains_left_deep_outer_join)
48  : CgenState(num_query_infos,
49  contains_left_deep_outer_join,
50  Executor::getExecutor(Executor::UNITARY_EXECUTOR_ID).get()) {}
51 
52 CgenState::CgenState(llvm::LLVMContext& context)
53  : executor_id_(Executor::INVALID_EXECUTOR_ID)
54  , module_(nullptr)
55  , row_func_(nullptr)
56  , context_(context)
57  , ir_builder_(context_)
58  , contains_left_deep_outer_join_(false)
59  , needs_error_check_(false)
60  , needs_geos_(false)
61  , query_func_(nullptr)
62  , query_func_entry_ir_builder_(context_){};
63 
// Returns the inlined NULL sentinel for an integer-like SQL type as an LLVM
// integer constant. The constant's width follows the C++ type passed to llInt:
//  - dictionary-encoded strings: 32-bit inline_int_null_val;
//  - none-encoded strings and arrays: 64-bit zero;
//  - BOOLEAN/TINYINT: 8-bit, SMALLINT: 16-bit, INT: 32-bit, BIGINT: 64-bit;
//  - fixed-encoded TIME/TIMESTAMP: inline_fixed_encoding_null_val;
//  - other TIME/TIMESTAMP, DATE, INTERVAL_DAY_TIME, DECIMAL, NUMERIC:
//    inline_int_null_val as int64_t.
// Any other type reaches abort(); string encodings other than DICT/NONE hit
// CHECK(false).
64 llvm::ConstantInt* CgenState::inlineIntNull(const SQLTypeInfo& type_info) {
65  auto type = type_info.get_type();
66  if (type_info.is_string()) {
67  switch (type_info.get_compression()) {
68  case kENCODING_DICT:
69  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
70  case kENCODING_NONE:
71  return llInt(int64_t(0));
72  default:
73  CHECK(false);
74  }
75  }
76  switch (type) {
77  case kBOOLEAN:
78  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
79  case kTINYINT:
80  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
81  case kSMALLINT:
82  return llInt(static_cast<int16_t>(inline_int_null_val(type_info)));
83  case kINT:
84  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
85  case kBIGINT:
86  return llInt(static_cast<int64_t>(inline_int_null_val(type_info)));
87  case kTIME:
88  case kTIMESTAMP:
89  if (type_info.get_compression() == kENCODING_FIXED) {
90  return llInt(inline_fixed_encoding_null_val(type_info));
91  }
// Not fixed-encoded: intentionally falls through to the 64-bit
// inline_int_null_val case below.
92  case kDATE:
93  case kINTERVAL_DAY_TIME:
// NOTE(review): one source line (94) is absent from this listing here;
// presumably another case label sharing this return. Confirm against the repo.
95  return llInt(inline_int_null_val(type_info));
96  case kDECIMAL:
97  case kNUMERIC:
98  return llInt(inline_int_null_val(type_info));
99  case kARRAY:
100  return llInt(int64_t(0));
101  default:
102  abort();
103  }
104 }
105 
106 llvm::ConstantFP* CgenState::inlineFpNull(const SQLTypeInfo& type_info) {
107  CHECK(type_info.is_fp());
108  switch (type_info.get_type()) {
109  case kFLOAT:
110  return llFp(NULL_FLOAT);
111  case kDOUBLE:
112  return llFp(NULL_DOUBLE);
113  default:
114  abort();
115  }
116 }
117 
118 llvm::Constant* CgenState::inlineNull(const SQLTypeInfo& ti) {
119  return ti.is_fp() ? static_cast<llvm::Constant*>(inlineFpNull(ti))
120  : static_cast<llvm::Constant*>(inlineIntNull(ti));
121 }
122 
123 std::pair<llvm::ConstantInt*, llvm::ConstantInt*> CgenState::inlineIntMaxMin(
124  const size_t byte_width,
125  const bool is_signed) {
126  int64_t max_int{0}, min_int{0};
127  if (is_signed) {
128  std::tie(max_int, min_int) = inline_int_max_min(byte_width);
129  } else {
130  uint64_t max_uint{0}, min_uint{0};
131  std::tie(max_uint, min_uint) = inline_uint_max_min(byte_width);
132  max_int = static_cast<int64_t>(max_uint);
133  CHECK_EQ(uint64_t(0), min_uint);
134  }
135  switch (byte_width) {
136  case 1:
137  return std::make_pair(::ll_int(static_cast<int8_t>(max_int), context_),
138  ::ll_int(static_cast<int8_t>(min_int), context_));
139  case 2:
140  return std::make_pair(::ll_int(static_cast<int16_t>(max_int), context_),
141  ::ll_int(static_cast<int16_t>(min_int), context_));
142  case 4:
143  return std::make_pair(::ll_int(static_cast<int32_t>(max_int), context_),
144  ::ll_int(static_cast<int32_t>(min_int), context_));
145  case 8:
146  return std::make_pair(::ll_int(max_int, context_), ::ll_int(min_int, context_));
147  default:
148  abort();
149  }
150 }
151 
152 llvm::Value* CgenState::castToTypeIn(llvm::Value* val, const size_t dst_bits) {
153  auto src_bits = val->getType()->getScalarSizeInBits();
154  if (src_bits == dst_bits) {
155  return val;
156  }
157  if (val->getType()->isIntegerTy()) {
158  return ir_builder_.CreateIntCast(
159  val, get_int_type(dst_bits, context_), src_bits != 1);
160  }
161  // real (not dictionary-encoded) strings; store the pointer to the payload
162  if (val->getType()->isPointerTy()) {
163  return ir_builder_.CreatePointerCast(val, get_int_type(dst_bits, context_));
164  }
165 
166  CHECK(val->getType()->isFloatTy() || val->getType()->isDoubleTy());
167 
168  llvm::Type* dst_type = nullptr;
169  switch (dst_bits) {
170  case 64:
171  dst_type = llvm::Type::getDoubleTy(context_);
172  break;
173  case 32:
174  dst_type = llvm::Type::getFloatTy(context_);
175  break;
176  default:
177  CHECK(false);
178  }
179 
180  return ir_builder_.CreateFPCast(val, dst_type);
181 }
182 
183 void CgenState::maybeCloneFunctionRecursive(llvm::Function* fn) {
184  CHECK(fn);
185  if (!fn->isDeclaration()) {
186  return;
187  }
188 
189  // Get the implementation from the runtime module.
190  auto func_impl = getExecutor()->get_rt_module()->getFunction(fn->getName());
191  CHECK(func_impl) << fn->getName().str();
192 
193  if (func_impl->isDeclaration()) {
194  return;
195  }
196 
197  auto DestI = fn->arg_begin();
198  for (auto arg_it = func_impl->arg_begin(); arg_it != func_impl->arg_end(); ++arg_it) {
199  DestI->setName(arg_it->getName());
200  vmap_[&*arg_it] = &*DestI++;
201  }
202 
203  llvm::SmallVector<llvm::ReturnInst*, 8> Returns; // Ignore returns cloned.
204 #if LLVM_VERSION_MAJOR > 12
205  llvm::CloneFunctionInto(
206  fn, func_impl, vmap_, llvm::CloneFunctionChangeType::DifferentModule, Returns);
207 #else
208  llvm::CloneFunctionInto(fn, func_impl, vmap_, /*ModuleLevelChanges=*/true, Returns);
209 #endif
210 
211  for (auto it = llvm::inst_begin(fn), e = llvm::inst_end(fn); it != e; ++it) {
212  if (llvm::isa<llvm::CallInst>(*it)) {
213  auto& call = llvm::cast<llvm::CallInst>(*it);
214  maybeCloneFunctionRecursive(call.getCalledFunction());
215  }
216  }
217 }
218 
// Emits a call to the function named `fname`, passing `args`, at the current
// insertion point of ir_builder_, and returns the created call.
// The function must already exist in module_ (CHECKed).
219 llvm::Value* CgenState::emitCall(const std::string& fname,
220  const std::vector<llvm::Value*>& args) {
221  // Get the function reference from the query module.
222  auto func = module_->getFunction(fname);
223  CHECK(func);
224  // If the function called isn't external, clone the implementation from the runtime
225  // module.
// NOTE(review): the statement implementing the comment above (source line 226)
// is absent from this listing; presumably maybeCloneFunctionRecursive(func).
// Confirm against the repository.
227 
228  return ir_builder_.CreateCall(func, args);
229 }
230 
// Same as emitCall, except the call is emitted through
// query_func_entry_ir_builder_ instead of ir_builder_.
// The function must already exist in module_ (CHECKed).
231 llvm::Value* CgenState::emitEntryCall(const std::string& fname,
232  const std::vector<llvm::Value*>& args) {
233  // Get the function reference from the query module.
234  auto func = module_->getFunction(fname);
235  CHECK(func);
236  // If the function called isn't external, clone the implementation from the runtime
237  // module.
// NOTE(review): the statement implementing the comment above (source line 238)
// is absent from this listing; presumably maybeCloneFunctionRecursive(func).
// Confirm against the repository.
239 
240  return query_func_entry_ir_builder_.CreateCall(func, args);
241 }
242 
243 void CgenState::emitErrorCheck(llvm::Value* condition,
244  llvm::Value* errorCode,
245  std::string label) {
246  needs_error_check_ = true;
247  auto check_ok = llvm::BasicBlock::Create(context_, label + "_ok", current_func_);
248  auto check_fail = llvm::BasicBlock::Create(context_, label + "_fail", current_func_);
249  ir_builder_.CreateCondBr(condition, check_ok, check_fail);
250  ir_builder_.SetInsertPoint(check_fail);
251  ir_builder_.CreateRet(errorCode);
252  ir_builder_.SetInsertPoint(check_ok);
253 }
254 
255 namespace {
256 
// getTy<T>(ctx) maps a C++ type to its LLVM type. The primary template handles
// pointer types by recursing on the pointee and taking a pointer to the result.
// Only the double specialization is enabled below; instantiating getTy for
// another non-pointer type would recurse without terminating.
257 // clang-format off
258 template <typename T>
259 llvm::Type* getTy(llvm::LLVMContext& ctx) { return getTy<std::remove_pointer_t<T>>(ctx)->getPointerTo(); }
260 // Commented out to avoid -Wunused-function warnings, but enable as needed.
261 // template<> llvm::Type* getTy<bool>(llvm::LLVMContext& ctx) { return llvm::Type::getInt1Ty(ctx); }
262 //template<> llvm::Type* getTy<int8_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt8Ty(ctx); }
263 // template<> llvm::Type* getTy<int16_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt16Ty(ctx); }
264 //template<> llvm::Type* getTy<int32_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt32Ty(ctx); }
265 // template<> llvm::Type* getTy<int64_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt64Ty(ctx); }
266 // template<> llvm::Type* getTy<float>(llvm::LLVMContext& ctx) { return llvm::Type::getFloatTy(ctx); }
267 template<> llvm::Type* getTy<double>(llvm::LLVMContext& ctx) { return llvm::Type::getDoubleTy(ctx); }
268 //template<> llvm::Type* getTy<void>(llvm::LLVMContext& ctx) { return llvm::Type::getVoidTy(ctx); }
269 // clang-format on
270 
// GpuFunctionDefinition: type-erased description of a GPU replacement function.
// It stores the replacement's symbol name, and getFunction() returns a
// declaration of it in a given module.
// NOTE(review): the opening `struct GpuFunctionDefinition {` line (source line
// 271) is absent from this listing; confirm against the repository.
272  GpuFunctionDefinition(char const* name) : name_(name) {}
273  char const* const name_;
274 
275  virtual ~GpuFunctionDefinition() = default;
276 
277  virtual llvm::FunctionCallee getFunction(llvm::Module* llvm_module,
278  llvm::LLVMContext& context) const = 0;
279 };
280 
281 // TYPES = return_type, arg0_type, arg1_type, arg2_type, ...
// getFunction() calls Module::getOrInsertFunction(name_, getTy<TYPES>...),
// which returns the existing declaration or inserts a new one.
// NOTE(review): getFunction could be marked `override`.
282 template <typename... TYPES>
283 struct GpuFunction final : public GpuFunctionDefinition {
284  GpuFunction(char const* name) : GpuFunctionDefinition(name) {}
285 
286  llvm::FunctionCallee getFunction(llvm::Module* llvm_module,
287  llvm::LLVMContext& context) const {
288  return llvm_module->getOrInsertFunction(name_, getTy<TYPES>(context)...);
289  }
290 };
291 
// gpu_replacement_functions: maps the name of a called function (e.g. "asin")
// to the GPU replacement to call instead (e.g. "Asin"). Every entry has the
// signature double(double), except "pow" -> "power", which is
// double(double, double).
// NOTE(review): the line naming the map (source line 293) is absent from this
// listing.
292 static const std::unordered_map<std::string, std::shared_ptr<GpuFunctionDefinition>>
294  {"asin", std::make_shared<GpuFunction<double, double>>("Asin")},
295  {"atanh", std::make_shared<GpuFunction<double, double>>("Atanh")},
296  {"atan", std::make_shared<GpuFunction<double, double>>("Atan")},
297  {"cosh", std::make_shared<GpuFunction<double, double>>("Cosh")},
298  {"cos", std::make_shared<GpuFunction<double, double>>("Cos")},
299  {"exp", std::make_shared<GpuFunction<double, double>>("Exp")},
300  {"log", std::make_shared<GpuFunction<double, double>>("ln")},
301  {"pow", std::make_shared<GpuFunction<double, double, double>>("power")},
302  {"sinh", std::make_shared<GpuFunction<double, double>>("Sinh")},
303  {"sin", std::make_shared<GpuFunction<double, double>>("Sin")},
304  {"sqrt", std::make_shared<GpuFunction<double, double>>("Sqrt")},
305  {"tan", std::make_shared<GpuFunction<double, double>>("Tan")}};
306 } // namespace
307 
// Scans every call instruction in the defined function `fn`. Returns the
// callee name for each call whose callee has an entry in
// gpu_replacement_functions, one entry per matching call site (no
// de-duplication).
// Every call must have a statically known callee: getCalledFunction() is null
// for indirect calls, which would fail the CHECK below.
308 std::vector<std::string> CgenState::gpuFunctionsToReplace(llvm::Function* fn) {
309  std::vector<std::string> ret;
310 
311  CHECK(fn);
312  CHECK(!fn->isDeclaration());
313 
314  for (auto& basic_block : *fn) {
315  auto& inst_list = basic_block.getInstList();
316  for (auto inst_itr = inst_list.begin(); inst_itr != inst_list.end(); ++inst_itr) {
317  if (auto call_inst = llvm::dyn_cast<llvm::CallInst>(inst_itr)) {
318  auto called_fcn = call_inst->getCalledFunction();
319  CHECK(called_fcn);
320 
// NOTE(review): the right-hand side of this comparison (source line 322) is
// absent from this listing; presumably the map's end() iterator.
321  if (gpu_replacement_functions.find(called_fcn->getName().str()) !=
323  ret.emplace_back(called_fcn->getName());
324  }
325  }
326  }
327  }
328  return ret;
329 }
330 
331 void CgenState::replaceFunctionForGpu(const std::string& fcn_to_replace,
332  llvm::Function* fn) {
333  CHECK(fn);
334  CHECK(!fn->isDeclaration());
335 
336  auto map_it = gpu_replacement_functions.find(fcn_to_replace);
337  if (map_it == gpu_replacement_functions.end()) {
338  throw QueryMustRunOnCpu("Codegen failed: Could not find replacement functon for " +
339  fcn_to_replace +
340  " to run on gpu. Query step must run in cpu mode.");
341  }
342  const auto& gpu_fcn_obj = map_it->second;
343  CHECK(gpu_fcn_obj);
344  VLOG(1) << "Replacing " << fcn_to_replace << " with " << gpu_fcn_obj->name_
345  << " for parent function " << fn->getName().str();
346 
347  for (auto& basic_block : *fn) {
348  auto& inst_list = basic_block.getInstList();
349  for (auto inst_itr = inst_list.begin(); inst_itr != inst_list.end(); ++inst_itr) {
350  if (auto call_inst = llvm::dyn_cast<llvm::CallInst>(inst_itr)) {
351  auto called_fcn = call_inst->getCalledFunction();
352  CHECK(called_fcn);
353 
354  if (called_fcn->getName() == fcn_to_replace) {
355  std::vector<llvm::Value*> args;
356  std::vector<llvm::Type*> arg_types;
357  for (auto& arg : call_inst->args()) {
358  arg_types.push_back(arg.get()->getType());
359  args.push_back(arg.get());
360  }
361  auto gpu_func = gpu_fcn_obj->getFunction(module_, context_);
362  CHECK(gpu_func);
363  auto gpu_func_type = gpu_func.getFunctionType();
364  CHECK(gpu_func_type);
365  CHECK_EQ(gpu_func_type->getReturnType(), called_fcn->getReturnType());
366  llvm::ReplaceInstWithInst(call_inst,
367  llvm::CallInst::Create(gpu_func, args, ""));
368  return;
369  }
370  }
371  }
372  }
373 }
374 
// Returns the Executor associated with this code-generation state.
// NOTE(review): the body (source lines 376-377) is absent from this listing.
// It presumably looks up the executor registered under executor_id_; confirm
// how it behaves when executor_id_ is Executor::INVALID_EXECUTOR_ID (the
// LLVMContext-only constructor sets that value).
375 std::shared_ptr<Executor> CgenState::getExecutor() const {
378 }
379 
380 llvm::LLVMContext& CgenState::getExecutorContext() const {
381  return getExecutor()->getContext();
382 }
383 
384 void CgenState::set_module_shallow_copy(const std::unique_ptr<llvm::Module>& llvm_module,
385  bool always_clone) {
386  module_ =
387  llvm::CloneModule(*llvm_module, vmap_, [always_clone](const llvm::GlobalValue* gv) {
388  auto func = llvm::dyn_cast<llvm::Function>(gv);
389  if (!func) {
390  return true;
391  }
392  return (func->getLinkage() == llvm::GlobalValue::LinkageTypes::PrivateLinkage ||
393  func->getLinkage() == llvm::GlobalValue::LinkageTypes::InternalLinkage ||
394  (always_clone && CodeGenerator::alwaysCloneRuntimeFunction(func)));
395  }).release();
396 }
#define CHECK_EQ(x, y)
Definition: Logger.h:230
llvm::Value * castToTypeIn(llvm::Value *val, const size_t bit_width)
Definition: CgenState.cpp:152
llvm::Value * emitEntryCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:231
llvm::FunctionCallee getFunction(llvm::Module *llvm_module, llvm::LLVMContext &context) const
Definition: CgenState.cpp:286
#define NULL_DOUBLE
Definition: sqltypes.h:63
llvm::LLVMContext & getExecutorContext() const
Definition: CgenState.cpp:380
std::shared_ptr< Executor > getExecutor() const
Definition: CgenState.cpp:375
#define NULL_FLOAT
void maybeCloneFunctionRecursive(llvm::Function *fn)
Definition: CgenState.cpp:183
bool is_fp() const
Definition: sqltypes.h:604
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder<> ir_builder_
Definition: CgenState.h:441
llvm::Type * getTy(llvm::LLVMContext &ctx)
Definition: CgenState.cpp:259
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:404
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
llvm::Type * getTy< double >(llvm::LLVMContext &ctx)
Definition: CgenState.cpp:267
static std::shared_ptr< Executor > getExecutor(const ExecutorId id, const std::string &debug_dir="", const std::string &debug_file="", const SystemParameters &system_parameters=SystemParameters())
Definition: Execute.cpp:477
static const std::unordered_map< std::string, std::shared_ptr< GpuFunctionDefinition > > gpu_replacement_functions
Definition: CgenState.cpp:293
llvm::Module * module_
Definition: CgenState.h:430
size_t executor_id_
Definition: CgenState.h:333
llvm::LLVMContext & context_
Definition: CgenState.h:439
llvm::Function * current_func_
Definition: CgenState.h:433
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:64
void replaceFunctionForGpu(const std::string &fcn_to_replace, llvm::Function *fn)
Definition: CgenState.cpp:331
bool needs_error_check_
Definition: CgenState.h:461
llvm::ConstantFP * llFp(const float v) const
Definition: CgenState.h:310
std::vector< std::string > gpuFunctionsToReplace(llvm::Function *fn)
Definition: CgenState.cpp:308
llvm::IRBuilder<> query_func_entry_ir_builder_
Definition: CgenState.h:465
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:219
static const ExecutorId INVALID_EXECUTOR_ID
Definition: Execute.h:377
std::pair< uint64_t, uint64_t > inline_uint_max_min(const size_t byte_width)
llvm::Constant * inlineNull(const SQLTypeInfo &)
Definition: CgenState.cpp:118
void set_module_shallow_copy(const std::unique_ptr< llvm::Module > &module, bool always_clone=false)
Definition: CgenState.cpp:384
Definition: sqltypes.h:67
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:412
static bool alwaysCloneRuntimeFunction(const llvm::Function *func)
void emitErrorCheck(llvm::Value *condition, llvm::Value *errorCode, std::string label)
Definition: CgenState.cpp:243
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:306
bool g_enable_watchdog false
Definition: Execute.cpp:79
#define CHECK(condition)
Definition: Logger.h:222
llvm::ValueToValueMapTy vmap_
Definition: CgenState.h:440
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
int64_t inline_fixed_encoding_null_val(const SQL_TYPE_INFO &ti)
std::pair< int64_t, int64_t > inline_int_max_min(const size_t byte_width)
Definition: sqltypes.h:59
bool is_string() const
Definition: sqltypes.h:600
string name
Definition: setup.in.py:72
CgenState(const size_t num_query_infos, const bool contains_left_deep_outer_join, Executor *executor)
Definition: CgenState.cpp:25
std::pair< llvm::ConstantInt *, llvm::ConstantInt * > inlineIntMaxMin(const size_t byte_width, const bool is_signed)
Definition: CgenState.cpp:123
#define VLOG(n)
Definition: Logger.h:316
llvm::ConstantFP * inlineFpNull(const SQLTypeInfo &)
Definition: CgenState.cpp:106