OmniSciDB  91042dcc5b
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Transform.h
Go to the documentation of this file.
1 /*
2  * Copyright 2021 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 namespace spatial_type {
21 
22 // ST_Transform
// Code generator for ST_Transform applied to a single POINT operand.
// codegenLoads() fetches the point's coordinate buffer (plus a null flag where
// one can be computed); codegen() rewrites the two coordinates from the input
// SRID to the output SRID (4326, 900913, or a UTM SRID accepted by isUtm()) and
// returns {coordinate buffer pointer, buffer size in bytes}.
//
// NOTE(review): this is a Doxygen source listing and the embedded line numbers
// skip several source lines inside this class (the transform_operator_
// initializer, the start of the codegenGeoArrayLoadAndNullcheck call, the
// can_transform_in_place_ assignments, the device-type `if`, and the private
// members). Regenerate from Transform.h before relying on this text as code.
23 class Transform : public Codegen {
24  public:
 // Binds the generic geo operator and keeps an ST_Transform-typed view of it.
 // catalog is forwarded unchanged to the Codegen base.
25  Transform(const Analyzer::GeoOperator* geo_operator,
26  const Catalog_Namespace::Catalog* catalog)
27  : Codegen(geo_operator, catalog)
 // NOTE(review): the initializer that receives this dynamic_cast is missing
 // from this listing; the member index names it transform_operator_.
29  dynamic_cast<const Analyzer::GeoTransformOperator*>(geo_operator)) {
30  CHECK_EQ(operator_->size(), size_t(1)); // geo input expr
 // NOTE(review): codegen() dereferences transform_operator_, but no visible
 // line CHECKs that the dynamic_cast above succeeded. A listing line is
 // missing here; confirm it is CHECK(transform_operator_).
32  const auto& ti = geo_operator->get_type_info();
 // The result is nullable unless the operator's type is declared NOT NULL.
33  if (ti.get_notnull()) {
34  is_nullable_ = false;
35  } else {
36  is_nullable_ = true;
37  }
38  }
39 
 // ST_Transform has exactly one geo operand (checked in the constructor).
40  size_t size() const override { return 1; }
41 
 // Null indicator type; presumably matches the i1 compare result that
 // codegenLoads() returns on the temporary-input path.
42  SQLTypeInfo getNullType() const override { return SQLTypeInfo(kBOOLEAN); }
43 
 // True for the WGS 84 / UTM SRIDs in the EPSG registry:
 // 32601-32660 (northern zones 1-60) and 32701-32760 (southern zones 1-60).
44  inline static bool isUtm(unsigned const srid) {
45  return (32601 <= srid && srid <= 32660) || (32701 <= srid && srid <= 32760);
46  }
47 
 // Emits the loads for the single POINT operand.
 // arg_lvs: operand values - a column byte stream, or ptr[, size].
 // pos_lvs: row position values; exactly size() == 1 is expected.
 // Returns {buffer pointer vector, null flag}. The null flag is nullptr for
 // geo constants, whose nulls are not supported.
 // Dispatches on the operand's expression kind: column, constant, or other
 // (temporary).
48  std::tuple<std::vector<llvm::Value*>, llvm::Value*> codegenLoads(
49  const std::vector<llvm::Value*>& arg_lvs,
50  const std::vector<llvm::Value*>& pos_lvs,
51  CgenState* cgen_state) override {
52  CHECK_EQ(pos_lvs.size(), size());
53  const auto geo_operand = getOperand(0);
54  const auto& operand_ti = geo_operand->get_type_info();
55  CHECK(operand_ti.is_geometry() && operand_ti.get_type() == kPOINT);
56 
57  if (dynamic_cast<const Analyzer::ColumnVar*>(geo_operand)) {
 // Column input: load this row's point array together with its null flag.
58  CHECK_EQ(arg_lvs.size(), size_t(1)); // col_byte_stream
 // NOTE(review): the line opening this call (binding arr_load_lvs to
 // codegenGeoArrayLoadAndNullcheck) is missing from this listing.
60  arg_lvs.front(), pos_lvs.front(), operand_ti, cgen_state);
61  return std::make_tuple(std::vector<llvm::Value*>{arr_load_lvs.buffer},
62  arr_load_lvs.is_null);
63  } else if (dynamic_cast<const Analyzer::GeoConstant*>(geo_operand)) {
64  CHECK_EQ(arg_lvs.size(), size_t(2)); // ptr, size
65 
66  // nulls not supported, and likely compressed, so require a new buffer for the
67  // transformation
 // NOTE(review): a listing line is missing here; presumably it sets
 // can_transform_in_place_ = false, which codegen() reads - confirm.
69  return std::make_tuple(std::vector<llvm::Value*>{arg_lvs.front()}, nullptr);
70  } else {
71  CHECK(arg_lvs.size() == size_t(1) ||
72  arg_lvs.size() == size_t(2)); // ptr or ptr, size
73  // coming from a temporary, can modify the memory pointer directly
 // NOTE(review): a listing line is missing here; presumably it sets
 // can_transform_in_place_ = true - confirm.
75  auto& builder = cgen_state->ir_builder_;
76 
 // A null buffer pointer is treated as a NULL point. The null constant is
 // typed by the operand's compression (i32* for GEOINT, else double*).
 // NOTE(review): LLVM icmp requires both operands to have the same type, so
 // arg_lvs.front() must already be exactly that pointer type here (see also
 // the address-space TODO below).
77  const auto is_null = builder.CreateICmp(
78  llvm::CmpInst::ICMP_EQ,
79  arg_lvs.front(),
80  llvm::ConstantPointerNull::get( // TODO: check ptr address space
81  operand_ti.get_compression() == kENCODING_GEOINT
82  ? llvm::Type::getInt32PtrTy(cgen_state->context_)
83  : llvm::Type::getDoublePtrTy(cgen_state->context_)));
84  return std::make_tuple(std::vector<llvm::Value*>{arg_lvs.front()}, is_null);
85  }
 // Every branch above returns; this tail only satisfies the compiler.
86  UNREACHABLE();
87  return std::make_tuple(std::vector<llvm::Value*>{}, nullptr);
88  }
89 
 // Emits IR that rewrites the point's (x, y) from the input SRID to the output
 // SRID.
 // args: exactly one value, the coordinate buffer pointer.
 // nullcheck_codegen: must be non-null when the result is nullable.
 // co: compilation options. The device-specific branch below depends on it,
 //     but the line testing it is missing from this listing.
 // Returns {buffer pointer, size in bytes: 8 if the output type is
 // GEOINT-compressed, otherwise 16}. When the SRIDs are equal it returns only
 // {args.front()}.
 // Throws std::runtime_error for unsupported (input, output) SRID pairs.
 // NOTE(review): getPointerElementType() is not available with LLVM opaque
 // pointers; an LLVM upgrade will need the i32/double element types passed to
 // CreateGEP/CreateLoad explicitly.
90  std::vector<llvm::Value*> codegen(const std::vector<llvm::Value*>& args,
91  CodeGenerator::NullCheckCodegen* nullcheck_codegen,
92  CgenState* cgen_state,
93  const CompilationOptions& co) override {
94  CHECK_EQ(args.size(), size_t(1));
95 
96  const auto geo_operand = getOperand(0);
97  const auto& operand_ti = geo_operand->get_type_info();
98  auto& builder = cgen_state->ir_builder_;
99 
100  llvm::Value* arr_buff_ptr = args.front();
 // GEOINT-compressed input: decode the two int32 coordinates into a new
 // stack (alloca) buffer of two doubles.
101  if (operand_ti.get_compression() == kENCODING_GEOINT) {
102  // decompress
103  auto new_arr_ptr =
104  builder.CreateAlloca(llvm::Type::getDoubleTy(cgen_state->context_),
105  cgen_state->llInt(int32_t(2)),
106  getName() + "_Array");
107  auto compressed_arr_ptr = builder.CreateBitCast(
108  arr_buff_ptr, llvm::Type::getInt32PtrTy(cgen_state->context_));
109  // x coord
110  auto* gep = builder.CreateGEP(
111  compressed_arr_ptr->getType()->getScalarType()->getPointerElementType(),
112  compressed_arr_ptr,
113  cgen_state->llInt(0));
114  auto x_coord_lv = cgen_state->emitExternalCall(
115  "decompress_x_coord_geoint",
116  llvm::Type::getDoubleTy(cgen_state->context_),
117  {builder.CreateLoad(
118  gep->getType()->getPointerElementType(), gep, "compressed_x_coord")});
119  builder.CreateStore(
120  x_coord_lv,
121  builder.CreateGEP(
122  new_arr_ptr->getType()->getScalarType()->getPointerElementType(),
123  new_arr_ptr,
124  cgen_state->llInt(0)));
 // y coord
125  gep = builder.CreateGEP(
126  compressed_arr_ptr->getType()->getScalarType()->getPointerElementType(),
127  compressed_arr_ptr,
128  cgen_state->llInt(1));
129  auto y_coord_lv = cgen_state->emitExternalCall(
130  "decompress_y_coord_geoint",
131  llvm::Type::getDoubleTy(cgen_state->context_),
132  {builder.CreateLoad(
133  gep->getType()->getPointerElementType(), gep, "compressed_y_coord")});
134  builder.CreateStore(
135  y_coord_lv,
136  builder.CreateGEP(
137  new_arr_ptr->getType()->getScalarType()->getPointerElementType(),
138  new_arr_ptr,
139  cgen_state->llInt(1)));
140  arr_buff_ptr = new_arr_ptr;
 // Uncompressed input that must not be modified in place: copy x and y
 // into a new two-double stack buffer and transform the copy.
141  } else if (!can_transform_in_place_) {
142  auto new_arr_ptr =
143  builder.CreateAlloca(llvm::Type::getDoubleTy(cgen_state->context_),
144  cgen_state->llInt(int32_t(2)),
145  getName() + "_Array");
146  const auto arr_buff_ptr_cast = builder.CreateBitCast(
147  arr_buff_ptr, llvm::Type::getDoublePtrTy(cgen_state->context_));
148 
149  auto* gep = builder.CreateGEP(
150  arr_buff_ptr_cast->getType()->getScalarType()->getPointerElementType(),
151  arr_buff_ptr_cast,
152  cgen_state->llInt(0));
153  builder.CreateStore(
154  builder.CreateLoad(gep->getType()->getPointerElementType(), gep),
155  builder.CreateGEP(
156  new_arr_ptr->getType()->getScalarType()->getPointerElementType(),
157  new_arr_ptr,
158  cgen_state->llInt(0)));
159  gep = builder.CreateGEP(
160  arr_buff_ptr_cast->getType()->getScalarType()->getPointerElementType(),
161  arr_buff_ptr_cast,
162  cgen_state->llInt(1));
163  builder.CreateStore(
164  builder.CreateLoad(gep->getType()->getPointerElementType(), gep),
165  builder.CreateGEP(
166  new_arr_ptr->getType()->getScalarType()->getPointerElementType(),
167  new_arr_ptr,
168  cgen_state->llInt(1)));
169  arr_buff_ptr = new_arr_ptr;
170  }
 // From here on the working buffer is always double*.
171  CHECK(arr_buff_ptr->getType() == llvm::Type::getDoublePtrTy(cgen_state->context_));
172 
173  auto const srid_in = static_cast<unsigned>(transform_operator_->getInputSRID());
174  auto const srid_out = static_cast<unsigned>(transform_operator_->getOutputSRID());
 // NOTE(review): this early return has three problems.
 //  - It returns the original args.front(), not arr_buff_ptr, so for GEOINT
 //    input it is the still-compressed pointer.
 //  - It returns one value, where the main path returns {ptr, size}.
 //  - It never calls nullcheck_codegen->finalize().
 // The decompress/copy IR emitted above is also left dead. Consider testing
 // the SRIDs before preparing the buffer and returning the same {ptr, size}
 // shape as the main path.
175  if (srid_in == srid_out) {
176  // noop
177  return {args.front()};
178  }
179 
180  // transform in place
 // Choose the runtime routine pair <prefix>x / <prefix>y from
 // (srid_in, srid_out). UTM routines receive the UTM SRID as their first
 // argument, ahead of the x and y coordinates appended below.
181  std::string transform_function_prefix{""};
182  std::vector<llvm::Value*> transform_args;
183 
184  if (srid_out == 900913) {
185  if (srid_in == 4326) {
186  transform_function_prefix = "transform_4326_900913_";
187  } else if (isUtm(srid_in)) {
188  transform_function_prefix = "transform_utm_900913_";
189  transform_args.push_back(cgen_state->llInt(srid_in));
190  } else {
191  throw std::runtime_error("Unsupported input SRID " + std::to_string(srid_in) +
192  " for output SRID " + std::to_string(srid_out));
193  }
194  } else if (srid_out == 4326) {
195  if (srid_in == 900913) {
196  transform_function_prefix = "transform_900913_4326_";
197  } else if (isUtm(srid_in)) {
198  transform_function_prefix = "transform_utm_4326_";
199  transform_args.push_back(cgen_state->llInt(srid_in));
200  } else {
201  throw std::runtime_error("Unsupported input SRID " + std::to_string(srid_in) +
202  " for output SRID " + std::to_string(srid_out));
203  }
204  } else if (isUtm(srid_out)) {
205  if (srid_in == 4326) {
206  transform_function_prefix = "transform_4326_utm_";
207  } else if (srid_in == 900913) {
208  transform_function_prefix = "transform_900913_utm_";
209  } else {
210  throw std::runtime_error("Unsupported input SRID " + std::to_string(srid_in) +
211  " for output SRID " + std::to_string(srid_out));
212  }
213  transform_args.push_back(cgen_state->llInt(srid_out));
214  } else {
215  throw std::runtime_error("Unsupported output SRID for ST_Transform: " +
216  std::to_string(srid_out));
217  }
218  CHECK(!transform_function_prefix.empty());
219 
 // Load both coordinates before either store below. This lets the x and y
 // routines each see the untransformed (x, y) pair.
220  auto x_coord_ptr_lv = builder.CreateGEP(
221  arr_buff_ptr->getType()->getScalarType()->getPointerElementType(),
222  arr_buff_ptr,
223  cgen_state->llInt(0),
224  "x_coord_ptr");
225  transform_args.push_back(builder.CreateLoad(
226  x_coord_ptr_lv->getType()->getPointerElementType(), x_coord_ptr_lv, "x_coord"));
227  auto y_coord_ptr_lv = builder.CreateGEP(
228  arr_buff_ptr->getType()->getScalarType()->getPointerElementType(),
229  arr_buff_ptr,
230  cgen_state->llInt(1),
231  "y_coord_ptr");
232  transform_args.push_back(builder.CreateLoad(
233  y_coord_ptr_lv->getType()->getPointerElementType(), y_coord_ptr_lv, "y_coord"));
 // NOTE(review): the `if` opening this branch is missing from this listing.
 // The `} else {` below and the CompilationOptions device_type field in the
 // member index suggest it tests co.device_type for GPU - confirm.
 // In this branch (judging by the helper names) each routine's definition is
 // pulled into the module with its callees, GPU-specific replacements are
 // applied, the IR is verified, and the routine is then called directly; the
 // result overwrites the coordinate in the buffer.
235  auto fn_x = cgen_state->module_->getFunction(transform_function_prefix + 'x');
236  CHECK(fn_x);
237  cgen_state->maybeCloneFunctionRecursive(fn_x);
238  CHECK(!fn_x->isDeclaration());
239 
240  auto gpu_functions_to_replace = cgen_state->gpuFunctionsToReplace(fn_x);
241  for (const auto& fcn_name : gpu_functions_to_replace) {
242  cgen_state->replaceFunctionForGpu(fcn_name, fn_x);
243  }
244  verify_function_ir(fn_x);
245  auto transform_call = builder.CreateCall(fn_x, transform_args);
246  builder.CreateStore(transform_call, x_coord_ptr_lv);
247 
248  auto fn_y = cgen_state->module_->getFunction(transform_function_prefix + 'y');
249  CHECK(fn_y);
250  cgen_state->maybeCloneFunctionRecursive(fn_y);
251  CHECK(!fn_y->isDeclaration());
252 
253  gpu_functions_to_replace = cgen_state->gpuFunctionsToReplace(fn_y);
254  for (const auto& fcn_name : gpu_functions_to_replace) {
255  cgen_state->replaceFunctionForGpu(fcn_name, fn_y);
256  }
257  verify_function_ir(fn_y);
258  transform_call = builder.CreateCall(fn_y, transform_args);
259  builder.CreateStore(transform_call, y_coord_ptr_lv);
260  } else {
 // Otherwise call the routines by name and store the results in place.
261  builder.CreateStore(
262  cgen_state->emitCall(transform_function_prefix + 'x', transform_args),
263  x_coord_ptr_lv);
264  builder.CreateStore(
265  cgen_state->emitCall(transform_function_prefix + 'y', transform_args),
266  y_coord_ptr_lv);
267  }
268  auto ret = arr_buff_ptr;
269  const auto& geo_ti = transform_operator_->get_type_info();
270 
 // Nullable result: merge with the null path. The null value is a null
 // pointer typed by the output compression; the non-null value is the
 // transformed buffer.
 // NOTE(review): arr_buff_ptr is always double* (CHECKed above). Both the
 // i32* null below and the 8-byte size in the return apply only when geo_ti
 // is GEOINT, and neither matches a double* buffer. Confirm that the output
 // type is never GEOINT-compressed here.
271  if (is_nullable_) {
272  CHECK(nullcheck_codegen);
273  ret = nullcheck_codegen->finalize(
274  llvm::ConstantPointerNull::get(
275  geo_ti.get_compression() == kENCODING_GEOINT
276  ? llvm::PointerType::get(llvm::Type::getInt32Ty(cgen_state->context_),
277  0)
278  : llvm::PointerType::get(llvm::Type::getDoubleTy(cgen_state->context_),
279  0)),
280  ret);
281  }
 // {buffer pointer, byte size of the two-coordinate buffer}.
282  return {ret,
283  cgen_state->llInt(static_cast<int32_t>(
284  geo_ti.get_compression() == kENCODING_GEOINT ? 8 : 16))};
285  }
286 
287  private:
 // NOTE(review): the member declarations are missing from this listing. The
 // member index places transform_operator_ (const
 // Analyzer::GeoTransformOperator*) here, and codegen() also reads
 // can_transform_in_place_.
290 };
291 
292 } // namespace spatial_type
std::vector< llvm::Value * > codegen(const std::vector< llvm::Value * > &args, CodeGenerator::NullCheckCodegen *nullcheck_codegen, CgenState *cgen_state, const CompilationOptions &co) override
Definition: Transform.h:90
#define CHECK_EQ(x, y)
Definition: Logger.h:219
Class for a per-database catalog; also includes metadata for the current database and the current user.
Definition: Catalog.h:114
int32_t getInputSRID() const
Definition: Analyzer.h:1965
#define const
void maybeCloneFunctionRecursive(llvm::Function *fn)
Definition: CgenState.cpp:179
llvm::Value * emitExternalCall(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, const std::vector< llvm::Attribute::AttrKind > &fnattrs={}, const bool has_struct_return=false)
Definition: CgenState.h:210
llvm::IRBuilder ir_builder_
Definition: CgenState.h:354
Transform(const Analyzer::GeoOperator *geo_operator, const Catalog_Namespace::Catalog *catalog)
Definition: Transform.h:25
#define UNREACHABLE()
Definition: Logger.h:255
std::string to_string(char const *&&v)
llvm::Module * module_
Definition: CgenState.h:343
void verify_function_ir(const llvm::Function *func)
llvm::LLVMContext & context_
Definition: CgenState.h:352
CONSTEXPR DEVICE bool is_null(const T &value)
void replaceFunctionForGpu(const std::string &fcn_to_replace, llvm::Function *fn)
Definition: CgenState.cpp:315
std::vector< std::string > gpuFunctionsToReplace(llvm::Function *fn)
Definition: CgenState.cpp:292
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:77
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:215
int32_t getOutputSRID() const
Definition: Analyzer.h:1967
ExecutorDeviceType device_type
size_t size() const override
Definition: Transform.h:40
SQLTypeInfo getNullType() const override
Definition: Transform.h:42
static ArrayLoadCodegen codegenGeoArrayLoadAndNullcheck(llvm::Value *byte_stream, llvm::Value *pos, const SQLTypeInfo &ti, CgenState *cgen_state)
Definition: GeoIR.cpp:23
static bool isUtm(unsigned const srid)
Definition: Transform.h:44
const Analyzer::GeoOperator * operator_
Definition: Codegen.h:70
size_t size() const
Definition: Analyzer.cpp:3644
std::tuple< std::vector< llvm::Value * >, llvm::Value * > codegenLoads(const std::vector< llvm::Value * > &arg_lvs, const std::vector< llvm::Value * > &pos_lvs, CgenState *cgen_state) override
Definition: Transform.h:48
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:289
const Analyzer::GeoTransformOperator * transform_operator_
Definition: Transform.h:288
llvm::Value * finalize(llvm::Value *null_lv, llvm::Value *notnull_lv)
Definition: IRCodegen.cpp:1452
#define CHECK(condition)
Definition: Logger.h:211
virtual const Analyzer::Expr * getOperand(const size_t index)
Definition: Codegen.cpp:63
std::string getName() const
Definition: Codegen.h:39