OmniSciDB  b28c0d5765
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TableFunctionsDataCache.h
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 #ifndef __CUDACC__
19 
#include <algorithm>  // std::min
#include <cstddef>
#include <cstdint>
#include <cstring>  // std::memcpy
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>  // std::unique_lock
#include <shared_mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

#include <tbb/parallel_for.h>
#include <tbb/task_arena.h>
29 
/// Owns a heap-allocated byte buffer of a fixed size.
///
/// The buffer is released in the destructor. Copy construction and copy
/// assignment are deleted: a member-wise copy would leave two objects owning
/// (and later delete[]-ing) the same pointer.
struct CacheDataTf {
  int8_t* data_buffer;  // owning pointer to num_bytes bytes (default-initialized)
  size_t num_bytes;     // size of data_buffer in bytes

  /// Allocates a buffer of @p num_bytes bytes; its contents are not zeroed.
  CacheDataTf(const size_t num_bytes)
      : data_buffer(new int8_t[num_bytes]), num_bytes(num_bytes) {}

  CacheDataTf(const CacheDataTf&) = delete;
  CacheDataTf& operator=(const CacheDataTf&) = delete;

  ~CacheDataTf() { delete[] data_buffer; }
};
40 
42  public:
43  bool isKeyCached(const std::string& key) const {
44  std::shared_lock<std::shared_mutex> read_lock(cache_mutex_);
45  return data_cache_.count(key) > 0;
46  }
47 
48  bool isKeyCachedAndSameLength(const std::string& key, const size_t num_bytes) const {
49  std::shared_lock<std::shared_mutex> read_lock(cache_mutex_);
50  const auto& cached_data_itr = data_cache_.find(key);
51  if (cached_data_itr == data_cache_.end()) {
52  return false;
53  }
54  return num_bytes == cached_data_itr->second->num_bytes;
55  }
56 
57  template <typename T>
58  void getDataForKey(const std::string& key, T* dest_buffer) const {
59  auto timer = DEBUG_TIMER(__func__);
60  std::shared_lock<std::shared_mutex> read_lock(cache_mutex_);
61  const auto& cached_data_itr = data_cache_.find(key);
62  if (cached_data_itr == data_cache_.end()) {
63  const std::string error_msg = "Data for key " + key + " not found in cache.";
64  throw std::runtime_error(error_msg);
65  }
66  copyData(reinterpret_cast<int8_t*>(dest_buffer),
67  cached_data_itr->second->data_buffer,
68  cached_data_itr->second->num_bytes);
69  }
70 
71  template <typename T>
72  const T& getDataRefForKey(const std::string& key) const {
73  std::shared_lock<std::shared_mutex> read_lock(cache_mutex_);
74  const auto& cached_data_itr = data_cache_.find(key);
75  if (cached_data_itr == data_cache_.end()) {
76  const std::string error_msg{"Data for key " + key + " not found in cache."};
77  throw std::runtime_error(error_msg);
78  }
79  return *reinterpret_cast<const T*>(cached_data_itr->second->data_buffer);
80  }
81 
82  template <typename T>
83  const T* getDataPtrForKey(const std::string& key) const {
84  std::shared_lock<std::shared_mutex> read_lock(cache_mutex_);
85  const auto& cached_data_itr = data_cache_.find(key);
86  if (cached_data_itr == data_cache_.end()) {
87  return nullptr;
88  }
89  return reinterpret_cast<const T* const>(cached_data_itr->second->data_buffer);
90  }
91 
92  template <typename T>
93  void putDataForKey(const std::string& key,
94  T* const data_buffer,
95  const size_t num_elements) {
96  auto timer = DEBUG_TIMER(__func__);
97  const size_t num_bytes(num_elements * sizeof(T));
98  auto cache_data = std::make_shared<CacheDataTf>(num_bytes);
99  copyData(cache_data->data_buffer, reinterpret_cast<int8_t*>(data_buffer), num_bytes);
100  std::unique_lock<std::shared_mutex> write_lock(cache_mutex_);
101  const auto& cached_data_itr = data_cache_.find(key);
102  if (data_cache_.find(key) != data_cache_.end()) {
103  if constexpr (debug_print_) {
104  const std::string warning_msg =
105  "Data for key " + key + " already exists in cache. Replacing.";
106  std::cout << warning_msg << std::endl;
107  }
108  cached_data_itr->second.reset();
109  cached_data_itr->second = cache_data;
110  return;
111  }
112  data_cache_.insert(std::make_pair(key, cache_data));
113  }
114 
115  private:
116  const size_t parallel_copy_min_bytes{1 << 20};
117 
118  void copyData(int8_t* dest, const int8_t* source, const size_t num_bytes) const {
119  if (num_bytes < parallel_copy_min_bytes) {
120  std::memcpy(dest, source, num_bytes);
121  return;
122  }
123  const size_t max_bytes_per_thread = parallel_copy_min_bytes;
124  const size_t num_threads =
125  (num_bytes + max_bytes_per_thread - 1) / max_bytes_per_thread;
127  tbb::blocked_range<size_t>(0, num_threads, 1),
128  [&](const tbb::blocked_range<size_t>& r) {
129  const size_t end_chunk_idx = r.end();
130  for (size_t chunk_idx = r.begin(); chunk_idx != end_chunk_idx; ++chunk_idx) {
131  const size_t start_byte = chunk_idx * max_bytes_per_thread;
132  const size_t length =
133  std::min(start_byte + max_bytes_per_thread, num_bytes) - start_byte;
134  std::memcpy(dest + start_byte, source + start_byte, length);
135  }
136  });
137  }
138 
139  std::unordered_map<std::string, std::shared_ptr<CacheDataTf>> data_cache_;
141  static constexpr bool debug_print_{false};
142 };
143 
/// Map from string keys to shared ownership of T values.
///
/// Access is synchronized with a shared mutex: isKeyCached and getDataForKey
/// take a shared (read) lock, and putDataForKey takes an exclusive lock.
template <typename T>
class DataCache {
 public:
  /// Returns true if a value is stored under @p key.
  bool isKeyCached(const std::string& key) const {
    std::shared_lock<std::shared_mutex> read_lock(cache_mutex_);
    return data_cache_.count(key) > 0;
  }

  /// Returns shared ownership of the value stored under @p key.
  /// @throws std::runtime_error if @p key is not cached.
  std::shared_ptr<T> getDataForKey(const std::string& key) const {
    std::shared_lock<std::shared_mutex> read_lock(cache_mutex_);
    const auto cached_data_itr = data_cache_.find(key);
    if (cached_data_itr == data_cache_.end()) {
      const std::string error_msg{"Data for key " + key + " not found in cache."};
      throw std::runtime_error(error_msg);
    }
    return cached_data_itr->second;
  }

  /// Stores @p data under @p key, replacing any existing value.
  void putDataForKey(const std::string& key, std::shared_ptr<T> data) {
    std::unique_lock<std::shared_mutex> write_lock(cache_mutex_);
    // One lookup that either inserts or overwrites; no follow-up insert needed.
    [[maybe_unused]] const bool inserted =
        data_cache_.insert_or_assign(key, std::move(data)).second;
    if constexpr (debug_print_) {
      if (!inserted) {
        const std::string warning_msg =
            "Data for key " + key + " already exists in cache. Replacing.";
        std::cout << warning_msg << '\n';
      }
    }
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<T>> data_cache_;
  // `mutable` so the const readers can take a shared lock.
  mutable std::shared_mutex cache_mutex_;
  static constexpr bool debug_print_{false};
};
182 
183 #endif
bool isKeyCachedAndSameLength(const std::string &key, const size_t num_bytes) const
heavyai::shared_lock< heavyai::shared_mutex > read_lock
void copyData(int8_t *dest, const int8_t *source, const size_t num_bytes) const
bool isKeyCached(const std::string &key) const
std::unordered_map< std::string, std::shared_ptr< T > > data_cache_
heavyai::unique_lock< heavyai::shared_mutex > write_lock
void putDataForKey(const std::string &key, T *const data_buffer, const size_t num_elements)
static constexpr bool debug_print_
bool isKeyCached(const std::string &key) const
void putDataForKey(const std::string &key, std::shared_ptr< T > const data)
void getDataForKey(const std::string &key, T *dest_buffer) const
std::shared_ptr< T > getDataForKey(const std::string &key) const
const T & getDataRefForKey(const std::string &key) const
const T * getDataPtrForKey(const std::string &key) const
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
const size_t parallel_copy_min_bytes
#define DEBUG_TIMER(name)
Definition: Logger.h:374
std::shared_mutex cache_mutex_
CacheDataTf(const size_t num_bytes)
static constexpr bool debug_print_
std::shared_timed_mutex shared_mutex
std::shared_mutex cache_mutex_
std::unordered_map< std::string, std::shared_ptr< CacheDataTf > > data_cache_