24 #include "../Shared/funcannotations.h" 30 template <
typename KeyT =
int64_t,
typename IndexT =
int32_t>
33 const size_t key_stride,
35 : buffer(key_buff), stride(key_stride), index(key_idx) {}
37 auto keys_ptr =
reinterpret_cast<const KeyT*
>(buffer + stride * rowid);
38 return keys_ptr[index];
46 template <
typename KeyT =
int64_t>
52 : heap_ordering(hp_order)
55 , nulls_ordering(null_order) {}
59 if (rhs == null_key) {
62 if (lhs == null_key) {
66 if (lhs == null_key) {
69 if (rhs == null_key) {
82 template <
typename KeyT =
int64_t,
typename NodeT =
int64_t>
84 const size_t heap_size,
88 for (NodeT i = curr_idx, last = static_cast<NodeT>(heap_size); i < last;) {
90 const auto left_child = min(2 * i + 1, last);
91 const auto right_child = min(2 * i + 2, last);
93 const auto left_child = std::min(2 * i + 1, last);
94 const auto right_child = std::min(2 * i + 2, last);
96 auto candidate_idx = last;
97 if (left_child < last) {
98 if (right_child < last) {
99 const auto left_key = accessor.
get(heap[left_child]);
100 const auto right_key = accessor.
get(heap[right_child]);
101 candidate_idx = compare(left_key, right_key) ? left_child : right_child;
103 candidate_idx = left_child;
106 candidate_idx = right_child;
108 if (candidate_idx >= last) {
111 const auto curr_key = accessor.
get(heap[i]);
112 const auto candidate_key = accessor.
get(heap[candidate_idx]);
113 if (compare(curr_key, candidate_key)) {
116 auto temp_id = heap[i];
117 heap[i] = heap[candidate_idx];
118 heap[candidate_idx] = temp_id;
123 template <
typename KeyT =
int64_t,
typename NodeT =
int64_t>
125 const NodeT curr_idx,
128 for (NodeT i = curr_idx; i > 0 && (i - 1) < i;) {
129 const auto parent = (i - 1) / 2;
130 const auto curr_key = accessor.
get(heap[i]);
131 const auto parent_key = accessor.
get(heap[parent]);
132 if (compare(parent_key, curr_key)) {
135 auto temp_id = heap[i];
136 heap[i] = heap[parent];
137 heap[parent] = temp_id;
142 template <
typename KeyT =
int64_t,
typename NodeT =
int64_t>
146 const uint32_t row_size_quad,
147 const uint32_t key_offset,
150 const KeyT curr_key) {
151 const NodeT bin_index = node_count++;
152 heap_ptr[bin_index] = bin_index;
153 int8_t* row_ptr =
reinterpret_cast<int8_t*
>(rows_ptr + bin_index * row_size_quad);
154 auto key_ptr =
reinterpret_cast<KeyT*
>(row_ptr + key_offset);
157 sift_up<KeyT, NodeT>(heap_ptr, bin_index, comparator, accessor);
160 template <
typename KeyT =
int64_t,
typename NodeT =
int64_t>
163 const NodeT node_count,
164 const uint32_t row_size_quad,
165 const uint32_t key_offset,
168 const KeyT curr_key) {
169 const NodeT top_bin_idx =
static_cast<NodeT
>(heap_ptr[0]);
170 int8_t* top_row_ptr =
reinterpret_cast<int8_t*
>(rows_ptr + top_bin_idx * row_size_quad);
171 auto top_key =
reinterpret_cast<KeyT*
>(top_row_ptr + key_offset);
172 if (compare(curr_key, *top_key)) {
178 sift_down<KeyT, NodeT>(heap_ptr, node_count, 0, compare, accessor);
183 template <
typename KeyT =
int64_t>
186 const uint32_t row_size_quad,
187 const uint32_t key_offset,
190 const bool nulls_first,
192 const KeyT curr_key) {
195 int64_t& node_count = heaps[thread_global_index];
196 int64_t* heap_ptr = heaps + thread_count + thread_global_index * k;
198 heaps + thread_count + thread_count * k + thread_global_index * row_size_quad * k;
204 row_size_quad *
sizeof(int64_t),
205 key_offset /
sizeof(KeyT));
206 if (node_count < static_cast<int64_t>(k)) {
215 const auto last_bin_index = node_count - 1;
216 auto row_ptr = rows_ptr + last_bin_index * row_size_quad;
217 row_ptr[0] = last_bin_index;
220 const int64_t top_bin_idx = heap_ptr[0];
232 auto row_ptr = rows_ptr + top_bin_idx * row_size_quad;
233 row_ptr[0] = top_bin_idx;
238 #define DEF_GET_BIN_FROM_K_HEAP(key_type) \ 239 extern "C" NEVER_INLINE DEVICE int64_t* get_bin_from_k_heap_##key_type( \ 242 const uint32_t row_size_quad, \ 243 const uint32_t key_offset, \ 244 const bool min_heap, \ 245 const bool has_null, \ 246 const bool nulls_first, \ 247 const key_type null_key, \ 248 const key_type curr_key) { \ 249 return get_bin_from_k_heap_impl(heaps, \ ALWAYS_INLINE DEVICE void sift_down(NodeT *heap, const size_t heap_size, const NodeT curr_idx, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor)
ALWAYS_INLINE DEVICE bool operator()(const KeyT lhs, const KeyT rhs) const
const NullsOrdering nulls_ordering
ALWAYS_INLINE DEVICE void sift_up(NodeT *heap, const NodeT curr_idx, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor)
ALWAYS_INLINE DEVICE void push_heap(int64_t *heap_ptr, int64_t *rows_ptr, NodeT &node_count, const uint32_t row_size_quad, const uint32_t key_offset, const KeyComparator< KeyT > &comparator, const KeyAccessor< KeyT, NodeT > &accessor, const KeyT curr_key)
NEVER_INLINE int32_t pos_step_impl()
const HeapOrdering heap_ordering
#define DEF_GET_BIN_FROM_K_HEAP(key_type)
ALWAYS_INLINE DEVICE bool pop_and_push_heap(int64_t *heap_ptr, int64_t *rows_ptr, const NodeT node_count, const uint32_t row_size_quad, const uint32_t key_offset, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor, const KeyT curr_key)
ALWAYS_INLINE DEVICE int64_t * get_bin_from_k_heap_impl(int64_t *heaps, const uint32_t k, const uint32_t row_size_quad, const uint32_t key_offset, const bool min_heap, const bool has_null, const bool nulls_first, const KeyT null_key, const KeyT curr_key)
ALWAYS_INLINE DEVICE KeyT get(const IndexT rowid) const
DEVICE KeyComparator(const HeapOrdering hp_order, const bool nullable, const KeyT null_val, const NullsOrdering null_order)
NEVER_INLINE int32_t pos_start_impl(int32_t *error_code)
DEVICE KeyAccessor(const int8_t *key_buff, const size_t key_stride, const size_t key_idx)