OmniSciDB  a5dc49c757
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionsStats.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef __CUDACC__
18 
19 #include "TableFunctionsStats.hpp"
20 
21 template <typename T>
23  const T* data,
24  const int64_t num_rows,
25  const StatsRequestPredicate& predicate) {
26  // const int64_t num_rows = col.size();
27  const size_t max_thread_count = std::thread::hardware_concurrency();
28  const size_t max_inputs_per_thread = 20000;
29  const size_t num_threads = std::min(
30  max_thread_count, ((num_rows + max_inputs_per_thread - 1) / max_inputs_per_thread));
31 
32  std::vector<T> local_col_mins(num_threads, std::numeric_limits<T>::max());
33  std::vector<T> local_col_maxes(num_threads, std::numeric_limits<T>::lowest());
34  std::vector<double> local_col_sums(num_threads, 0.);
35  std::vector<int64_t> local_col_non_null_or_filtered_counts(num_threads, 0L);
36  tbb::task_arena limited_arena(num_threads);
37  limited_arena.execute([&] {
39  tbb::blocked_range<int64_t>(0, num_rows),
40  [&](const tbb::blocked_range<int64_t>& r) {
41  const int64_t start_idx = r.begin();
42  const int64_t end_idx = r.end();
43  T local_col_min = std::numeric_limits<T>::max();
44  T local_col_max = std::numeric_limits<T>::lowest();
45  double local_col_sum = 0.;
46  int64_t local_col_non_null_or_filtered_count = 0;
47  for (int64_t r = start_idx; r < end_idx; ++r) {
48  const T val = data[r];
49  if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
50  if (std::isnan(val) || std::isinf(val)) {
51  continue;
52  }
53  }
54  if (val == inline_null_value<T>()) {
55  continue;
56  }
57  if (!predicate(val)) {
58  continue;
59  }
60  if (val < local_col_min) {
61  local_col_min = val;
62  }
63  if (val > local_col_max) {
64  local_col_max = val;
65  }
66  local_col_sum += data[r];
67  local_col_non_null_or_filtered_count++;
68  }
69  size_t thread_idx = tbb::this_task_arena::current_thread_index();
70  if (local_col_min < local_col_mins[thread_idx]) {
71  local_col_mins[thread_idx] = local_col_min;
72  }
73  if (local_col_max > local_col_maxes[thread_idx]) {
74  local_col_maxes[thread_idx] = local_col_max;
75  }
76  local_col_sums[thread_idx] += local_col_sum;
77  local_col_non_null_or_filtered_counts[thread_idx] +=
78  local_col_non_null_or_filtered_count;
79  });
80  });
81 
82  ColumnStats<T> column_stats;
83  // Use separate double col_sum instead of column_stats.sum to avoid fp imprecision if T
84  // is float
85  double col_sum = 0.0;
86  column_stats.total_count = num_rows;
87 
88  for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
89  if (local_col_mins[thread_idx] < column_stats.min) {
90  column_stats.min = local_col_mins[thread_idx];
91  }
92  if (local_col_maxes[thread_idx] > column_stats.max) {
93  column_stats.max = local_col_maxes[thread_idx];
94  }
95  col_sum += local_col_sums[thread_idx];
96  column_stats.non_null_or_filtered_count +=
97  local_col_non_null_or_filtered_counts[thread_idx];
98  }
99 
100  if (column_stats.non_null_or_filtered_count > 0) {
101  column_stats.sum = col_sum;
102  column_stats.mean = col_sum / column_stats.non_null_or_filtered_count;
103  }
104  return column_stats;
105 }
106 
108  const int8_t* data,
109  const int64_t num_rows,
110  const StatsRequestPredicate& predicate);
112  const int16_t* data,
113  const int64_t num_rows,
114  const StatsRequestPredicate& predicate);
116  const int32_t* data,
117  const int64_t num_rows,
118  const StatsRequestPredicate& predicate);
120  const int64_t* data,
121  const int64_t num_rows,
122  const StatsRequestPredicate& predicate);
124  const float* data,
125  const int64_t num_rows,
126  const StatsRequestPredicate& predicate);
128  const double* data,
129  const int64_t num_rows,
130  const StatsRequestPredicate& predicate);
131 
132 template <typename T>
134  const Column<T>& col,
135  const StatsRequestPredicate& predicate) {
136  return get_column_stats(col.getPtr(), col.size(), predicate);
137 }
138 
140  const Column<int8_t>& col,
141  const StatsRequestPredicate& predicate);
143  const Column<int16_t>& col,
144  const StatsRequestPredicate& predicate);
146  const Column<int32_t>& col,
147  const StatsRequestPredicate& predicate);
149  const Column<int64_t>& col,
150  const StatsRequestPredicate& predicate);
152  const Column<float>& col,
153  const StatsRequestPredicate& predicate);
155  const Column<double>& col,
156  const StatsRequestPredicate& predicate);
157 
159  if (str == "COUNT") {
161  }
162  if (str == "MIN") {
164  }
165  if (str == "MAX") {
167  }
168  if (str == "SUM") {
170  }
171  if (str == "AVG") {
173  }
174  throw std::runtime_error("Invalid StatsRequestAggType: " + str);
175 }
176 
178  const std::string& str) {
179  if (str == "NONE") {
181  }
182  if (str == "LT" || str == "<") {
184  }
185  if (str == "GT" || str == ">") {
187  }
188  throw std::runtime_error("Invalid StatsRequestPredicateOp: " + str);
189 }
190 
/**
 * Returns a copy of `str` in which every non-overlapping occurrence of
 * `pattern_str` is replaced by `replacement_str`, scanning left to right.
 *
 * Text produced by a replacement is never rescanned, so a replacement that
 * itself contains the pattern does not trigger repeated substitution.
 * An empty `pattern_str` matches nothing and `str` is returned unchanged.
 * An empty pattern would otherwise match at every position and never terminate.
 *
 * @param str              input text
 * @param pattern_str      substring to search for
 * @param replacement_str  text substituted for each match
 * @return the rewritten string
 */
std::string replace_substrings(const std::string& str,
                               const std::string& pattern_str,
                               const std::string& replacement_str) {
  if (pattern_str.empty()) {
    return str;
  }

  // Build the result in one pass instead of repeatedly replacing in place,
  // which would shift the tail of the string on every match.
  std::string replaced_str;
  replaced_str.reserve(str.size());
  size_t copy_from_index = 0;
  for (size_t match_index = str.find(pattern_str); match_index != std::string::npos;
       match_index = str.find(pattern_str, copy_from_index)) {
    replaced_str.append(str, copy_from_index, match_index - copy_from_index);
    replaced_str += replacement_str;
    copy_from_index = match_index + pattern_str.size();
  }
  replaced_str.append(str, copy_from_index, std::string::npos);
  return replaced_str;
}
210 
211 std::vector<StatsRequest> parse_stats_requests_json(
212  const std::string& stats_requests_json_str,
213  const int64_t num_attrs) {
214  std::vector<StatsRequest> stats_requests;
215  rapidjson::Document doc;
216 
217  // remove double double quotes our parser introduces
218  const auto fixed_stats_requests_json_str =
219  replace_substrings(stats_requests_json_str, "\"\"", "\"");
220 
221  if (doc.Parse(fixed_stats_requests_json_str.c_str()).HasParseError()) {
222  // Not valid JSON
223  std::cout << "DEBUG: Failed JSON: " << fixed_stats_requests_json_str << std::endl;
224  throw std::runtime_error("Could not parse Stats Requests JSON.");
225  }
226  // Todo (todd): Enforce Schema
227  if (!doc.IsArray()) {
228  throw std::runtime_error("Stats Request JSON did not contain valid root Array.");
229  }
230  const std::vector<std::string> required_keys = {
231  "name", "attr_id", "agg_type", "filter_type"};
232 
233  for (const auto& stat_request_obj : doc.GetArray()) {
234  for (const auto& required_key : required_keys) {
235  if (!stat_request_obj.HasMember(required_key)) {
236  throw std::runtime_error("Stats Request JSON missing key " + required_key + ".");
237  }
238  if (required_key == "attr_id") {
239  if (!stat_request_obj[required_key].IsUint()) {
240  throw std::runtime_error(required_key + " must be int type");
241  }
242  } else {
243  if (!stat_request_obj[required_key].IsString()) {
244  throw std::runtime_error(required_key + " must be string type");
245  }
246  }
247  }
248  StatsRequest stats_request;
249  stats_request.name = stat_request_obj["name"].GetString();
250  stats_request.attr_id = stat_request_obj["attr_id"].GetInt() - 1;
251  if (stats_request.attr_id < 0 || stats_request.attr_id >= num_attrs) {
252  throw std::runtime_error("Invalid attr_id: " +
253  std::to_string(stats_request.attr_id));
254  }
255 
256  std::string agg_type_str = stat_request_obj["agg_type"].GetString();
258  agg_type_str.begin(), agg_type_str.end(), agg_type_str.begin(), ::toupper);
259  stats_request.agg_type = convert_string_to_stats_request_agg_type(agg_type_str);
260 
261  std::string filter_type_str = stat_request_obj["filter_type"].GetString();
262  std::transform(filter_type_str.begin(),
263  filter_type_str.end(),
264  filter_type_str.begin(),
265  ::toupper);
266  stats_request.filter_type =
268  if (stats_request.filter_type != StatsRequestPredicateOp::NONE) {
269  if (!stat_request_obj.HasMember("filter_val")) {
270  throw std::runtime_error("Stats Request JSON missing expected filter_val");
271  }
272  if (!stat_request_obj["filter_val"].IsNumber()) {
273  throw std::runtime_error("Stats Request JSON filter_val should be numeric.");
274  }
275  stats_request.filter_val = stat_request_obj["filter_val"].GetDouble();
276  }
277  stats_requests.emplace_back(stats_request);
278  }
279  return stats_requests;
280 }
281 
282 std::vector<std::pair<const char*, double>> get_stats_key_value_pairs(
283  const std::vector<StatsRequest>& stats_requests) {
284  std::vector<std::pair<const char*, double>> stats_key_value_pairs;
285  for (const auto& stats_request : stats_requests) {
286  stats_key_value_pairs.emplace_back(
287  std::make_pair(stats_request.name.c_str(), stats_request.result));
288  }
289  return stats_key_value_pairs;
290 }
291 
292 #endif // __CUDACC__
std::vector< StatsRequest > parse_stats_requests_json(const std::string &stats_requests_json_str, const int64_t num_attrs)
std::vector< std::pair< const char *, double > > get_stats_key_value_pairs(const std::vector< StatsRequest > &stats_requests)
DEVICE int64_t size() const
NEVER_INLINE HOST ColumnStats< T > get_column_stats(const T *data, const int64_t num_rows, const StatsRequestPredicate &predicate)
StatsRequestPredicateOp convert_string_to_stats_request_predicate_op(const std::string &str)
StatsRequestPredicateOp
DEVICE T * getPtr() const
std::string to_string(char const *&&v)
std::string replace_substrings(const std::string &str, const std::string &pattern_str, const std::string &replacement_str)
#define HOST
const size_t max_inputs_per_thread
int64_t non_null_or_filtered_count
OUTPUT transform(INPUT const &input, FUNC const &func)
Definition: misc.h:329
StatsRequestAggType convert_string_to_stats_request_agg_type(const std::string &str)
StatsRequestPredicateOp filter_type
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
#define NEVER_INLINE
StatsRequestAggType
StatsRequestAggType agg_type