41 #include <type_traits>
47 template <
typename RealType,
typename IndexType>
53 template <
typename RealType,
typename IndexType>
// Compile-time special values for RealType, taken from std::numeric_limits:
// `infinity` is RealType's infinity() and `nan` is its quiet_NaN().
// NOTE(review): both are only meaningful when RealType is a floating-point
// type that provides these values - confirm against the instantiations used.
60 static constexpr RealType
infinity = std::numeric_limits<RealType>::infinity();
61 static constexpr RealType
nan = std::numeric_limits<RealType>::quiet_NaN();
68 :
sums_(sums, size, size),
counts_(counts, size, size) {}
// Befriends the stream-insertion operator for every Centroids<RealType2,
// IndexType2> instantiation, so that the operator can read this class's
// non-public members when printing.
126 template <
typename RealType2,
typename IndexType2>
127 friend std::ostream& operator<<(std::ostream&, Centroids<RealType2, IndexType2>
const&);
131 template <
typename RealType,
typename IndexType>
165 template <
typename RealType,
typename IndexType>
// Returns a byte count for the stored centroid data: the number of entries in
// sums_ multiplied by the combined size of one RealType and one IndexType.
// NOTE(review): the counts container is not consulted here; this presumably
// relies on it always having the same length as sums_ - confirm.
173 size_t nbytes()
const {
return sums_.size() * (
sizeof(RealType) +
sizeof(IndexType)); }
181 template <
typename RealType,
typename IndexType =
size_t>
202 IndexType
const idx1,
203 IndexType
const prefix_sum);
207 DEVICE RealType
slope(IndexType
const idx1, IndexType
const idx2);
215 :
centroids_(mem.sums().data(), mem.counts().data(), mem.size()) {
220 IndexType buf_allocate,
221 IndexType centroids_allocate)
285 template <
typename RealType,
typename IndexType>
292 return lhs < rhs || (lhs == rhs && b.
value1() < a.
value1());
298 template <
typename RealType,
typename IndexType>
300 if (inc_ == -1 && curr_idx_ != 0) {
310 IndexType
const offset = inc_ == 1 ? 0 : buff.
curr_idx_;
311 IndexType
const buff_size =
316 IndexType
const curr_size = inc_ == 1 ? curr_idx_ + 1 : size() - curr_idx_;
317 IndexType
const total_size = curr_size + buff_sums.
size();
318 assert(total_size <= sums_.capacity());
319 sums_.resize(total_size);
320 gpu_enabled::copy(buff_sums.begin(), buff_sums.end(), sums_.begin() + curr_size);
321 assert(total_size <= counts_.capacity());
322 counts_.resize(total_size);
323 gpu_enabled::copy(buff_counts.begin(), buff_counts.end(), counts_.begin() + curr_size);
332 template <
typename RealType,
typename IndexType>
334 IndexType
const max_count) {
335 if (counts_[curr_idx_] + centroid.
nextCount() <= max_count) {
336 sums_[curr_idx_] += centroid.
nextSum();
337 counts_[curr_idx_] += centroid.
nextCount();
344 template <
typename RealType,
typename IndexType>
347 if (curr_idx_ != next_idx_) {
348 sums_[curr_idx_] = sums_[next_idx_];
349 counts_[curr_idx_] = counts_[next_idx_];
354 template <
typename RealType,
typename IndexType>
358 curr_idx_ = ~IndexType(0);
363 static_assert(std::is_unsigned<IndexType>::value,
364 "IndexType must be an unsigned type.");
365 next_idx_ = curr_idx_ + inc_;
369 template <
typename RealType,
typename IndexType>
372 out <<
"Centroids<" <<
typeid(RealType).
name() <<
',' <<
typeid(IndexType).
name()
373 <<
">(size(" << centroids.
size() <<
") curr_idx_(" << centroids.
curr_idx_
374 <<
") next_idx_(" << centroids.
next_idx_ <<
") sums_(";
375 for (IndexType
i = 0;
i < centroids.
sums_.
size(); ++
i) {
376 out << (
i ?
" " :
"") << std::setprecision(20) << centroids.
sums_[
i];
380 out << (
i ?
" " :
"") << centroids.
counts_[
i];
388 template <
typename RealType,
typename IndexType>
394 , centroids_(centroids)
395 , total_weight_(centroids->totalWeight() + buf->totalWeight())
396 , forward_(forward) {
403 template <
typename RealType,
typename IndexType>
406 if (buf_->hasNext()) {
407 if (centroids_->hasNext()) {
408 return (*buf_ < *centroids_) == forward_ ? buf_ : centroids_;
411 }
else if (centroids_->hasNext()) {
422 template <
typename RealType,
typename IndexType>
// Two IndexType counters, both initialised to zero.
// NOTE(review): presumably these are the per-entry tallies that the skip/merge
// bookkeeping increments and resets through data_[idx] - confirm against the
// enclosing struct, whose header is not visible here.
427 IndexType count_merged_{0};
428 IndexType count_skipped_{0};
445 template <
typename T>
448 IndexType
const merged,
451 T* src = begin + inc * (skipped - 1);
452 T* dst = src + inc * merged;
453 for (; skipped; --skipped, src -= inc, dst -= inc) {
458 std::copy_backward(begin, begin + skipped, begin + skipped + merged);
460 std::copy(begin + 1 - skipped, begin + 1, begin + 1 - skipped - merged);
467 return data_[0].centroid_ != centroid;
470 return mean_.sum_ * next_centroid->
nextCount() !=
471 next_centroid->
nextSum() * mean_.count_;
474 IndexType
const idx = index(next_centroid);
475 if (data_[idx].count_skipped_) {
476 ++data_[idx].count_merged_;
// Boolean conversion: evaluates to true exactly when the centroid pointer held
// in data_[0] is non-null.
// NOTE(review): presumably this signals that at least one centroid has been
// recorded as skipped - confirm against the code that fills data_[0].
479 DEVICE operator bool()
const {
return data_[0].centroid_; }
482 shiftCentroids(data_[0]);
483 data_[0].centroid_->next_idx_ = data_[0].start_;
484 if (data_[1].centroid_) {
485 shiftCentroids(data_[1]);
486 data_[1].centroid_->next_idx_ = data_[1].start_;
492 data_[0] = {next_centroid, next_centroid->
next_idx_, 0, 1};
496 IndexType
const idx = index(next_centroid);
497 if (idx == 1 && data_[1].centroid_ ==
nullptr) {
498 data_[1] = {next_centroid, next_centroid->
next_idx_, 0, 1};
500 if (data_[idx].count_merged_) {
501 shiftCentroids(data_[idx]);
502 data_[idx].count_merged_ = 0;
504 ++data_[idx].count_skipped_;
513 template <
typename RealType,
typename IndexType>
515 Skipped<RealType, IndexType> skipped;
516 while (
auto* next_centroid = getNextCentroid()) {
518 if (skipped.isDifferentMean(next_centroid)) {
520 }
else if (curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
521 skipped.merged(next_centroid);
523 skipped.skipSubsequent(next_centroid);
525 }
else if (!curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
526 skipped.skipFirst(next_centroid);
530 skipped.shiftCentroidsAndSetNext();
542 template <
typename RealType,
typename IndexType>
544 if (centroids_->max_ < buf_->max_) {
545 centroids_->max_ = buf_->max_;
547 if (buf_->min_ < centroids_->min_) {
548 centroids_->min_ = buf_->min_;
553 template <
typename RealType,
typename IndexType>
555 prefix_sum_ += curr_centroid_->currCount();
559 template <
typename RealType,
typename IndexType>
561 if ((curr_centroid_ = getNextCentroid())) {
562 curr_centroid_->moveNextToCurrent();
568 template <
typename RealType,
typename IndexType>
570 if (buf_.sums_.full()) {
573 buf_.sums_.push_back(value);
574 buf_.counts_.push_back(1);
578 template <
typename RealType,
typename IndexType>
580 if (buf_.capacity() == 0) {
581 auto* p0 = simple_allocator_->allocate(buf_allocate_ *
sizeof(RealType));
582 auto* p1 = simple_allocator_->allocate(buf_allocate_ *
sizeof(IndexType));
586 p0 = simple_allocator_->allocate(centroids_allocate_ *
sizeof(RealType));
587 p1 = simple_allocator_->allocate(centroids_allocate_ *
sizeof(IndexType));
594 template <
typename RealType,
typename IndexType>
597 IndexType
const total_weight) {
598 IndexType
const max_bins = centroids_.capacity();
599 return max_bins < total_weight ? 2 * total_weight / max_bins : 0;
603 template <
typename RealType,
typename IndexType>
607 buf_.min_ = buf_.sums_.front();
608 buf_.max_ = buf_.sums_.back();
609 mergeCentroids(buf_);
613 template <
typename RealType,
typename IndexType>
618 if (buf_.capacity() == 0) {
622 buf_.counts_.set(counts, size);
625 buf_.min_ = buf_.sums_.front();
626 buf_.max_ = buf_.sums_.back();
627 mergeCentroids(buf_);
636 template <
typename RealType,
typename IndexType>
643 for (CM cm(&buf, ¢roids_, forward_); cm.hasNext(); cm.next()) {
646 IndexType
const max_cardinality = maxCardinality(cm.prefixSum(), cm.totalWeight());
647 cm.
merge(max_cardinality);
650 centroids_.appendAndSortCurrent(buf);
656 template <
typename CountsIterator>
663 template <
typename RealType,
typename IndexType>
667 }
else if (centroids_.size() == 1) {
668 return oneCentroid(x);
669 }
else if (centroids_.counts_.front() == 2) {
670 RealType
const sum = centroids_.sums_.front();
671 return x == 1 ? 0.5 * sum : sum - min();
673 RealType
const count = centroids_.counts_.front();
674 RealType
const dx = x - RealType(0.5) * (1 +
count);
675 RealType
const mean = (centroids_.sums_.front() - min()) / (count - 1);
676 return mean + slope(0, 0 < dx) * dx;
681 template <
typename RealType,
typename IndexType>
684 IndexType
const idx1,
685 IndexType
const prefix_sum) {
686 if (
isSingleton(centroids_.counts_.begin() + idx1)) {
687 RealType
const sum1 = centroids_.sums_[idx1];
688 if (x == prefix_sum - centroids_.counts_[idx1]) {
689 if (
isSingleton(centroids_.counts_.begin() + idx1 - 1)) {
690 return 0.5 * (centroids_.sums_[idx1 - 1] + sum1);
691 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
692 return 0.5 * (centroids_.sums_[idx1 - 1] - min() + sum1);
697 RealType
const dx = x + RealType(0.5) * centroids_.counts_[idx1] - prefix_sum;
698 IndexType
const idx2 = idx1 + 2 * (0 < dx) - 1;
699 return centroids_.mean(idx1) + slope(idx1, idx2) * dx;
704 template <
typename RealType,
typename IndexType>
710 IndexType
const idx1 = centroids_.size() - 1;
711 RealType
const sum1 = centroids_.sums_[idx1];
712 IndexType
const count1 = centroids_.counts_[idx1];
714 if (
isSingleton(centroids_.counts_.begin() + (idx1 - 1))) {
715 return 0.5 * (centroids_.sums_[idx1 - 1] + sum1);
716 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
717 return 0.5 * (centroids_.sums_[idx1 - 1] - min() + sum1);
721 }
else if (count1 == 2) {
724 }
else if (x == N - 2) {
725 RealType
const sum2 = centroids_.sums_[idx1 - 1];
726 if (
isSingleton(centroids_.counts_.begin() + (idx1 - 1))) {
727 return 0.5 * (sum2 + sum1 - max());
728 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
729 return 0.5 * (sum2 - min() + sum1 - max());
734 RealType
const dx = x + RealType(0.5) * (count1 + 1) - N;
735 RealType
const mean = (sum1 - max()) / (count1 - 1);
736 return mean + slope(idx1, idx1 - (dx < 0)) * dx;
741 template <
typename RealType,
typename IndexType>
743 IndexType
const N = centroids_.counts_.front();
747 return 0.5 * centroids_.sums_.front();
750 return 0.5 * (centroids_.sums_.front() - min());
752 RealType
const s = centroids_.sums_.front() - max();
753 return x == 1 ? 0.5 * s : s - min();
756 RealType
const dx = x - RealType(0.5) * N;
757 RealType
const mean = (centroids_.sums_.front() - (min() + max())) / (N - 2);
758 RealType
const slope = 2 * (0 < dx ? max() - mean : mean - min()) / (N - 2);
759 return mean + slope * dx;
764 template <
typename RealType,
typename IndexType>
766 if (centroids_.size()) {
769 centroids_.counts_.begin(), centroids_.counts_.end(),
partial_sum.begin());
771 RealType
const x = q * N;
774 return firstCentroid(x);
778 return lastCentroid(x, N);
780 return interiorCentroid(x, it1 -
partial_sum.begin(), *it1);
783 return centroids_.nan;
791 template <
typename RealType,
typename IndexType>
793 IndexType
const M = centroids_.size();
795 RealType
const n =
static_cast<RealType
>(centroids_.counts_[idx1]);
796 RealType
const s = centroids_.sums_[idx1];
797 return idx1 == 0 ? 2 * (s - n * min()) / ((n - 1) * (n - 1))
798 : 2 * (n * max() - s) / ((n - 1) * (n - 1));
800 bool const min1 = idx1 == 0;
801 bool const max1 = idx1 == M - 1;
802 bool const min2 = idx2 == 0;
803 bool const max2 = idx2 == M - 1;
804 RealType
const n1 =
static_cast<RealType
>(centroids_.counts_[idx1] - min1 - max1);
805 RealType
const s1 = centroids_.sums_[idx1] - (min1 ? min() : max1 ? max() : 0);
806 RealType
const s2 = centroids_.sums_[idx2] - (min2 ? min() : max2 ? max() : 0);
807 if (
isSingleton(centroids_.counts_.begin() + idx2)) {
808 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - s1) / (n1 * n1);
810 RealType
const n2 =
static_cast<RealType
>(centroids_.counts_[idx2] - min2 - max2);
811 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - n2 * s1) / (n1 * n2 * (n1 + n2));
DEVICE auto upper_bound(ARGS &&...args)
DEVICE void setCurrCentroid()
static constexpr RealType infinity
DEVICE void push_back(RealType const value, RealType const count)
DEVICE void push_back(T const &value)
std::vector< RealType > sums_
DEVICE size_type capacity() const
Centroids< RealType, IndexType > * buf_
DEVICE void skipSubsequent(Centroids< RealType, IndexType > *next_centroid)
std::ostream & operator<<(std::ostream &out, Centroids< RealType, IndexType > const &centroids)
DEVICE RealType quantile(RealType const q)
CentroidsMemory(size_t const size)
DEVICE RealType mean(IndexType const i) const
DEVICE bool index(Centroids< RealType, IndexType > *centroid) const
DEVICE void resetIndices(bool const forward)
DEVICE Centroids< RealType, IndexType > & centroids()
DEVICE void moveNextToCurrent()
DEVICE void add(RealType value)
Centroids< RealType, IndexType > * centroid_
IndexType const total_weight_
DEVICE void sort(ARGS &&...args)
DEVICE void merged(Centroids< RealType, IndexType > *next_centroid)
DEVICE bool mergeIfFits(Centroids &centroid, IndexType const max_count)
DEVICE bool hasNext() const
DEVICE TDigest(Memory &mem)
DEVICE void mergeBuffer()
DEVICE IndexType totalWeight() const
VectorView< RealType > sums()
VectorView< RealType > sums_
VectorView< IndexType > counts_
DEVICE IndexType maxCardinality(IndexType const sum, IndexType const total_weight)
std::vector< IndexType > counts_
DEVICE RealType lastCentroid(RealType const x, IndexType const N)
static DEVICE void shiftCentroids(Data &data)
DEVICE RealType currMean() const
DEVICE void setCentroids(Memory &mem)
DEVICE RealType oneCentroid(RealType const x)
DEVICE size_type size() const
Centroids< RealType, IndexType > * centroids_
DEVICE RealType firstCentroid(RealType const x)
DEVICE RealType max() const
DEVICE void fill(ARGS &&...args)
DEVICE void set(T *data, size_type const size)
DEVICE auto copy(ARGS &&...args)
DEVICE Centroids(VectorView< RealType > sums, VectorView< IndexType > counts)
DEVICE CentroidsMerger(Centroids< RealType, IndexType > *buf, Centroids< RealType, IndexType > *centroids, bool const forward)
DEVICE void mergeTDigest(TDigest &t_digest)
DEVICE IndexType nextCount() const
DEVICE void partial_sum(ARGS &&...args)
DEVICE void setCentroids(VectorView< RealType > const sums, VectorView< IndexType > const counts)
DEVICE RealType interiorCentroid(RealType const x, IndexType const idx1, IndexType const prefix_sum)
DEVICE auto accumulate(ARGS &&...args)
Centroids< RealType, IndexType > buf_
IndexType const buf_allocate_
DEVICE RealType nextSum() const
DEVICE void skipFirst(Centroids< RealType, IndexType > *next_centroid)
static DEVICE void shiftRange(T *const begin, IndexType skipped, IndexType const merged, int const inc)
DEVICE void setBuffer(Memory &mem)
DEVICE bool operator()(Value const &a, Value const &b) const
static constexpr RealType nan
DEVICE Centroids< RealType, IndexType > * getNextCentroid() const
DEVICE IndexType prefixSum() const
DEVICE Centroids(RealType *sums, IndexType *counts, IndexType const size)
SimpleAllocator *const simple_allocator_
DEVICE void merge(IndexType const max_count)
DEVICE bool operator<(Centroids const &b) const
DEVICE bool hasNext() const
DEVICE TDigest(SimpleAllocator *simple_allocator, IndexType buf_allocate, IndexType centroids_allocate)
IndexType const centroids_allocate_
DEVICE void shiftCentroidsAndSetNext()
Centroids< RealType, IndexType > * curr_centroid_
DEVICE size_t size() const
DEVICE IndexType currCount() const
DEVICE IndexType totalWeight() const
DEVICE bool isDifferentMean(Centroids< RealType, IndexType > *next_centroid) const
DEVICE RealType slope(IndexType const idx1, IndexType const idx2)
DEVICE IndexType capacity() const
DEVICE void mergeMinMax()
DEVICE IndexType totalWeight() const
DEVICE void mergeCentroids(Centroids< RealType, IndexType > &)
DEVICE void reverse(ARGS &&...args)
DEVICE RealType min() const
DEVICE bool isSingleton(CountsIterator itr)
Centroid< RealType, IndexType > mean_
DEVICE void appendAndSortCurrent(Centroids &buff)
DEVICE bool hasCurr() const
DEVICE void mergeSorted(RealType *sums, IndexType *counts, IndexType size)
VectorView< IndexType > counts()
Centroids< RealType, IndexType > centroids_