43 #include <type_traits>
49 template <
typename RealType,
typename IndexType>
55 template <
typename RealType,
typename IndexType>
// Compile-time constant holding std::numeric_limits<RealType>::infinity().
62 static constexpr RealType
infinity = std::numeric_limits<RealType>::infinity();
// Compile-time constant holding RealType's quiet NaN. Code later in this file
// returns it through `centroids_.nan`.
63 static constexpr RealType
nan = std::numeric_limits<RealType>::quiet_NaN();
70 :
sums_(sums, size, size),
counts_(counts, size, size) {}
// Friend declaration for the stream-insertion operator, templated on its own
// parameter pair so it applies to any Centroids<RealType2, IndexType2>.
// The operator<< code visible later in this file reads curr_idx_, next_idx_,
// sums_ and counts_ directly from its Centroids argument.
128 template <
typename RealType2,
typename IndexType2>
129 friend std::ostream& operator<<(std::ostream&, Centroids<RealType2, IndexType2>
const&);
133 template <
typename RealType,
typename IndexType>
167 template <
typename RealType,
typename IndexType>
175 size_t nbytes()
const {
return sums_.size() * (
sizeof(RealType) +
sizeof(IndexType)); }
183 template <
typename RealType,
typename IndexType =
size_t>
// Const optional RealType whose default member initializer leaves it empty
// (std::nullopt).
// NOTE(review): presumably the quantile passed as `q` to a constructor — confirm
// against the constructor definitions, which are not fully visible here.
196 std::optional<RealType>
const q_{std::nullopt};
210 IndexType
const total_weight,
218 IndexType
const idx1,
219 IndexType
const prefix_sum)
const;
// Declaration of the const member slope(): takes two IndexType indices and
// returns a RealType. The definition fragments visible later in this file use
// idx1 and idx2 to index centroids_.sums_ and centroids_.counts_.
// NOTE(review): DEVICE is a macro defined outside this view — confirm its meaning.
223 DEVICE RealType
slope(IndexType
const idx1, IndexType
const idx2)
const;
231 :
centroids_(mem.sums().data(), mem.counts().data(), mem.size()) {
237 IndexType buf_allocate,
238 IndexType centroids_allocate)
273 RealType
const q)
const;
317 template <
typename RealType,
typename IndexType>
324 return lhs < rhs || (lhs == rhs && b.
value1() < a.
value1());
330 template <
typename RealType,
typename IndexType>
332 if (inc_ == -1 && curr_idx_ != 0) {
342 IndexType
const offset = inc_ == 1 ? 0 : buff.
curr_idx_;
343 IndexType
const buff_size =
348 IndexType
const curr_size = inc_ == 1 ? curr_idx_ + 1 : size() - curr_idx_;
349 IndexType
const total_size = curr_size + buff_sums.
size();
350 assert(total_size <= sums_.capacity());
351 sums_.resize(total_size);
352 gpu_enabled::copy(buff_sums.begin(), buff_sums.end(), sums_.begin() + curr_size);
353 assert(total_size <= counts_.capacity());
354 counts_.resize(total_size);
355 gpu_enabled::copy(buff_counts.begin(), buff_counts.end(), counts_.begin() + curr_size);
364 template <
typename RealType,
typename IndexType>
366 IndexType
const max_count) {
367 if (counts_[curr_idx_] + centroid.
nextCount() <= max_count) {
368 sums_[curr_idx_] += centroid.
nextSum();
369 counts_[curr_idx_] += centroid.
nextCount();
376 template <
typename RealType,
typename IndexType>
379 if (curr_idx_ != next_idx_) {
380 sums_[curr_idx_] = sums_[next_idx_];
381 counts_[curr_idx_] = counts_[next_idx_];
386 template <
typename RealType,
typename IndexType>
390 curr_idx_ = ~IndexType(0);
395 static_assert(std::is_unsigned<IndexType>::value,
396 "IndexType must be an unsigned type.");
397 next_idx_ = curr_idx_ + inc_;
401 template <
typename RealType,
typename IndexType>
404 out <<
"Centroids<" <<
typeid(RealType).
name() <<
',' <<
typeid(IndexType).
name()
405 <<
">(size(" << centroids.
size() <<
") curr_idx_(" << centroids.
curr_idx_
406 <<
") next_idx_(" << centroids.
next_idx_ <<
") sums_(";
407 for (IndexType i = 0; i < centroids.
sums_.
size(); ++i) {
408 out << (i ?
" " :
"") << std::setprecision(20) << centroids.
sums_[i];
411 for (IndexType i = 0; i < centroids.
counts_.
size(); ++i) {
412 out << (i ?
" " :
"") << centroids.
counts_[i];
420 template <
typename RealType,
typename IndexType>
426 , centroids_(centroids)
427 , total_weight_(centroids->totalWeight() + buf->totalWeight())
428 , forward_(forward) {
435 template <
typename RealType,
typename IndexType>
438 if (buf_->hasNext()) {
439 if (centroids_->hasNext()) {
440 return (*buf_ < *centroids_) == forward_ ? buf_ : centroids_;
443 }
else if (centroids_->hasNext()) {
454 template <
typename RealType,
typename IndexType>
459 IndexType count_merged_{0};
460 IndexType count_skipped_{0};
477 template <
typename T>
480 IndexType
const merged,
483 T* src = begin + inc * (skipped - 1);
484 T* dst = src + inc * merged;
485 for (; skipped; --skipped, src -= inc, dst -= inc) {
490 std::copy_backward(begin, begin + skipped, begin + skipped + merged);
492 std::copy(begin + 1 - skipped, begin + 1, begin + 1 - skipped - merged);
499 return data_[0].centroid_ != centroid;
502 return mean_.sum_ * next_centroid->
nextCount() !=
503 next_centroid->
nextSum() * mean_.count_;
506 IndexType
const idx = index(next_centroid);
507 if (data_[idx].count_skipped_) {
508 ++data_[idx].count_merged_;
512 return data_[0].centroid_;
516 shiftCentroids(data_[0]);
517 data_[0].centroid_->next_idx_ = data_[0].start_;
518 if (data_[1].centroid_) {
519 shiftCentroids(data_[1]);
520 data_[1].centroid_->next_idx_ = data_[1].start_;
526 data_[0] = {next_centroid, next_centroid->
next_idx_, 0, 1};
530 IndexType
const idx = index(next_centroid);
531 if (idx == 1 && data_[1].centroid_ ==
nullptr) {
532 data_[1] = {next_centroid, next_centroid->
next_idx_, 0, 1};
534 if (data_[idx].count_merged_) {
535 shiftCentroids(data_[idx]);
536 data_[idx].count_merged_ = 0;
538 ++data_[idx].count_skipped_;
547 template <
typename RealType,
typename IndexType>
549 Skipped<RealType, IndexType> skipped;
550 while (
auto* next_centroid = getNextCentroid()) {
552 if (skipped.isDifferentMean(next_centroid)) {
554 }
else if (curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
555 skipped.merged(next_centroid);
557 skipped.skipSubsequent(next_centroid);
559 }
else if (!curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
560 skipped.skipFirst(next_centroid);
564 skipped.shiftCentroidsAndSetNext();
576 template <
typename RealType,
typename IndexType>
578 if (centroids_->max_ < buf_->max_) {
579 centroids_->max_ = buf_->max_;
581 if (buf_->min_ < centroids_->min_) {
582 centroids_->min_ = buf_->min_;
587 template <
typename RealType,
typename IndexType>
589 prefix_sum_ += curr_centroid_->currCount();
593 template <
typename RealType,
typename IndexType>
595 if ((curr_centroid_ = getNextCentroid())) {
596 curr_centroid_->moveNextToCurrent();
602 template <
typename RealType,
typename IndexType>
604 if (buf_.sums_.full()) {
607 buf_.sums_.push_back(value);
608 buf_.counts_.push_back(1);
612 template <
typename RealType,
typename IndexType>
614 if (buf_.capacity() == 0) {
615 auto* p0 = simple_allocator_->allocate(buf_allocate_ *
sizeof(RealType));
616 auto* p1 = simple_allocator_->allocate(buf_allocate_ *
sizeof(IndexType));
620 p0 = simple_allocator_->allocate(centroids_allocate_ *
sizeof(RealType));
621 p1 = simple_allocator_->allocate(centroids_allocate_ *
sizeof(IndexType));
629 template <
typename RealType,
typename IndexType>
632 IndexType
const total_weight,
634 IndexType
const max_bins = centroids_.capacity();
635 if (total_weight <= max_bins) {
637 }
else if (use_linear_scaling_function_) {
639 return 2 * total_weight / max_bins;
642 RealType
const x = 2.0 * sum / total_weight - 1;
643 RealType
const f_inv = 0.5 + 0.5 * std::sin(c + std::asin(x));
644 constexpr RealType eps = 1e-5;
645 IndexType
const dsum =
static_cast<IndexType
>(total_weight * f_inv + eps);
646 return dsum < sum ? 0 : dsum - sum;
651 template <
typename RealType,
typename IndexType>
655 buf_.min_ = buf_.sums_.front();
656 buf_.max_ = buf_.sums_.back();
657 mergeCentroids(buf_);
662 template <
typename RealType,
typename IndexType>
664 auto const call_once = [
this] {
666 assert(centroids_.size() <= buf_.capacity());
667 partialSumOfCounts(buf_.counts_.data());
672 std::call_once(merge_buffer_final_once_, call_once);
676 template <
typename RealType,
typename IndexType>
681 if (buf_.capacity() == 0) {
685 buf_.counts_.set(counts, size);
688 buf_.min_ = buf_.sums_.front();
689 buf_.max_ = buf_.sums_.back();
690 mergeCentroids(buf_);
699 template <
typename RealType,
typename IndexType>
702 constexpr RealType two_pi = 6.283185307179586476925286766559005768e+00;
704 RealType
const c = two_pi / centroids_.capacity();
709 for (CM cm(&buf, ¢roids_, forward_); cm.hasNext(); cm.next()) {
712 IndexType
const max_cardinality = maxCardinality(cm.prefixSum(), cm.totalWeight(), c);
713 cm.
merge(max_cardinality);
716 centroids_.appendAndSortCurrent(buf);
722 template <
typename CountsIterator>
729 template <
typename RealType,
typename IndexType>
733 }
else if (centroids_.size() == 1) {
734 return oneCentroid(x);
735 }
else if (centroids_.counts_.front() == 2) {
736 RealType
const sum = centroids_.sums_.front();
737 return x == 1 ? 0.5 * sum : sum - min();
739 RealType
const count = centroids_.counts_.front();
740 RealType
const dx = x - RealType(0.5) * (1 + count);
741 RealType
const mean = (centroids_.sums_.front() - min()) / (count - 1);
742 return mean + slope(0, 0 < dx) * dx;
747 template <
typename RealType,
typename IndexType>
750 IndexType
const idx1,
751 IndexType
const prefix_sum)
const {
752 if (
isSingleton(centroids_.counts_.begin() + idx1)) {
753 RealType
const sum1 = centroids_.sums_[idx1];
754 if (x == prefix_sum - centroids_.counts_[idx1]) {
755 if (
isSingleton(centroids_.counts_.begin() + idx1 - 1)) {
756 return 0.5 * (centroids_.sums_[idx1 - 1] + sum1);
757 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
758 return 0.5 * (centroids_.sums_[idx1 - 1] - min() + sum1);
763 RealType
const dx = x + RealType(0.5) * centroids_.counts_[idx1] - prefix_sum;
764 IndexType
const idx2 = idx1 + 2 * (0 < dx) - 1;
765 return centroids_.mean(idx1) + slope(idx1, idx2) * dx;
770 template <
typename RealType,
typename IndexType>
772 IndexType
const N)
const {
776 IndexType
const idx1 = centroids_.size() - 1;
777 RealType
const sum1 = centroids_.sums_[idx1];
778 IndexType
const count1 = centroids_.counts_[idx1];
780 if (
isSingleton(centroids_.counts_.begin() + (idx1 - 1))) {
781 return 0.5 * (centroids_.sums_[idx1 - 1] + sum1);
782 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
783 return 0.5 * (centroids_.sums_[idx1 - 1] - min() + sum1);
787 }
else if (count1 == 2) {
790 }
else if (x == N - 2) {
791 RealType
const sum2 = centroids_.sums_[idx1 - 1];
792 if (
isSingleton(centroids_.counts_.begin() + (idx1 - 1))) {
793 return 0.5 * (sum2 + sum1 - max());
794 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
795 return 0.5 * (sum2 - min() + sum1 - max());
800 RealType
const dx = x + RealType(0.5) * (count1 + 1) - N;
801 RealType
const mean = (sum1 - max()) / (count1 - 1);
802 return mean + slope(idx1, idx1 - (dx < 0)) * dx;
807 template <
typename RealType,
typename IndexType>
809 IndexType
const N = centroids_.counts_.front();
813 return 0.5 * centroids_.sums_.front();
816 return 0.5 * (centroids_.sums_.front() - min());
818 RealType
const s = centroids_.sums_.front() - max();
819 return x == 1 ? 0.5 * s : s - min();
822 RealType
const dx = x - RealType(0.5) *
N;
823 RealType
const mean = (centroids_.sums_.front() - (min() + max())) / (N - 2);
824 RealType
const slope = 2 * (0 < dx ? max() - mean : mean - min()) / (N - 2);
825 return mean + slope * dx;
830 template <
typename RealType,
typename IndexType>
832 IndexType*
const buf)
const {
834 return {buf, centroids_.size()};
837 template <
typename RealType,
typename IndexType>
840 RealType
const q)
const {
841 if (centroids_.size()) {
842 IndexType
const N = partial_sum.
back();
843 RealType
const x = q *
N;
845 if (it1 == partial_sum.
begin()) {
846 return firstCentroid(x);
847 }
else if (it1 == partial_sum.
end()) {
849 }
else if (it1 + 1 == partial_sum.
end()) {
850 return lastCentroid(x, N);
852 return interiorCentroid(x, it1 - partial_sum.
begin(), *it1);
855 return centroids_.nan;
863 template <
typename RealType,
typename IndexType>
865 IndexType idx2)
const {
866 IndexType
const M = centroids_.size();
868 RealType
const n =
static_cast<RealType
>(centroids_.counts_[idx1]);
869 RealType
const s = centroids_.sums_[idx1];
870 return idx1 == 0 ? 2 * (s - n * min()) / ((n - 1) * (n - 1))
871 : 2 * (n * max() - s) / ((n - 1) * (n - 1));
873 bool const min1 = idx1 == 0;
874 bool const max1 = idx1 == M - 1;
875 bool const min2 = idx2 == 0;
876 bool const max2 = idx2 == M - 1;
877 RealType
const n1 =
static_cast<RealType
>(centroids_.counts_[idx1] - min1 - max1);
878 RealType
const s1 = centroids_.sums_[idx1] - (min1 ? min() : max1 ? max() : 0);
879 RealType
const s2 = centroids_.sums_[idx2] - (min2 ? min() : max2 ? max() : 0);
880 if (
isSingleton(centroids_.counts_.begin() + idx2)) {
881 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - s1) / (n1 * n1);
883 RealType
const n2 =
static_cast<RealType
>(centroids_.counts_[idx2] - min2 - max2);
884 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - n2 * s1) / (n1 * n2 * (n1 + n2));
DEVICE auto upper_bound(ARGS &&...args)
DEVICE void setCurrCentroid()
static constexpr RealType infinity
DEVICE void push_back(RealType const value, RealType const count)
DEVICE void push_back(T const &value)
std::vector< RealType > sums_
DEVICE size_type capacity() const
Centroids< RealType, IndexType > * buf_
DEVICE void skipSubsequent(Centroids< RealType, IndexType > *next_centroid)
std::ostream & operator<<(std::ostream &out, Centroids< RealType, IndexType > const ¢roids)
bool const use_linear_scaling_function_
CentroidsMemory(size_t const size)
DEVICE RealType mean(IndexType const i) const
DEVICE bool index(Centroids< RealType, IndexType > *centroid) const
DEVICE void resetIndices(bool const forward)
DEVICE Centroids< RealType, IndexType > & centroids()
DEVICE void moveNextToCurrent()
DEVICE void add(RealType value)
DEVICE VectorView< IndexType const > partialSumOfCounts(IndexType *const buf) const
Centroids< RealType, IndexType > * centroid_
DEVICE RealType quantile(RealType const q) const
std::once_flag merge_buffer_final_once_
IndexType const total_weight_
DEVICE void sort(ARGS &&...args)
DEVICE void merged(Centroids< RealType, IndexType > *next_centroid)
DEVICE bool mergeIfFits(Centroids ¢roid, IndexType const max_count)
DEVICE bool hasNext() const
DEVICE TDigest(Memory &mem)
DEVICE void mergeBuffer()
DEVICE IndexType totalWeight() const
VectorView< RealType > sums()
VectorView< RealType > sums_
VectorView< IndexType > counts_
std::vector< IndexType > counts_
static DEVICE void shiftCentroids(Data &data)
DEVICE void mergeBufferFinal()
DEVICE RealType currMean() const
DEVICE RealType quantile()
DEVICE void setCentroids(Memory &mem)
DEVICE size_type size() const
DEVICE TDigest(RealType q, SimpleAllocator *simple_allocator, IndexType buf_allocate, IndexType centroids_allocate)
Centroids< RealType, IndexType > * centroids_
DEVICE RealType max() const
DEVICE void fill(ARGS &&...args)
DEVICE void set(T *data, size_type const size)
DEVICE auto copy(ARGS &&...args)
DEVICE Centroids(VectorView< RealType > sums, VectorView< IndexType > counts)
DEVICE CentroidsMerger(Centroids< RealType, IndexType > *buf, Centroids< RealType, IndexType > *centroids, bool const forward)
DEVICE void mergeTDigest(TDigest &t_digest)
DEVICE IndexType nextCount() const
DEVICE void partial_sum(ARGS &&...args)
DEVICE void setCentroids(VectorView< RealType > const sums, VectorView< IndexType > const counts)
DEVICE RealType slope(IndexType const idx1, IndexType const idx2) const
DEVICE auto accumulate(ARGS &&...args)
Centroids< RealType, IndexType > buf_
IndexType const buf_allocate_
DEVICE RealType nextSum() const
DEVICE RealType lastCentroid(RealType const x, IndexType const N) const
std::optional< RealType > const q_
DEVICE void skipFirst(Centroids< RealType, IndexType > *next_centroid)
static DEVICE void shiftRange(T *const begin, IndexType skipped, IndexType const merged, int const inc)
DEVICE void setBuffer(Memory &mem)
DEVICE bool operator()(Value const &a, Value const &b) const
static constexpr RealType nan
DEVICE Centroids< RealType, IndexType > * getNextCentroid() const
DEVICE IndexType prefixSum() const
DEVICE Centroids(RealType *sums, IndexType *counts, IndexType const size)
SimpleAllocator *const simple_allocator_
DEVICE void merge(IndexType const max_count)
DEVICE bool operator<(Centroids const &b) const
DEVICE bool hasNext() const
IndexType const centroids_allocate_
DEVICE void shiftCentroidsAndSetNext()
Centroids< RealType, IndexType > * curr_centroid_
DEVICE IndexType maxCardinality(IndexType const sum, IndexType const total_weight, RealType const c)
DEVICE size_t size() const
DEVICE IndexType currCount() const
DEVICE IndexType totalWeight() const
DEVICE bool isDifferentMean(Centroids< RealType, IndexType > *next_centroid) const
DEVICE IndexType capacity() const
DEVICE void mergeMinMax()
DEVICE IndexType totalWeight() const
DEVICE RealType interiorCentroid(RealType const x, IndexType const idx1, IndexType const prefix_sum) const
DEVICE RealType oneCentroid(RealType const x) const
DEVICE void mergeCentroids(Centroids< RealType, IndexType > &)
DEVICE void reverse(ARGS &&...args)
DEVICE RealType min() const
DEVICE RealType firstCentroid(RealType const x) const
DEVICE bool isSingleton(CountsIterator itr)
Centroid< RealType, IndexType > mean_
DEVICE void appendAndSortCurrent(Centroids &buff)
DEVICE bool hasCurr() const
DEVICE void mergeSorted(RealType *sums, IndexType *counts, IndexType size)
VectorView< IndexType > counts()
Centroids< RealType, IndexType > centroids_