OmniSciDB  c0231cc57d
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ArrowResultSet.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <arrow/api.h>
20 #include <arrow/io/memory.h>
21 #include <arrow/ipc/api.h>
22 
24 #include "Shared/ArrowUtil.h"
25 
26 namespace {
27 
28 SQLTypeInfo type_from_arrow_field(const arrow::Field& field) {
29  switch (field.type()->id()) {
30  case arrow::Type::INT8:
31  return SQLTypeInfo(kTINYINT, !field.nullable());
32  case arrow::Type::INT16:
33  return SQLTypeInfo(kSMALLINT, !field.nullable());
34  case arrow::Type::INT32:
35  return SQLTypeInfo(kINT, !field.nullable());
36  case arrow::Type::INT64:
37  return SQLTypeInfo(kBIGINT, !field.nullable());
38  case arrow::Type::FLOAT:
39  return SQLTypeInfo(kFLOAT, !field.nullable());
40  case arrow::Type::DOUBLE:
41  return SQLTypeInfo(kDOUBLE, !field.nullable());
42  case arrow::Type::DICTIONARY:
43  return SQLTypeInfo(kTEXT, !field.nullable(), kENCODING_DICT);
44  case arrow::Type::TIMESTAMP: {
45  // TODO(Wamsi): go right fold expr in c++17
46  auto get_precision = [&field](auto type) { return field.type()->Equals(type); };
47  if (get_precision(arrow::timestamp(arrow::TimeUnit::SECOND))) {
48  return SQLTypeInfo(kTIMESTAMP, !field.nullable());
49  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::MILLI))) {
50  return SQLTypeInfo(kTIMESTAMP, 3, 0, !field.nullable());
51  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::MICRO))) {
52  return SQLTypeInfo(kTIMESTAMP, 6, 0, !field.nullable());
53  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::NANO))) {
54  return SQLTypeInfo(kTIMESTAMP, 9, 0, !field.nullable());
55  } else {
56  UNREACHABLE();
57  }
58  }
59  case arrow::Type::DATE32:
60  return SQLTypeInfo(kDATE, !field.nullable(), kENCODING_DATE_IN_DAYS);
61  case arrow::Type::DATE64:
62  return SQLTypeInfo(kDATE, !field.nullable());
63  case arrow::Type::TIME32:
64  return SQLTypeInfo(kTIME, !field.nullable());
65  default:
66  CHECK(false);
67  }
68  CHECK(false);
69  return SQLTypeInfo();
70 }
71 
72 } // namespace
73 
74 ArrowResultSet::ArrowResultSet(const std::shared_ptr<ResultSet>& rows,
75  const std::vector<TargetMetaInfo>& targets_meta,
76  const ExecutorDeviceType device_type)
77  : rows_(rows), targets_meta_(targets_meta), crt_row_idx_(0) {
78  resultSetArrowLoopback(device_type);
79  auto schema = record_batch_->schema();
80  for (int i = 0; i < schema->num_fields(); ++i) {
81  std::shared_ptr<arrow::Field> field = schema->field(i);
82  SQLTypeInfo type_info = type_from_arrow_field(*schema->field(i));
83  column_metainfo_.emplace_back(field->name(), type_info);
84  columns_.emplace_back(record_batch_->column(i));
85  }
86 }
87 
89  const std::shared_ptr<ResultSet>& rows,
90  const std::vector<TargetMetaInfo>& targets_meta,
91  const ExecutorDeviceType device_type,
92  const size_t min_result_size_for_bulk_dictionary_fetch,
93  const double max_dictionary_to_result_size_ratio_for_bulk_dictionary_fetch)
94  : rows_(rows), targets_meta_(targets_meta), crt_row_idx_(0) {
95  resultSetArrowLoopback(device_type,
96  min_result_size_for_bulk_dictionary_fetch,
97  max_dictionary_to_result_size_ratio_for_bulk_dictionary_fetch);
98  auto schema = record_batch_->schema();
99  for (int i = 0; i < schema->num_fields(); ++i) {
100  std::shared_ptr<arrow::Field> field = schema->field(i);
101  SQLTypeInfo type_info = type_from_arrow_field(*schema->field(i));
102  column_metainfo_.emplace_back(field->name(), type_info);
103  columns_.emplace_back(record_batch_->column(i));
104  }
105 }
106 
107 template <typename Type, typename ArrayType>
108 void ArrowResultSet::appendValue(std::vector<TargetValue>& row,
109  const arrow::Array& column,
110  const Type null_val,
111  const size_t idx) const {
112  const auto& col = static_cast<const ArrayType&>(column);
113  row.emplace_back(col.IsNull(idx) ? null_val : static_cast<Type>(col.Value(idx)));
114 }
115 
116 std::vector<std::string> ArrowResultSet::getDictionaryStrings(
117  const size_t col_idx) const {
118  if (col_idx >= colCount()) {
119  throw std::runtime_error("ArrowResultSet::getDictionaryStrings: col_idx is invalid.");
120  }
121  const auto& column_typeinfo = getColType(col_idx);
122  if (column_typeinfo.get_type() != kTEXT) {
123  throw std::runtime_error(
124  "ArrowResultSet::getDictionaryStrings: col_idx does not refer to column of type "
125  "TEXT.");
126  }
127  CHECK_EQ(kENCODING_DICT, column_typeinfo.get_compression());
128  const auto& column = *columns_[col_idx];
129  CHECK_EQ(arrow::Type::DICTIONARY, column.type_id());
130  const auto& dict_column = static_cast<const arrow::DictionaryArray&>(column);
131  const auto& dictionary =
132  static_cast<const arrow::StringArray&>(*dict_column.dictionary());
133  const size_t dictionary_size = dictionary.length();
134  std::vector<std::string> dictionary_strings;
135  dictionary_strings.reserve(dictionary_size);
136  for (size_t d = 0; d < dictionary_size; ++d) {
137  dictionary_strings.emplace_back(dictionary.GetString(d));
138  }
139  return dictionary_strings;
140 }
141 
142 std::vector<TargetValue> ArrowResultSet::getRowAt(const size_t index) const {
143  if (index >= rowCount()) {
144  return {};
145  }
146 
147  CHECK_LT(index, rowCount());
148  std::vector<TargetValue> row;
149  for (int i = 0; i < record_batch_->num_columns(); ++i) {
150  const auto& column = *columns_[i];
151  const auto& column_typeinfo = getColType(i);
152  switch (column_typeinfo.get_type()) {
153  case kTINYINT: {
154  CHECK_EQ(arrow::Type::INT8, column.type_id());
155  appendValue<int64_t, arrow::Int8Array>(
156  row, column, inline_int_null_val(column_typeinfo), index);
157  break;
158  }
159  case kSMALLINT: {
160  CHECK_EQ(arrow::Type::INT16, column.type_id());
161  appendValue<int64_t, arrow::Int16Array>(
162  row, column, inline_int_null_val(column_typeinfo), index);
163  break;
164  }
165  case kINT: {
166  CHECK_EQ(arrow::Type::INT32, column.type_id());
167  appendValue<int64_t, arrow::Int32Array>(
168  row, column, inline_int_null_val(column_typeinfo), index);
169  break;
170  }
171  case kBIGINT: {
172  CHECK_EQ(arrow::Type::INT64, column.type_id());
173  appendValue<int64_t, arrow::Int64Array>(
174  row, column, inline_int_null_val(column_typeinfo), index);
175  break;
176  }
177  case kFLOAT: {
178  CHECK_EQ(arrow::Type::FLOAT, column.type_id());
179  appendValue<float, arrow::FloatArray>(
180  row, column, inline_fp_null_value<float>(), index);
181  break;
182  }
183  case kDOUBLE: {
184  CHECK_EQ(arrow::Type::DOUBLE, column.type_id());
185  appendValue<double, arrow::DoubleArray>(
186  row, column, inline_fp_null_value<double>(), index);
187  break;
188  }
189  case kTEXT: {
190  CHECK_EQ(kENCODING_DICT, column_typeinfo.get_compression());
191  CHECK_EQ(arrow::Type::DICTIONARY, column.type_id());
192  const auto& dict_column = static_cast<const arrow::DictionaryArray&>(column);
193  if (dict_column.IsNull(index)) {
194  row.emplace_back(NullableString(nullptr));
195  } else {
196  const auto& indices =
197  static_cast<const arrow::Int32Array&>(*dict_column.indices());
198  const auto& dictionary =
199  static_cast<const arrow::StringArray&>(*dict_column.dictionary());
200  row.emplace_back(dictionary.GetString(indices.Value(index)));
201  }
202  break;
203  }
204  case kTIMESTAMP: {
205  CHECK_EQ(arrow::Type::TIMESTAMP, column.type_id());
206  appendValue<int64_t, arrow::TimestampArray>(
207  row, column, inline_int_null_val(column_typeinfo), index);
208  break;
209  }
210  case kDATE: {
211  // TODO(wamsi): constexpr?
212  CHECK(arrow::Type::DATE32 == column.type_id() ||
213  arrow::Type::DATE64 == column.type_id());
214  column_typeinfo.is_date_in_days()
215  ? appendValue<int64_t, arrow::Date32Array>(
216  row, column, inline_int_null_val(column_typeinfo), index)
217  : appendValue<int64_t, arrow::Date64Array>(
218  row, column, inline_int_null_val(column_typeinfo), index);
219  break;
220  }
221  case kTIME: {
222  CHECK_EQ(arrow::Type::TIME32, column.type_id());
223  appendValue<int64_t, arrow::Time32Array>(
224  row, column, inline_int_null_val(column_typeinfo), index);
225  break;
226  }
227  default:
228  CHECK(false);
229  }
230  }
231  return row;
232 }
233 
234 std::vector<TargetValue> ArrowResultSet::getNextRow(const bool translate_strings,
235  const bool decimal_to_double) const {
236  if (crt_row_idx_ == rowCount()) {
237  return {};
238  }
240  auto row = getRowAt(crt_row_idx_);
241  ++crt_row_idx_;
242  return row;
243 }
244 
245 size_t ArrowResultSet::colCount() const {
246  return column_metainfo_.size();
247 }
248 
249 SQLTypeInfo ArrowResultSet::getColType(const size_t col_idx) const {
250  CHECK_LT(col_idx, column_metainfo_.size());
251  return column_metainfo_[col_idx].get_type_info();
252 }
253 
255  return !rowCount();
256 }
257 
258 size_t ArrowResultSet::rowCount() const {
259  return record_batch_->num_rows();
260 }
261 
262 // Function is for parity with ResultSet interface
263 // and associated tests
265  return rowCount();
266 }
267 
268 // Function is for parity with ResultSet interface
269 // and associated tests
271  return rowCount() == static_cast<size_t>(0);
272 }
273 
276  device_type,
279  default_max_dictionary_to_result_size_ratio_for_bulk_dictionary_fetch);
280 }
281 
283  const ExecutorDeviceType device_type,
284  const size_t min_result_size_for_bulk_dictionary_fetch,
285  const double max_dictionary_to_result_size_ratio_for_bulk_dictionary_fetch) {
286  std::vector<std::string> col_names;
287 
288  if (!targets_meta_.empty()) {
289  for (auto& meta : targets_meta_) {
290  col_names.push_back(meta.get_resname());
291  }
292  } else {
293  for (unsigned int i = 0; i < rows_->colCount(); i++) {
294  col_names.push_back("col_" + std::to_string(i));
295  }
296  }
297 
298  // We convert the given rows to arrow, which gets serialized
299  // into a buffer by Arrow Wire.
300  auto converter = ArrowResultSetConverter(
301  rows_,
302  col_names,
303  -1,
304  min_result_size_for_bulk_dictionary_fetch,
305  max_dictionary_to_result_size_ratio_for_bulk_dictionary_fetch);
306  converter.transport_method_ = ArrowTransport::WIRE;
307  converter.device_type_ = device_type;
308 
309  // Lifetime of the result buffer is that of ArrowResultSet
310  results_ = std::make_shared<ArrowResult>(converter.getArrowResult());
311 
312  // Create a reader for reading back serialized
313  arrow::io::BufferReader reader(
314  reinterpret_cast<const uint8_t*>(results_->df_buffer.data()), results_->df_size);
315 
316  ARROW_ASSIGN_OR_THROW(auto batch_reader,
317  arrow::ipc::RecordBatchStreamReader::Open(&reader));
318 
319  ARROW_THROW_NOT_OK(batch_reader->ReadNext(&record_batch_));
320 
321  // Collect dictionaries from the record batch into the dictionary memo.
323  arrow::ipc::internal::CollectDictionaries(*record_batch_, &dictionary_memo_));
324 
325  CHECK_EQ(record_batch_->schema()->num_fields(), record_batch_->num_columns());
326 }
327 
328 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
329  const ExecutionResult& results) {
330  // NOTE(wesm): About memory ownership
331 
332  // After calling ReadRecordBatch, the buffers inside arrow::RecordBatch now
333  // share ownership of the memory in serialized_arrow_output.records (zero
334  // copy). Not necessary to retain these buffers. Same is true of any
335  // dictionaries contained in serialized_arrow_output.schema; the arrays
336  // reference that memory (zero copy).
337  return std::make_unique<ArrowResultSet>(results.getRows(), results.getTargetsMeta());
338 }
339 
340 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
341  const ExecutionResult* results,
342  const std::shared_ptr<ResultSet>& rows,
343  const ExecutorDeviceType device_type) {
344  return results ? std::make_unique<ArrowResultSet>(
345  rows, results->getTargetsMeta(), device_type)
346  : std::make_unique<ArrowResultSet>(rows, device_type);
347 }
348 
349 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
350  const ExecutionResult* results,
351  const std::shared_ptr<ResultSet>& rows,
352  const ExecutorDeviceType device_type,
353  const size_t min_result_size_for_bulk_dictionary_fetch,
354  const double max_dictionary_to_result_size_ratio_for_bulk_dictionary_fetch) {
355  std::vector<TargetMetaInfo> dummy_targets_meta;
356  return results ? std::make_unique<ArrowResultSet>(
357  rows,
358  results->getTargetsMeta(),
359  device_type,
360  min_result_size_for_bulk_dictionary_fetch,
361  max_dictionary_to_result_size_ratio_for_bulk_dictionary_fetch)
362  : std::make_unique<ArrowResultSet>(
363  rows,
364  dummy_targets_meta,
365  device_type,
366  min_result_size_for_bulk_dictionary_fetch,
367  max_dictionary_to_result_size_ratio_for_bulk_dictionary_fetch);
368 }
#define CHECK_EQ(x, y)
Definition: Logger.h:230
#define ARROW_THROW_NOT_OK(s)
Definition: ArrowUtil.h:36
Definition: sqltypes.h:63
double decimal_to_double(const SQLTypeInfo &otype, int64_t oval)
ExecutorDeviceType
std::shared_ptr< ArrowResult > results_
std::shared_ptr< ResultSet > rows_
size_t rowCount() const
SQLTypeInfo getColType(const size_t col_idx) const
#define ARROW_ASSIGN_OR_THROW(lhs, rexpr)
Definition: ArrowUtil.h:60
#define UNREACHABLE()
Definition: Logger.h:266
arrow::ipc::DictionaryMemo dictionary_memo_
SQLTypeInfo type_from_arrow_field(const arrow::Field &field)
std::string to_string(char const *&&v)
ArrowResultSet(const std::shared_ptr< ResultSet > &rows, const std::vector< TargetMetaInfo > &targets_meta, const ExecutorDeviceType device_type=ExecutorDeviceType::CPU)
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
size_t entryCount() const
const std::vector< TargetMetaInfo > & getTargetsMeta() const
const std::shared_ptr< ResultSet > & getRows() const
std::vector< TargetValue > getNextRow(const bool translate_strings, const bool decimal_to_double) const
#define CHECK_LT(x, y)
Definition: Logger.h:232
Definition: sqltypes.h:66
Definition: sqltypes.h:67
size_t colCount() const
std::unique_ptr< ArrowResultSet > result_set_arrow_loopback(const ExecutionResult &results)
constexpr float inline_fp_null_value< float >()
std::vector< TargetMetaInfo > column_metainfo_
constexpr double inline_fp_null_value< double >()
boost::variant< std::string, void * > NullableString
Definition: TargetValue.h:179
bool definitelyHasNoRows() const
static constexpr size_t default_min_result_size_for_bulk_dictionary_fetch
void appendValue(std::vector< TargetValue > &row, const arrow::Array &column, const Type null_val, const size_t idx) const
#define CHECK(condition)
Definition: Logger.h:222
SQLTypes type
Definition: sqltypes.h:1023
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
bool isEmpty() const
Definition: sqltypes.h:59
void resultSetArrowLoopback(const ExecutorDeviceType device_type=ExecutorDeviceType::CPU)
std::vector< TargetValue > getRowAt(const size_t index) const
std::shared_ptr< arrow::RecordBatch > record_batch_
std::vector< TargetMetaInfo > targets_meta_
std::vector< std::shared_ptr< arrow::Array > > columns_