OmniSciDB  04ee39c94c
ArrowResultSet.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ArrowResultSet.h"
18 #include "ArrowUtil.h"
20 
21 #include <arrow/api.h>
22 #include <arrow/io/memory.h>
23 #include <arrow/ipc/api.h>
24 
25 namespace {
26 
27 SQLTypeInfo type_from_arrow_field(const arrow::Field& field) {
28  switch (field.type()->id()) {
29  case arrow::Type::INT16:
30  return SQLTypeInfo(kSMALLINT, !field.nullable());
31  case arrow::Type::INT32:
32  return SQLTypeInfo(kINT, !field.nullable());
33  case arrow::Type::INT64:
34  return SQLTypeInfo(kBIGINT, !field.nullable());
35  case arrow::Type::FLOAT:
36  return SQLTypeInfo(kFLOAT, !field.nullable());
37  case arrow::Type::DOUBLE:
38  return SQLTypeInfo(kDOUBLE, !field.nullable());
39  case arrow::Type::DICTIONARY:
40  return SQLTypeInfo(kTEXT, !field.nullable(), kENCODING_DICT);
41  case arrow::Type::TIMESTAMP: {
42  // TODO(Wamsi): go right fold expr in c++17
43  auto get_precision = [&field](auto type) { return field.type()->Equals(type); };
44  if (get_precision(arrow::timestamp(arrow::TimeUnit::SECOND))) {
45  return SQLTypeInfo(kTIMESTAMP, !field.nullable());
46  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::MILLI))) {
47  return SQLTypeInfo(kTIMESTAMP, 3, 0, !field.nullable());
48  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::MICRO))) {
49  return SQLTypeInfo(kTIMESTAMP, 6, 0, !field.nullable());
50  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::NANO))) {
51  return SQLTypeInfo(kTIMESTAMP, 9, 0, !field.nullable());
52  } else {
53  UNREACHABLE();
54  }
55  }
56  case arrow::Type::DATE32:
57  return SQLTypeInfo(kDATE, !field.nullable(), kENCODING_DATE_IN_DAYS);
58  case arrow::Type::DATE64:
59  return SQLTypeInfo(kDATE, !field.nullable());
60  case arrow::Type::TIME32:
61  return SQLTypeInfo(kTIME, !field.nullable());
62  default:
63  CHECK(false);
64  }
65  CHECK(false);
66  return SQLTypeInfo();
67 }
68 
69 } // namespace
70 
71 ArrowResultSet::ArrowResultSet(const std::shared_ptr<ResultSet>& rows,
72  const std::vector<TargetMetaInfo>& targets_meta)
73  : rows_(rows), targets_meta_(targets_meta), crt_row_idx_(0) {
75  auto schema = record_batch_->schema();
76  for (int i = 0; i < schema->num_fields(); ++i) {
77  std::shared_ptr<arrow::Field> field = schema->field(i);
78  SQLTypeInfo type_info = type_from_arrow_field(*schema->field(i));
79  column_metainfo_.emplace_back(field->name(), type_info);
80  columns_.emplace_back(record_batch_->column(i));
81  }
82 }
83 
84 template <typename Type, typename ArrayType>
85 void ArrowResultSet::appendValue(std::vector<TargetValue>& row,
86  const arrow::Array& column,
87  const Type null_val,
88  const size_t idx) const {
89  const auto& col = static_cast<const ArrayType&>(column);
90  row.emplace_back(col.IsNull(idx) ? null_val : static_cast<Type>(col.Value(idx)));
91 }
92 
93 std::vector<TargetValue> ArrowResultSet::getRowAt(const size_t index) const {
94  if (index >= rowCount()) {
95  return {};
96  }
97 
98  CHECK_LT(index, rowCount());
99  std::vector<TargetValue> row;
100  for (int i = 0; i < record_batch_->num_columns(); ++i) {
101  const auto& column = *columns_[i];
102  const auto& column_typeinfo = getColType(i);
103  switch (column_typeinfo.get_type()) {
104  case kSMALLINT: {
105  CHECK_EQ(arrow::Type::INT16, column.type_id());
106  appendValue<int64_t, arrow::Int16Array>(
107  row, column, inline_int_null_val(column_typeinfo), index);
108  break;
109  }
110  case kINT: {
111  CHECK_EQ(arrow::Type::INT32, column.type_id());
112  appendValue<int64_t, arrow::Int32Array>(
113  row, column, inline_int_null_val(column_typeinfo), index);
114  break;
115  }
116  case kBIGINT: {
117  CHECK_EQ(arrow::Type::INT64, column.type_id());
118  appendValue<int64_t, arrow::Int64Array>(
119  row, column, inline_int_null_val(column_typeinfo), index);
120  break;
121  }
122  case kFLOAT: {
123  CHECK_EQ(arrow::Type::FLOAT, column.type_id());
124  appendValue<float, arrow::FloatArray>(
125  row, column, inline_fp_null_value<float>(), index);
126  break;
127  }
128  case kDOUBLE: {
129  CHECK_EQ(arrow::Type::DOUBLE, column.type_id());
130  appendValue<double, arrow::DoubleArray>(
131  row, column, inline_fp_null_value<double>(), index);
132  break;
133  }
134  case kTEXT: {
135  CHECK_EQ(kENCODING_DICT, column_typeinfo.get_compression());
136  CHECK_EQ(arrow::Type::DICTIONARY, column.type_id());
137  const auto& dict_column = static_cast<const arrow::DictionaryArray&>(column);
138  if (dict_column.IsNull(index)) {
139  row.emplace_back(NullableString(nullptr));
140  } else {
141  const auto& indices =
142  static_cast<const arrow::Int32Array&>(*dict_column.indices());
143  const auto& dictionary =
144  static_cast<const arrow::StringArray&>(*dict_column.dictionary());
145  row.emplace_back(dictionary.GetString(indices.Value(index)));
146  }
147  break;
148  }
149  case kTIMESTAMP: {
150  CHECK_EQ(arrow::Type::TIMESTAMP, column.type_id());
151  appendValue<int64_t, arrow::TimestampArray>(
152  row, column, inline_int_null_val(column_typeinfo), index);
153  break;
154  }
155  case kDATE: {
156  // TODO(wamsi): constexpr?
157  CHECK(arrow::Type::DATE32 == column.type_id() ||
158  arrow::Type::DATE64 == column.type_id());
159  column_typeinfo.is_date_in_days()
160  ? appendValue<int64_t, arrow::Date32Array>(
161  row, column, inline_int_null_val(column_typeinfo), index)
162  : appendValue<int64_t, arrow::Date64Array>(
163  row, column, inline_int_null_val(column_typeinfo), index);
164  break;
165  }
166  case kTIME: {
167  CHECK_EQ(arrow::Type::TIME32, column.type_id());
168  appendValue<int64_t, arrow::Time32Array>(
169  row, column, inline_int_null_val(column_typeinfo), index);
170  break;
171  }
172  default:
173  CHECK(false);
174  }
175  }
176  return row;
177 }
178 
179 std::vector<TargetValue> ArrowResultSet::getNextRow(const bool translate_strings,
180  const bool decimal_to_double) const {
181  if (crt_row_idx_ == rowCount()) {
182  return {};
183  }
185  auto row = getRowAt(crt_row_idx_);
186  ++crt_row_idx_;
187  return row;
188 }
189 
190 size_t ArrowResultSet::colCount() const {
191  return column_metainfo_.size();
192 }
193 
194 SQLTypeInfo ArrowResultSet::getColType(const size_t col_idx) const {
195  CHECK_LT(col_idx, column_metainfo_.size());
196  return column_metainfo_[col_idx].get_type_info();
197 }
198 
200  return !rowCount();
201 }
202 
203 size_t ArrowResultSet::rowCount() const {
204  return record_batch_->num_rows();
205 }
206 
208  std::vector<std::string> col_names;
209 
210  if (!targets_meta_.empty()) {
211  for (auto& meta : targets_meta_) {
212  col_names.push_back(meta.get_resname());
213  }
214  } else {
215  for (unsigned int i = 0; i < rows_->colCount(); i++) {
216  col_names.push_back("col_" + std::to_string(i));
217  }
218  }
219  const auto converter = ArrowResultSetConverter(rows_, col_names, -1);
220  const auto serialized_arrow_output = converter.getSerializedArrowOutput();
221 
222  arrow::io::BufferReader schema_reader(serialized_arrow_output.schema);
223 
224  std::shared_ptr<arrow::Schema> schema;
225  ARROW_THROW_NOT_OK(arrow::ipc::ReadSchema(&schema_reader, &schema));
226 
227  arrow::io::BufferReader records_reader(serialized_arrow_output.records);
229  arrow::ipc::ReadRecordBatch(schema, &records_reader, &record_batch_));
230 
231  CHECK_EQ(schema->num_fields(), record_batch_->num_columns());
232 }
233 
234 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
235  const ExecutionResult& results) {
236  // NOTE(wesm): About memory ownership
237 
238  // After calling ReadRecordBatch, the buffers inside arrow::RecordBatch now
239  // share ownership of the memory in serialized_arrow_output.records (zero
240  // copy). Not necessary to retain these buffers. Same is true of any
241  // dictionaries contained in serialized_arrow_output.schema; the arrays
242  // reference that memory (zero copy).
243  return std::make_unique<ArrowResultSet>(results.getRows(), results.getTargetsMeta());
244 }
245 
246 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
247  const ExecutionResult* results,
248  const std::shared_ptr<ResultSet>& rows) {
249  return results ? std::make_unique<ArrowResultSet>(rows, results->getTargetsMeta())
250  : std::make_unique<ArrowResultSet>(rows);
251 }
#define CHECK_EQ(x, y)
Definition: Logger.h:195
#define ARROW_THROW_NOT_OK(s)
Definition: ArrowUtil.h:28
Definition: sqltypes.h:51
double decimal_to_double(const SQLTypeInfo &otype, int64_t oval)
std::shared_ptr< ResultSet > rows_
void appendValue(std::vector< TargetValue > &row, const arrow::Array &column, const Type null_val, const size_t idx) const
#define UNREACHABLE()
Definition: Logger.h:231
SQLTypeInfo type_from_arrow_field(const arrow::Field &field)
bool definitelyHasNoRows() const
std::string to_string(char const *&&v)
std::vector< TargetValue > getNextRow(const bool translate_strings, const bool decimal_to_double) const
SQLTypeInfo getColType(const size_t col_idx) const
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
std::vector< TargetValue > getRowAt(const size_t index) const
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:823
#define CHECK_LT(x, y)
Definition: Logger.h:197
Definition: sqltypes.h:54
Definition: sqltypes.h:55
const std::shared_ptr< ResultSet > & getRows() const
std::unique_ptr< ArrowResultSet > result_set_arrow_loopback(const ExecutionResult &results)
constexpr float inline_fp_null_value< float >()
std::vector< TargetMetaInfo > column_metainfo_
constexpr double inline_fp_null_value< double >()
boost::variant< std::string, void * > NullableString
Definition: TargetValue.h:155
const std::vector< TargetMetaInfo > & getTargetsMeta() const
#define CHECK(condition)
Definition: Logger.h:187
void resultSetArrowLoopback()
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
Definition: sqltypes.h:47
std::shared_ptr< arrow::RecordBatch > record_batch_
size_t rowCount() const
SQLTypes type
Definition: sqltypes.h:642
ArrowResultSet(const std::shared_ptr< ResultSet > &rows, const std::vector< TargetMetaInfo > &targets_meta)
std::vector< TargetMetaInfo > targets_meta_
std::vector< std::shared_ptr< arrow::Array > > columns_
size_t colCount() const