OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CaseIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 
20 std::vector<llvm::Value*> CodeGenerator::codegen(const Analyzer::CaseExpr* case_expr,
21  const CompilationOptions& co) {
23  const auto case_ti = case_expr->get_type_info();
24  llvm::Type* case_llvm_type = nullptr;
25  bool is_real_str = false;
26  if (case_ti.is_integer() || case_ti.is_time() || case_ti.is_decimal()) {
27  case_llvm_type = get_int_type(get_bit_width(case_ti), cgen_state_->context_);
28  } else if (case_ti.is_fp()) {
29  case_llvm_type = case_ti.get_type() == kFLOAT
30  ? llvm::Type::getFloatTy(cgen_state_->context_)
31  : llvm::Type::getDoubleTy(cgen_state_->context_);
32  } else if (case_ti.is_string()) {
33  if (case_ti.get_compression() == kENCODING_DICT) {
34  case_llvm_type =
35  get_int_type(8 * case_ti.get_logical_size(), cgen_state_->context_);
36  } else {
37  is_real_str = true;
38  case_llvm_type = createStringViewStructType();
39  }
40  } else if (case_ti.is_boolean()) {
41  case_llvm_type = get_int_type(8 * case_ti.get_logical_size(), cgen_state_->context_);
42  } else if (case_ti.is_geometry()) {
43  throw std::runtime_error(
44  "Geospatial column projections are currently not supported in conditional "
45  "expressions.");
46  } else if (case_ti.is_array()) {
47  throw std::runtime_error(
48  "Array column projections are currently not supported in conditional "
49  "expressions.");
50  }
51  CHECK(case_llvm_type);
52  const auto& else_ti = case_expr->get_else_expr()->get_type_info();
53  CHECK_EQ(else_ti.get_type(), case_ti.get_type());
54  llvm::Value* case_val = codegenCase(case_expr, case_llvm_type, is_real_str, co);
55  std::vector<llvm::Value*> ret_vals{case_val};
56  if (is_real_str) {
57  ret_vals.push_back(cgen_state_->ir_builder_.CreateExtractValue(case_val, 0));
58  ret_vals.push_back(cgen_state_->ir_builder_.CreateExtractValue(case_val, 1));
59  ret_vals.back() = cgen_state_->ir_builder_.CreateTrunc(
60  ret_vals.back(), llvm::Type::getInt32Ty(cgen_state_->context_));
61  }
62  return ret_vals;
63 }
64 
// Emits the branching code for a CASE expression and returns the PHI node (of
// type case_llvm_type) that merges every branch result in end_bb.
//
// For each WHEN/THEN pair, in list order:
//   * the WHEN condition is generated in the current block and converted with
//     toBool();
//   * a "then_case" block, inserted before end_bb, receives the THEN value and
//     branches unconditionally to end_bb;
//   * the block the condition ended in branches to the THEN block when true,
//     or to a newly created block (when_bb) where the next condition -- or
//     finally the ELSE expression -- is generated.
// The ELSE expression must be present (CHECKed below).
//
// Parameters:
//   case_expr      - the CASE expression to generate.
//   case_llvm_type - LLVM type of the resulting PHI; every branch value is
//                    added as an incoming value of that PHI.
//   is_real_str    - true for a non-dictionary-encoded string result; a THEN
//                    result made of three values is then combined into one
//                    value with string_pack(elements 1 and 2).
//   co             - compilation options forwarded to sub-expression codegen.
65 llvm::Value* CodeGenerator::codegenCase(const Analyzer::CaseExpr* case_expr,
66  llvm::Type* case_llvm_type,
67  const bool is_real_str,
68  const CompilationOptions& co) {
70  // Here the linear control flow will diverge and expressions cached during the
71  // code branch code generation (currently just column decoding) are not going
72  // to be available once we're done generating the case. Take a snapshot of
73  // the cache with FetchCacheAnchor and restore it once we're done with CASE.
75  const auto& expr_pair_list = case_expr->get_expr_pair_list();
  // Per-branch results and the blocks they flow out of; both are used to
  // populate the PHI once all branches are generated.
76  std::vector<llvm::Value*> then_lvs;
77  std::vector<llvm::BasicBlock*> then_bbs;
  // Join block that every THEN branch and the ELSE branch jump to.
78  const auto end_bb = llvm::BasicBlock::Create(
80  for (const auto& expr_pair : expr_pair_list) {
82  const auto when_lv = toBool(codegen(expr_pair.first.get(), true, co).front());
  // Remember where the condition ended; the conditional branch is emitted
  // from this block after the THEN block has been built.
83  const auto cmp_bb = cgen_state_->ir_builder_.GetInsertBlock();
84  const auto then_bb = llvm::BasicBlock::Create(cgen_state_->context_,
85  "then_case",
87  /*insert_before=*/end_bb);
88  cgen_state_->ir_builder_.SetInsertPoint(then_bb);
89  auto then_bb_lvs = codegen(expr_pair.second.get(), true, co);
  // String results: a three-value result is packed into a single value so it
  // matches the PHI type; otherwise exactly one value is expected.
90  if (is_real_str) {
91  if (then_bb_lvs.size() == 3) {
92  then_lvs.push_back(
93  cgen_state_->emitCall("string_pack", {then_bb_lvs[1], then_bb_lvs[2]}));
94  } else {
95  then_lvs.push_back(then_bb_lvs.front());
96  }
97  } else {
98  CHECK_EQ(size_t(1), then_bb_lvs.size());
99  then_lvs.push_back(then_bb_lvs.front());
100  }
  // Codegen of the THEN value may have moved the insert point past then_bb,
  // so the PHI's incoming block is the current insert block, not then_bb.
101  then_bbs.push_back(cgen_state_->ir_builder_.GetInsertBlock());
102  cgen_state_->ir_builder_.CreateBr(end_bb);
103  const auto when_bb = llvm::BasicBlock::Create(
  // Go back to the condition's block to emit the branch (true -> then_bb,
  // false -> when_bb), then continue generating in when_bb.
105  cgen_state_->ir_builder_.SetInsertPoint(cmp_bb);
106  cgen_state_->ir_builder_.CreateCondBr(when_lv, then_bb, when_bb);
107  cgen_state_->ir_builder_.SetInsertPoint(when_bb);
108  }
109  const auto else_expr = case_expr->get_else_expr();
110  CHECK(else_expr);
111  auto else_lvs = codegen(else_expr, true, co);
112  llvm::Value* else_lv{nullptr};
  // NOTE(review): unlike the THEN branches above, this packs any three-value
  // result regardless of is_real_str and does not CHECK for a single value in
  // the non-string case; confirm the ELSE result can never disagree with
  // case_llvm_type (e.g. a differently-encoded string).
113  if (else_lvs.size() == 3) {
114  else_lv = cgen_state_->emitCall("string_pack", {else_lvs[1], else_lvs[2]});
115  } else {
116  else_lv = else_lvs.front();
117  }
118  CHECK(else_lv);
  // As with the THEN branches, use the block the ELSE codegen ended in.
119  auto else_bb = cgen_state_->ir_builder_.GetInsertBlock();
120  cgen_state_->ir_builder_.CreateBr(end_bb);
121  cgen_state_->ir_builder_.SetInsertPoint(end_bb);
  // Reserve one incoming edge per THEN branch plus one for the ELSE.
122  auto then_phi =
123  cgen_state_->ir_builder_.CreatePHI(case_llvm_type, expr_pair_list.size() + 1);
124  CHECK_EQ(then_bbs.size(), then_lvs.size());
125  for (size_t i = 0; i < then_bbs.size(); ++i) {
126  then_phi->addIncoming(then_lvs[i], then_bbs[i]);
127  }
128  then_phi->addIncoming(else_lv, else_bb);
129  return then_phi;
130 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
const Expr * get_else_expr() const
Definition: Analyzer.h:1387
CgenState * cgen_state_
llvm::IRBuilder ir_builder_
Definition: CgenState.h:384
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
size_t get_bit_width(const SQLTypeInfo &ti)
llvm::LLVMContext & context_
Definition: CgenState.h:382
llvm::Function * current_func_
Definition: CgenState.h:376
#define AUTOMATIC_IR_METADATA(CGENSTATE)
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:217
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:30
llvm::StructType * createStringViewStructType()
llvm::Value * codegenCase(const Analyzer::CaseExpr *, llvm::Type *case_llvm_type, const bool is_real_str, const CompilationOptions &)
Definition: CaseIR.cpp:65
llvm::Value * toBool(llvm::Value *)
Definition: LogicalIR.cpp:343
#define CHECK(condition)
Definition: Logger.h:291
const std::list< std::pair< std::shared_ptr< Analyzer::Expr >, std::shared_ptr< Analyzer::Expr > > > & get_expr_pair_list() const
Definition: Analyzer.h:1384