OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TypedThrustAllocator.h
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <thrust/device_vector.h>
20 #include <thrust/mr/allocator.h>
21 #include <thrust/mr/memory_resource.h>
22 
24 
25 namespace Data_Namespace {
26 
27 namespace detail {
// DataMgrMemoryResource: a thrust::mr::memory_resource that forwards raw byte
// allocation/deallocation requests to a Data_Namespace::ThrustAllocator.
// NOTE(review): this listing omits source lines 28-33, 39, 41, 43, 45-48, 57-60, 63
// and 73; the gaps that affect meaning are called out inline below.
34 template <typename Pointer>
35 class DataMgrMemoryResource final : public thrust::mr::memory_resource<Pointer> {
36  using base = thrust::mr::memory_resource<Pointer>;
37
38  public:
// Constructor. Its signature is on omitted line 39; per the member index it is
// DataMgrMemoryResource(ThrustAllocator& thrust_allocator).
// Only the allocator's address is stored. Nothing visible here takes ownership, so
// presumably the ThrustAllocator must outlive this resource and all of its copies.
40  : base(), thrust_allocator_(&thrust_allocator) {}
// Copy constructor. Its signature is on omitted line 41:
// DataMgrMemoryResource(const DataMgrMemoryResource& other).
// The pointer is copied, so the copy forwards to the same ThrustAllocator as `other`.
42  : base(other), thrust_allocator_(other.thrust_allocator_) {}
44
// Overrides thrust::mr::memory_resource::do_allocate. It allocates `bytes` bytes via
// ThrustAllocator::allocate and wraps the returned address in the tagged Pointer type.
// NOTE(review): `alignment` is ignored. The result is only as aligned as
// ThrustAllocator::allocate makes it; confirm that satisfies every caller.
// NOTE(review): a default argument on a virtual override is unusual. Default
// arguments are resolved from the static type of the call expression, so this
// default only applies when do_allocate is named through a DataMgrMemoryResource.
49  Pointer do_allocate(std::size_t bytes,
50  std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) final {
51  (void)alignment; // dummy cast to avoid unused warnings
52  return Pointer(
53  reinterpret_cast<typename thrust::detail::pointer_traits<Pointer>::element_type*>(
54  thrust_allocator_->allocate(bytes)));
55  }
56
// Overrides thrust::mr::memory_resource::do_deallocate. It hands the raw address held
// by `p` (as int8_t*) and its size `bytes` back to the wrapped ThrustAllocator.
// `alignment` is ignored.
// NOTE(review): omitted line 63 holds the call that receives the two arguments below.
// Per the member index it is presumably
// thrust_allocator_->deallocate(int8_t* ptr, size_t num_bytes).
61  void do_deallocate(Pointer p, std::size_t bytes, std::size_t alignment) final {
62  (void)alignment; // dummy cast to avoid unused warnings
64  reinterpret_cast<int8_t*>(thrust::detail::pointer_traits<Pointer>::get(p)),
65  bytes);
66  }
67
// Returns the wrapped ThrustAllocator. TypedThrustAllocatorState::operator= compares
// these pointers to assert that two states share one allocator.
68  __host__ __device__ const ThrustAllocator* getThrustAllocator() const {
69  return thrust_allocator_;
70  }
71
72  private:
// NOTE(review): omitted line 73 declares thrust_allocator_. It is initialized from
// &thrust_allocator and returned as const ThrustAllocator*, so it is presumably a
// non-owning ThrustAllocator*.
74 };
75 
// TypedThrustAllocatorState: per the member index, it "Manages the underlying state of a
// TypedThrustAllocator". It holds a DataMgrMemoryResource (data_mgr_mem_rsrc_) wrapping
// a ThrustAllocator, plus a device_ptr_memory_resource (device_rsrc_) bound to it.
// NOTE(review): this listing omits several parts of the class:
//   - the class doc comment and header (source lines 76-83);
//   - constructor signatures (lines 93, 98-99 and 101);
//   - lines 103 and 111;
//   - the member declarations (lines 117-118).
84  public:
// Untyped (void) pointer tagged with thrust::device_system_tag. It is the pointer type
// of the underlying DataMgrMemoryResource.
85  using Pointer = thrust::
86  pointer<void, thrust::device_system_tag, thrust::use_default, thrust::use_default>;
87
88  // Need to define a device_ptr_memory_resource here so any implied execution
89  // policies can be defined as device execution policies
90  using DeviceResource =
91  thrust::device_ptr_memory_resource<DataMgrMemoryResource<Pointer>>;
92
// Constructor. Omitted line 93 holds its signature:
// TypedThrustAllocatorState(ThrustAllocator& thrust_allocator).
// device_rsrc_ is bound to this object's own data_mgr_mem_rsrc_.
94  : data_mgr_mem_rsrc_(thrust_allocator), device_rsrc_(&data_mgr_mem_rsrc_) {}
95
96  // Need to override the default copy constructor/operator and move constructor to ensure
97  // that the device_rsrc_ is constructed with a pointer to the local data_mgr_mem_rsrc_
// NOTE(review): the copy constructor (lines 98-99) is omitted from this listing. Its
// signature is TypedThrustAllocatorState(const TypedThrustAllocatorState& other).
100
// Move constructor. Omitted line 101 holds its signature:
// TypedThrustAllocatorState(TypedThrustAllocatorState&& other).
// DataMgrMemoryResource declares a copy constructor and no move constructor in this
// listing, so std::move below presumably still selects its copy constructor.
// NOTE(review): the rest of the initializer list (line 103) is omitted. Per the comment
// above, it presumably binds device_rsrc_ to &data_mgr_mem_rsrc_.
102  : data_mgr_mem_rsrc_(std::move(other.data_mgr_mem_rsrc_))
104
// Copy assignment. Both sides must wrap the same ThrustAllocator. This is enforced only
// by assert, so it goes unchecked when NDEBUG is defined.
// NOTE(review): returns void rather than TypedThrustAllocatorState&, so chained
// assignment (a = b = c) does not compile.
// NOTE(review): line 111, presumably the actual copy of data_mgr_mem_rsrc_, is omitted.
105  __host__ __device__ void operator=(const TypedThrustAllocatorState& other) {
106  assert(data_mgr_mem_rsrc_.getThrustAllocator() ==
107  other.data_mgr_mem_rsrc_.getThrustAllocator());
108  // NOTE: only copying the data_mgr_mem_rsrc_
109  // The device_rsrc_ should have already been constructed with a pointer to the local
110  // data_mgr_mem_rsrc_ and is therefore up-to-date.
112  }
113
114  // TODO(croot): handle rvalue operator=?
// The user-declared copy assignment and move constructor suppress the implicit move
// assignment. An rvalue assignment therefore binds to the copy operator= above.
115
116  protected:
// NOTE(review): the member declarations (lines 117-118) are omitted from this listing.
// They are presumably data_mgr_mem_rsrc_ and device_rsrc_.
119 };
120 } // namespace detail
121 
// TypedThrustAllocator<T>: a typed allocator built on Data_Namespace::ThrustAllocator.
// It is used as the allocator of ThrustAllocatorDeviceVector.
// It derives from detail::TypedThrustAllocatorState, which holds the memory resources.
// It also derives from a thrust::mr::allocator whose resource pointer is bound to that
// state's device_rsrc_.
// NOTE(review): this listing omits the class doc comment (source lines 122-134), the
// class header and first base specifier (136-137), and lines 143, 149, 152 and 159.
135 template <typename T>
138  public thrust::mr::allocator<T, detail::TypedThrustAllocatorState::DeviceResource> {
139  using Base =
140  thrust::mr::allocator<T, detail::TypedThrustAllocatorState::DeviceResource>;
141
142  public:
// Constructor. Omitted line 143 holds its signature:
// TypedThrustAllocator(ThrustAllocator& thrust_allocator).
// Base receives the address of the state's device_rsrc_. Allocations therefore go
// through DataMgrMemoryResource to the wrapped ThrustAllocator.
// NOTE(review): this presumably relies on detail::TypedThrustAllocatorState preceding
// Base in the omitted base-specifier list, so the state is constructed first.
144  : detail::TypedThrustAllocatorState(thrust_allocator), Base(&device_rsrc_) {}
145
146  // Need to override the default copy constructor/operator and move constructor to ensure
147  // that our Base(thrust::mr::allocator) is constructed with a pointer to our
148  // device_rsrc_ state
// Copy constructor. Omitted line 149 holds its signature:
// TypedThrustAllocator(const TypedThrustAllocator& other).
150  : detail::TypedThrustAllocatorState(other), Base(&device_rsrc_) {}
151
// Move constructor. Omitted line 152 holds its signature:
// TypedThrustAllocator(TypedThrustAllocator&& other).
153  : detail::TypedThrustAllocatorState(std::move(other)), Base(&device_rsrc_) {}
154
// Copy assignment. It returns void, so chained assignment does not compile.
// NOTE(review): line 159 is omitted from this listing. It is presumably the call
// detail::TypedThrustAllocatorState::operator=(other).
155  __host__ __device__ void operator=(const TypedThrustAllocator<T>& other) {
156  // NOTE: only applying the copy operator to TypedThrustAllocatorState
157  // The thrust::mr::allocator should have already been constructed with a pointer to
158  // the local state and is therefore up-to-date
160  }
161 }; // class TypedThrustAllocator
162 
// Convenience alias for a thrust::device_vector whose storage comes from
// TypedThrustAllocator<T>, and so ultimately from the wrapped ThrustAllocator.
// NOTE(review): TypedThrustAllocator shows no default constructor in this listing.
// Presumably every vector must therefore be constructed with an explicit allocator
// instance.
163 template <typename T>
164 using ThrustAllocatorDeviceVector = thrust::device_vector<T, TypedThrustAllocator<T>>;
165 
166 } // namespace Data_Namespace
DataMgrMemoryResource(const DataMgrMemoryResource &other)
void do_deallocate(Pointer p, std::size_t bytes, std::size_t alignment) final
Overrides a pure virtual function defined in thrust::mr::memory_resource to deallocate memory from a ...
int8_t * allocate(std::ptrdiff_t num_bytes)
__host__ __device__ void operator=(const TypedThrustAllocator< T > &other)
void deallocate(int8_t *ptr, size_t num_bytes)
Pointer do_allocate(std::size_t bytes, std::size_t alignment=THRUST_MR_DEFAULT_ALIGNMENT) final
Overrides a pure virtual function defined in thrust::mr::memory_resource to allocate from a ThrustAll...
Manages the underlying state of a TypedThrustAllocator. The state consists of: DataMgrMemoryResource:...
__host__ __device__ const ThrustAllocator * getThrustAllocator() const
A thrust memory resource wrapped around a Data_Namespace::ThrustAllocator that allocates memory via D...
__host__ __device__ void operator=(const TypedThrustAllocatorState &other)
thrust::pointer< void, thrust::device_system_tag, thrust::use_default, thrust::use_default > Pointer
thrust::mr::allocator< T, detail::TypedThrustAllocatorState::DeviceResource > Base
TypedThrustAllocatorState(ThrustAllocator &thrust_allocator)
DataMgrMemoryResource(ThrustAllocator &thrust_allocator)
A templated version of Data_Namespace::ThrustAllocator that can be used as a custom allocator in thru...
TypedThrustAllocatorState(TypedThrustAllocatorState &&other)
TypedThrustAllocator(ThrustAllocator &thrust_allocator)
TypedThrustAllocatorState(const TypedThrustAllocatorState &other)
TypedThrustAllocator(TypedThrustAllocator &&other)
thrust::mr::memory_resource< Pointer > base
thrust::device_vector< T, TypedThrustAllocator< T >> ThrustAllocatorDeviceVector
TypedThrustAllocator(const TypedThrustAllocator &other)
thrust::device_ptr_memory_resource< DataMgrMemoryResource< Pointer >> DeviceResource