// get_column_stats (excerpt): scans `data[0, num_rows)` in parallel and fills
// column_stats with min, max and sum over the values that are neither the
// inline null sentinel nor rejected by `predicate`.
// NOTE(review): this excerpt is missing lines (function header, several
// branch bodies and closing braces); comments describe only what is visible.
24 const int64_t num_rows,
// NOTE(review): std::thread::hardware_concurrency() is allowed to return 0
// when the value is not computable; that would make num_threads 0 below.
27 const size_t max_thread_count = std::thread::hardware_concurrency();
// Thread count = min(hardware threads, ceil(num_rows / max_inputs_per_thread)).
29 const size_t num_threads = std::min(
30 max_thread_count, ((num_rows + max_inputs_per_thread - 1) / max_inputs_per_thread));
// One accumulator slot per thread, seeded with the identity of each reduction.
32 std::vector<T> local_col_mins(num_threads, std::numeric_limits<T>::max());
33 std::vector<T> local_col_maxes(num_threads, std::numeric_limits<T>::lowest());
34 std::vector<double> local_col_sums(num_threads, 0.);
35 std::vector<int64_t> local_col_non_null_or_filtered_counts(num_threads, 0L);
// Run the work inside an arena constructed with num_threads.
36 tbb::task_arena limited_arena(num_threads);
37 limited_arena.execute([&] {
// Body applied to one sub-range of rows; presumably passed to
// tbb::parallel_for (the call line is not visible here) — TODO confirm.
39 [&](
const tbb::blocked_range<int64_t>& r) {
40 const int64_t start_idx = r.begin();
41 const int64_t end_idx = r.end();
// Accumulators private to this sub-range; merged into the per-thread slot below.
42 T local_col_min = std::numeric_limits<T>::max();
43 T local_col_max = std::numeric_limits<T>::lowest();
44 double local_col_sum = 0.;
45 int64_t local_col_non_null_or_filtered_count = 0;
// NOTE(review): the loop index `r` shadows the blocked_range parameter `r`.
46 for (int64_t r = start_idx; r < end_idx; ++r) {
47 const T val = data[r];
// Null-sentinel check; the branch body is not visible (presumably skips the row).
48 if (val == inline_null_value<T>()) {
// Predicate check; the branch body is not visible (presumably skips the row).
51 if (!predicate(val)) {
// Track running min / max; the assignments inside these branches are not visible.
54 if (val < local_col_min) {
57 if (val > local_col_max) {
// Sum is accumulated as double; re-reads data[r] rather than using `val`.
60 local_col_sum += data[r];
61 local_col_non_null_or_filtered_count++;
// Fold this sub-range's results into the slot of the executing thread.
// NOTE(review): assumes the returned index is always < num_threads — confirm.
63 size_t thread_idx = tbb::this_task_arena::current_thread_index();
64 if (local_col_min < local_col_mins[thread_idx]) {
65 local_col_mins[thread_idx] = local_col_min;
67 if (local_col_max > local_col_maxes[thread_idx]) {
68 local_col_maxes[thread_idx] = local_col_max;
70 local_col_sums[thread_idx] += local_col_sum;
71 local_col_non_null_or_filtered_counts[thread_idx] +=
72 local_col_non_null_or_filtered_count;
// Final reduction: combine the per-thread slots into column_stats.
82 for (
size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
83 if (local_col_mins[thread_idx] < column_stats.
min) {
84 column_stats.
min = local_col_mins[thread_idx];
86 if (local_col_maxes[thread_idx] > column_stats.
max) {
87 column_stats.
max = local_col_maxes[thread_idx];
89 col_sum += local_col_sums[thread_idx];
// Right-hand side of a count accumulation whose left-hand side is not visible.
91 local_col_non_null_or_filtered_counts[thread_idx];
95 column_stats.
sum = col_sum;
// NOTE(review): each of the next six lines is the `num_rows` parameter of a
// separate declaration whose other lines are not present in this excerpt —
// presumably per-type explicit instantiations of get_column_stats; confirm
// against the full file.
103 const int64_t num_rows,
107 const int64_t num_rows,
111 const int64_t num_rows,
115 const int64_t num_rows,
119 const int64_t num_rows,
123 const int64_t num_rows,
// Header of a further function template; its declaration and body are not visible.
126 template <
typename T>
// convert_string_to_stats_request_agg_type (excerpt): maps an aggregate name
// to its enum value. Only the "COUNT" comparison is visible here; the other
// branches and the return statements are missing from this excerpt.
// The comparison is an exact, case-sensitive string match.
153 if (str ==
"COUNT") {
// No branch matched: reject the input, echoing the offending string.
168 throw std::runtime_error(
"Invalid StatsRequestAggType: " + str);
// convert_string_to_stats_request_predicate_op (excerpt): maps a filter
// operator string to its enum value. Visible branches accept either the
// mnemonic or the symbol form ("LT"/"<", "GT"/">"); any earlier branches and
// the return statements are missing from this excerpt.
172 const std::string& str) {
176 if (str ==
"LT" || str ==
"<") {
179 if (str ==
"GT" || str ==
">") {
// No branch matched: reject the input, echoing the offending string.
182 throw std::runtime_error(
"Invalid StatsRequestPredicateOp: " + str);
// replace_substrings (excerpt): works on a copy of `str` and substitutes
// occurrences of `pattern_str` with `replacement_str`. The loop header, the
// exit statement and the return are missing from this excerpt.
186 const std::string& pattern_str,
187 const std::string& replacement_str) {
// Work on a copy so the caller's string is left untouched.
188 std::string replaced_str(str);
190 size_t search_start_index = 0;
191 const auto pattern_str_len = pattern_str.size();
192 const auto replacement_str_len = replacement_str.size();
// Find the next occurrence at or after the current search position.
// NOTE(review): with an empty pattern_str, find() matches at every position,
// so this never yields npos while the index stays within the string — confirm
// the (not visible) loop or its callers rule out an empty pattern.
195 search_start_index = replaced_str.find(pattern_str, search_start_index);
// No further occurrence; the branch body is not visible (presumably exits the loop).
196 if (search_start_index == std::string::npos) {
199 replaced_str.replace(search_start_index, pattern_str_len, replacement_str);
// Resume after the inserted text so the replacement itself is not rescanned.
200 search_start_index += replacement_str_len;
// parse_stats_requests_json (excerpt): parses a JSON array of stats-request
// objects into a vector of StatsRequest, validating each entry.
// Throws std::runtime_error on malformed JSON, a non-array root, a missing or
// wrongly typed key, an out-of-range attr_id, or a missing/non-numeric
// filter_val. Several lines (including closing braces and some call heads)
// are missing from this excerpt.
206 const std::string& stats_requests_json_str,
207 const int64_t num_attrs) {
208 std::vector<StatsRequest> stats_requests;
209 rapidjson::Document doc;
// The initializer of this pre-processed copy of the input is not visible.
212 const auto fixed_stats_requests_json_str =
215 if (doc.Parse(fixed_stats_requests_json_str.c_str()).HasParseError()) {
// NOTE(review): debug dump of the rejected JSON goes to stdout (flushed by
// std::endl) before the exception is thrown.
217 std::cout <<
"DEBUG: Failed JSON: " << fixed_stats_requests_json_str << std::endl;
218 throw std::runtime_error(
"Could not parse Stats Requests JSON.");
// The document root must be a JSON array of request objects.
221 if (!doc.IsArray()) {
222 throw std::runtime_error(
"Stats Request JSON did not contain valid root Array.");
// Keys every request object must provide.
224 const std::vector<std::string> required_keys = {
225 "name",
"attr_id",
"agg_type",
"filter_type"};
227 for (
const auto& stat_request_obj : doc.GetArray()) {
// Validate presence and JSON type of each required key.
228 for (
const auto& required_key : required_keys) {
229 if (!stat_request_obj.HasMember(required_key)) {
230 throw std::runtime_error(
"Stats Request JSON missing key " + required_key +
".");
// attr_id must be an unsigned integer; the remaining keys are checked as
// strings below (the branch separating the two checks is not visible).
232 if (required_key ==
"attr_id") {
233 if (!stat_request_obj[required_key].IsUint()) {
234 throw std::runtime_error(required_key +
" must be int type");
237 if (!stat_request_obj[required_key].IsString()) {
238 throw std::runtime_error(required_key +
" must be string type");
243 stats_request.
name = stat_request_obj[
"name"].GetString();
// The JSON attr_id is decremented by one before the range check below.
// NOTE(review): the value was validated with IsUint() but is read with
// GetInt(); values above INT_MAX pass the former but not IsInt() — confirm
// how RapidJSON's GetInt() behaves for them in this build.
244 stats_request.
attr_id = stat_request_obj[
"attr_id"].GetInt() - 1;
// Rejects a JSON attr_id of 0 (becomes -1) and anything beyond num_attrs.
245 if (stats_request.
attr_id < 0 || stats_request.
attr_id >= num_attrs) {
246 throw std::runtime_error(
"Invalid attr_id: " +
250 std::string agg_type_str = stat_request_obj[
"agg_type"].GetString();
// Upper-cases agg_type_str in place (the head of this call is not visible).
252 agg_type_str.begin(), agg_type_str.end(), agg_type_str.begin(), ::toupper);
255 std::string filter_type_str = stat_request_obj[
"filter_type"].GetString();
// In-place transform of filter_type_str (call head and functor not visible;
// presumably the same upper-casing as for agg_type_str).
257 filter_type_str.end(),
258 filter_type_str.begin(),
// filter_val is required here; the condition guarding this section is not
// visible (presumably only when a filter is actually requested).
263 if (!stat_request_obj.HasMember(
"filter_val")) {
264 throw std::runtime_error(
"Stats Request JSON missing expected filter_val");
266 if (!stat_request_obj[
"filter_val"].IsNumber()) {
267 throw std::runtime_error(
"Stats Request JSON filter_val should be numeric.");
269 stats_request.
filter_val = stat_request_obj[
"filter_val"].GetDouble();
// emplace_back with an existing object copy-constructs it into the vector.
271 stats_requests.emplace_back(stats_request);
273 return stats_requests;
// get_stats_key_value_pairs (excerpt): flattens each request into a
// (name, result) pair and returns them in input order.
// NOTE(review): the `const char*` in each pair points into the `name` string
// owned by the corresponding element of `stats_requests`; the returned
// pointers are only usable while that vector is alive and unmodified.
277 const std::vector<StatsRequest>& stats_requests) {
278 std::vector<std::pair<const char*, double>> stats_key_value_pairs;
279 for (
const auto& stats_request : stats_requests) {
280 stats_key_value_pairs.emplace_back(
281 std::make_pair(stats_request.name.c_str(), stats_request.result));
283 return stats_key_value_pairs;
std::vector< StatsRequest > parse_stats_requests_json(const std::string &stats_requests_json_str, const int64_t num_attrs)
std::vector< std::pair< const char *, double > > get_stats_key_value_pairs(const std::vector< StatsRequest > &stats_requests)
DEVICE int64_t size() const
NEVER_INLINE HOST ColumnStats< T > get_column_stats(const T *data, const int64_t num_rows, const StatsRequestPredicate &predicate)
StatsRequestPredicateOp convert_string_to_stats_request_predicate_op(const std::string &str)
DEVICE T * getPtr() const
const size_t max_inputs_per_thread
std::string replace_substrings(const std::string &str, const std::string &pattern_str, const std::string &replacement_str)
int64_t non_null_or_filtered_count
OUTPUT transform(INPUT const &input, FUNC const &func)
StatsRequestAggType convert_string_to_stats_request_agg_type(const std::string &str)
StatsRequestPredicateOp filter_type
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
StatsRequestAggType agg_type