OmniSciDB  b28c0d5765
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Transform.h
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
19 namespace spatial_type {
20 
21 // ST_Transform
// Code generator for ST_Transform applied to a single POINT operand (the CHECK
// in codegenLoads requires kPOINT). SRID pairs handled by codegen():
//   4326 <-> 900913, UTM -> 4326 / 900913, and 4326 / 900913 -> UTM,
// where "UTM" is any SRID accepted by isUtm(). An identical input/output SRID is
// treated as a no-op; any other pair throws std::runtime_error during codegen.
//
// NOTE(review): this listing skips source lines 18, 27, 30, 58, 67, 73, 233 and
// 287-288 (see the gaps in the line numbers), so the code shown here is not
// complete. Per the cross-reference entries, the gaps include the
// transform_operator_ member (declared at line 287); confirm the rest against
// the actual header.
22 class Transform : public Codegen {
23  public:
// Requires exactly one operand (the geo input expression). The generated result
// is nullable unless the operator's own type info is marked NOT NULL.
24  Transform(const Analyzer::GeoOperator* geo_operator,
25  const Catalog_Namespace::Catalog* catalog)
26  : Codegen(geo_operator, catalog)
// NOTE(review): listing line 27 is missing here; it presumably opens the
// transform_operator_ initializer whose argument continues on the next line.
28  dynamic_cast<const Analyzer::GeoTransformOperator*>(geo_operator)) {
29  CHECK_EQ(operator_->size(), size_t(1)); // geo input expr
31  const auto& ti = geo_operator->get_type_info();
32  if (ti.get_notnull()) {
33  is_nullable_ = false;
34  } else {
35  is_nullable_ = true;
36  }
37  }
38 
// Number of operands this codegen consumes; codegenLoads expects exactly one
// position value per operand.
39  size_t size() const override { return 1; }
40 
// Type of the null indicator reported by this codegen: a boolean.
41  SQLTypeInfo getNullType() const override { return SQLTypeInfo(kBOOLEAN); }
42 
// True for the UTM SRID ranges accepted here: 32601-32660 and 32701-32760
// (presumably the WGS 84 UTM north / south zone codes).
43  inline static bool isUtm(unsigned const srid) {
44  return (32601 <= srid && srid <= 32660) || (32701 <= srid && srid <= 32760);
45  }
46 
// Emits the loads for the single POINT operand and returns
// {buffer pointer(s), null flag}. How the buffer is obtained depends on the
// operand kind:
//  - ColumnVar: one arg (the column byte stream). The array is loaded and
//    null-checked by a helper; its returned buffer and is_null are forwarded.
//  - GeoConstant: two args (ptr, size). Only ptr is forwarded, and no null flag
//    is produced (nullptr).
//  - anything else: one or two args. The null flag compares the pointer against
//    a null int32* (GEOINT-compressed input) or a null double* (otherwise).
47  std::tuple<std::vector<llvm::Value*>, llvm::Value*> codegenLoads(
48  const std::vector<llvm::Value*>& arg_lvs,
49  const std::vector<llvm::Value*>& pos_lvs,
50  CgenState* cgen_state) override {
51  CHECK_EQ(pos_lvs.size(), size());
52  const auto geo_operand = getOperand(0);
53  const auto& operand_ti = geo_operand->get_type_info();
54  CHECK(operand_ti.is_geometry() && operand_ti.get_type() == kPOINT);
55 
56  if (dynamic_cast<const Analyzer::ColumnVar*>(geo_operand)) {
57  CHECK_EQ(arg_lvs.size(), size_t(1)); // col_byte_stream
// NOTE(review): listing line 58 (the start of this call) is missing. Given the
// argument list and the cross-reference, it is presumably
// `auto arr_load_lvs = ...codegenGeoArrayLoadAndNullcheck(`.
59  arg_lvs.front(), pos_lvs.front(), operand_ti, cgen_state);
60  return std::make_tuple(std::vector<llvm::Value*>{arr_load_lvs.buffer},
61  arr_load_lvs.is_null);
62  } else if (dynamic_cast<const Analyzer::GeoConstant*>(geo_operand)) {
63  CHECK_EQ(arg_lvs.size(), size_t(2)); // ptr, size
64 
65  // nulls not supported, and likely compressed, so require a new buffer for the
66  // transformation
// NOTE(review): listing line 67 is missing. Given the comment above, it
// presumably clears can_transform_in_place_ (read later in codegen); confirm.
68  return std::make_tuple(std::vector<llvm::Value*>{arg_lvs.front()}, nullptr);
69  } else {
70  CHECK(arg_lvs.size() == size_t(1) ||
71  arg_lvs.size() == size_t(2)); // ptr or ptr, size
72  // coming from a temporary, can modify the memory pointer directly
// NOTE(review): listing line 73 is missing. Given the comment above, it
// presumably sets can_transform_in_place_; confirm.
74  auto& builder = cgen_state->ir_builder_;
75 
76  const auto is_null = builder.CreateICmp(
77  llvm::CmpInst::ICMP_EQ,
78  arg_lvs.front(),
79  llvm::ConstantPointerNull::get( // TODO: check ptr address space
80  operand_ti.get_compression() == kENCODING_GEOINT
81  ? llvm::Type::getInt32PtrTy(cgen_state->context_)
82  : llvm::Type::getDoublePtrTy(cgen_state->context_)));
83  return std::make_tuple(std::vector<llvm::Value*>{arg_lvs.front()}, is_null);
84  }
85  UNREACHABLE();
86  return std::make_tuple(std::vector<llvm::Value*>{}, nullptr);
87  }
88 
// Emits the coordinate transform for the loaded POINT buffer (args[0]).
// Returns {result pointer, int32 size}. The size is 8 when the output type is
// GEOINT-compressed and 16 otherwise (presumably the byte size of two int32 or
// two double coordinates).
// Steps:
//  1. Obtain a double[2] buffer to work on:
//     - GEOINT input is decompressed into a new 2-double alloca;
//     - uncompressed input is copied into a new alloca unless
//       can_transform_in_place_ is set;
//     - otherwise the input buffer is used directly.
//  2. Pick the runtime functions "<prefix>x" / "<prefix>y" from the input and
//     output SRIDs. Unsupported pairs throw std::runtime_error.
//  3. Load x and y, call both functions, and store the results back into the
//     buffer.
//  4. If the result is nullable, merge it with a null pointer via
//     nullcheck_codegen->finalize().
89  std::vector<llvm::Value*> codegen(const std::vector<llvm::Value*>& args,
90  CodeGenerator::NullCheckCodegen* nullcheck_codegen,
91  CgenState* cgen_state,
92  const CompilationOptions& co) override {
93  CHECK_EQ(args.size(), size_t(1));
94 
95  const auto geo_operand = getOperand(0);
96  const auto& operand_ti = geo_operand->get_type_info();
97  auto& builder = cgen_state->ir_builder_;
98 
99  llvm::Value* arr_buff_ptr = args.front();
100  if (operand_ti.get_compression() == kENCODING_GEOINT) {
101  // decompress
// Stack buffer for the two decompressed coordinates (x, y) as doubles.
102  auto new_arr_ptr =
103  builder.CreateAlloca(llvm::Type::getDoubleTy(cgen_state->context_),
104  cgen_state->llInt(int32_t(2)),
105  getName() + "_Array");
// View the input as int32 coordinates; element 0 is x and element 1 is y.
106  auto compressed_arr_ptr = builder.CreateBitCast(
107  arr_buff_ptr, llvm::Type::getInt32PtrTy(cgen_state->context_));
108  // x coord
109  auto* gep = builder.CreateGEP(
110  compressed_arr_ptr->getType()->getScalarType()->getPointerElementType(),
111  compressed_arr_ptr,
112  cgen_state->llInt(0));
113  auto x_coord_lv = cgen_state->emitExternalCall(
114  "decompress_x_coord_geoint",
115  llvm::Type::getDoubleTy(cgen_state->context_),
116  {builder.CreateLoad(
117  gep->getType()->getPointerElementType(), gep, "compressed_x_coord")});
118  builder.CreateStore(
119  x_coord_lv,
120  builder.CreateGEP(
121  new_arr_ptr->getType()->getScalarType()->getPointerElementType(),
122  new_arr_ptr,
123  cgen_state->llInt(0)));
// y coord
124  gep = builder.CreateGEP(
125  compressed_arr_ptr->getType()->getScalarType()->getPointerElementType(),
126  compressed_arr_ptr,
127  cgen_state->llInt(1));
128  auto y_coord_lv = cgen_state->emitExternalCall(
129  "decompress_y_coord_geoint",
130  llvm::Type::getDoubleTy(cgen_state->context_),
131  {builder.CreateLoad(
132  gep->getType()->getPointerElementType(), gep, "compressed_y_coord")});
133  builder.CreateStore(
134  y_coord_lv,
135  builder.CreateGEP(
136  new_arr_ptr->getType()->getScalarType()->getPointerElementType(),
137  new_arr_ptr,
138  cgen_state->llInt(1)));
139  arr_buff_ptr = new_arr_ptr;
140  } else if (!can_transform_in_place_) {
// Uncompressed input that must not be modified: copy x and y into a fresh
// 2-double stack buffer and transform the copy.
141  auto new_arr_ptr =
142  builder.CreateAlloca(llvm::Type::getDoubleTy(cgen_state->context_),
143  cgen_state->llInt(int32_t(2)),
144  getName() + "_Array");
145  const auto arr_buff_ptr_cast = builder.CreateBitCast(
146  arr_buff_ptr, llvm::Type::getDoublePtrTy(cgen_state->context_));
147 
148  auto* gep = builder.CreateGEP(
149  arr_buff_ptr_cast->getType()->getScalarType()->getPointerElementType(),
150  arr_buff_ptr_cast,
151  cgen_state->llInt(0));
152  builder.CreateStore(
153  builder.CreateLoad(gep->getType()->getPointerElementType(), gep),
154  builder.CreateGEP(
155  new_arr_ptr->getType()->getScalarType()->getPointerElementType(),
156  new_arr_ptr,
157  cgen_state->llInt(0)));
158  gep = builder.CreateGEP(
159  arr_buff_ptr_cast->getType()->getScalarType()->getPointerElementType(),
160  arr_buff_ptr_cast,
161  cgen_state->llInt(1));
162  builder.CreateStore(
163  builder.CreateLoad(gep->getType()->getPointerElementType(), gep),
164  builder.CreateGEP(
165  new_arr_ptr->getType()->getScalarType()->getPointerElementType(),
166  new_arr_ptr,
167  cgen_state->llInt(1)));
168  arr_buff_ptr = new_arr_ptr;
169  }
// Every path above must leave a double* buffer to transform.
170  CHECK(arr_buff_ptr->getType() == llvm::Type::getDoublePtrTy(cgen_state->context_));
171 
172  auto const srid_in = static_cast<unsigned>(transform_operator_->getInputSRID());
173  auto const srid_out = static_cast<unsigned>(transform_operator_->getOutputSRID());
// NOTE(review): this early return differs from the normal return at the end:
//  - it yields a single value (no size entry);
//  - it skips nullcheck_codegen->finalize();
//  - it returns the original args.front(), leaving any decompression/copy
//    buffer built above unused.
// Confirm that callers handle this shape.
174  if (srid_in == srid_out) {
175  // noop
176  return {args.front()};
177  }
178 
179  // transform in place
// UTM SRIDs are passed as a leading int argument, ahead of x and y: the input
// SRID for utm -> X and the output SRID for X -> utm.
180  std::string transform_function_prefix{""};
181  std::vector<llvm::Value*> transform_args;
182 
183  if (srid_out == 900913) {
184  if (srid_in == 4326) {
185  transform_function_prefix = "transform_4326_900913_";
186  } else if (isUtm(srid_in)) {
187  transform_function_prefix = "transform_utm_900913_";
188  transform_args.push_back(cgen_state->llInt(srid_in));
189  } else {
190  throw std::runtime_error("Unsupported input SRID " + std::to_string(srid_in) +
191  " for output SRID " + std::to_string(srid_out));
192  }
193  } else if (srid_out == 4326) {
194  if (srid_in == 900913) {
195  transform_function_prefix = "transform_900913_4326_";
196  } else if (isUtm(srid_in)) {
197  transform_function_prefix = "transform_utm_4326_";
198  transform_args.push_back(cgen_state->llInt(srid_in));
199  } else {
200  throw std::runtime_error("Unsupported input SRID " + std::to_string(srid_in) +
201  " for output SRID " + std::to_string(srid_out));
202  }
203  } else if (isUtm(srid_out)) {
204  if (srid_in == 4326) {
205  transform_function_prefix = "transform_4326_utm_";
206  } else if (srid_in == 900913) {
207  transform_function_prefix = "transform_900913_utm_";
208  } else {
209  throw std::runtime_error("Unsupported input SRID " + std::to_string(srid_in) +
210  " for output SRID " + std::to_string(srid_out));
211  }
212  transform_args.push_back(cgen_state->llInt(srid_out));
213  } else {
214  throw std::runtime_error("Unsupported output SRID for ST_Transform: " +
215  std::to_string(srid_out));
216  }
217  CHECK(!transform_function_prefix.empty());
218 
// x and y are both loaded here, before either store below. The y function
// therefore receives the untransformed x (and vice versa), and the shared
// transform_args serves both calls.
219  auto x_coord_ptr_lv = builder.CreateGEP(
220  arr_buff_ptr->getType()->getScalarType()->getPointerElementType(),
221  arr_buff_ptr,
222  cgen_state->llInt(0),
223  "x_coord_ptr");
224  transform_args.push_back(builder.CreateLoad(
225  x_coord_ptr_lv->getType()->getPointerElementType(), x_coord_ptr_lv, "x_coord"));
226  auto y_coord_ptr_lv = builder.CreateGEP(
227  arr_buff_ptr->getType()->getScalarType()->getPointerElementType(),
228  arr_buff_ptr,
229  cgen_state->llInt(1),
230  "y_coord_ptr");
231  transform_args.push_back(builder.CreateLoad(
232  y_coord_ptr_lv->getType()->getPointerElementType(), y_coord_ptr_lv, "y_coord"));
// NOTE(review): listing line 233 is missing. It presumably opens the GPU-only
// branch that is closed by the `} else {` below, which uses emitCall instead;
// the cross-reference hints at a co.device_type test. Confirm.
// In that branch, each runtime function is looked up in the module, cloned into
// it if needed so it has a body, has its GPU-specific callees replaced, is
// verified, and is then called directly.
234  auto fn_x = cgen_state->module_->getFunction(transform_function_prefix + 'x');
235  CHECK(fn_x);
236  cgen_state->maybeCloneFunctionRecursive(fn_x);
237  CHECK(!fn_x->isDeclaration());
238 
239  auto gpu_functions_to_replace = cgen_state->gpuFunctionsToReplace(fn_x);
240  for (const auto& fcn_name : gpu_functions_to_replace) {
241  cgen_state->replaceFunctionForGpu(fcn_name, fn_x);
242  }
243  verify_function_ir(fn_x);
244  auto transform_call = builder.CreateCall(fn_x, transform_args);
245  builder.CreateStore(transform_call, x_coord_ptr_lv);
246 
247  auto fn_y = cgen_state->module_->getFunction(transform_function_prefix + 'y');
248  CHECK(fn_y);
249  cgen_state->maybeCloneFunctionRecursive(fn_y);
250  CHECK(!fn_y->isDeclaration());
251 
252  gpu_functions_to_replace = cgen_state->gpuFunctionsToReplace(fn_y);
253  for (const auto& fcn_name : gpu_functions_to_replace) {
254  cgen_state->replaceFunctionForGpu(fcn_name, fn_y);
255  }
256  verify_function_ir(fn_y);
257  transform_call = builder.CreateCall(fn_y, transform_args);
258  builder.CreateStore(transform_call, y_coord_ptr_lv);
259  } else {
260  builder.CreateStore(
261  cgen_state->emitCall(transform_function_prefix + 'x', transform_args),
262  x_coord_ptr_lv);
263  builder.CreateStore(
264  cgen_state->emitCall(transform_function_prefix + 'y', transform_args),
265  y_coord_ptr_lv);
266  }
267  auto ret = arr_buff_ptr;
268  const auto& geo_ti = transform_operator_->get_type_info();
269 
// NOTE(review): arr_buff_ptr was CHECKed to be double* above, but the null
// pointer type here and the returned size follow the OUTPUT type's compression.
// A GEOINT-compressed output type would pair an int32* null with a double*
// buffer and report 8 for a 2-double buffer. Presumably the output is never
// compressed; confirm.
270  if (is_nullable_) {
271  CHECK(nullcheck_codegen);
272  ret = nullcheck_codegen->finalize(
273  llvm::ConstantPointerNull::get(
274  geo_ti.get_compression() == kENCODING_GEOINT
275  ? llvm::PointerType::get(llvm::Type::getInt32Ty(cgen_state->context_),
276  0)
277  : llvm::PointerType::get(llvm::Type::getDoubleTy(cgen_state->context_),
278  0)),
279  ret);
280  }
281  return {ret,
282  cgen_state->llInt(static_cast<int32_t>(
283  geo_ti.get_compression() == kENCODING_GEOINT ? 8 : 16))};
284  }
285 
286  private:
// NOTE(review): listing lines 287-288 (the private data members) are missing.
// Per the cross-reference, transform_operator_ is declared at line 287;
// can_transform_in_place_ (read in codegen) is presumably the other member.
289 };
290 
291 } // namespace spatial_type
std::vector< llvm::Value * > codegen(const std::vector< llvm::Value * > &args, CodeGenerator::NullCheckCodegen *nullcheck_codegen, CgenState *cgen_state, const CompilationOptions &co) override
Definition: Transform.h:89
#define CHECK_EQ(x, y)
Definition: Logger.h:230
class for a per-database catalog. also includes metadata for the current database and the current use...
Definition: Catalog.h:132
int32_t getInputSRID() const
Definition: Analyzer.h:2971
void maybeCloneFunctionRecursive(llvm::Function *fn)
Definition: CgenState.cpp:180
llvm::IRBuilder ir_builder_
Definition: CgenState.h:375
Transform(const Analyzer::GeoOperator *geo_operator, const Catalog_Namespace::Catalog *catalog)
Definition: Transform.h:24
#define UNREACHABLE()
Definition: Logger.h:266
std::string to_string(char const *&&v)
llvm::Module * module_
Definition: CgenState.h:364
void verify_function_ir(const llvm::Function *func)
llvm::LLVMContext & context_
Definition: CgenState.h:373
CONSTEXPR DEVICE bool is_null(const T &value)
llvm::Value * emitExternalCall(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, const std::vector< llvm::Attribute::AttrKind > &fnattrs={}, const bool has_struct_return=false)
Definition: CgenState.cpp:396
void replaceFunctionForGpu(const std::string &fcn_to_replace, llvm::Function *fn)
Definition: CgenState.cpp:328
std::vector< std::string > gpuFunctionsToReplace(llvm::Function *fn)
Definition: CgenState.cpp:305
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:82
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:216
int32_t getOutputSRID() const
Definition: Analyzer.h:2973
ExecutorDeviceType device_type
size_t size() const override
Definition: Transform.h:39
SQLTypeInfo getNullType() const override
Definition: Transform.h:41
static ArrayLoadCodegen codegenGeoArrayLoadAndNullcheck(llvm::Value *byte_stream, llvm::Value *pos, const SQLTypeInfo &ti, CgenState *cgen_state)
Definition: GeoIR.cpp:23
static bool isUtm(unsigned const srid)
Definition: Transform.h:43
const Analyzer::GeoOperator * operator_
Definition: Codegen.h:70
size_t size() const
Definition: Analyzer.cpp:3950
std::tuple< std::vector< llvm::Value * >, llvm::Value * > codegenLoads(const std::vector< llvm::Value * > &arg_lvs, const std::vector< llvm::Value * > &pos_lvs, CgenState *cgen_state) override
Definition: Transform.h:47
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:240
const Analyzer::GeoTransformOperator * transform_operator_
Definition: Transform.h:287
llvm::Value * finalize(llvm::Value *null_lv, llvm::Value *notnull_lv)
Definition: IRCodegen.cpp:1459
#define CHECK(condition)
Definition: Logger.h:222
virtual const Analyzer::Expr * getOperand(const size_t index)
Definition: Codegen.cpp:65
std::string getName() const
Definition: Codegen.h:39