OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Distance.h
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
21 namespace spatial_type {
22 
23 class Distance : public Codegen {
24  public:
25  Distance(const Analyzer::GeoOperator* geo_operator,
26  const Catalog_Namespace::Catalog* catalog)
27  : Codegen(geo_operator, catalog) {
28  CHECK_EQ(operator_->size(), size_t(2));
29  const auto& ti = operator_->get_type_info();
30  is_nullable_ = !ti.get_notnull();
31  }
32 
33  size_t size() const final { return 2; }
34 
35  SQLTypeInfo getNullType() const final { return SQLTypeInfo(kBOOLEAN); }
36 
37  std::tuple<std::vector<llvm::Value*>, llvm::Value*> codegenLoads(
38  const std::vector<llvm::Value*>& arg_lvs,
39  const std::vector<llvm::Value*>& pos_lvs,
40  CgenState* cgen_state) final {
41  CHECK_EQ(pos_lvs.size(), size());
42  std::string size_fn_name = "array_size";
43  if (is_nullable_) {
44  size_fn_name += "_nullable";
45  }
46 
47  auto& builder = cgen_state->ir_builder_;
48  llvm::Value* is_null = cgen_state->llBool(false);
49 
50  std::vector<llvm::Value*> operand_lvs;
51  size_t arg_lvs_index{0};
52  for (size_t i = 0; i < size(); i++) {
53  const auto operand = getOperand(i);
54  CHECK(operand);
55  const auto& operand_ti = operand->get_type_info();
56  CHECK(IS_GEO(operand_ti.get_type()));
57  const size_t num_physical_coord_lvs = operand_ti.get_physical_coord_cols();
58 
59  // iterate over column inputs
60  bool is_coords_lv{true};
61  if (dynamic_cast<const Analyzer::ColumnVar*>(operand)) {
62  for (size_t j = 0; j < num_physical_coord_lvs; j++) {
63  CHECK_LT(arg_lvs_index, arg_lvs.size());
64  auto lv = arg_lvs[arg_lvs_index++];
65  // TODO: fast fixlen array buff for coords
66  auto array_buff_lv =
67  cgen_state->emitExternalCall("array_buff",
68  llvm::Type::getInt8PtrTy(cgen_state->context_),
69  {lv, pos_lvs[i]});
70  auto const is_coords = (j == 0);
71  if (!is_coords) {
72  // cast additional columns to i32*
73  array_buff_lv = builder.CreateBitCast(
74  array_buff_lv, llvm::Type::getInt32PtrTy(cgen_state->context_));
75  }
76  operand_lvs.push_back(array_buff_lv);
77 
78  const auto ptr_type = llvm::dyn_cast_or_null<llvm::PointerType>(lv->getType());
79  CHECK(ptr_type);
80  const auto elem_type = ptr_type->getPointerElementType();
81  CHECK(elem_type);
82  auto const shift = log2_bytes(is_coords ? 1 : 4);
83  std::vector<llvm::Value*> array_sz_args{
84  lv, pos_lvs[i], cgen_state->llInt(shift)};
85  if (is_nullable_) { // TODO: should we do this for all arguments, or just
86  // coords?
87  array_sz_args.push_back(cgen_state->llInt(
88  static_cast<int32_t>(inline_int_null_value<int32_t>())));
89  }
90  operand_lvs.push_back(cgen_state->emitExternalCall(
91  size_fn_name, get_int_type(32, cgen_state->context_), array_sz_args));
92  llvm::Value* operand_is_null_lv{nullptr};
93  if (is_nullable_ && is_coords_lv) {
94  if (operand_ti.get_type() == kPOINT) {
95  operand_is_null_lv = cgen_state->emitExternalCall(
96  "point_coord_array_is_null",
97  llvm::Type::getInt1Ty(cgen_state->context_),
98  {lv, pos_lvs[i]});
99  } else {
100  operand_is_null_lv = builder.CreateICmpEQ(
101  operand_lvs.back(),
102  cgen_state->llInt(
103  static_cast<int32_t>(inline_int_null_value<int32_t>())));
104  }
105  is_null = builder.CreateOr(is_null, operand_is_null_lv);
106  }
107  is_coords_lv = false;
108  }
109  } else {
110  bool is_coords_lv{true};
111  for (size_t j = 0; j < num_physical_coord_lvs; j++) {
112  // ptr
113  CHECK_LT(arg_lvs_index, arg_lvs.size());
114  auto array_buff_lv = arg_lvs[arg_lvs_index++];
115  if (j == 0) {
116  // cast alloca to i8*
117  array_buff_lv = builder.CreateBitCast(
118  array_buff_lv, llvm::Type::getInt8PtrTy(cgen_state->context_));
119  } else {
120  // cast additional columns to i32*
121  array_buff_lv = builder.CreateBitCast(
122  array_buff_lv, llvm::Type::getInt32PtrTy(cgen_state->context_));
123  }
124  operand_lvs.push_back(array_buff_lv);
125  if (is_nullable_ && is_coords_lv) {
126  auto coords_array_type =
127  llvm::dyn_cast<llvm::PointerType>(operand_lvs.back()->getType());
128  CHECK(coords_array_type);
129  is_null = builder.CreateOr(
130  is_null,
131  builder.CreateICmpEQ(operand_lvs.back(),
132  llvm::ConstantPointerNull::get(coords_array_type)));
133  }
134  is_coords_lv = false;
135  CHECK_LT(arg_lvs_index, arg_lvs.size());
136  operand_lvs.push_back(arg_lvs[arg_lvs_index++]);
137  }
138  }
139  }
140  CHECK_EQ(arg_lvs_index, arg_lvs.size());
141 
142  // use the points array size argument for nullability
143  return std::make_tuple(operand_lvs, is_nullable_ ? is_null : nullptr);
144  }
145 
146  std::vector<llvm::Value*> codegen(const std::vector<llvm::Value*>& args,
147  CodeGenerator::NullCheckCodegen* nullcheck_codegen,
148  CgenState* cgen_state,
149  const CompilationOptions& co) final {
150  const auto& first_operand_ti = getOperand(0)->get_type_info();
151  const auto& second_operand_ti = getOperand(1)->get_type_info();
152 
153  const bool is_geodesic = first_operand_ti.get_subtype() == kGEOGRAPHY &&
154  first_operand_ti.get_output_srid() == 4326;
155 
156  if (is_geodesic && !((first_operand_ti.get_type() == kPOINT &&
157  second_operand_ti.get_type() == kPOINT) ||
158  (first_operand_ti.get_type() == kLINESTRING &&
159  second_operand_ti.get_type() == kPOINT) ||
160  (first_operand_ti.get_type() == kPOINT &&
161  second_operand_ti.get_type() == kLINESTRING))) {
162  throw std::runtime_error(getName() +
163  " currently doesn't accept non-POINT geographies");
164  }
165 
166  bool unsupported_args = false;
167  if (first_operand_ti.get_type() == kMULTILINESTRING) {
168  unsupported_args = (second_operand_ti.get_type() != kPOINT);
169  } else if (second_operand_ti.get_type() == kMULTILINESTRING) {
170  unsupported_args = (first_operand_ti.get_type() != kPOINT);
171  }
172  if (unsupported_args) {
173  throw std::runtime_error(getName() +
174  " currently doesn't support this argument combination");
175  }
176 
177  std::string func_name = getName() + suffix(first_operand_ti.get_type()) +
178  suffix(second_operand_ti.get_type());
179  if (is_geodesic) {
180  func_name += "_Geodesic";
181  }
182  auto& builder = cgen_state->ir_builder_;
183 
184  std::vector<llvm::Value*> operand_lvs;
185  for (size_t i = 0; i < args.size(); i += 2) {
186  operand_lvs.push_back(args[i]);
187  operand_lvs.push_back(
188  builder.CreateSExt(args[i + 1], llvm::Type::getInt64Ty(cgen_state->context_)));
189  }
190 
191  const auto& ret_ti = operator_->get_type_info();
192  // push back ic, isr, osr for now
193  operand_lvs.push_back(
194  cgen_state->llInt(Geospatial::get_compression_scheme(first_operand_ti))); // ic 1
195  operand_lvs.push_back(
196  cgen_state->llInt(first_operand_ti.get_input_srid())); // in srid 1
197  operand_lvs.push_back(cgen_state->llInt(
198  Geospatial::get_compression_scheme(second_operand_ti))); // ic 2
199  operand_lvs.push_back(
200  cgen_state->llInt(second_operand_ti.get_input_srid())); // in srid 2
201  const auto srid_override = operator_->getOutputSridOverride();
202  operand_lvs.push_back(
203  cgen_state->llInt(srid_override ? *srid_override : 0)); // out srid
204 
205  if (getName() == "ST_Distance" && first_operand_ti.get_subtype() != kGEOGRAPHY &&
206  (first_operand_ti.get_type() != kPOINT ||
207  second_operand_ti.get_type() != kPOINT)) {
208  operand_lvs.push_back(cgen_state->llFp(double(0.0)));
209  }
210 
211  CHECK(ret_ti.get_type() == kDOUBLE);
212  auto ret = cgen_state->emitExternalCall(
213  func_name, llvm::Type::getDoubleTy(cgen_state->context_), operand_lvs);
214  if (is_nullable_) {
215  CHECK(nullcheck_codegen);
216  ret = nullcheck_codegen->finalize(cgen_state->inlineFpNull(ret_ti), ret);
217  }
218  return {ret};
219  }
220 };
221 
222 } // namespace spatial_type
#define CHECK_EQ(x, y)
Definition: Logger.h:297
class for a per-database catalog. also includes metadata for the current database and the current use...
Definition: Catalog.h:132
int32_t get_compression_scheme(const SQLTypeInfo &ti)
Definition: Compression.cpp:23
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
std::string suffix(SQLTypes type)
Definition: Codegen.cpp:70
std::tuple< std::vector< llvm::Value * >, llvm::Value * > codegenLoads(const std::vector< llvm::Value * > &arg_lvs, const std::vector< llvm::Value * > &pos_lvs, CgenState *cgen_state) final
Definition: Distance.h:37
CONSTEXPR DEVICE bool is_null(const T &value)
std::vector< llvm::Value * > codegen(const std::vector< llvm::Value * > &args, CodeGenerator::NullCheckCodegen *nullcheck_codegen, CgenState *cgen_state, const CompilationOptions &co) final
Definition: Distance.h:146
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:83
#define CHECK_LT(x, y)
Definition: Logger.h:299
Distance(const Analyzer::GeoOperator *geo_operator, const Catalog_Namespace::Catalog *catalog)
Definition: Distance.h:25
const Analyzer::GeoOperator * operator_
Definition: Codegen.h:70
size_t size() const
Definition: Analyzer.cpp:3967
size_t size() const final
Definition: Distance.h:33
#define CHECK(condition)
Definition: Logger.h:289
virtual const Analyzer::Expr * getOperand(const size_t index)
Definition: Codegen.cpp:65
uint32_t log2_bytes(const uint32_t bytes)
Definition: Execute.h:176
#define IS_GEO(T)
Definition: sqltypes.h:298
SQLTypeInfo getNullType() const final
Definition: Distance.h:35