40 #include <type_traits> 46 template <
typename RealType,
typename IndexType>
52 template <
typename RealType,
typename IndexType>
59 static constexpr RealType infinity = std::numeric_limits<RealType>::infinity();
60 static constexpr RealType nan = std::numeric_limits<RealType>::quiet_NaN();
61 RealType max_{-infinity};
62 RealType min_{infinity};
67 : sums_(sums, size, size), counts_(counts, size, size) {}
70 : sums_(sums), counts_(counts) {}
// Returns the mean of the centroid at index i, computed as its accumulated
// sum divided by its count (sums_[i] / counts_[i]).
// No bounds check on i and no guard against counts_[i] == 0 is performed here.
91 DEVICE RealType
mean(IndexType
const i)
const {
return sums_[i] / counts_[i]; }
// Attempts to merge the next entry of `centroid` into this object's current
// entry. The visible part of the definition adds centroid.nextSum() and
// centroid.nextCount() into sums_[curr_idx_] / counts_[curr_idx_] when
// counts_[curr_idx_] + centroid.nextCount() <= max_count.
// NOTE(review): the rest of the definition is not visible in this listing;
// presumably the return value reports whether the merge happened - confirm.
94 DEVICE bool mergeIfFits(
Centroids& centroid, IndexType
const max_count);
96 DEVICE void moveNextToCurrent();
106 RealType
const lhs = nextSum() * b.
nextCount();
107 RealType
const rhs = b.
nextSum() * nextCount();
108 return lhs < rhs || (lhs == rhs && b.
nextCount() < nextCount());
116 DEVICE void resetIndices(
bool const forward);
125 template <
typename RealType2,
typename IndexType2>
126 friend std::ostream& operator<<(std::ostream&, Centroids<RealType2, IndexType2>
const&);
130 template <
typename RealType,
typename IndexType>
135 IndexType prefix_sum_{0};
139 DEVICE void mergeMinMax();
140 DEVICE void setCurrCentroid();
153 DEVICE void merge(IndexType
const max_count);
164 template <
typename RealType,
typename IndexType>
// Returns the combined byte size of the stored sums and counts:
// sums_.size() elements, each counted as one RealType plus one IndexType.
// The size of counts_ is not consulted; the element count comes from sums_ only.
172 size_t nbytes()
const {
return sums_.size() * (
sizeof(RealType) +
sizeof(IndexType)); }
// Returns the number of elements held in sums_.
173 size_t size()
const {
return sums_.size(); }
176 return {counts_.
data(), counts_.size(), counts_.size()};
180 template <
typename RealType,
typename IndexType =
size_t>
// Computes an upper bound on the count a single centroid may hold.
// The definition fragment visible in this listing takes max_bins from
// centroids_.capacity() and returns 2 * total_weight / max_bins when
// max_bins < total_weight, otherwise 0.
// NOTE(review): `sum` is not referenced in the visible part of the
// definition - confirm whether it is intentionally unused.
189 DEVICE IndexType maxCardinality(IndexType
const sum, IndexType
const total_weight);
194 DEVICE RealType firstCentroid(RealType
const x);
195 DEVICE RealType interiorCentroid(RealType
const x,
196 IndexType
const idx1,
197 IndexType
const prefix_sum);
198 DEVICE RealType lastCentroid(RealType
const x, IndexType
const N);
199 DEVICE RealType oneCentroid(RealType
const x);
201 DEVICE RealType slope(IndexType
const idx1, IndexType
const idx2);
209 : centroids_(mem.sums().data(), mem.counts().data(), mem.size()) {
216 DEVICE void add(RealType value);
218 DEVICE void mergeBuffer();
221 DEVICE void mergeSorted(RealType* sums, IndexType* counts, IndexType size);
269 template <
typename RealType,
typename IndexType>
276 return lhs < rhs || (lhs == rhs && b.
value1() < a.
value1());
282 template <
typename RealType,
typename IndexType>
284 if (inc_ == -1 && curr_idx_ != 0) {
294 IndexType
const offset = inc_ == 1 ? 0 : buff.
curr_idx_;
295 IndexType
const buff_size =
300 IndexType
const curr_size = inc_ == 1 ? curr_idx_ + 1 : size() - curr_idx_;
301 IndexType
const total_size = curr_size + buff_sums.
size();
302 assert(total_size <= sums_.capacity());
303 sums_.resize(total_size);
304 gpu_enabled::copy(buff_sums.begin(), buff_sums.end(), sums_.begin() + curr_size);
305 assert(total_size <= counts_.capacity());
306 counts_.resize(total_size);
307 gpu_enabled::copy(buff_counts.begin(), buff_counts.end(), counts_.begin() + curr_size);
316 template <
typename RealType,
typename IndexType>
318 IndexType
const max_count) {
319 if (counts_[curr_idx_] + centroid.
nextCount() <= max_count) {
320 sums_[curr_idx_] += centroid.
nextSum();
321 counts_[curr_idx_] += centroid.
nextCount();
328 template <
typename RealType,
typename IndexType>
331 if (curr_idx_ != next_idx_) {
332 sums_[curr_idx_] = sums_[next_idx_];
333 counts_[curr_idx_] = counts_[next_idx_];
338 template <
typename RealType,
typename IndexType>
342 curr_idx_ = ~IndexType(0);
347 static_assert(std::is_unsigned<IndexType>::value,
348 "IndexType must be an unsigned type.");
349 next_idx_ = curr_idx_ + inc_;
353 template <
typename RealType,
typename IndexType>
356 out <<
"Centroids<" <<
typeid(RealType).
name() <<
',' <<
typeid(IndexType).
name()
357 <<
">(size(" << centroids.
size() <<
") curr_idx_(" << centroids.
curr_idx_ 358 <<
") next_idx_(" << centroids.
next_idx_ <<
") sums_(";
359 for (IndexType i = 0; i < centroids.
sums_.
size(); ++i) {
360 out << (i ?
" " :
"") << std::setprecision(20) << centroids.
sums_[i];
363 for (IndexType i = 0; i < centroids.
counts_.
size(); ++i) {
364 out << (i ?
" " :
"") << centroids.
counts_[i];
372 template <
typename RealType,
typename IndexType>
378 , centroids_(centroids)
379 , total_weight_(centroids->totalWeight() + buf->totalWeight())
380 , forward_(forward) {
387 template <
typename RealType,
typename IndexType>
392 }
else if (
buf_->hasNext()) {
404 template <
typename RealType,
typename IndexType>
409 IndexType count_merged_{};
410 IndexType count_skipped_{};
427 template <
typename T>
430 IndexType
const merged,
433 T* src = begin + inc * (skipped - 1);
434 T* dst = src + inc * merged;
435 for (; skipped; --skipped, src -= inc, dst -= inc) {
440 std::copy_backward(begin, begin + skipped, begin + skipped + merged);
442 std::copy(begin + 1 - skipped, begin + 1, begin + 1 - skipped - merged);
449 return data_[0].centroid_ != centroid;
456 IndexType
const idx = index(next_centroid);
457 if (idx == 1 && data_[1].centroid_ ==
nullptr) {
458 data_[1] = {next_centroid, next_centroid->
next_idx_ + next_centroid->
inc_, 0, 0};
459 }
else if (data_[idx].count_skipped_) {
460 ++data_[idx].count_merged_;
463 data_[1].start_ += next_centroid->
inc_;
// Converts to bool by testing data_[0].centroid_: true when that pointer has
// been set (non-null), false otherwise.
466 DEVICE operator bool()
const {
return data_[0].centroid_; }
469 shiftCentroids(data_[0]);
470 data_[0].centroid_->next_idx_ = data_[0].start_;
471 if (data_[1].centroid_) {
472 shiftCentroids(data_[1]);
473 data_[1].centroid_->next_idx_ = data_[1].start_;
479 data_[0] = {next_centroid, next_centroid->
next_idx_, 0, 1};
482 IndexType
const idx = index(next_centroid);
483 if (idx == 1 && data_[1].centroid_ ==
nullptr) {
484 data_[1] = {next_centroid, next_centroid->
next_idx_, 0, 1};
486 if (data_[idx].count_merged_) {
487 shiftCentroids(data_[idx]);
488 data_[idx].count_merged_ = 0;
490 ++data_[idx].count_skipped_;
499 template <
typename RealType,
typename IndexType>
501 Skipped<RealType, IndexType> skipped;
504 if (skipped.isDifferentMean(next_centroid)) {
506 }
else if (
curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
507 skipped.merged(next_centroid);
509 skipped.skipSubsequent(next_centroid);
511 }
else if (!
curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
512 skipped.skipFirst(next_centroid);
516 skipped.shiftCentroidsAndSetNext();
528 template <
typename RealType,
typename IndexType>
539 template <
typename RealType,
typename IndexType>
545 template <
typename RealType,
typename IndexType>
554 template <
typename RealType,
typename IndexType>
556 if (
buf_.sums_.full()) {
559 buf_.sums_.push_back(value);
560 buf_.counts_.push_back(1);
563 template <
typename RealType,
typename IndexType>
566 IndexType
const total_weight) {
567 IndexType
const max_bins =
centroids_.capacity();
568 return max_bins < total_weight ? 2 * total_weight / max_bins : 0;
572 template <
typename RealType,
typename IndexType>
578 mergeCentroids(
buf_);
582 template <
typename RealType,
typename IndexType>
587 if (
buf_.capacity() == 0) {
590 buf_.sums_.set(sums, size);
591 buf_.counts_.set(counts, size);
596 mergeCentroids(
buf_);
605 template <
typename RealType,
typename IndexType>
615 IndexType
const max_cardinality = maxCardinality(cm.prefixSum(), cm.totalWeight());
616 cm.
merge(max_cardinality);
625 template <
typename CountsIterator>
632 template <
typename RealType,
typename IndexType>
637 return oneCentroid(x);
639 RealType
const sum =
centroids_.sums_.front();
640 return x == 1 ? 0.5 * sum : sum - min();
642 RealType
const count =
centroids_.counts_.front();
643 RealType
const dx = x - RealType(0.5) * (1 + count);
644 RealType
const mean = (
centroids_.sums_.front() - min()) / (count - 1);
645 return mean + slope(0, 0 < dx) * dx;
650 template <
typename RealType,
typename IndexType>
653 IndexType
const idx1,
654 IndexType
const prefix_sum) {
657 if (x == prefix_sum -
centroids_.counts_[idx1]) {
659 return 0.5 * (
centroids_.sums_[idx1 - 1] + sum1);
660 }
else if (idx1 == 1 &&
centroids_.counts_[0] == 2) {
661 return 0.5 * (
centroids_.sums_[idx1 - 1] - min() + sum1);
666 RealType
const dx = x + RealType(0.5) *
centroids_.counts_[idx1] - prefix_sum;
667 IndexType
const idx2 = idx1 + 2 * (0 < dx) - 1;
668 return centroids_.mean(idx1) + slope(idx1, idx2) * dx;
673 template <
typename RealType,
typename IndexType>
681 IndexType
const count1 =
centroids_.counts_[idx1];
684 return 0.5 * (
centroids_.sums_[idx1 - 1] + sum1);
685 }
else if (idx1 == 1 &&
centroids_.counts_[0] == 2) {
686 return 0.5 * (
centroids_.sums_[idx1 - 1] - min() + sum1);
690 }
else if (count1 == 2) {
693 }
else if (x == N - 2) {
694 RealType
const sum2 =
centroids_.sums_[idx1 - 1];
696 return 0.5 * (sum2 + sum1 - max());
697 }
else if (idx1 == 1 &&
centroids_.counts_[0] == 2) {
698 return 0.5 * (sum2 - min() + sum1 - max());
703 RealType
const dx = x + RealType(0.5) * (count1 + 1) - N;
704 RealType
const mean = (sum1 - max()) / (count1 - 1);
705 return mean + slope(idx1, idx1 - (dx < 0)) * dx;
710 template <
typename RealType,
typename IndexType>
712 IndexType
const N =
centroids_.counts_.front();
719 return 0.5 * (
centroids_.sums_.front() - min());
721 RealType
const s =
centroids_.sums_.front() - max();
722 return x == 1 ? 0.5 * s : s - min();
725 RealType
const dx = x - RealType(0.5) * N;
726 RealType
const mean = (
centroids_.sums_.front() - (min() + max())) / (N - 2);
727 RealType
const slope = 2 * (0 < dx ? max() - mean : mean - min()) / (N - 2);
728 return mean + slope * dx;
733 template <
typename RealType,
typename IndexType>
740 RealType
const x = q * N;
743 return firstCentroid(x);
747 return lastCentroid(x, N);
749 return interiorCentroid(x, it1 -
partial_sum.begin(), *it1);
760 template <
typename RealType,
typename IndexType>
764 RealType
const n =
static_cast<RealType
>(
centroids_.counts_[idx1]);
766 return idx1 == 0 ? 2 * (s - n * min()) / ((n - 1) * (n - 1))
767 : 2 * (n * max() - s) / ((n - 1) * (n - 1));
769 bool const min1 = idx1 == 0;
770 bool const max1 = idx1 == M - 1;
771 bool const min2 = idx2 == 0;
772 bool const max2 = idx2 == M - 1;
773 RealType
const n1 =
static_cast<RealType
>(
centroids_.counts_[idx1] - min1 - max1);
774 RealType
const s1 =
centroids_.sums_[idx1] - (min1 ? min() : max1 ? max() : 0);
775 RealType
const s2 =
centroids_.sums_[idx2] - (min2 ? min() : max2 ? max() : 0);
777 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - s1) / (n1 * n1);
779 RealType
const n2 =
static_cast<RealType
>(
centroids_.counts_[idx2] - min2 - max2);
780 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - n2 * s1) / (n1 * n2 * (n1 + n2));
DEVICE void setCurrCentroid()
DEVICE void push_back(RealType const value, RealType const count)
DEVICE void push_back(T const &value)
std::vector< RealType > sums_
Centroids< RealType, IndexType > * buf_
DEVICE void skipSubsequent(Centroids< RealType, IndexType > *next_centroid)
std::ostream & operator<<(std::ostream &out, Centroids< RealType, IndexType > const &centroids)
DEVICE void sort(ARGS &&... args)
DEVICE RealType quantile(RealType const q)
DEVICE IndexType capacity() const
DEVICE void fill(ARGS &&... args)
CentroidsMemory(size_t const size)
DEVICE void resetIndices(bool const forward)
DEVICE Centroids< RealType, IndexType > & centroids()
DEVICE IndexType totalWeight() const
DEVICE void moveNextToCurrent()
DEVICE void add(RealType value)
DEVICE RealType nextSum() const
Centroids< RealType, IndexType > * centroid_
IndexType const total_weight_
DEVICE void merged(Centroids< RealType, IndexType > *next_centroid)
DEVICE bool mergeIfFits(Centroids &centroid, IndexType const max_count)
DEVICE TDigest(Memory &mem)
DEVICE void mergeBuffer()
DEVICE RealType max() const
DEVICE void partial_sum(ARGS &&... args)
VectorView< RealType > sums()
VectorView< RealType > sums_
VectorView< IndexType > counts_
DEVICE IndexType maxCardinality(IndexType const sum, IndexType const total_weight)
std::vector< IndexType > counts_
DEVICE RealType lastCentroid(RealType const x, IndexType const N)
static DEVICE void shiftCentroids(Data &data)
DEVICE IndexType currCount() const
DEVICE auto copy(ARGS &&... args)
DEVICE void setCentroids(Memory &mem)
DEVICE Centroids< RealType, IndexType > * getNextCentroid() const
DEVICE RealType oneCentroid(RealType const x)
DEVICE IndexType prefixSum() const
Centroids< RealType, IndexType > * centroids_
DEVICE size_type size() const
DEVICE RealType firstCentroid(RealType const x)
DEVICE auto upper_bound(ARGS &&... args)
DEVICE Centroids(VectorView< RealType > sums, VectorView< IndexType > counts)
DEVICE CentroidsMerger(Centroids< RealType, IndexType > *buf, Centroids< RealType, IndexType > *centroids, bool const forward)
DEVICE void mergeTDigest(TDigest &t_digest)
DEVICE void setCentroids(VectorView< RealType > const sums, VectorView< IndexType > const counts)
DEVICE RealType currMean() const
DEVICE RealType interiorCentroid(RealType const x, IndexType const idx1, IndexType const prefix_sum)
Centroids< RealType, IndexType > buf_
DEVICE void skipFirst(Centroids< RealType, IndexType > *next_centroid)
static DEVICE void shiftRange(T *const begin, IndexType skipped, IndexType const merged, int const inc)
DEVICE RealType mean(IndexType const i) const
DEVICE void setBuffer(Memory &mem)
DEVICE IndexType totalWeight() const
DEVICE void reverse(ARGS &&... args)
DEVICE Centroids(RealType *sums, IndexType *counts, IndexType const size)
DEVICE IndexType totalWeight() const
DEVICE void merge(IndexType const max_count)
DEVICE bool hasNext() const
DEVICE bool isDifferentMean(Centroids< RealType, IndexType > *next_centroid) const
DEVICE IndexType nextCount() const
DEVICE void shiftCentroidsAndSetNext()
DEVICE bool index(Centroids< RealType, IndexType > *centroid) const
Centroids< RealType, IndexType > * curr_centroid_
DEVICE RealType slope(IndexType const idx1, IndexType const idx2)
detail::TDigest< double, size_t > TDigest
DEVICE bool hasNext() const
DEVICE size_type capacity() const
DEVICE void mergeMinMax()
DEVICE bool hasCurr() const
DEVICE size_t size() const
DEVICE void mergeCentroids(Centroids< RealType, IndexType > &)
DEVICE bool operator()(Value const &a, Value const &b) const
DEVICE bool operator<(Centroids const &b) const
DEVICE RealType min() const
DEVICE auto accumulate(ARGS &&... args)
DEVICE bool isSingleton(CountsIterator itr)
Centroid< RealType, IndexType > mean_
DEVICE void appendAndSortCurrent(Centroids &buff)
DEVICE void mergeSorted(RealType *sums, IndexType *counts, IndexType size)
VectorView< IndexType > counts()
Centroids< RealType, IndexType > centroids_