OmniSciDB  471d68cefb
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ArrowResultSet.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <arrow/api.h>
20 #include <arrow/io/memory.h>
21 #include <arrow/ipc/api.h>
22 
24 #include "Shared/ArrowUtil.h"
25 
26 namespace {
27 
28 SQLTypeInfo type_from_arrow_field(const arrow::Field& field) {
29  switch (field.type()->id()) {
30  case arrow::Type::INT8:
31  return SQLTypeInfo(kTINYINT, !field.nullable());
32  case arrow::Type::INT16:
33  return SQLTypeInfo(kSMALLINT, !field.nullable());
34  case arrow::Type::INT32:
35  return SQLTypeInfo(kINT, !field.nullable());
36  case arrow::Type::INT64:
37  return SQLTypeInfo(kBIGINT, !field.nullable());
38  case arrow::Type::FLOAT:
39  return SQLTypeInfo(kFLOAT, !field.nullable());
41  return SQLTypeInfo(kDOUBLE, !field.nullable());
43  return SQLTypeInfo(kTEXT, !field.nullable(), kENCODING_DICT);
45  // TODO(Wamsi): go right fold expr in c++17
46  auto get_precision = [&field](auto type) { return field.type()->Equals(type); };
47  if (get_precision(arrow::timestamp(arrow::TimeUnit::SECOND))) {
48  return SQLTypeInfo(kTIMESTAMP, !field.nullable());
49  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::MILLI))) {
50  return SQLTypeInfo(kTIMESTAMP, 3, 0, !field.nullable());
51  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::MICRO))) {
52  return SQLTypeInfo(kTIMESTAMP, 6, 0, !field.nullable());
53  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::NANO))) {
54  return SQLTypeInfo(kTIMESTAMP, 9, 0, !field.nullable());
55  } else {
56  UNREACHABLE();
57  }
58  }
59  case arrow::Type::DATE32:
60  return SQLTypeInfo(kDATE, !field.nullable(), kENCODING_DATE_IN_DAYS);
61  case arrow::Type::DATE64:
62  return SQLTypeInfo(kDATE, !field.nullable());
63  case arrow::Type::TIME32:
64  return SQLTypeInfo(kTIME, !field.nullable());
65  default:
66  CHECK(false);
67  }
68  CHECK(false);
69  return SQLTypeInfo();
70 }
71 
72 } // namespace
73 
74 ArrowResultSet::ArrowResultSet(const std::shared_ptr<ResultSet>& rows,
75  const std::vector<TargetMetaInfo>& targets_meta,
76  const ExecutorDeviceType device_type)
77  : rows_(rows), targets_meta_(targets_meta), crt_row_idx_(0) {
78  resultSetArrowLoopback(device_type);
79  auto schema = record_batch_->schema();
80  for (int i = 0; i < schema->num_fields(); ++i) {
81  std::shared_ptr<arrow::Field> field = schema->field(i);
82  SQLTypeInfo type_info = type_from_arrow_field(*schema->field(i));
83  column_metainfo_.emplace_back(field->name(), type_info);
84  columns_.emplace_back(record_batch_->column(i));
85  }
86 }
87 
88 template <typename Type, typename ArrayType>
89 void ArrowResultSet::appendValue(std::vector<TargetValue>& row,
90  const arrow::Array& column,
91  const Type null_val,
92  const size_t idx) const {
93  const auto& col = static_cast<const ArrayType&>(column);
94  row.emplace_back(col.IsNull(idx) ? null_val : static_cast<Type>(col.Value(idx)));
95 }
96 
97 std::vector<TargetValue> ArrowResultSet::getRowAt(const size_t index) const {
98  if (index >= rowCount()) {
99  return {};
100  }
101 
102  CHECK_LT(index, rowCount());
103  std::vector<TargetValue> row;
104  for (int i = 0; i < record_batch_->num_columns(); ++i) {
105  const auto& column = *columns_[i];
106  const auto& column_typeinfo = getColType(i);
107  switch (column_typeinfo.get_type()) {
108  case kTINYINT: {
109  CHECK_EQ(arrow::Type::INT8, column.type_id());
110  appendValue<int64_t, arrow::Int8Array>(
111  row, column, inline_int_null_val(column_typeinfo), index);
112  break;
113  }
114  case kSMALLINT: {
115  CHECK_EQ(arrow::Type::INT16, column.type_id());
116  appendValue<int64_t, arrow::Int16Array>(
117  row, column, inline_int_null_val(column_typeinfo), index);
118  break;
119  }
120  case kINT: {
121  CHECK_EQ(arrow::Type::INT32, column.type_id());
122  appendValue<int64_t, arrow::Int32Array>(
123  row, column, inline_int_null_val(column_typeinfo), index);
124  break;
125  }
126  case kBIGINT: {
127  CHECK_EQ(arrow::Type::INT64, column.type_id());
128  appendValue<int64_t, arrow::Int64Array>(
129  row, column, inline_int_null_val(column_typeinfo), index);
130  break;
131  }
132  case kFLOAT: {
133  CHECK_EQ(arrow::Type::FLOAT, column.type_id());
134  appendValue<float, arrow::FloatArray>(
135  row, column, inline_fp_null_value<float>(), index);
136  break;
137  }
138  case kDOUBLE: {
139  CHECK_EQ(arrow::Type::DOUBLE, column.type_id());
140  appendValue<double, arrow::DoubleArray>(
141  row, column, inline_fp_null_value<double>(), index);
142  break;
143  }
144  case kTEXT: {
145  CHECK_EQ(kENCODING_DICT, column_typeinfo.get_compression());
146  CHECK_EQ(arrow::Type::DICTIONARY, column.type_id());
147  const auto& dict_column = static_cast<const arrow::DictionaryArray&>(column);
148  if (dict_column.IsNull(index)) {
149  row.emplace_back(NullableString(nullptr));
150  } else {
151  const auto& indices =
152  static_cast<const arrow::Int32Array&>(*dict_column.indices());
153  const auto& dictionary =
154  static_cast<const arrow::StringArray&>(*dict_column.dictionary());
155  row.emplace_back(dictionary.GetString(indices.Value(index)));
156  }
157  break;
158  }
159  case kTIMESTAMP: {
160  CHECK_EQ(arrow::Type::TIMESTAMP, column.type_id());
161  appendValue<int64_t, arrow::TimestampArray>(
162  row, column, inline_int_null_val(column_typeinfo), index);
163  break;
164  }
165  case kDATE: {
166  // TODO(wamsi): constexpr?
167  CHECK(arrow::Type::DATE32 == column.type_id() ||
168  arrow::Type::DATE64 == column.type_id());
169  column_typeinfo.is_date_in_days()
170  ? appendValue<int64_t, arrow::Date32Array>(
171  row, column, inline_int_null_val(column_typeinfo), index)
172  : appendValue<int64_t, arrow::Date64Array>(
173  row, column, inline_int_null_val(column_typeinfo), index);
174  break;
175  }
176  case kTIME: {
177  CHECK_EQ(arrow::Type::TIME32, column.type_id());
178  appendValue<int64_t, arrow::Time32Array>(
179  row, column, inline_int_null_val(column_typeinfo), index);
180  break;
181  }
182  default:
183  CHECK(false);
184  }
185  }
186  return row;
187 }
188 
189 std::vector<TargetValue> ArrowResultSet::getNextRow(const bool translate_strings,
190  const bool decimal_to_double) const {
191  if (crt_row_idx_ == rowCount()) {
192  return {};
193  }
195  auto row = getRowAt(crt_row_idx_);
196  ++crt_row_idx_;
197  return row;
198 }
199 
200 size_t ArrowResultSet::colCount() const {
201  return column_metainfo_.size();
202 }
203 
204 SQLTypeInfo ArrowResultSet::getColType(const size_t col_idx) const {
205  CHECK_LT(col_idx, column_metainfo_.size());
206  return column_metainfo_[col_idx].get_type_info();
207 }
208 
210  return !rowCount();
211 }
212 
213 size_t ArrowResultSet::rowCount() const {
214  return record_batch_->num_rows();
215 }
216 
218  std::vector<std::string> col_names;
219 
220  if (!targets_meta_.empty()) {
221  for (auto& meta : targets_meta_) {
222  col_names.push_back(meta.get_resname());
223  }
224  } else {
225  for (unsigned int i = 0; i < rows_->colCount(); i++) {
226  col_names.push_back("col_" + std::to_string(i));
227  }
228  }
229 
230  // We convert the given rows to arrow, which gets serialized
231  // into a buffer by Arrow Wire.
232  auto converter = ArrowResultSetConverter(rows_, col_names, -1);
233  converter.transport_method_ = ArrowTransport::WIRE;
234  converter.device_type_ = device_type;
235 
236  // Lifetime of the result buffer is that of ArrowResultSet
237  results_ = std::make_shared<ArrowResult>(converter.getArrowResult());
238 
239  // Create a reader for reading back serialized
240  arrow::io::BufferReader reader(
241  reinterpret_cast<const uint8_t*>(results_->df_buffer.data()), results_->df_size);
242 
243  ARROW_ASSIGN_OR_THROW(auto batch_reader,
244  arrow::ipc::RecordBatchStreamReader::Open(&reader));
245 
246  ARROW_THROW_NOT_OK(batch_reader->ReadNext(&record_batch_));
247 
248  // Collect dictionaries from the record batch into the dictionary memo.
250  arrow::ipc::internal::CollectDictionaries(*record_batch_, &dictionary_memo_));
251 
252  CHECK_EQ(record_batch_->schema()->num_fields(), record_batch_->num_columns());
253 }
254 
255 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
256  const ExecutionResult& results) {
257  // NOTE(wesm): About memory ownership
258 
259  // After calling ReadRecordBatch, the buffers inside arrow::RecordBatch now
260  // share ownership of the memory in serialized_arrow_output.records (zero
261  // copy). Not necessary to retain these buffers. Same is true of any
262  // dictionaries contained in serialized_arrow_output.schema; the arrays
263  // reference that memory (zero copy).
264  return std::make_unique<ArrowResultSet>(results.getRows(), results.getTargetsMeta());
265 }
266 
267 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
268  const ExecutionResult* results,
269  const std::shared_ptr<ResultSet>& rows,
270  const ExecutorDeviceType device_type) {
271  return results ? std::make_unique<ArrowResultSet>(
272  rows, results->getTargetsMeta(), device_type)
273  : std::make_unique<ArrowResultSet>(rows, device_type);
274 }
#define CHECK_EQ(x, y)
Definition: Logger.h:217
#define ARROW_THROW_NOT_OK(s)
Definition: ArrowUtil.h:36
Definition: sqltypes.h:49
double decimal_to_double(const SQLTypeInfo &otype, int64_t oval)
ExecutorDeviceType
std::shared_ptr< ArrowResult > results_
std::shared_ptr< ResultSet > rows_
size_t rowCount() const
SQLTypeInfo getColType(const size_t col_idx) const
#define ARROW_ASSIGN_OR_THROW(lhs, rexpr)
Definition: ArrowUtil.h:60
#define DOUBLE
#define UNREACHABLE()
Definition: Logger.h:253
#define DICTIONARY
arrow::ipc::DictionaryMemo dictionary_memo_
SQLTypeInfo type_from_arrow_field(const arrow::Field &field)
std::string to_string(char const *&&v)
ArrowResultSet(const std::shared_ptr< ResultSet > &rows, const std::vector< TargetMetaInfo > &targets_meta, const ExecutorDeviceType device_type=ExecutorDeviceType::CPU)
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
const std::vector< TargetMetaInfo > & getTargetsMeta() const
const std::shared_ptr< ResultSet > & getRows() const
std::vector< TargetValue > getNextRow(const bool translate_strings, const bool decimal_to_double) const
#define CHECK_LT(x, y)
Definition: Logger.h:219
Definition: sqltypes.h:52
Definition: sqltypes.h:53
size_t colCount() const
std::unique_ptr< ArrowResultSet > result_set_arrow_loopback(const ExecutionResult &results)
constexpr float inline_fp_null_value< float >()
#define TIMESTAMP
std::vector< TargetMetaInfo > column_metainfo_
constexpr double inline_fp_null_value< double >()
boost::variant< std::string, void * > NullableString
Definition: TargetValue.h:155
bool definitelyHasNoRows() const
void appendValue(std::vector< TargetValue > &row, const arrow::Array &column, const Type null_val, const size_t idx) const
#define CHECK(condition)
Definition: Logger.h:209
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
Definition: sqltypes.h:45
void resultSetArrowLoopback(const ExecutorDeviceType device_type=ExecutorDeviceType::CPU)
std::shared_ptr< arrow::RecordBatch > record_batch_
#define FLOAT
std::vector< TargetMetaInfo > targets_meta_
std::vector< std::shared_ptr< arrow::Array > > columns_