OmniSciDB  72180abbfe
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ArrowResultSet.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <arrow/api.h>
20 #include <arrow/io/memory.h>
21 #include <arrow/ipc/api.h>
22 
24 #include "Shared/ArrowUtil.h"
25 
26 namespace {
27 
28 SQLTypeInfo type_from_arrow_field(const arrow::Field& field) {
29  switch (field.type()->id()) {
30  case arrow::Type::INT8:
31  return SQLTypeInfo(kTINYINT, !field.nullable());
32  case arrow::Type::INT16:
33  return SQLTypeInfo(kSMALLINT, !field.nullable());
34  case arrow::Type::INT32:
35  return SQLTypeInfo(kINT, !field.nullable());
36  case arrow::Type::INT64:
37  return SQLTypeInfo(kBIGINT, !field.nullable());
38  case arrow::Type::FLOAT:
39  return SQLTypeInfo(kFLOAT, !field.nullable());
40  case arrow::Type::DOUBLE:
41  return SQLTypeInfo(kDOUBLE, !field.nullable());
42  case arrow::Type::DICTIONARY:
43  return SQLTypeInfo(kTEXT, !field.nullable(), kENCODING_DICT);
44  case arrow::Type::TIMESTAMP: {
45  // TODO(Wamsi): go right fold expr in c++17
46  auto get_precision = [&field](auto type) { return field.type()->Equals(type); };
47  if (get_precision(arrow::timestamp(arrow::TimeUnit::SECOND))) {
48  return SQLTypeInfo(kTIMESTAMP, !field.nullable());
49  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::MILLI))) {
50  return SQLTypeInfo(kTIMESTAMP, 3, 0, !field.nullable());
51  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::MICRO))) {
52  return SQLTypeInfo(kTIMESTAMP, 6, 0, !field.nullable());
53  } else if (get_precision(arrow::timestamp(arrow::TimeUnit::NANO))) {
54  return SQLTypeInfo(kTIMESTAMP, 9, 0, !field.nullable());
55  } else {
56  UNREACHABLE();
57  }
58  }
59  case arrow::Type::DATE32:
60  return SQLTypeInfo(kDATE, !field.nullable(), kENCODING_DATE_IN_DAYS);
61  case arrow::Type::DATE64:
62  return SQLTypeInfo(kDATE, !field.nullable());
63  case arrow::Type::TIME32:
64  return SQLTypeInfo(kTIME, !field.nullable());
65  default:
66  CHECK(false);
67  }
68  CHECK(false);
69  return SQLTypeInfo();
70 }
71 
72 } // namespace
73 
// Builds the result set from `rows` / `targets_meta`. For every field of
// record_batch_'s schema it then caches:
//   - the field name together with its SQLTypeInfo (via type_from_arrow_field);
//   - the matching column array.
// NOTE(review): Doxygen source line 77 is missing from this listing. It is
// presumably the statement that populates record_batch_ (likely
// resultSetArrowLoopback()), since record_batch_ is dereferenced right below.
// Confirm against the real source.
74 ArrowResultSet::ArrowResultSet(const std::shared_ptr<ResultSet>& rows,
 75  const std::vector<TargetMetaInfo>& targets_meta)
 76  : rows_(rows), targets_meta_(targets_meta), crt_row_idx_(0) {
 78  auto schema = record_batch_->schema();
 // column_metainfo_ and columns_ are appended in lockstep, so index i refers to
 // the same column in both vectors.
 79  for (int i = 0; i < schema->num_fields(); ++i) {
 80  std::shared_ptr<arrow::Field> field = schema->field(i);
 81  SQLTypeInfo type_info = type_from_arrow_field(*schema->field(i));
 82  column_metainfo_.emplace_back(field->name(), type_info);
 83  columns_.emplace_back(record_batch_->column(i));
 84  }
85 }
86 
87 template <typename Type, typename ArrayType>
88 void ArrowResultSet::appendValue(std::vector<TargetValue>& row,
89  const arrow::Array& column,
90  const Type null_val,
91  const size_t idx) const {
92  const auto& col = static_cast<const ArrayType&>(column);
93  row.emplace_back(col.IsNull(idx) ? null_val : static_cast<Type>(col.Value(idx)));
94 }
95 
96 std::vector<TargetValue> ArrowResultSet::getRowAt(const size_t index) const {
97  if (index >= rowCount()) {
98  return {};
99  }
100 
101  CHECK_LT(index, rowCount());
102  std::vector<TargetValue> row;
103  for (int i = 0; i < record_batch_->num_columns(); ++i) {
104  const auto& column = *columns_[i];
105  const auto& column_typeinfo = getColType(i);
106  switch (column_typeinfo.get_type()) {
107  case kTINYINT: {
108  CHECK_EQ(arrow::Type::INT8, column.type_id());
109  appendValue<int64_t, arrow::Int8Array>(
110  row, column, inline_int_null_val(column_typeinfo), index);
111  break;
112  }
113  case kSMALLINT: {
114  CHECK_EQ(arrow::Type::INT16, column.type_id());
115  appendValue<int64_t, arrow::Int16Array>(
116  row, column, inline_int_null_val(column_typeinfo), index);
117  break;
118  }
119  case kINT: {
120  CHECK_EQ(arrow::Type::INT32, column.type_id());
121  appendValue<int64_t, arrow::Int32Array>(
122  row, column, inline_int_null_val(column_typeinfo), index);
123  break;
124  }
125  case kBIGINT: {
126  CHECK_EQ(arrow::Type::INT64, column.type_id());
127  appendValue<int64_t, arrow::Int64Array>(
128  row, column, inline_int_null_val(column_typeinfo), index);
129  break;
130  }
131  case kFLOAT: {
132  CHECK_EQ(arrow::Type::FLOAT, column.type_id());
133  appendValue<float, arrow::FloatArray>(
134  row, column, inline_fp_null_value<float>(), index);
135  break;
136  }
137  case kDOUBLE: {
138  CHECK_EQ(arrow::Type::DOUBLE, column.type_id());
139  appendValue<double, arrow::DoubleArray>(
140  row, column, inline_fp_null_value<double>(), index);
141  break;
142  }
143  case kTEXT: {
144  CHECK_EQ(kENCODING_DICT, column_typeinfo.get_compression());
145  CHECK_EQ(arrow::Type::DICTIONARY, column.type_id());
146  const auto& dict_column = static_cast<const arrow::DictionaryArray&>(column);
147  if (dict_column.IsNull(index)) {
148  row.emplace_back(NullableString(nullptr));
149  } else {
150  const auto& indices =
151  static_cast<const arrow::Int32Array&>(*dict_column.indices());
152  const auto& dictionary =
153  static_cast<const arrow::StringArray&>(*dict_column.dictionary());
154  row.emplace_back(dictionary.GetString(indices.Value(index)));
155  }
156  break;
157  }
158  case kTIMESTAMP: {
159  CHECK_EQ(arrow::Type::TIMESTAMP, column.type_id());
160  appendValue<int64_t, arrow::TimestampArray>(
161  row, column, inline_int_null_val(column_typeinfo), index);
162  break;
163  }
164  case kDATE: {
165  // TODO(wamsi): constexpr?
166  CHECK(arrow::Type::DATE32 == column.type_id() ||
167  arrow::Type::DATE64 == column.type_id());
168  column_typeinfo.is_date_in_days()
169  ? appendValue<int64_t, arrow::Date32Array>(
170  row, column, inline_int_null_val(column_typeinfo), index)
171  : appendValue<int64_t, arrow::Date64Array>(
172  row, column, inline_int_null_val(column_typeinfo), index);
173  break;
174  }
175  case kTIME: {
176  CHECK_EQ(arrow::Type::TIME32, column.type_id());
177  appendValue<int64_t, arrow::Time32Array>(
178  row, column, inline_int_null_val(column_typeinfo), index);
179  break;
180  }
181  default:
182  CHECK(false);
183  }
184  }
185  return row;
186 }
187 
// Returns the row at the cursor crt_row_idx_ (via getRowAt) and advances the
// cursor by one. Returns an empty vector once the cursor equals rowCount().
// translate_strings and decimal_to_double are not read in the visible body.
// NOTE(review): crt_row_idx_ is incremented inside a const method, so it is
// presumably declared mutable. Doxygen source line 193 is missing from this
// listing, so the body shown here may be incomplete.
188 std::vector<TargetValue> ArrowResultSet::getNextRow(const bool translate_strings,
 189  const bool decimal_to_double) const {
 190  if (crt_row_idx_ == rowCount()) {
 191  return {};
 192  }
 194  auto row = getRowAt(crt_row_idx_);
 195  ++crt_row_idx_;
 196  return row;
197 }
198 
199 size_t ArrowResultSet::colCount() const {
200  return column_metainfo_.size();
201 }
202 
203 SQLTypeInfo ArrowResultSet::getColType(const size_t col_idx) const {
204  CHECK_LT(col_idx, column_metainfo_.size());
205  return column_metainfo_[col_idx].get_type_info();
206 }
207 
 // Body of what is presumably definitelyHasNoRows(); its signature line
 // (Doxygen source line 208) is missing from this listing. Returns true
 // exactly when the record batch has zero rows.
 209  return !rowCount();
210 }
211 
212 size_t ArrowResultSet::rowCount() const {
213  return record_batch_->num_rows();
214 }
215 
 // Body of what is presumably ArrowResultSet::resultSetArrowLoopback(); its
 // signature line (Doxygen source line 216) is missing from this listing.
 // It serializes rows_ to Arrow IPC with ArrowResultSetConverter, then reads
 // the schema and record batch back into dictionary_memo_ / record_batch_.
 217  std::vector<std::string> col_names;
218 
 // Column names come from the target metadata when present; otherwise
 // positional names "col_0", "col_1", ... are generated.
 219  if (!targets_meta_.empty()) {
 220  for (auto& meta : targets_meta_) {
 221  col_names.push_back(meta.get_resname());
 222  }
 223  } else {
 224  for (unsigned int i = 0; i < rows_->colCount(); i++) {
 225  col_names.push_back("col_" + std::to_string(i));
 226  }
 227  }
 // NOTE(review): the meaning of the third converter argument (-1) is defined
 // by ArrowResultSetConverter. TODO confirm.
 228  const auto converter = ArrowResultSetConverter(rows_, col_names, -1);
229 
 // schema_memo receives the dictionaries produced during serialization.
 230  arrow::ipc::DictionaryMemo schema_memo;
 231  const auto serialized_arrow_output = converter.getSerializedArrowOutput(&schema_memo);
232 
 233  arrow::io::BufferReader schema_reader(serialized_arrow_output.schema);
234 
 // Reading the schema registers its dictionary fields in dictionary_memo_.
 // The field count must then match the serializer's memo.
 235  std::shared_ptr<arrow::Schema> schema;
 236  ARROW_THROW_NOT_OK(arrow::ipc::ReadSchema(&schema_reader, &dictionary_memo_, &schema));
 237  CHECK_EQ(schema_memo.num_fields(), dictionary_memo_.num_fields());
238 
 239  // add the dictionaries from the serialized output to the newly created memo
 240  const auto& serialized_id_to_dict = schema_memo.id_to_dictionary();
 241  for (const auto& itr : serialized_id_to_dict) {
 242  const auto& id = itr.first;
 243  const auto& dict = itr.second;
 244  CHECK(!dictionary_memo_.HasDictionary(id));
 245  ARROW_THROW_NOT_OK(dictionary_memo_.AddDictionary(id, dict));
 246  }
247 
 248  arrow::io::BufferReader records_reader(serialized_arrow_output.records);
249 
 // Decode the record batch against the reconstructed schema and dictionaries.
 250  ARROW_THROW_NOT_OK(arrow::ipc::ReadRecordBatch(
 251  schema, &dictionary_memo_, &records_reader, &record_batch_));
252 
 253  CHECK_EQ(schema->num_fields(), record_batch_->num_columns());
254 }
255 
256 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
257  const ExecutionResult& results) {
258  // NOTE(wesm): About memory ownership
259 
260  // After calling ReadRecordBatch, the buffers inside arrow::RecordBatch now
261  // share ownership of the memory in serialized_arrow_output.records (zero
262  // copy). Not necessary to retain these buffers. Same is true of any
263  // dictionaries contained in serialized_arrow_output.schema; the arrays
264  // reference that memory (zero copy).
265  return std::make_unique<ArrowResultSet>(results.getRows(), results.getTargetsMeta());
266 }
267 
268 std::unique_ptr<ArrowResultSet> result_set_arrow_loopback(
269  const ExecutionResult* results,
270  const std::shared_ptr<ResultSet>& rows) {
271  return results ? std::make_unique<ArrowResultSet>(rows, results->getTargetsMeta())
272  : std::make_unique<ArrowResultSet>(rows);
273 }
#define CHECK_EQ(x, y)
Definition: Logger.h:205
#define ARROW_THROW_NOT_OK(s)
Definition: ArrowUtil.h:37
Definition: sqltypes.h:50
double decimal_to_double(const SQLTypeInfo &otype, int64_t oval)
std::shared_ptr< ResultSet > rows_
size_t rowCount() const
SQLTypeInfo getColType(const size_t col_idx) const
#define UNREACHABLE()
Definition: Logger.h:241
arrow::ipc::DictionaryMemo dictionary_memo_
SQLTypeInfo type_from_arrow_field(const arrow::Field &field)
std::string to_string(char const *&&v)
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:31
CHECK(cgen_state)
const std::vector< TargetMetaInfo > & getTargetsMeta() const
const std::shared_ptr< ResultSet > & getRows() const
std::vector< TargetValue > getNextRow(const bool translate_strings, const bool decimal_to_double) const
#define CHECK_LT(x, y)
Definition: Logger.h:207
Definition: sqltypes.h:53
Definition: sqltypes.h:54
size_t colCount() const
std::unique_ptr< ArrowResultSet > result_set_arrow_loopback(const ExecutionResult &results)
constexpr float inline_fp_null_value< float >()
std::vector< TargetMetaInfo > column_metainfo_
constexpr double inline_fp_null_value< double >()
boost::variant< std::string, void * > NullableString
Definition: TargetValue.h:155
bool definitelyHasNoRows() const
void appendValue(std::vector< TargetValue > &row, const arrow::Array &column, const Type null_val, const size_t idx) const
void resultSetArrowLoopback()
SQLTypes type
Definition: sqltypes.h:645
int64_t inline_int_null_val(const SQL_TYPE_INFO &ti)
Definition: sqltypes.h:46
std::shared_ptr< arrow::RecordBatch > record_batch_
ArrowResultSet(const std::shared_ptr< ResultSet > &rows, const std::vector< TargetMetaInfo > &targets_meta)
std::vector< TargetMetaInfo > targets_meta_
std::vector< std::shared_ptr< arrow::Array > > columns_