OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Centroid.h
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
21 namespace spatial_type {
22 
23 class Centroid : public Codegen {
24  public:
25  Centroid(const Analyzer::GeoOperator* geo_operator) : Codegen(geo_operator) {
26  CHECK_EQ(operator_->size(), size_t(1));
27  const auto& ti = operator_->get_type_info();
28  is_nullable_ = !ti.get_notnull();
29  }
30 
31  size_t size() const final { return 1; }
32 
33  SQLTypeInfo getNullType() const final { return SQLTypeInfo(kINT); }
34 
35  std::tuple<std::vector<llvm::Value*>, llvm::Value*> codegenLoads(
36  const std::vector<llvm::Value*>& arg_lvs,
37  const std::vector<llvm::Value*>& pos_lvs,
38  CgenState* cgen_state) final {
39  CHECK_EQ(pos_lvs.size(), size());
40  const auto operand = getOperand(0);
41  CHECK(operand);
42  const auto& operand_ti = operand->get_type_info();
43 
44  std::string size_fn_name = "array_size";
45  if (is_nullable_) {
46  size_fn_name += "_nullable";
47  }
48 
49  auto& builder = cgen_state->ir_builder_;
50 
51  std::vector<llvm::Value*> operand_lvs;
52  // iterate over column inputs
53  if (dynamic_cast<const Analyzer::ColumnVar*>(operand)) {
54  for (size_t i = 0; i < arg_lvs.size(); i++) {
55  auto lv = arg_lvs[i];
56  auto array_buff_lv =
57  cgen_state->emitExternalCall("array_buff",
58  llvm::Type::getInt8PtrTy(cgen_state->context_),
59  {lv, pos_lvs.front()});
60  auto const is_coords = (i == 0);
61  if (!is_coords) {
62  array_buff_lv = builder.CreateBitCast(
63  array_buff_lv, llvm::Type::getInt32PtrTy(cgen_state->context_));
64  }
65  operand_lvs.push_back(array_buff_lv);
66  const auto ptr_type = llvm::dyn_cast_or_null<llvm::PointerType>(lv->getType());
67  CHECK(ptr_type);
68  const auto elem_type = ptr_type->getPointerElementType();
69  CHECK(elem_type);
70  auto const shift = log2_bytes(is_coords ? 1 : 4);
71  std::vector<llvm::Value*> array_sz_args{
72  lv, pos_lvs.front(), cgen_state->llInt(shift)};
73  if (is_nullable_) { // TODO: should we do this for all arguments, or just points?
74  array_sz_args.push_back(
75  cgen_state->llInt(static_cast<int32_t>(inline_int_null_value<int32_t>())));
76  }
77  operand_lvs.push_back(cgen_state->emitExternalCall(
78  size_fn_name, get_int_type(32, cgen_state->context_), array_sz_args));
79  }
80  } else {
81  for (size_t i = 0; i < arg_lvs.size(); i++) {
82  auto arg_lv = arg_lvs[i];
83  if (i > 0 && arg_lv->getType()->isPointerTy()) {
84  arg_lv = builder.CreateBitCast(arg_lv,
85  llvm::Type::getInt32PtrTy(cgen_state->context_));
86  }
87  operand_lvs.push_back(arg_lv);
88  }
89  }
90  CHECK_EQ(operand_lvs.size(),
91  size_t(2 * operand_ti.get_physical_coord_cols())); // array ptr and size
92 
93  // note that this block is the only one that differs from Area/Perimeter
94  // use the points array size argument for nullability
95  llvm::Value* null_check_operand_lv{nullptr};
96  if (is_nullable_) {
97  null_check_operand_lv = operand_lvs[1];
98  if (null_check_operand_lv->getType() !=
99  llvm::Type::getInt32Ty(cgen_state->context_)) {
100  CHECK(null_check_operand_lv->getType() ==
101  llvm::Type::getInt64Ty(cgen_state->context_));
102  // Geos functions come out 64-bit, cast down to 32 for now
103 
104  null_check_operand_lv = builder.CreateTrunc(
105  null_check_operand_lv, llvm::Type::getInt32Ty(cgen_state->context_));
106  }
107  }
108 
109  return std::make_tuple(operand_lvs, null_check_operand_lv);
110  }
111 
112  std::vector<llvm::Value*> codegen(const std::vector<llvm::Value*>& args,
113  CodeGenerator::NullCheckCodegen* nullcheck_codegen,
114  CgenState* cgen_state,
115  const CompilationOptions& co) final {
116  std::string func_name = "ST_Centroid";
117  const auto& ret_ti = operator_->get_type_info();
118  CHECK(ret_ti.is_geometry() && ret_ti.get_type() == kPOINT);
119  const auto& operand_ti = getOperand(0)->get_type_info();
120 
121  auto& builder = cgen_state->ir_builder_;
122 
123  // Allocate local storage for centroid point
124  auto elem_ty = llvm::Type::getDoubleTy(cgen_state->context_);
125  llvm::ArrayType* arr_type = llvm::ArrayType::get(elem_ty, 2);
126  auto pt_local_storage_lv =
127  builder.CreateAlloca(arr_type, nullptr, func_name + "_Local_Storage");
128 
129  llvm::Value* pt_compressed_local_storage_lv{NULL};
130  // Allocate local storage for compressed centroid point
131  if (ret_ti.get_compression() == kENCODING_GEOINT) {
132  auto elem_ty = llvm::Type::getInt32Ty(cgen_state->context_);
133  llvm::ArrayType* arr_type = llvm::ArrayType::get(elem_ty, 2);
134  pt_compressed_local_storage_lv = builder.CreateAlloca(
135  arr_type, nullptr, func_name + "_Compressed_Local_Storage");
136  }
137 
138  func_name += spatial_type::suffix(operand_ti.get_type());
139 
140  auto operand_lvs = args;
141 
142  // push back ic, isr, osr for now
143  operand_lvs.push_back(
144  cgen_state->llInt(Geospatial::get_compression_scheme(operand_ti))); // ic
145  operand_lvs.push_back(cgen_state->llInt(operand_ti.get_input_srid())); // in srid
146  auto output_srid = operand_ti.get_output_srid();
147  if (const auto srid_override = operator_->getOutputSridOverride()) {
148  output_srid = *srid_override;
149  }
150  operand_lvs.push_back(cgen_state->llInt(output_srid)); // out srid
151 
152  auto idx_lv = cgen_state->llInt(0);
153  auto pt_local_storage_gep = llvm::GetElementPtrInst::CreateInBounds(
154  pt_local_storage_lv->getType()->getScalarType()->getPointerElementType(),
155  pt_local_storage_lv,
156  {idx_lv, idx_lv},
157  "",
158  builder.GetInsertBlock());
159  // Pass local storage to centroid function
160  operand_lvs.push_back(pt_local_storage_gep);
161  CHECK(ret_ti.get_type() == kPOINT);
162  cgen_state->emitExternalCall(
163  func_name, llvm::Type::getVoidTy(cgen_state->context_), operand_lvs);
164 
165  llvm::Value* ret_coords = pt_local_storage_lv;
166  if (ret_ti.get_compression() == kENCODING_GEOINT) {
167  // Compress centroid point if requested
168  // Take values out of local storage, compress, store in compressed local storage
169 
170  auto x_ptr = builder.CreateGEP(
171  pt_local_storage_lv->getType()->getScalarType()->getPointerElementType(),
172  pt_local_storage_lv,
173  {cgen_state->llInt(0), cgen_state->llInt(0)},
174  "x_ptr");
175  auto x_lv = builder.CreateLoad(x_ptr->getType()->getPointerElementType(), x_ptr);
176  auto compressed_x_lv =
177  cgen_state->emitExternalCall("compress_x_coord_geoint",
178  llvm::Type::getInt32Ty(cgen_state->context_),
179  {x_lv});
180  auto compressed_x_ptr =
181  builder.CreateGEP(pt_compressed_local_storage_lv->getType()
182  ->getScalarType()
183  ->getPointerElementType(),
184  pt_compressed_local_storage_lv,
185  {cgen_state->llInt(0), cgen_state->llInt(0)},
186  "compressed_x_ptr");
187  builder.CreateStore(compressed_x_lv, compressed_x_ptr);
188 
189  auto y_ptr = builder.CreateGEP(
190  pt_local_storage_lv->getType()->getScalarType()->getPointerElementType(),
191  pt_local_storage_lv,
192  {cgen_state->llInt(0), cgen_state->llInt(1)},
193  "y_ptr");
194  auto y_lv = builder.CreateLoad(y_ptr->getType()->getPointerElementType(), y_ptr);
195  auto compressed_y_lv =
196  cgen_state->emitExternalCall("compress_y_coord_geoint",
197  llvm::Type::getInt32Ty(cgen_state->context_),
198  {y_lv});
199  auto compressed_y_ptr =
200  builder.CreateGEP(pt_compressed_local_storage_lv->getType()
201  ->getScalarType()
202  ->getPointerElementType(),
203  pt_compressed_local_storage_lv,
204  {cgen_state->llInt(0), cgen_state->llInt(1)},
205  "compressed_y_ptr");
206  builder.CreateStore(compressed_y_lv, compressed_y_ptr);
207 
208  ret_coords = pt_compressed_local_storage_lv;
209  } else {
210  CHECK(ret_ti.get_compression() == kENCODING_NONE);
211  }
212 
213  auto ret_ty = ret_ti.get_compression() == kENCODING_GEOINT
214  ? llvm::Type::getInt32PtrTy(cgen_state->context_)
215  : llvm::Type::getDoublePtrTy(cgen_state->context_);
216  ret_coords = builder.CreateBitCast(ret_coords, ret_ty);
217 
218  if (is_nullable_) {
219  CHECK(nullcheck_codegen);
220  ret_coords = nullcheck_codegen->finalize(
221  llvm::ConstantPointerNull::get(
222  ret_ti.get_compression() == kENCODING_GEOINT
223  ? llvm::PointerType::get(llvm::Type::getInt32Ty(cgen_state->context_),
224  0)
225  : llvm::PointerType::get(llvm::Type::getDoubleTy(cgen_state->context_),
226  0)),
227  ret_coords);
228  }
229 
230  return {ret_coords,
231  cgen_state->llInt(ret_ti.get_compression() == kENCODING_GEOINT ? 8 : 16)};
232  }
233 };
234 
235 } // namespace spatial_type
#define CHECK_EQ(x, y)
Definition: Logger.h:301
int32_t get_compression_scheme(const SQLTypeInfo &ti)
Definition: Compression.cpp:23
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
std::vector< llvm::Value * > codegen(const std::vector< llvm::Value * > &args, CodeGenerator::NullCheckCodegen *nullcheck_codegen, CgenState *cgen_state, const CompilationOptions &co) final
Definition: Centroid.h:112
std::string suffix(SQLTypes type)
Definition: Codegen.cpp:69
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
Centroid(const Analyzer::GeoOperator *geo_operator)
Definition: Centroid.h:25
std::tuple< std::vector< llvm::Value * >, llvm::Value * > codegenLoads(const std::vector< llvm::Value * > &arg_lvs, const std::vector< llvm::Value * > &pos_lvs, CgenState *cgen_state) final
Definition: Centroid.h:35
const Analyzer::GeoOperator * operator_
Definition: Codegen.h:67
size_t size() const
Definition: Analyzer.cpp:4182
size_t size() const final
Definition: Centroid.h:31
#define CHECK(condition)
Definition: Logger.h:291
virtual const Analyzer::Expr * getOperand(const size_t index)
Definition: Codegen.cpp:64
uint32_t log2_bytes(const uint32_t bytes)
Definition: Execute.h:198
Definition: sqltypes.h:72
SQLTypeInfo getNullType() const final
Definition: Centroid.h:33