OmniSciDB  cde582ebc3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TopKRuntime.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
23 #include "../Shared/funcannotations.h"
24 
// Selects the relational operator KeyComparator applies to two non-null keys:
// MIN compares with `<`, MAX compares with `>`.
enum class HeapOrdering {
  MIN,
  MAX,
};

// Selects which of the two null-handling branches KeyComparator takes when
// null keys are enabled.
enum class NullsOrdering {
  FIRST,
  LAST,
};
28 
29 template <typename KeyT = int64_t, typename IndexT = int32_t>
30 struct KeyAccessor {
31  DEVICE KeyAccessor(const int8_t* key_buff,
32  const size_t key_stride,
33  const size_t key_idx)
34  : buffer(key_buff), stride(key_stride), index(key_idx) {}
35  ALWAYS_INLINE DEVICE KeyT get(const IndexT rowid) const {
36  auto keys_ptr = reinterpret_cast<const KeyT*>(buffer + stride * rowid);
37  return keys_ptr[index];
38  }
39 
40  const int8_t* buffer;
41  const size_t stride;
42  const size_t index;
43 };
44 
45 template <typename KeyT = int64_t>
46 struct KeyComparator {
48  const bool nullable,
49  const KeyT null_val,
50  const NullsOrdering null_order)
51  : heap_ordering(hp_order)
52  , has_nulls(nullable)
53  , null_key(null_val)
54  , nulls_ordering(null_order) {}
55  ALWAYS_INLINE DEVICE bool operator()(const KeyT lhs, const KeyT rhs) const {
56  if (has_nulls) {
58  if (rhs == null_key) {
59  return true;
60  }
61  if (lhs == null_key) {
62  return false;
63  }
64  } else {
65  if (lhs == null_key) {
66  return true;
67  }
68  if (rhs == null_key) {
69  return false;
70  }
71  }
72  }
73  return heap_ordering == HeapOrdering::MIN ? (lhs < rhs) : (lhs > rhs);
74  }
76  const bool has_nulls;
77  const KeyT null_key;
79 };
80 
81 template <typename KeyT = int64_t, typename NodeT = int64_t>
82 ALWAYS_INLINE DEVICE void sift_down(NodeT* heap,
83  const size_t heap_size,
84  const NodeT curr_idx,
85  const KeyComparator<KeyT>& compare,
86  const KeyAccessor<KeyT, NodeT>& accessor) {
87  for (NodeT i = curr_idx, last = static_cast<NodeT>(heap_size); i < last;) {
88 #ifdef __CUDACC__
89  const auto left_child = min(2 * i + 1, last);
90  const auto right_child = min(2 * i + 2, last);
91 #else
92  const auto left_child = std::min(2 * i + 1, last);
93  const auto right_child = std::min(2 * i + 2, last);
94 #endif
95  auto candidate_idx = last;
96  if (left_child < last) {
97  if (right_child < last) {
98  const auto left_key = accessor.get(heap[left_child]);
99  const auto right_key = accessor.get(heap[right_child]);
100  candidate_idx = compare(left_key, right_key) ? left_child : right_child;
101  } else {
102  candidate_idx = left_child;
103  }
104  } else {
105  candidate_idx = right_child;
106  }
107  if (candidate_idx >= last) {
108  break;
109  }
110  const auto curr_key = accessor.get(heap[i]);
111  const auto candidate_key = accessor.get(heap[candidate_idx]);
112  if (compare(curr_key, candidate_key)) {
113  break;
114  }
115  auto temp_id = heap[i];
116  heap[i] = heap[candidate_idx];
117  heap[candidate_idx] = temp_id;
118  i = candidate_idx;
119  }
120 }
121 
122 template <typename KeyT = int64_t, typename NodeT = int64_t>
123 ALWAYS_INLINE DEVICE void sift_up(NodeT* heap,
124  const NodeT curr_idx,
125  const KeyComparator<KeyT>& compare,
126  const KeyAccessor<KeyT, NodeT>& accessor) {
127  for (NodeT i = curr_idx; i > 0 && (i - 1) < i;) {
128  const auto parent = (i - 1) / 2;
129  const auto curr_key = accessor.get(heap[i]);
130  const auto parent_key = accessor.get(heap[parent]);
131  if (compare(parent_key, curr_key)) {
132  break;
133  }
134  auto temp_id = heap[i];
135  heap[i] = heap[parent];
136  heap[parent] = temp_id;
137  i = parent;
138  }
139 }
140 
141 template <typename KeyT = int64_t, typename NodeT = int64_t>
142 ALWAYS_INLINE DEVICE void push_heap(int64_t* heap_ptr,
143  int64_t* rows_ptr,
144  NodeT& node_count,
145  const uint32_t row_size_quad,
146  const uint32_t key_offset,
147  const KeyComparator<KeyT>& comparator,
148  const KeyAccessor<KeyT, NodeT>& accessor,
149  const KeyT curr_key) {
150  const NodeT bin_index = node_count++;
151  heap_ptr[bin_index] = bin_index;
152  int8_t* row_ptr = reinterpret_cast<int8_t*>(rows_ptr + bin_index * row_size_quad);
153  auto key_ptr = reinterpret_cast<KeyT*>(row_ptr + key_offset);
154  *key_ptr = curr_key;
155  // sift up
156  sift_up<KeyT, NodeT>(heap_ptr, bin_index, comparator, accessor);
157 }
158 
159 template <typename KeyT = int64_t, typename NodeT = int64_t>
160 ALWAYS_INLINE DEVICE bool pop_and_push_heap(int64_t* heap_ptr,
161  int64_t* rows_ptr,
162  const NodeT node_count,
163  const uint32_t row_size_quad,
164  const uint32_t key_offset,
165  const KeyComparator<KeyT>& compare,
166  const KeyAccessor<KeyT, NodeT>& accessor,
167  const KeyT curr_key) {
168  const NodeT top_bin_idx = static_cast<NodeT>(heap_ptr[0]);
169  int8_t* top_row_ptr = reinterpret_cast<int8_t*>(rows_ptr + top_bin_idx * row_size_quad);
170  auto top_key = reinterpret_cast<KeyT*>(top_row_ptr + key_offset);
171  if (compare(curr_key, *top_key)) {
172  return false;
173  }
174  // kick out
175  *top_key = curr_key;
176  // sift down
177  sift_down<KeyT, NodeT>(heap_ptr, node_count, 0, compare, accessor);
178  return true;
179 }
180 
181 // This function only works on rowwise layout.
182 template <typename KeyT = int64_t>
184  const uint32_t k,
185  const uint32_t row_size_quad,
186  const uint32_t key_offset,
187  const bool min_heap,
188  const bool has_null,
189  const bool nulls_first,
190  const KeyT null_key,
191  const KeyT curr_key) {
192  const int32_t thread_global_index = pos_start_impl(nullptr);
193  const int32_t thread_count = pos_step_impl();
194  int64_t& node_count = heaps[thread_global_index];
195  int64_t* heap_ptr = heaps + thread_count + thread_global_index * k;
196  int64_t* rows_ptr =
197  heaps + thread_count + thread_count * k + thread_global_index * row_size_quad * k;
199  has_null,
200  null_key,
201  nulls_first ? NullsOrdering::FIRST : NullsOrdering::LAST);
202  KeyAccessor<KeyT, int64_t> accessor(reinterpret_cast<int8_t*>(rows_ptr),
203  row_size_quad * sizeof(int64_t),
204  key_offset / sizeof(KeyT));
205  if (node_count < static_cast<int64_t>(k)) {
206  push_heap(heap_ptr,
207  rows_ptr,
208  node_count,
209  row_size_quad,
210  key_offset,
211  compare,
212  accessor,
213  curr_key);
214  const auto last_bin_index = node_count - 1;
215  auto row_ptr = rows_ptr + last_bin_index * row_size_quad;
216  row_ptr[0] = last_bin_index;
217  return row_ptr + 1;
218  } else {
219  const int64_t top_bin_idx = heap_ptr[0];
220  const bool rejected = !pop_and_push_heap(heap_ptr,
221  rows_ptr,
222  node_count,
223  row_size_quad,
224  key_offset,
225  compare,
226  accessor,
227  curr_key);
228  if (rejected) {
229  return nullptr;
230  }
231  auto row_ptr = rows_ptr + top_bin_idx * row_size_quad;
232  row_ptr[0] = top_bin_idx;
233  return row_ptr + 1;
234  }
235 }
236 
// Defines an extern "C" entry point named get_bin_from_k_heap_<key_type> that
// forwards all of its arguments, unchanged and in order, to
// get_bin_from_k_heap_impl instantiated for `key_type`.
#define DEF_GET_BIN_FROM_K_HEAP(key_type)                                         \
  extern "C" RUNTIME_EXPORT NEVER_INLINE DEVICE int64_t*                          \
      get_bin_from_k_heap_##key_type(int64_t* heaps,                              \
                                     const uint32_t k,                            \
                                     const uint32_t row_size_quad,                \
                                     const uint32_t key_offset,                   \
                                     const bool min_heap,                         \
                                     const bool has_null,                         \
                                     const bool nulls_first,                      \
                                     const key_type null_key,                     \
                                     const key_type curr_key) {                   \
    return get_bin_from_k_heap_impl<key_type>(heaps,                              \
                                              k,                                  \
                                              row_size_quad,                      \
                                              key_offset,                         \
                                              min_heap,                           \
                                              has_null,                           \
                                              nulls_first,                        \
                                              null_key,                           \
                                              curr_key);                          \
  }
258 
ALWAYS_INLINE DEVICE void sift_down(NodeT *heap, const size_t heap_size, const NodeT curr_idx, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor)
Definition: TopKRuntime.cpp:82
const NullsOrdering nulls_ordering
Definition: TopKRuntime.cpp:78
ALWAYS_INLINE DEVICE void sift_up(NodeT *heap, const NodeT curr_idx, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor)
ALWAYS_INLINE DEVICE void push_heap(int64_t *heap_ptr, int64_t *rows_ptr, NodeT &node_count, const uint32_t row_size_quad, const uint32_t key_offset, const KeyComparator< KeyT > &comparator, const KeyAccessor< KeyT, NodeT > &accessor, const KeyT curr_key)
__device__ int32_t pos_step_impl()
Definition: cuda_mapd_rt.cu:35
const bool has_nulls
Definition: TopKRuntime.cpp:76
HeapOrdering
Definition: TopKRuntime.cpp:25
const KeyT null_key
Definition: TopKRuntime.cpp:77
#define DEVICE
const size_t index
Definition: TopKRuntime.cpp:42
const HeapOrdering heap_ordering
Definition: TopKRuntime.cpp:75
__device__ int32_t pos_start_impl(const int32_t *row_index_resume)
Definition: cuda_mapd_rt.cu:27
const size_t stride
Definition: TopKRuntime.cpp:41
#define DEF_GET_BIN_FROM_K_HEAP(key_type)
NullsOrdering
Definition: TopKRuntime.cpp:27
ALWAYS_INLINE DEVICE KeyT get(const IndexT rowid) const
Definition: TopKRuntime.cpp:35
ALWAYS_INLINE DEVICE bool operator()(const KeyT lhs, const KeyT rhs) const
Definition: TopKRuntime.cpp:55
ALWAYS_INLINE DEVICE bool pop_and_push_heap(int64_t *heap_ptr, int64_t *rows_ptr, const NodeT node_count, const uint32_t row_size_quad, const uint32_t key_offset, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor, const KeyT curr_key)
ALWAYS_INLINE DEVICE int64_t * get_bin_from_k_heap_impl(int64_t *heaps, const uint32_t k, const uint32_t row_size_quad, const uint32_t key_offset, const bool min_heap, const bool has_null, const bool nulls_first, const KeyT null_key, const KeyT curr_key)
DEVICE KeyComparator(const HeapOrdering hp_order, const bool nullable, const KeyT null_val, const NullsOrdering null_order)
Definition: TopKRuntime.cpp:47
#define ALWAYS_INLINE
const int8_t * buffer
Definition: TopKRuntime.cpp:40
DEVICE KeyAccessor(const int8_t *key_buff, const size_t key_stride, const size_t key_idx)
Definition: TopKRuntime.cpp:31