OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionsStats.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef __CUDACC__
18 
19 #include "TableFunctionsStats.hpp"
20 
21 template <typename T>
23  const T* data,
24  const int64_t num_rows,
25  const StatsRequestPredicate& predicate) {
26  // const int64_t num_rows = col.size();
27  const size_t max_thread_count = std::thread::hardware_concurrency();
28  const size_t max_inputs_per_thread = 20000;
29  const size_t num_threads = std::min(
30  max_thread_count, ((num_rows + max_inputs_per_thread - 1) / max_inputs_per_thread));
31 
32  std::vector<T> local_col_mins(num_threads, std::numeric_limits<T>::max());
33  std::vector<T> local_col_maxes(num_threads, std::numeric_limits<T>::lowest());
34  std::vector<double> local_col_sums(num_threads, 0.);
35  std::vector<int64_t> local_col_non_null_or_filtered_counts(num_threads, 0L);
36  tbb::task_arena limited_arena(num_threads);
37  limited_arena.execute([&] {
38  tbb::parallel_for(tbb::blocked_range<int64_t>(0, num_rows),
39  [&](const tbb::blocked_range<int64_t>& r) {
40  const int64_t start_idx = r.begin();
41  const int64_t end_idx = r.end();
42  T local_col_min = std::numeric_limits<T>::max();
43  T local_col_max = std::numeric_limits<T>::lowest();
44  double local_col_sum = 0.;
45  int64_t local_col_non_null_or_filtered_count = 0;
46  for (int64_t r = start_idx; r < end_idx; ++r) {
47  const T val = data[r];
48  if (val == inline_null_value<T>()) {
49  continue;
50  }
51  if (!predicate(val)) {
52  continue;
53  }
54  if (val < local_col_min) {
55  local_col_min = val;
56  }
57  if (val > local_col_max) {
58  local_col_max = val;
59  }
60  local_col_sum += data[r];
61  local_col_non_null_or_filtered_count++;
62  }
63  size_t thread_idx = tbb::this_task_arena::current_thread_index();
64  if (local_col_min < local_col_mins[thread_idx]) {
65  local_col_mins[thread_idx] = local_col_min;
66  }
67  if (local_col_max > local_col_maxes[thread_idx]) {
68  local_col_maxes[thread_idx] = local_col_max;
69  }
70  local_col_sums[thread_idx] += local_col_sum;
71  local_col_non_null_or_filtered_counts[thread_idx] +=
72  local_col_non_null_or_filtered_count;
73  });
74  });
75 
76  ColumnStats<T> column_stats;
77  // Use separate double col_sum instead of column_stats.sum to avoid fp imprecision if T
78  // is float
79  double col_sum = 0.0;
80  column_stats.total_count = num_rows;
81 
82  for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
83  if (local_col_mins[thread_idx] < column_stats.min) {
84  column_stats.min = local_col_mins[thread_idx];
85  }
86  if (local_col_maxes[thread_idx] > column_stats.max) {
87  column_stats.max = local_col_maxes[thread_idx];
88  }
89  col_sum += local_col_sums[thread_idx];
90  column_stats.non_null_or_filtered_count +=
91  local_col_non_null_or_filtered_counts[thread_idx];
92  }
93 
94  if (column_stats.non_null_or_filtered_count > 0) {
95  column_stats.sum = col_sum;
96  column_stats.mean = col_sum / column_stats.non_null_or_filtered_count;
97  }
98  return column_stats;
99 }
100 
102  const int8_t* data,
103  const int64_t num_rows,
104  const StatsRequestPredicate& predicate);
106  const int16_t* data,
107  const int64_t num_rows,
108  const StatsRequestPredicate& predicate);
110  const int32_t* data,
111  const int64_t num_rows,
112  const StatsRequestPredicate& predicate);
114  const int64_t* data,
115  const int64_t num_rows,
116  const StatsRequestPredicate& predicate);
118  const float* data,
119  const int64_t num_rows,
120  const StatsRequestPredicate& predicate);
122  const double* data,
123  const int64_t num_rows,
124  const StatsRequestPredicate& predicate);
125 
126 template <typename T>
128  const Column<T>& col,
129  const StatsRequestPredicate& predicate) {
130  return get_column_stats(col.getPtr(), col.size(), predicate);
131 }
132 
134  const Column<int8_t>& col,
135  const StatsRequestPredicate& predicate);
137  const Column<int16_t>& col,
138  const StatsRequestPredicate& predicate);
140  const Column<int32_t>& col,
141  const StatsRequestPredicate& predicate);
143  const Column<int64_t>& col,
144  const StatsRequestPredicate& predicate);
146  const Column<float>& col,
147  const StatsRequestPredicate& predicate);
149  const Column<double>& col,
150  const StatsRequestPredicate& predicate);
151 
153  if (str == "COUNT") {
155  }
156  if (str == "MIN") {
158  }
159  if (str == "MAX") {
161  }
162  if (str == "SUM") {
164  }
165  if (str == "AVG") {
167  }
168  throw std::runtime_error("Invalid StatsRequestAggType: " + str);
169 }
170 
172  const std::string& str) {
173  if (str == "NONE") {
175  }
176  if (str == "LT" || str == "<") {
178  }
179  if (str == "GT" || str == ">") {
181  }
182  throw std::runtime_error("Invalid StatsRequestPredicateOp: " + str);
183 }
184 
// Returns a copy of `str` in which every non-overlapping occurrence of
// `pattern_str` (scanned left to right) is replaced by `replacement_str`.
//
// Text produced by a replacement is never re-scanned, so a replacement that
// contains the pattern does not cause repeated substitution.
//
// An empty `pattern_str` would match at every position and never advance, so
// it is treated as "nothing to replace" and `str` is returned unchanged.
std::string replace_substrings(const std::string& str,
                               const std::string& pattern_str,
                               const std::string& replacement_str) {
  if (pattern_str.empty()) {
    return str;
  }

  // Build the result in a single pass instead of calling replace() in place.
  // In-place replacement shifts the remaining tail on every match.
  std::string replaced_str;
  replaced_str.reserve(str.size());

  const size_t pattern_str_len = pattern_str.size();
  size_t copy_start_index = 0;  // first byte of `str` not yet copied

  while (true) {
    const size_t match_index = str.find(pattern_str, copy_start_index);
    if (match_index == std::string::npos) {
      break;
    }
    replaced_str.append(str, copy_start_index, match_index - copy_start_index);
    replaced_str += replacement_str;
    copy_start_index = match_index + pattern_str_len;
  }
  // Copy whatever follows the last match (or the whole string if none).
  replaced_str.append(str, copy_start_index, std::string::npos);
  return replaced_str;
}
204 
205 std::vector<StatsRequest> parse_stats_requests_json(
206  const std::string& stats_requests_json_str,
207  const int64_t num_attrs) {
208  std::vector<StatsRequest> stats_requests;
209  rapidjson::Document doc;
210 
211  // remove double double quotes our parser introduces
212  const auto fixed_stats_requests_json_str =
213  replace_substrings(stats_requests_json_str, "\"\"", "\"");
214 
215  if (doc.Parse(fixed_stats_requests_json_str.c_str()).HasParseError()) {
216  // Not valid JSON
217  std::cout << "DEBUG: Failed JSON: " << fixed_stats_requests_json_str << std::endl;
218  throw std::runtime_error("Could not parse Stats Requests JSON.");
219  }
220  // Todo (todd): Enforce Schema
221  if (!doc.IsArray()) {
222  throw std::runtime_error("Stats Request JSON did not contain valid root Array.");
223  }
224  const std::vector<std::string> required_keys = {
225  "name", "attr_id", "agg_type", "filter_type"};
226 
227  for (const auto& stat_request_obj : doc.GetArray()) {
228  for (const auto& required_key : required_keys) {
229  if (!stat_request_obj.HasMember(required_key)) {
230  throw std::runtime_error("Stats Request JSON missing key " + required_key + ".");
231  }
232  if (required_key == "attr_id") {
233  if (!stat_request_obj[required_key].IsUint()) {
234  throw std::runtime_error(required_key + " must be int type");
235  }
236  } else {
237  if (!stat_request_obj[required_key].IsString()) {
238  throw std::runtime_error(required_key + " must be string type");
239  }
240  }
241  }
242  StatsRequest stats_request;
243  stats_request.name = stat_request_obj["name"].GetString();
244  stats_request.attr_id = stat_request_obj["attr_id"].GetInt() - 1;
245  if (stats_request.attr_id < 0 || stats_request.attr_id >= num_attrs) {
246  throw std::runtime_error("Invalid attr_id: " +
247  std::to_string(stats_request.attr_id));
248  }
249 
250  std::string agg_type_str = stat_request_obj["agg_type"].GetString();
252  agg_type_str.begin(), agg_type_str.end(), agg_type_str.begin(), ::toupper);
253  stats_request.agg_type = convert_string_to_stats_request_agg_type(agg_type_str);
254 
255  std::string filter_type_str = stat_request_obj["filter_type"].GetString();
256  std::transform(filter_type_str.begin(),
257  filter_type_str.end(),
258  filter_type_str.begin(),
259  ::toupper);
260  stats_request.filter_type =
262  if (stats_request.filter_type != StatsRequestPredicateOp::NONE) {
263  if (!stat_request_obj.HasMember("filter_val")) {
264  throw std::runtime_error("Stats Request JSON missing expected filter_val");
265  }
266  if (!stat_request_obj["filter_val"].IsNumber()) {
267  throw std::runtime_error("Stats Request JSON filter_val should be numeric.");
268  }
269  stats_request.filter_val = stat_request_obj["filter_val"].GetDouble();
270  }
271  stats_requests.emplace_back(stats_request);
272  }
273  return stats_requests;
274 }
275 
276 std::vector<std::pair<const char*, double>> get_stats_key_value_pairs(
277  const std::vector<StatsRequest>& stats_requests) {
278  std::vector<std::pair<const char*, double>> stats_key_value_pairs;
279  for (const auto& stats_request : stats_requests) {
280  stats_key_value_pairs.emplace_back(
281  std::make_pair(stats_request.name.c_str(), stats_request.result));
282  }
283  return stats_key_value_pairs;
284 }
285 
286 #endif // __CUDACC__
std::vector< StatsRequest > parse_stats_requests_json(const std::string &stats_requests_json_str, const int64_t num_attrs)
std::vector< std::pair< const char *, double > > get_stats_key_value_pairs(const std::vector< StatsRequest > &stats_requests)
DEVICE int64_t size() const
NEVER_INLINE HOST ColumnStats< T > get_column_stats(const T *data, const int64_t num_rows, const StatsRequestPredicate &predicate)
StatsRequestPredicateOp convert_string_to_stats_request_predicate_op(const std::string &str)
StatsRequestPredicateOp
DEVICE T * getPtr() const
std::string to_string(char const *&&v)
std::string replace_substrings(const std::string &str, const std::string &pattern_str, const std::string &replacement_str)
#define HOST
const size_t max_inputs_per_thread
int64_t non_null_or_filtered_count
OUTPUT transform(INPUT const &input, FUNC const &func)
Definition: misc.h:320
StatsRequestAggType convert_string_to_stats_request_agg_type(const std::string &str)
StatsRequestPredicateOp filter_type
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
#define NEVER_INLINE
StatsRequestAggType
StatsRequestAggType agg_type