OmniSciDB  04ee39c94c
TopKRuntime.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /*
18  * @file TopKRuntime.cpp
19  * @author Minggang Yu <miyu@mapd.com>
20  * @brief Structures and runtime functions of streaming top-k heap
21  *
22  * Copyright (c) 2017 MapD Technologies, Inc. All rights reserved.
23  */
24 #include "../Shared/funcannotations.h"
25 
// Which extreme of the key order KeyComparator keeps at the root of the heap.
enum class HeapOrdering {
  MIN,  // keys compared with `<`: the smallest key rises to the root
  MAX   // keys compared with `>`: the largest key rises to the root
};
27 
// How KeyComparator ranks a key equal to its null sentinel; selected from the
// caller's `nulls_first` flag.
enum class NullsOrdering {
  FIRST,  // chosen when nulls_first is true
  LAST    // chosen when nulls_first is false
};
29 
30 template <typename KeyT = int64_t, typename IndexT = int32_t>
31 struct KeyAccessor {
32  DEVICE KeyAccessor(const int8_t* key_buff,
33  const size_t key_stride,
34  const size_t key_idx)
35  : buffer(key_buff), stride(key_stride), index(key_idx) {}
36  ALWAYS_INLINE DEVICE KeyT get(const IndexT rowid) const {
37  auto keys_ptr = reinterpret_cast<const KeyT*>(buffer + stride * rowid);
38  return keys_ptr[index];
39  }
40 
41  const int8_t* buffer;
42  const size_t stride;
43  const size_t index;
44 };
45 
46 template <typename KeyT = int64_t>
47 struct KeyComparator {
49  const bool nullable,
50  const KeyT null_val,
51  const NullsOrdering null_order)
52  : heap_ordering(hp_order)
53  , has_nulls(nullable)
54  , null_key(null_val)
55  , nulls_ordering(null_order) {}
56  ALWAYS_INLINE DEVICE bool operator()(const KeyT lhs, const KeyT rhs) const {
57  if (has_nulls) {
58  if (nulls_ordering == NullsOrdering::FIRST) {
59  if (rhs == null_key) {
60  return true;
61  }
62  if (lhs == null_key) {
63  return false;
64  }
65  } else {
66  if (lhs == null_key) {
67  return true;
68  }
69  if (rhs == null_key) {
70  return false;
71  }
72  }
73  }
74  return heap_ordering == HeapOrdering::MIN ? (lhs < rhs) : (lhs > rhs);
75  }
77  const bool has_nulls;
78  const KeyT null_key;
80 };
81 
82 template <typename KeyT = int64_t, typename NodeT = int64_t>
83 ALWAYS_INLINE DEVICE void sift_down(NodeT* heap,
84  const size_t heap_size,
85  const NodeT curr_idx,
86  const KeyComparator<KeyT>& compare,
87  const KeyAccessor<KeyT, NodeT>& accessor) {
88  for (NodeT i = curr_idx, last = static_cast<NodeT>(heap_size); i < last;) {
89 #ifdef __CUDACC__
90  const auto left_child = min(2 * i + 1, last);
91  const auto right_child = min(2 * i + 2, last);
92 #else
93  const auto left_child = std::min(2 * i + 1, last);
94  const auto right_child = std::min(2 * i + 2, last);
95 #endif
96  auto candidate_idx = last;
97  if (left_child < last) {
98  if (right_child < last) {
99  const auto left_key = accessor.get(heap[left_child]);
100  const auto right_key = accessor.get(heap[right_child]);
101  candidate_idx = compare(left_key, right_key) ? left_child : right_child;
102  } else {
103  candidate_idx = left_child;
104  }
105  } else {
106  candidate_idx = right_child;
107  }
108  if (candidate_idx >= last) {
109  break;
110  }
111  const auto curr_key = accessor.get(heap[i]);
112  const auto candidate_key = accessor.get(heap[candidate_idx]);
113  if (compare(curr_key, candidate_key)) {
114  break;
115  }
116  auto temp_id = heap[i];
117  heap[i] = heap[candidate_idx];
118  heap[candidate_idx] = temp_id;
119  i = candidate_idx;
120  }
121 }
122 
123 template <typename KeyT = int64_t, typename NodeT = int64_t>
124 ALWAYS_INLINE DEVICE void sift_up(NodeT* heap,
125  const NodeT curr_idx,
126  const KeyComparator<KeyT>& compare,
127  const KeyAccessor<KeyT, NodeT>& accessor) {
128  for (NodeT i = curr_idx; i > 0 && (i - 1) < i;) {
129  const auto parent = (i - 1) / 2;
130  const auto curr_key = accessor.get(heap[i]);
131  const auto parent_key = accessor.get(heap[parent]);
132  if (compare(parent_key, curr_key)) {
133  break;
134  }
135  auto temp_id = heap[i];
136  heap[i] = heap[parent];
137  heap[parent] = temp_id;
138  i = parent;
139  }
140 }
141 
142 template <typename KeyT = int64_t, typename NodeT = int64_t>
143 ALWAYS_INLINE DEVICE void push_heap(int64_t* heap_ptr,
144  int64_t* rows_ptr,
145  NodeT& node_count,
146  const uint32_t row_size_quad,
147  const uint32_t key_offset,
148  const KeyComparator<KeyT>& comparator,
149  const KeyAccessor<KeyT, NodeT>& accessor,
150  const KeyT curr_key) {
151  const NodeT bin_index = node_count++;
152  heap_ptr[bin_index] = bin_index;
153  int8_t* row_ptr = reinterpret_cast<int8_t*>(rows_ptr + bin_index * row_size_quad);
154  auto key_ptr = reinterpret_cast<KeyT*>(row_ptr + key_offset);
155  *key_ptr = curr_key;
156  // sift up
157  sift_up<KeyT, NodeT>(heap_ptr, bin_index, comparator, accessor);
158 }
159 
160 template <typename KeyT = int64_t, typename NodeT = int64_t>
161 ALWAYS_INLINE DEVICE bool pop_and_push_heap(int64_t* heap_ptr,
162  int64_t* rows_ptr,
163  const NodeT node_count,
164  const uint32_t row_size_quad,
165  const uint32_t key_offset,
166  const KeyComparator<KeyT>& compare,
167  const KeyAccessor<KeyT, NodeT>& accessor,
168  const KeyT curr_key) {
169  const NodeT top_bin_idx = static_cast<NodeT>(heap_ptr[0]);
170  int8_t* top_row_ptr = reinterpret_cast<int8_t*>(rows_ptr + top_bin_idx * row_size_quad);
171  auto top_key = reinterpret_cast<KeyT*>(top_row_ptr + key_offset);
172  if (compare(curr_key, *top_key)) {
173  return false;
174  }
175  // kick out
176  *top_key = curr_key;
177  // sift down
178  sift_down<KeyT, NodeT>(heap_ptr, node_count, 0, compare, accessor);
179  return true;
180 }
181 
182 // This function only works on rowwise layout.
183 template <typename KeyT = int64_t>
185  const uint32_t k,
186  const uint32_t row_size_quad,
187  const uint32_t key_offset,
188  const bool min_heap,
189  const bool has_null,
190  const bool nulls_first,
191  const KeyT null_key,
192  const KeyT curr_key) {
193  const int32_t thread_global_index = pos_start_impl(nullptr);
194  const int32_t thread_count = pos_step_impl();
195  int64_t& node_count = heaps[thread_global_index];
196  int64_t* heap_ptr = heaps + thread_count + thread_global_index * k;
197  int64_t* rows_ptr =
198  heaps + thread_count + thread_count * k + thread_global_index * row_size_quad * k;
200  has_null,
201  null_key,
202  nulls_first ? NullsOrdering::FIRST : NullsOrdering::LAST);
203  KeyAccessor<KeyT, int64_t> accessor(reinterpret_cast<int8_t*>(rows_ptr),
204  row_size_quad * sizeof(int64_t),
205  key_offset / sizeof(KeyT));
206  if (node_count < static_cast<int64_t>(k)) {
207  push_heap(heap_ptr,
208  rows_ptr,
209  node_count,
210  row_size_quad,
211  key_offset,
212  compare,
213  accessor,
214  curr_key);
215  const auto last_bin_index = node_count - 1;
216  auto row_ptr = rows_ptr + last_bin_index * row_size_quad;
217  row_ptr[0] = last_bin_index;
218  return row_ptr + 1;
219  } else {
220  const int64_t top_bin_idx = heap_ptr[0];
221  const bool rejected = !pop_and_push_heap(heap_ptr,
222  rows_ptr,
223  node_count,
224  row_size_quad,
225  key_offset,
226  compare,
227  accessor,
228  curr_key);
229  if (rejected) {
230  return nullptr;
231  }
232  auto row_ptr = rows_ptr + top_bin_idx * row_size_quad;
233  row_ptr[0] = top_bin_idx;
234  return row_ptr + 1;
235  }
236 }
237 
// Defines an extern "C", never-inlined entry point named
// get_bin_from_k_heap_<key_type>. It forwards every argument unchanged to
// get_bin_from_k_heap_impl<key_type>. Because `key_type` is token-pasted into
// the function name, it must be a single identifier (e.g. int32_t, double).
#define DEF_GET_BIN_FROM_K_HEAP(key_type)                                  \
  extern "C" NEVER_INLINE DEVICE int64_t* get_bin_from_k_heap_##key_type( \
      int64_t* heaps,                                                      \
      const uint32_t k,                                                    \
      const uint32_t row_size_quad,                                        \
      const uint32_t key_offset,                                           \
      const bool min_heap,                                                 \
      const bool has_null,                                                 \
      const bool nulls_first,                                              \
      const key_type null_key,                                             \
      const key_type curr_key) {                                           \
    return get_bin_from_k_heap_impl<key_type>(heaps,                       \
                                              k,                           \
                                              row_size_quad,               \
                                              key_offset,                  \
                                              min_heap,                    \
                                              has_null,                    \
                                              nulls_first,                 \
                                              null_key,                    \
                                              curr_key);                   \
  }
259 
ALWAYS_INLINE DEVICE void sift_down(NodeT *heap, const size_t heap_size, const NodeT curr_idx, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor)
Definition: TopKRuntime.cpp:83
ALWAYS_INLINE DEVICE bool operator()(const KeyT lhs, const KeyT rhs) const
Definition: TopKRuntime.cpp:56
const NullsOrdering nulls_ordering
Definition: TopKRuntime.cpp:79
ALWAYS_INLINE DEVICE void sift_up(NodeT *heap, const NodeT curr_idx, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor)
Definition: TopKRuntime.cpp:124
ALWAYS_INLINE DEVICE void push_heap(int64_t *heap_ptr, int64_t *rows_ptr, NodeT &node_count, const uint32_t row_size_quad, const uint32_t key_offset, const KeyComparator< KeyT > &comparator, const KeyAccessor< KeyT, NodeT > &accessor, const KeyT curr_key)
Definition: TopKRuntime.cpp:143
const bool has_nulls
Definition: TopKRuntime.cpp:77
HeapOrdering
Definition: TopKRuntime.cpp:26
const KeyT null_key
Definition: TopKRuntime.cpp:78
#define DEVICE
const size_t index
Definition: TopKRuntime.cpp:43
const HeapOrdering heap_ordering
Definition: TopKRuntime.cpp:76
const size_t stride
Definition: TopKRuntime.cpp:42
#define DEF_GET_BIN_FROM_K_HEAP(key_type)
NullsOrdering
Definition: TopKRuntime.cpp:28
ALWAYS_INLINE DEVICE bool pop_and_push_heap(int64_t *heap_ptr, int64_t *rows_ptr, const NodeT node_count, const uint32_t row_size_quad, const uint32_t key_offset, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor, const KeyT curr_key)
Definition: TopKRuntime.cpp:161
ALWAYS_INLINE DEVICE int64_t * get_bin_from_k_heap_impl(int64_t *heaps, const uint32_t k, const uint32_t row_size_quad, const uint32_t key_offset, const bool min_heap, const bool has_null, const bool nulls_first, const KeyT null_key, const KeyT curr_key)
Definition: TopKRuntime.cpp:184
ALWAYS_INLINE DEVICE KeyT get(const IndexT rowid) const
Definition: TopKRuntime.cpp:36
DEVICE KeyComparator(const HeapOrdering hp_order, const bool nullable, const KeyT null_val, const NullsOrdering null_order)
Definition: TopKRuntime.cpp:48
#define ALWAYS_INLINE
const int8_t * buffer
Definition: TopKRuntime.cpp:41
DEVICE KeyAccessor(const int8_t *key_buff, const size_t key_stride, const size_t key_idx)
Definition: TopKRuntime.cpp:32