OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CgenState.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CgenState.h"
18 #include "CodeGenerator.h"
20 
21 #include <llvm/IR/InstIterator.h>
22 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
23 #include <llvm/Transforms/Utils/Cloning.h>
24 
25 CgenState::CgenState(const size_t num_query_infos,
26  const bool contains_left_deep_outer_join,
27  Executor* executor)
28  : executor_id_(executor->getExecutorId())
29  , module_(nullptr)
30  , row_func_(nullptr)
31  , filter_func_(nullptr)
32  , current_func_(nullptr)
33  , row_func_bb_(nullptr)
34  , filter_func_bb_(nullptr)
35  , row_func_call_(nullptr)
36  , filter_func_call_(nullptr)
37  , context_(executor->getContext())
38  , ir_builder_(context_)
39  , contains_left_deep_outer_join_(contains_left_deep_outer_join)
40  , outer_join_match_found_per_level_(std::max(num_query_infos, size_t(1)) - 1)
41  , needs_error_check_(false)
42  , needs_geos_(false)
43  , query_func_(nullptr)
44  , query_func_entry_ir_builder_(context_){};
45 
46 CgenState::CgenState(const size_t num_query_infos,
47  const bool contains_left_deep_outer_join)
48  : CgenState(num_query_infos,
49  contains_left_deep_outer_join,
50  Executor::getExecutor(Executor::UNITARY_EXECUTOR_ID).get()) {}
51 
52 CgenState::CgenState(llvm::LLVMContext& context)
53  : executor_id_(Executor::INVALID_EXECUTOR_ID)
54  , module_(nullptr)
55  , row_func_(nullptr)
56  , context_(context)
57  , ir_builder_(context_)
58  , contains_left_deep_outer_join_(false)
59  , needs_error_check_(false)
60  , needs_geos_(false)
61  , query_func_(nullptr)
62  , query_func_entry_ir_builder_(context_){};
63 
64 llvm::ConstantInt* CgenState::inlineIntNull(const SQLTypeInfo& type_info) {
65  auto type = type_info.get_type();
66  if (type_info.is_string()) {
67  switch (type_info.get_compression()) {
68  case kENCODING_DICT:
69  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
70  case kENCODING_NONE:
71  return llInt(int64_t(0));
72  default:
73  CHECK(false);
74  }
75  }
76  switch (type) {
77  case kBOOLEAN:
78  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
79  case kTINYINT:
80  return llInt(static_cast<int8_t>(inline_int_null_val(type_info)));
81  case kSMALLINT:
82  return llInt(static_cast<int16_t>(inline_int_null_val(type_info)));
83  case kINT:
84  return llInt(static_cast<int32_t>(inline_int_null_val(type_info)));
85  case kBIGINT:
86  return llInt(static_cast<int64_t>(inline_int_null_val(type_info)));
87  case kTIME:
88  case kTIMESTAMP:
89  case kDATE:
90  case kINTERVAL_DAY_TIME:
92  return llInt(inline_int_null_val(type_info));
93  case kDECIMAL:
94  case kNUMERIC:
95  return llInt(inline_int_null_val(type_info));
96  case kARRAY:
97  return llInt(int64_t(0));
98  default:
99  abort();
100  }
101 }
102 
103 llvm::ConstantFP* CgenState::inlineFpNull(const SQLTypeInfo& type_info) {
104  CHECK(type_info.is_fp());
105  switch (type_info.get_type()) {
106  case kFLOAT:
107  return llFp(NULL_FLOAT);
108  case kDOUBLE:
109  return llFp(NULL_DOUBLE);
110  default:
111  abort();
112  }
113 }
114 
115 llvm::Constant* CgenState::inlineNull(const SQLTypeInfo& ti) {
116  return ti.is_fp() ? static_cast<llvm::Constant*>(inlineFpNull(ti))
117  : static_cast<llvm::Constant*>(inlineIntNull(ti));
118 }
119 
120 std::pair<llvm::ConstantInt*, llvm::ConstantInt*> CgenState::inlineIntMaxMin(
121  const size_t byte_width,
122  const bool is_signed) {
123  int64_t max_int{0}, min_int{0};
124  if (is_signed) {
125  std::tie(max_int, min_int) = inline_int_max_min(byte_width);
126  } else {
127  uint64_t max_uint{0}, min_uint{0};
128  std::tie(max_uint, min_uint) = inline_uint_max_min(byte_width);
129  max_int = static_cast<int64_t>(max_uint);
130  CHECK_EQ(uint64_t(0), min_uint);
131  }
132  switch (byte_width) {
133  case 1:
134  return std::make_pair(::ll_int(static_cast<int8_t>(max_int), context_),
135  ::ll_int(static_cast<int8_t>(min_int), context_));
136  case 2:
137  return std::make_pair(::ll_int(static_cast<int16_t>(max_int), context_),
138  ::ll_int(static_cast<int16_t>(min_int), context_));
139  case 4:
140  return std::make_pair(::ll_int(static_cast<int32_t>(max_int), context_),
141  ::ll_int(static_cast<int32_t>(min_int), context_));
142  case 8:
143  return std::make_pair(::ll_int(max_int, context_), ::ll_int(min_int, context_));
144  default:
145  abort();
146  }
147 }
148 
149 llvm::Value* CgenState::castToTypeIn(llvm::Value* val, const size_t dst_bits) {
150  auto src_bits = val->getType()->getScalarSizeInBits();
151  if (src_bits == dst_bits) {
152  return val;
153  }
154  if (val->getType()->isIntegerTy()) {
155  return ir_builder_.CreateIntCast(
156  val, get_int_type(dst_bits, context_), src_bits != 1);
157  }
158  // real (not dictionary-encoded) strings; store the pointer to the payload
159  if (val->getType()->isPointerTy()) {
160  return ir_builder_.CreatePointerCast(val, get_int_type(dst_bits, context_));
161  }
162 
163  CHECK(val->getType()->isFloatTy() || val->getType()->isDoubleTy());
164 
165  llvm::Type* dst_type = nullptr;
166  switch (dst_bits) {
167  case 64:
168  dst_type = llvm::Type::getDoubleTy(context_);
169  break;
170  case 32:
171  dst_type = llvm::Type::getFloatTy(context_);
172  break;
173  default:
174  CHECK(false);
175  }
176 
177  return ir_builder_.CreateFPCast(val, dst_type);
178 }
179 
180 void CgenState::maybeCloneFunctionRecursive(llvm::Function* fn) {
181  CHECK(fn);
182  if (!fn->isDeclaration()) {
183  return;
184  }
185 
186  // Get the implementation from the runtime module.
187  auto func_impl = getExecutor()->get_rt_module()->getFunction(fn->getName());
188  CHECK(func_impl) << fn->getName().str();
189 
190  if (func_impl->isDeclaration()) {
191  return;
192  }
193 
194  auto DestI = fn->arg_begin();
195  for (auto arg_it = func_impl->arg_begin(); arg_it != func_impl->arg_end(); ++arg_it) {
196  DestI->setName(arg_it->getName());
197  vmap_[&*arg_it] = &*DestI++;
198  }
199 
200  llvm::SmallVector<llvm::ReturnInst*, 8> Returns; // Ignore returns cloned.
201 #if LLVM_VERSION_MAJOR > 12
202  llvm::CloneFunctionInto(
203  fn, func_impl, vmap_, llvm::CloneFunctionChangeType::DifferentModule, Returns);
204 #else
205  llvm::CloneFunctionInto(fn, func_impl, vmap_, /*ModuleLevelChanges=*/true, Returns);
206 #endif
207 
208  for (auto it = llvm::inst_begin(fn), e = llvm::inst_end(fn); it != e; ++it) {
209  if (llvm::isa<llvm::CallInst>(*it)) {
210  auto& call = llvm::cast<llvm::CallInst>(*it);
211  maybeCloneFunctionRecursive(call.getCalledFunction());
212  }
213  }
214 }
215 
216 llvm::Value* CgenState::emitCall(const std::string& fname,
217  const std::vector<llvm::Value*>& args) {
218  // Get the function reference from the query module.
219  auto func = module_->getFunction(fname);
220  CHECK(func) << fname;
221  // If the function called isn't external, clone the implementation from the runtime
222  // module.
224 
225  return ir_builder_.CreateCall(func, args);
226 }
227 
228 llvm::Value* CgenState::emitEntryCall(const std::string& fname,
229  const std::vector<llvm::Value*>& args) {
230  // Get the function reference from the query module.
231  auto func = module_->getFunction(fname);
232  CHECK(func);
233  // If the function called isn't external, clone the implementation from the runtime
234  // module.
236 
237  return query_func_entry_ir_builder_.CreateCall(func, args);
238 }
239 
240 void CgenState::emitErrorCheck(llvm::Value* condition,
241  llvm::Value* errorCode,
242  std::string label) {
243  needs_error_check_ = true;
244  auto check_ok = llvm::BasicBlock::Create(context_, label + "_ok", current_func_);
245  auto check_fail = llvm::BasicBlock::Create(context_, label + "_fail", current_func_);
246  ir_builder_.CreateCondBr(condition, check_ok, check_fail);
247  ir_builder_.SetInsertPoint(check_fail);
248  ir_builder_.CreateRet(errorCode);
249  ir_builder_.SetInsertPoint(check_ok);
250 }
251 
252 namespace {
253 
254 // clang-format off
255 template <typename T>
256 llvm::Type* getTy(llvm::LLVMContext& ctx) { return getTy<std::remove_pointer_t<T>>(ctx)->getPointerTo(); }
257 // Commented out to avoid -Wunused-function warnings, but enable as needed.
258 // template<> llvm::Type* getTy<bool>(llvm::LLVMContext& ctx) { return llvm::Type::getInt1Ty(ctx); }
259 //template<> llvm::Type* getTy<int8_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt8Ty(ctx); }
260 // template<> llvm::Type* getTy<int16_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt16Ty(ctx); }
261 //template<> llvm::Type* getTy<int32_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt32Ty(ctx); }
262 // template<> llvm::Type* getTy<int64_t>(llvm::LLVMContext& ctx) { return llvm::Type::getInt64Ty(ctx); }
263 // template<> llvm::Type* getTy<float>(llvm::LLVMContext& ctx) { return llvm::Type::getFloatTy(ctx); }
264 template<> llvm::Type* getTy<double>(llvm::LLVMContext& ctx) { return llvm::Type::getDoubleTy(ctx); }
265 //template<> llvm::Type* getTy<void>(llvm::LLVMContext& ctx) { return llvm::Type::getVoidTy(ctx); }
266 // clang-format on
267 
269  GpuFunctionDefinition(char const* name) : name_(name) {}
270  char const* const name_;
271 
272  virtual ~GpuFunctionDefinition() = default;
273 
274  virtual llvm::FunctionCallee getFunction(llvm::Module* llvm_module,
275  llvm::LLVMContext& context) const = 0;
276 };
277 
278 // TYPES = return_type, arg0_type, arg1_type, arg2_type, ...
279 template <typename... TYPES>
280 struct GpuFunction final : public GpuFunctionDefinition {
281  GpuFunction(char const* name) : GpuFunctionDefinition(name) {}
282 
283  llvm::FunctionCallee getFunction(llvm::Module* llvm_module,
284  llvm::LLVMContext& context) const {
285  return llvm_module->getOrInsertFunction(name_, getTy<TYPES>(context)...);
286  }
287 };
288 
289 static const std::unordered_map<std::string, std::shared_ptr<GpuFunctionDefinition>>
291  {"asin", std::make_shared<GpuFunction<double, double>>("Asin")},
292  {"atanh", std::make_shared<GpuFunction<double, double>>("Atanh")},
293  {"atan", std::make_shared<GpuFunction<double, double>>("Atan")},
294  {"cosh", std::make_shared<GpuFunction<double, double>>("Cosh")},
295  {"cos", std::make_shared<GpuFunction<double, double>>("Cos")},
296  {"exp", std::make_shared<GpuFunction<double, double>>("Exp")},
297  {"log", std::make_shared<GpuFunction<double, double>>("ln")},
298  {"pow", std::make_shared<GpuFunction<double, double, double>>("power")},
299  {"sinh", std::make_shared<GpuFunction<double, double>>("Sinh")},
300  {"sin", std::make_shared<GpuFunction<double, double>>("Sin")},
301  {"sqrt", std::make_shared<GpuFunction<double, double>>("Sqrt")},
302  {"tan", std::make_shared<GpuFunction<double, double>>("Tan")}};
303 } // namespace
304 
305 std::vector<std::string> CgenState::gpuFunctionsToReplace(llvm::Function* fn) {
306  std::vector<std::string> ret;
307 
308  CHECK(fn);
309  CHECK(!fn->isDeclaration());
310 
311  for (auto& basic_block : *fn) {
312  auto& inst_list = basic_block.getInstList();
313  for (auto inst_itr = inst_list.begin(); inst_itr != inst_list.end(); ++inst_itr) {
314  if (auto call_inst = llvm::dyn_cast<llvm::CallInst>(inst_itr)) {
315  auto called_fcn = call_inst->getCalledFunction();
316  CHECK(called_fcn);
317 
318  if (gpu_replacement_functions.find(called_fcn->getName().str()) !=
320  ret.emplace_back(called_fcn->getName());
321  }
322  }
323  }
324  }
325  return ret;
326 }
327 
328 void CgenState::replaceFunctionForGpu(const std::string& fcn_to_replace,
329  llvm::Function* fn) {
330  CHECK(fn);
331  CHECK(!fn->isDeclaration());
332 
333  auto map_it = gpu_replacement_functions.find(fcn_to_replace);
334  if (map_it == gpu_replacement_functions.end()) {
335  throw QueryMustRunOnCpu("Codegen failed: Could not find replacement functon for " +
336  fcn_to_replace +
337  " to run on gpu. Query step must run in cpu mode.");
338  }
339  const auto& gpu_fcn_obj = map_it->second;
340  CHECK(gpu_fcn_obj);
341  VLOG(1) << "Replacing " << fcn_to_replace << " with " << gpu_fcn_obj->name_
342  << " for parent function " << fn->getName().str();
343 
344  for (auto& basic_block : *fn) {
345  auto& inst_list = basic_block.getInstList();
346  for (auto inst_itr = inst_list.begin(); inst_itr != inst_list.end(); ++inst_itr) {
347  if (auto call_inst = llvm::dyn_cast<llvm::CallInst>(inst_itr)) {
348  auto called_fcn = call_inst->getCalledFunction();
349  CHECK(called_fcn);
350 
351  if (called_fcn->getName() == fcn_to_replace) {
352  std::vector<llvm::Value*> args;
353  std::vector<llvm::Type*> arg_types;
354  for (auto& arg : call_inst->args()) {
355  arg_types.push_back(arg.get()->getType());
356  args.push_back(arg.get());
357  }
358  auto gpu_func = gpu_fcn_obj->getFunction(module_, context_);
359  CHECK(gpu_func);
360  auto gpu_func_type = gpu_func.getFunctionType();
361  CHECK(gpu_func_type);
362  CHECK_EQ(gpu_func_type->getReturnType(), called_fcn->getReturnType());
363  llvm::ReplaceInstWithInst(call_inst,
364  llvm::CallInst::Create(gpu_func, args, ""));
365  return;
366  }
367  }
368  }
369  }
370 }
371 
372 std::shared_ptr<Executor> CgenState::getExecutor() const {
375 }
376 
377 llvm::LLVMContext& CgenState::getExecutorContext() const {
378  return getExecutor()->getContext();
379 }
380 
381 void CgenState::set_module_shallow_copy(const std::unique_ptr<llvm::Module>& llvm_module,
382  bool always_clone) {
383  module_ =
384  llvm::CloneModule(*llvm_module, vmap_, [always_clone](const llvm::GlobalValue* gv) {
385  auto func = llvm::dyn_cast<llvm::Function>(gv);
386  if (!func) {
387  return true;
388  }
389  return (func->getLinkage() == llvm::GlobalValue::LinkageTypes::PrivateLinkage ||
390  func->getLinkage() == llvm::GlobalValue::LinkageTypes::InternalLinkage ||
391  (always_clone && CodeGenerator::alwaysCloneRuntimeFunction(func)));
392  }).release();
393 }
394 
395 
397  const std::string& fname,
398  llvm::Type* ret_type,
399  const std::vector<llvm::Value*> args,
400  const std::vector<llvm::Attribute::AttrKind>& fnattrs,
401  const bool has_struct_return) {
402  std::vector<llvm::Type*> arg_types;
403  for (const auto arg : args) {
404  CHECK(arg);
405  arg_types.push_back(arg->getType());
406  }
407  auto func_ty = llvm::FunctionType::get(ret_type, arg_types, false);
408  llvm::AttributeList attrs;
409  if (!fnattrs.empty()) {
410  std::vector<std::pair<unsigned, llvm::Attribute>> indexedAttrs;
411  indexedAttrs.reserve(fnattrs.size());
412  for (auto attr : fnattrs) {
413  indexedAttrs.emplace_back(llvm::AttributeList::FunctionIndex,
414  llvm::Attribute::get(context_, attr));
415  }
416  attrs = llvm::AttributeList::get(context_,
417  {&indexedAttrs.front(), indexedAttrs.size()});
418  }
419 
420  auto func_p = module_->getOrInsertFunction(fname, func_ty, attrs);
421  CHECK(func_p);
422  auto callee = func_p.getCallee();
423  llvm::Function* func{nullptr};
424  if (auto callee_cast = llvm::dyn_cast<llvm::ConstantExpr>(callee)) {
425  // Get or insert function automatically adds a ConstantExpr cast if the return type
426  // of the existing function does not match the supplied return type.
427  CHECK(callee_cast->isCast());
428  CHECK_EQ(callee_cast->getNumOperands(), size_t(1));
429  func = llvm::dyn_cast<llvm::Function>(callee_cast->getOperand(0));
430  } else {
431  func = llvm::dyn_cast<llvm::Function>(callee);
432  }
433  CHECK(func);
434  llvm::FunctionType* func_type = func_p.getFunctionType();
435  CHECK(func_type);
436  if (has_struct_return) {
437  const auto arg_ti = func_type->getParamType(0);
438  CHECK(arg_ti->isPointerTy() && arg_ti->getPointerElementType()->isStructTy());
439  auto attr_list = func->getAttributes();
440 #if 14 <= LLVM_VERSION_MAJOR
441  llvm::AttrBuilder arr_arg_builder(context_, attr_list.getParamAttrs(0));
442 #else
443  llvm::AttrBuilder arr_arg_builder(attr_list.getParamAttributes(0));
444 #endif
445  arr_arg_builder.addAttribute(llvm::Attribute::StructRet);
446  func->addParamAttrs(0, arr_arg_builder);
447  }
448  llvm::Value* result = ir_builder_.CreateCall(func_p, args);
449  // check the assumed type
450  CHECK_EQ(result->getType(), ret_type);
451  return result;
452 }
#define CHECK_EQ(x, y)
Definition: Logger.h:297
llvm::Value * castToTypeIn(llvm::Value *val, const size_t bit_width)
Definition: CgenState.cpp:149
llvm::Value * emitEntryCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:228
llvm::FunctionCallee getFunction(llvm::Module *llvm_module, llvm::LLVMContext &context) const
Definition: CgenState.cpp:283
#define NULL_DOUBLE
Definition: sqltypes.h:64
llvm::LLVMContext & getExecutorContext() const
Definition: CgenState.cpp:377
std::shared_ptr< Executor > getExecutor() const
Definition: CgenState.cpp:372
#define NULL_FLOAT
void maybeCloneFunctionRecursive(llvm::Function *fn)
Definition: CgenState.cpp:180
bool is_fp() const
Definition: sqltypes.h:580
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:375
llvm::Type * getTy(llvm::LLVMContext &ctx)
Definition: CgenState.cpp:256
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:380
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
llvm::Type * getTy< double >(llvm::LLVMContext &ctx)
Definition: CgenState.cpp:264
static std::shared_ptr< Executor > getExecutor(const ExecutorId id, const std::string &debug_dir="", const std::string &debug_file="", const SystemParameters &system_parameters=SystemParameters())
Definition: Execute.cpp:477
static const std::unordered_map< std::string, std::shared_ptr< GpuFunctionDefinition > > gpu_replacement_functions
Definition: CgenState.cpp:290
llvm::Module * module_
Definition: CgenState.h:364
size_t executor_id_
Definition: CgenState.h:267
llvm::LLVMContext & context_
Definition: CgenState.h:373
llvm::Function * current_func_
Definition: CgenState.h:367
llvm::Value * emitExternalCall(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, const std::vector< llvm::Attribute::AttrKind > &fnattrs={}, const bool has_struct_return=false)
Definition: CgenState.cpp:396
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:64
void replaceFunctionForGpu(const std::string &fcn_to_replace, llvm::Function *fn)
Definition: CgenState.cpp:328
bool needs_error_check_
Definition: CgenState.h:395
llvm::ConstantFP * llFp(const float v) const
Definition: CgenState.h:244
std::vector< std::string > gpuFunctionsToReplace(llvm::Function *fn)
Definition: CgenState.cpp:305
llvm::IRBuilder query_func_entry_ir_builder_
Definition: CgenState.h:399
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:216
static const ExecutorId INVALID_EXECUTOR_ID
Definition: Execute.h:377
std::pair< uint64_t, uint64_t > inline_uint_max_min(const size_t byte_width)
llvm::Constant * inlineNull(const SQLTypeInfo &)
Definition: CgenState.cpp:115
void set_module_shallow_copy(const std::unique_ptr< llvm::Module > &module, bool always_clone=false)
Definition: CgenState.cpp:381
Definition: sqltypes.h:68
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:388
static bool alwaysCloneRuntimeFunction(const llvm::Function *func)
void emitErrorCheck(llvm::Value *condition, llvm::Value *errorCode, std::string label)
Definition: CgenState.cpp:240
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:240
bool g_enable_watchdog false
Definition: Execute.cpp:79
#define CHECK(condition)
Definition: Logger.h:289
llvm::ValueToValueMapTy vmap_
Definition: CgenState.h:374
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
std::pair< int64_t, int64_t > inline_int_max_min(const size_t byte_width)
Definition: sqltypes.h:60
bool is_string() const
Definition: sqltypes.h:576
string name
Definition: setup.in.py:72
CgenState(const size_t num_query_infos, const bool contains_left_deep_outer_join, Executor *executor)
Definition: CgenState.cpp:25
std::pair< llvm::ConstantInt *, llvm::ConstantInt * > inlineIntMaxMin(const size_t byte_width, const bool is_signed)
Definition: CgenState.cpp:120
#define VLOG(n)
Definition: Logger.h:383
llvm::ConstantFP * inlineFpNull(const SQLTypeInfo &)
Definition: CgenState.cpp:103