43 #include <type_traits>
49 template <
typename RealType,
typename IndexType>
55 template <
typename RealType,
typename IndexType>
62 static constexpr RealType
infinity = std::numeric_limits<RealType>::infinity();
63 static constexpr RealType
nan = std::numeric_limits<RealType>::quiet_NaN();
70 :
sums_(sums, size, size),
counts_(counts, size, size) {}
128 template <
typename RealType2,
typename IndexType2>
129 friend std::ostream& operator<<(std::ostream&, Centroids<RealType2, IndexType2>
const&);
133 template <
typename RealType,
typename IndexType>
167 template <
typename RealType,
typename IndexType>
175 size_t nbytes()
const {
return sums_.size() * (
sizeof(RealType) +
sizeof(IndexType)); }
183 template <
typename RealType,
typename IndexType =
size_t>
196 std::optional<RealType>
const q_{std::nullopt};
206 IndexType
const total_weight,
214 IndexType
const idx1,
215 IndexType
const prefix_sum)
const;
219 DEVICE RealType
slope(IndexType
const idx1, IndexType
const idx2)
const;
227 :
centroids_(mem.sums().data(), mem.counts().data(), mem.size()) {
233 IndexType buf_allocate,
234 IndexType centroids_allocate)
267 RealType
const q)
const;
307 template <
typename RealType,
typename IndexType>
314 return lhs < rhs || (lhs == rhs && b.
value1() < a.
value1());
320 template <
typename RealType,
typename IndexType>
322 if (inc_ == -1 && curr_idx_ != 0) {
332 IndexType
const offset = inc_ == 1 ? 0 : buff.
curr_idx_;
333 IndexType
const buff_size =
338 IndexType
const curr_size = inc_ == 1 ? curr_idx_ + 1 : size() - curr_idx_;
339 IndexType
const total_size = curr_size + buff_sums.
size();
340 assert(total_size <= sums_.capacity());
341 sums_.resize(total_size);
342 gpu_enabled::copy(buff_sums.begin(), buff_sums.end(), sums_.begin() + curr_size);
343 assert(total_size <= counts_.capacity());
344 counts_.resize(total_size);
345 gpu_enabled::copy(buff_counts.begin(), buff_counts.end(), counts_.begin() + curr_size);
354 template <
typename RealType,
typename IndexType>
356 IndexType
const max_count) {
357 if (counts_[curr_idx_] + centroid.
nextCount() <= max_count) {
358 sums_[curr_idx_] += centroid.
nextSum();
359 counts_[curr_idx_] += centroid.
nextCount();
366 template <
typename RealType,
typename IndexType>
369 if (curr_idx_ != next_idx_) {
370 sums_[curr_idx_] = sums_[next_idx_];
371 counts_[curr_idx_] = counts_[next_idx_];
376 template <
typename RealType,
typename IndexType>
380 curr_idx_ = ~IndexType(0);
385 static_assert(std::is_unsigned<IndexType>::value,
386 "IndexType must be an unsigned type.");
387 next_idx_ = curr_idx_ + inc_;
391 template <
typename RealType,
typename IndexType>
394 out <<
"Centroids<" <<
typeid(RealType).
name() <<
',' <<
typeid(IndexType).
name()
395 <<
">(size(" << centroids.
size() <<
") curr_idx_(" << centroids.
curr_idx_
396 <<
") next_idx_(" << centroids.
next_idx_ <<
") sums_(";
397 for (IndexType i = 0; i < centroids.
sums_.
size(); ++i) {
398 out << (i ?
" " :
"") << std::setprecision(20) << centroids.
sums_[i];
401 for (IndexType i = 0; i < centroids.
counts_.
size(); ++i) {
402 out << (i ?
" " :
"") << centroids.
counts_[i];
410 template <
typename RealType,
typename IndexType>
416 , centroids_(centroids)
417 , total_weight_(centroids->totalWeight() + buf->totalWeight())
418 , forward_(forward) {
425 template <
typename RealType,
typename IndexType>
428 if (buf_->hasNext()) {
429 if (centroids_->hasNext()) {
430 return (*buf_ < *centroids_) == forward_ ? buf_ : centroids_;
433 }
else if (centroids_->hasNext()) {
444 template <
typename RealType,
typename IndexType>
449 IndexType count_merged_{0};
450 IndexType count_skipped_{0};
467 template <
typename T>
470 IndexType
const merged,
473 T* src = begin + inc * (skipped - 1);
474 T* dst = src + inc * merged;
475 for (; skipped; --skipped, src -= inc, dst -= inc) {
480 std::copy_backward(begin, begin + skipped, begin + skipped + merged);
482 std::copy(begin + 1 - skipped, begin + 1, begin + 1 - skipped - merged);
489 return data_[0].centroid_ != centroid;
492 return mean_.sum_ * next_centroid->
nextCount() !=
493 next_centroid->
nextSum() * mean_.count_;
496 IndexType
const idx = index(next_centroid);
497 if (data_[idx].count_skipped_) {
498 ++data_[idx].count_merged_;
501 DEVICE operator bool()
const {
return data_[0].centroid_; }
504 shiftCentroids(data_[0]);
505 data_[0].centroid_->next_idx_ = data_[0].start_;
506 if (data_[1].centroid_) {
507 shiftCentroids(data_[1]);
508 data_[1].centroid_->next_idx_ = data_[1].start_;
514 data_[0] = {next_centroid, next_centroid->
next_idx_, 0, 1};
518 IndexType
const idx = index(next_centroid);
519 if (idx == 1 && data_[1].centroid_ ==
nullptr) {
520 data_[1] = {next_centroid, next_centroid->
next_idx_, 0, 1};
522 if (data_[idx].count_merged_) {
523 shiftCentroids(data_[idx]);
524 data_[idx].count_merged_ = 0;
526 ++data_[idx].count_skipped_;
535 template <
typename RealType,
typename IndexType>
537 Skipped<RealType, IndexType> skipped;
538 while (
auto* next_centroid = getNextCentroid()) {
540 if (skipped.isDifferentMean(next_centroid)) {
542 }
else if (curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
543 skipped.merged(next_centroid);
545 skipped.skipSubsequent(next_centroid);
547 }
else if (!curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
548 skipped.skipFirst(next_centroid);
552 skipped.shiftCentroidsAndSetNext();
564 template <
typename RealType,
typename IndexType>
566 if (centroids_->max_ < buf_->max_) {
567 centroids_->max_ = buf_->max_;
569 if (buf_->min_ < centroids_->min_) {
570 centroids_->min_ = buf_->min_;
575 template <
typename RealType,
typename IndexType>
577 prefix_sum_ += curr_centroid_->currCount();
581 template <
typename RealType,
typename IndexType>
583 if ((curr_centroid_ = getNextCentroid())) {
584 curr_centroid_->moveNextToCurrent();
590 template <
typename RealType,
typename IndexType>
592 if (buf_.sums_.full()) {
595 buf_.sums_.push_back(value);
596 buf_.counts_.push_back(1);
600 template <
typename RealType,
typename IndexType>
602 if (buf_.capacity() == 0) {
603 auto* p0 = simple_allocator_->allocate(buf_allocate_ *
sizeof(RealType));
604 auto* p1 = simple_allocator_->allocate(buf_allocate_ *
sizeof(IndexType));
608 p0 = simple_allocator_->allocate(centroids_allocate_ *
sizeof(RealType));
609 p1 = simple_allocator_->allocate(centroids_allocate_ *
sizeof(IndexType));
617 template <
typename RealType,
typename IndexType>
620 IndexType
const total_weight,
622 IndexType
const max_bins = centroids_.capacity();
623 if (total_weight <= max_bins) {
625 }
else if (use_linear_scaling_function_) {
627 return 2 * total_weight / max_bins;
630 RealType
const x = 2.0 * sum / total_weight - 1;
631 RealType
const f_inv = 0.5 + 0.5 * std::sin(c + std::asin(x));
632 constexpr RealType eps = 1e-5;
633 IndexType
const dsum =
static_cast<IndexType
>(total_weight * f_inv + eps);
634 return dsum < sum ? 0 : dsum - sum;
639 template <
typename RealType,
typename IndexType>
643 buf_.min_ = buf_.sums_.front();
644 buf_.max_ = buf_.sums_.back();
645 mergeCentroids(buf_);
650 template <
typename RealType,
typename IndexType>
652 auto const call_once = [
this] {
654 assert(centroids_.size() <= buf_.capacity());
655 partialSumOfCounts(buf_.counts_.data());
660 std::call_once(merge_buffer_final_once_, call_once);
664 template <
typename RealType,
typename IndexType>
669 if (buf_.capacity() == 0) {
673 buf_.counts_.set(counts, size);
676 buf_.min_ = buf_.sums_.front();
677 buf_.max_ = buf_.sums_.back();
678 mergeCentroids(buf_);
687 template <
typename RealType,
typename IndexType>
690 constexpr RealType two_pi = 6.283185307179586476925286766559005768e+00;
692 RealType
const c = two_pi / centroids_.capacity();
697 for (CM cm(&buf, ¢roids_, forward_); cm.hasNext(); cm.next()) {
700 IndexType
const max_cardinality = maxCardinality(cm.prefixSum(), cm.totalWeight(), c);
701 cm.
merge(max_cardinality);
704 centroids_.appendAndSortCurrent(buf);
710 template <
typename CountsIterator>
717 template <
typename RealType,
typename IndexType>
721 }
else if (centroids_.size() == 1) {
722 return oneCentroid(x);
723 }
else if (centroids_.counts_.front() == 2) {
724 RealType
const sum = centroids_.sums_.front();
725 return x == 1 ? 0.5 * sum : sum - min();
727 RealType
const count = centroids_.counts_.front();
728 RealType
const dx = x - RealType(0.5) * (1 + count);
729 RealType
const mean = (centroids_.sums_.front() - min()) / (count - 1);
730 return mean + slope(0, 0 < dx) * dx;
735 template <
typename RealType,
typename IndexType>
738 IndexType
const idx1,
739 IndexType
const prefix_sum)
const {
740 if (
isSingleton(centroids_.counts_.begin() + idx1)) {
741 RealType
const sum1 = centroids_.sums_[idx1];
742 if (x == prefix_sum - centroids_.counts_[idx1]) {
743 if (
isSingleton(centroids_.counts_.begin() + idx1 - 1)) {
744 return 0.5 * (centroids_.sums_[idx1 - 1] + sum1);
745 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
746 return 0.5 * (centroids_.sums_[idx1 - 1] - min() + sum1);
751 RealType
const dx = x + RealType(0.5) * centroids_.counts_[idx1] - prefix_sum;
752 IndexType
const idx2 = idx1 + 2 * (0 < dx) - 1;
753 return centroids_.mean(idx1) + slope(idx1, idx2) * dx;
758 template <
typename RealType,
typename IndexType>
760 IndexType
const N)
const {
764 IndexType
const idx1 = centroids_.size() - 1;
765 RealType
const sum1 = centroids_.sums_[idx1];
766 IndexType
const count1 = centroids_.counts_[idx1];
768 if (
isSingleton(centroids_.counts_.begin() + (idx1 - 1))) {
769 return 0.5 * (centroids_.sums_[idx1 - 1] + sum1);
770 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
771 return 0.5 * (centroids_.sums_[idx1 - 1] - min() + sum1);
775 }
else if (count1 == 2) {
778 }
else if (x == N - 2) {
779 RealType
const sum2 = centroids_.sums_[idx1 - 1];
780 if (
isSingleton(centroids_.counts_.begin() + (idx1 - 1))) {
781 return 0.5 * (sum2 + sum1 - max());
782 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
783 return 0.5 * (sum2 - min() + sum1 - max());
788 RealType
const dx = x + RealType(0.5) * (count1 + 1) - N;
789 RealType
const mean = (sum1 - max()) / (count1 - 1);
790 return mean + slope(idx1, idx1 - (dx < 0)) * dx;
795 template <
typename RealType,
typename IndexType>
797 IndexType
const N = centroids_.counts_.front();
801 return 0.5 * centroids_.sums_.front();
804 return 0.5 * (centroids_.sums_.front() - min());
806 RealType
const s = centroids_.sums_.front() - max();
807 return x == 1 ? 0.5 * s : s - min();
810 RealType
const dx = x - RealType(0.5) *
N;
811 RealType
const mean = (centroids_.sums_.front() - (min() + max())) / (N - 2);
812 RealType
const slope = 2 * (0 < dx ? max() - mean : mean - min()) / (N - 2);
813 return mean + slope * dx;
818 template <
typename RealType,
typename IndexType>
820 IndexType*
const buf)
const {
822 return {buf, centroids_.size()};
825 template <
typename RealType,
typename IndexType>
828 RealType
const q)
const {
829 if (centroids_.size()) {
830 IndexType
const N = partial_sum.
back();
831 RealType
const x = q *
N;
833 if (it1 == partial_sum.
begin()) {
834 return firstCentroid(x);
835 }
else if (it1 == partial_sum.
end()) {
837 }
else if (it1 + 1 == partial_sum.
end()) {
838 return lastCentroid(x, N);
840 return interiorCentroid(x, it1 - partial_sum.
begin(), *it1);
843 return centroids_.nan;
851 template <
typename RealType,
typename IndexType>
853 IndexType idx2)
const {
854 IndexType
const M = centroids_.size();
856 RealType
const n =
static_cast<RealType
>(centroids_.counts_[idx1]);
857 RealType
const s = centroids_.sums_[idx1];
858 return idx1 == 0 ? 2 * (s - n * min()) / ((n - 1) * (n - 1))
859 : 2 * (n * max() - s) / ((n - 1) * (n - 1));
861 bool const min1 = idx1 == 0;
862 bool const max1 = idx1 == M - 1;
863 bool const min2 = idx2 == 0;
864 bool const max2 = idx2 == M - 1;
865 RealType
const n1 =
static_cast<RealType
>(centroids_.counts_[idx1] - min1 - max1);
866 RealType
const s1 = centroids_.sums_[idx1] - (min1 ? min() : max1 ? max() : 0);
867 RealType
const s2 = centroids_.sums_[idx2] - (min2 ? min() : max2 ? max() : 0);
868 if (
isSingleton(centroids_.counts_.begin() + idx2)) {
869 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - s1) / (n1 * n1);
871 RealType
const n2 =
static_cast<RealType
>(centroids_.counts_[idx2] - min2 - max2);
872 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - n2 * s1) / (n1 * n2 * (n1 + n2));
DEVICE auto upper_bound(ARGS &&...args)
DEVICE void setCurrCentroid()
static constexpr RealType infinity
DEVICE void push_back(RealType const value, RealType const count)
DEVICE void push_back(T const &value)
std::vector< RealType > sums_
DEVICE size_type capacity() const
Centroids< RealType, IndexType > * buf_
DEVICE void skipSubsequent(Centroids< RealType, IndexType > *next_centroid)
std::ostream & operator<<(std::ostream &out, Centroids< RealType, IndexType > const ¢roids)
bool const use_linear_scaling_function_
CentroidsMemory(size_t const size)
DEVICE RealType mean(IndexType const i) const
DEVICE bool index(Centroids< RealType, IndexType > *centroid) const
DEVICE void resetIndices(bool const forward)
DEVICE Centroids< RealType, IndexType > & centroids()
DEVICE void moveNextToCurrent()
DEVICE void add(RealType value)
DEVICE VectorView< IndexType const > partialSumOfCounts(IndexType *const buf) const
Centroids< RealType, IndexType > * centroid_
DEVICE RealType quantile(RealType const q) const
std::once_flag merge_buffer_final_once_
IndexType const total_weight_
DEVICE void sort(ARGS &&...args)
DEVICE void merged(Centroids< RealType, IndexType > *next_centroid)
DEVICE bool mergeIfFits(Centroids ¢roid, IndexType const max_count)
DEVICE bool hasNext() const
DEVICE TDigest(Memory &mem)
DEVICE void mergeBuffer()
DEVICE IndexType totalWeight() const
VectorView< RealType > sums()
VectorView< RealType > sums_
VectorView< IndexType > counts_
std::vector< IndexType > counts_
static DEVICE void shiftCentroids(Data &data)
DEVICE void mergeBufferFinal()
DEVICE RealType currMean() const
DEVICE RealType quantile()
DEVICE void setCentroids(Memory &mem)
DEVICE size_type size() const
DEVICE TDigest(RealType q, SimpleAllocator *simple_allocator, IndexType buf_allocate, IndexType centroids_allocate)
Centroids< RealType, IndexType > * centroids_
DEVICE RealType max() const
DEVICE void fill(ARGS &&...args)
DEVICE void set(T *data, size_type const size)
DEVICE auto copy(ARGS &&...args)
DEVICE Centroids(VectorView< RealType > sums, VectorView< IndexType > counts)
DEVICE CentroidsMerger(Centroids< RealType, IndexType > *buf, Centroids< RealType, IndexType > *centroids, bool const forward)
DEVICE void mergeTDigest(TDigest &t_digest)
DEVICE IndexType nextCount() const
DEVICE void partial_sum(ARGS &&...args)
DEVICE void setCentroids(VectorView< RealType > const sums, VectorView< IndexType > const counts)
DEVICE RealType slope(IndexType const idx1, IndexType const idx2) const
DEVICE auto accumulate(ARGS &&...args)
Centroids< RealType, IndexType > buf_
IndexType const buf_allocate_
DEVICE RealType nextSum() const
DEVICE RealType lastCentroid(RealType const x, IndexType const N) const
std::optional< RealType > const q_
DEVICE void skipFirst(Centroids< RealType, IndexType > *next_centroid)
static DEVICE void shiftRange(T *const begin, IndexType skipped, IndexType const merged, int const inc)
DEVICE void setBuffer(Memory &mem)
DEVICE bool operator()(Value const &a, Value const &b) const
static constexpr RealType nan
DEVICE Centroids< RealType, IndexType > * getNextCentroid() const
DEVICE IndexType prefixSum() const
DEVICE Centroids(RealType *sums, IndexType *counts, IndexType const size)
SimpleAllocator *const simple_allocator_
DEVICE void merge(IndexType const max_count)
DEVICE bool operator<(Centroids const &b) const
DEVICE bool hasNext() const
IndexType const centroids_allocate_
DEVICE void shiftCentroidsAndSetNext()
Centroids< RealType, IndexType > * curr_centroid_
DEVICE IndexType maxCardinality(IndexType const sum, IndexType const total_weight, RealType const c)
DEVICE size_t size() const
DEVICE IndexType currCount() const
DEVICE IndexType totalWeight() const
DEVICE bool isDifferentMean(Centroids< RealType, IndexType > *next_centroid) const
DEVICE IndexType capacity() const
DEVICE void mergeMinMax()
DEVICE IndexType totalWeight() const
DEVICE RealType interiorCentroid(RealType const x, IndexType const idx1, IndexType const prefix_sum) const
DEVICE RealType oneCentroid(RealType const x) const
DEVICE void mergeCentroids(Centroids< RealType, IndexType > &)
DEVICE void reverse(ARGS &&...args)
DEVICE RealType min() const
DEVICE RealType firstCentroid(RealType const x) const
DEVICE bool isSingleton(CountsIterator itr)
Centroid< RealType, IndexType > mean_
DEVICE void appendAndSortCurrent(Centroids &buff)
DEVICE bool hasCurr() const
DEVICE void mergeSorted(RealType *sums, IndexType *counts, IndexType size)
VectorView< IndexType > counts()
Centroids< RealType, IndexType > centroids_