OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
SessionsStore.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "SessionsStore.h"
#include "Catalog.h"
#include "Shared/StringTransform.h"

#include <boost/algorithm/string.hpp>
#include <memory>
#include <thread>
#include <unordered_map>
#include <utility>
25 
26 using namespace Catalog_Namespace;
27 
28 SessionInfo SessionsStore::getSessionCopy(const std::string& session_id) {
29  auto origin = get(session_id);
30  if (origin) {
31  heavyai::shared_lock<heavyai::shared_mutex> lock(origin->getLock());
32  return *origin;
33  }
34  throw std::runtime_error("No session with id " + session_id);
35 }
36 
// Removes the session with the given id from the store (no-op semantics are whatever
// eraseUnlocked() implements for unknown ids).
// NOTE(review): the rendered source numbering jumps from 37 to 39, so one body line is
// not visible here; presumably it acquires the store lock (getLock()) before calling
// eraseUnlocked() -- confirm against the original file.
37 void SessionsStore::erase(const std::string& session_id) {
39  eraseUnlocked(session_id);
40 }
41 
42 void SessionsStore::eraseByUser(const std::string& user_name) {
43  eraseIf([&user_name](const SessionInfoPtr& session_ptr) {
44  return boost::iequals(user_name, session_ptr->get_currentUser().userName);
45  });
46 }
47 
48 void SessionsStore::eraseByDB(const std::string& db_name) {
49  eraseIf([&db_name](const SessionInfoPtr& session_ptr) {
50  return boost::iequals(db_name, session_ptr->getCatalog().getCurrentDB().dbName);
51  });
52 }
53 
// Disconnects and removes the session identified by `session_id`, if it exists.
// Logs the user, database and public session id, invokes the store's disconnect
// callback, then erases the session via eraseUnlocked(). Unknown ids are ignored
// silently (there is no else branch).
// NOTE(review): source line 55 is not visible in this rendering; since only the
// *Unlocked helpers are used below, presumably that line takes the store lock -- confirm.
// NOTE(review): unlike CachedSessionStore's expiry paths, an exception thrown by the
// disconnect callback here propagates before eraseUnlocked(), leaving the session stored.
54 void SessionsStore::disconnect(const std::string session_id) {
56  auto session_ptr = getUnlocked(session_id);
57  if (session_ptr) {
58  const auto dbname = session_ptr->getCatalog().getCurrentDB().dbName;
59  LOG(INFO) << "User " << session_ptr->get_currentUser().userLoggable()
60  << " disconnected from database " << dbname
61  << " with public_session_id: " << session_ptr->get_public_session_id();
// Notify the owner via the callback first, then drop the session from the store.
62  getDisconnectCallback()(session_ptr);
63  eraseUnlocked(session_ptr->get_session_id());
64  }
65 }
66 
68  int idle_session_duration,
69  int max_session_duration) {
// isSessionExpired(session_ptr, idle_session_duration, max_session_duration) -- the
// opening signature line (source line 67) is not visible in this rendering.
// Decides whether a session should be invalidated:
//   * a session that isSessionInUse() reports as in use is never expired;
//   * otherwise it is expired if the time since its last use exceeds
//     `idle_session_duration`, or the time since it started exceeds
//     `max_session_duration`.
// Durations are differences of time(0) values and are reported as seconds in the log
// messages. Each expiry reason is logged with the session's public id.
// Returns true if the session should be invalidated.
// An in-use session is never reported as expired.
70  if (isSessionInUse(session_ptr)) {
71  return false;
72  }
73  time_t last_used_time = session_ptr->get_last_used_time();
74  time_t start_time = session_ptr->get_start_time();
// Idle check: elapsed time since the session was last used.
75  const auto current_session_duration = time(0) - last_used_time;
76  if (current_session_duration > idle_session_duration) {
77  LOG(INFO) << "Session " << session_ptr->get_public_session_id() << " idle duration "
78  << current_session_duration << " seconds exceeds maximum idle duration "
79  << idle_session_duration << " seconds. Invalidating session.";
80  return true;
81  }
// Lifetime check: elapsed time since the session started (time(0) is sampled again).
82  const auto total_session_duration = time(0) - start_time;
83  if (total_session_duration > max_session_duration) {
84  LOG(INFO) << "Session " << session_ptr->get_public_session_id() << " total duration "
85  << total_session_duration
86  << " seconds exceeds maximum total session duration "
87  << max_session_duration << " seconds. Invalidating session.";
88  return true;
89  }
90  return false;
91 }
92 
93 std::vector<SessionInfoPtr> SessionsStore::getAllSessions() {
94  return getIf([](const SessionInfoPtr&) { return true; });
95 }
96 
97 std::vector<SessionInfoPtr> SessionsStore::getUserSessions(const std::string& user_name) {
98  return getIf([&user_name](const SessionInfoPtr& session_ptr) {
99  return session_ptr->get_currentUser().userName == user_name;
100  });
101 }
102 
103 SessionInfoPtr SessionsStore::getByPublicID(const std::string& public_id) {
104  auto sessions = getIf([&public_id](const SessionInfoPtr& session_ptr) {
105  return session_ptr->get_public_session_id() == public_id;
106  });
107  if (sessions.empty()) {
108  return nullptr;
109  }
110  CHECK_EQ(sessions.size(), 1ul);
111  return sessions[0];
112 }
113 
115  public:
116  CachedSessionStore(int idle_session_duration,
117  int max_session_duration,
118  int capacity,
119  DisconnectCallback disconnect_callback)
120  : idle_session_duration_(idle_session_duration)
121  , max_session_duration_(max_session_duration)
122  , capacity_(capacity > 0 ? capacity : INT_MAX)
123  , disconnect_callback_(disconnect_callback) {}
124 
126  std::shared_ptr<Catalog> cat,
127  ExecutorDeviceType device) override {
// add(user_meta, cat, device) -- the first signature line (source line 125) and source
// line 128 are not visible in this rendering; presumably 128 takes the store lock.
// Creates and registers a new session for `user_meta` on catalog `cat`:
//   1. if the store is at or above capacity, every expired session is disconnected
//      (via the callback) and erased;
//   2. if there is then room, a session id not already present is chosen, a new
//      SessionInfo is stored under it and returned;
//   3. otherwise std::runtime_error("Too many active sessions") is thrown.
// At or above capacity: try to make room by evicting expired sessions.
129  if (int(sessions_.size()) >= capacity_) {
// Collect first and erase afterwards, so iteration over sessions_ is not invalidated.
130  std::vector<SessionInfoPtr> expired_sessions;
131  for (auto it = sessions_.begin(); it != sessions_.end(); it++) {
132  if (isSessionExpired(it->second, idle_session_duration_, max_session_duration_)) {
133  expired_sessions.push_back(it->second);
134  }
135  }
// NOTE(review): if the callback throws for one session, that session is still erased,
// but the remaining expired sessions are skipped and the whole add() fails.
136  for (auto& session_ptr : expired_sessions) {
137  try {
138  disconnect_callback_(session_ptr);
139  eraseUnlocked(session_ptr->get_session_id());
140  } catch (const std::exception& e) {
141  eraseUnlocked(session_ptr->get_session_id());
// NOTE(review): `throw e;` throws a new copy of static type std::exception, slicing
// away the original exception's dynamic type (and possibly its what() text);
// `throw;` would rethrow the original object.
142  throw e;
143  }
144  }
145  }
146  if (int(sessions_.size()) < capacity_) {
// Retry until an id that is not already in sessions_ is found. NOTE(review): the line
// producing `session_id` (source line 148) is not visible; presumably it is
// generate_random_string(SESSION_ID_LENGTH), both of which appear in the symbol index
// -- confirm.
147  do {
149  if (sessions_.count(session_id) != 0) {
150  continue;
151  }
152  auto session_ptr = std::make_shared<Catalog_Namespace::SessionInfo>(
153  cat, user_meta, device, session_id);
154  sessions_[session_id] = session_ptr;
155  return session_ptr;
156  } while (true);
157  UNREACHABLE();
158  }
// Still full after evicting every expired session.
159  throw std::runtime_error("Too many active sessions");
160  }
161 
// Returns the session stored under `session_id`, or nullptr if there is none.
// If the session is found but is expired, it is disconnected via the callback, erased,
// and nullptr is returned. Otherwise its last-used time is refreshed and it is returned.
// NOTE(review): source lines 163 and 166 are not visible in this rendering; presumably
// 163 takes the store lock and 166 opens `if (isSessionExpired(` (the visible
// arguments match that function's signature) -- confirm.
162  SessionInfoPtr get(const std::string& session_id) override {
164  auto session_ptr = getUnlocked(session_id);
165  if (session_ptr) {
167  session_ptr, idle_session_duration_, max_session_duration_)) {
168  try {
169  disconnect_callback_(session_ptr);
170  eraseUnlocked(session_ptr->get_session_id());
171  } catch (const std::exception& e) {
172  eraseUnlocked(session_ptr->get_session_id());
// NOTE(review): `throw e;` slices the exception to std::exception; `throw;` would
// rethrow the original object with its dynamic type intact.
173  throw e;
174  }
175  return nullptr;
176  }
// Live session: record this access, which feeds the idle-expiry check.
177  session_ptr->update_last_used_time();
178  return session_ptr;
179  }
180  return nullptr;
181  }
182 
183  heavyai::shared_mutex& getLock() override { return mtx_; }
184 
// Removes every stored session for which `predicate` returns true.
// Unlike getIf(), the visible code does not take each session's own lock while
// evaluating the predicate, and it does not invoke the disconnect callback.
// NOTE(review): source line 186 is not visible in this rendering; presumably it takes
// the store lock exclusively -- confirm.
185  void eraseIf(std::function<bool(const SessionInfoPtr&)> predicate) override {
187  for (auto it = sessions_.begin(); it != sessions_.end();) {
188  if (predicate(it->second)) {
// erase() returns the next valid iterator, keeping the loop safe.
189  it = sessions_.erase(it);
190  } else {
191  it++;
192  }
193  }
194  }
195 
196  ~CachedSessionStore() override {
197  std::lock_guard lg(mtx_);
198  sessions_.clear();
199  }
200 
201  protected:
202  void eraseUnlocked(const std::string& session_id) override {
203  sessions_.erase(session_id);
204  }
205 
206  bool isSessionInUse(const SessionInfoPtr& session_ptr) override {
207  return session_ptr.use_count() > 2;
208  }
209 
210  SessionInfoPtr getUnlocked(const std::string& session_id) override {
211  if (auto session_it = sessions_.find(session_id); session_it != sessions_.end()) {
212  return session_it->second;
213  }
214  return nullptr;
215  }
216 
217  DisconnectCallback getDisconnectCallback() override { return disconnect_callback_; }
218 
219  std::vector<SessionInfoPtr> getIf(
220  std::function<bool(const SessionInfoPtr&)> predicate) override {
221  std::vector<SessionInfoPtr> out;
222  heavyai::shared_lock<heavyai::shared_mutex> sessions_lock(getLock());
223  for (auto& [_, session] : sessions_) {
224  heavyai::shared_lock<heavyai::shared_mutex> session_lock(session->getLock());
225  if (predicate(session)) {
226  out.push_back(session);
227  }
228  }
229  return out;
230  }
231 
232  private:
// All live sessions, keyed by their (private) session id.
233  std::unordered_map<std::string, SessionInfoPtr> sessions_;
// NOTE(review): source lines 234-236 and 238 are not visible in this rendering;
// presumably they declare mtx_, idle_session_duration_, max_session_duration_ and
// disconnect_callback_, which the symbol index lists -- confirm.
// Maximum number of live sessions; INT_MAX when constructed with capacity <= 0.
237  const int capacity_;
239 };
240 
241 std::unique_ptr<SessionsStore> SessionsStore::create(
242  const std::string& base_path,
243  size_t n_workers,
244  int idle_session_duration,
245  int max_session_duration,
246  int capacity,
247  DisconnectCallback disconnect_callback) {
248  return std::make_unique<CachedSessionStore>(
249  idle_session_duration, max_session_duration, capacity, disconnect_callback);
250 }
std::lock_guard< T > lock_guard
void erase(const std::string &session_id)
#define CHECK_EQ(x, y)
Definition: Logger.h:301
virtual SessionInfoPtr getUnlocked(const std::string &session_id)=0
std::string cat(Ts &&...args)
std::function< void(SessionInfoPtr &session)> DisconnectCallback
Definition: SessionsStore.h:29
const int max_session_duration_
#define LOG(tag)
Definition: Logger.h:285
std::vector< SessionInfoPtr > getIf(std::function< bool(const SessionInfoPtr &)> predicate) override
#define UNREACHABLE()
Definition: Logger.h:338
std::unordered_map< std::string, SessionInfoPtr > sessions_
virtual void eraseUnlocked(const std::string &session_id)=0
const int idle_session_duration_
virtual DisconnectCallback getDisconnectCallback()=0
void eraseUnlocked(const std::string &session_id) override
ExecutorDeviceType
~CachedSessionStore() override
std::shared_lock< T > shared_lock
This file contains the class specification and related data structures for Catalog.
SessionInfo getSessionCopy(const std::string &session_id)
virtual heavyai::shared_mutex & getLock()=0
std::string generate_random_string(const size_t len)
bool isSessionExpired(const SessionInfoPtr &session_ptr, int idle_session_duration, int max_session_duration)
virtual bool isSessionInUse(const SessionInfoPtr &session_ptr)=0
virtual void eraseIf(std::function< bool(const SessionInfoPtr &)> predicate)=0
static std::unique_ptr< SessionsStore > create(const std::string &base_path, size_t n_workers, int idle_session_duration, int max_session_duration, int capacity, DisconnectCallback disconnect_callback)
DisconnectCallback getDisconnectCallback() override
const size_t SESSION_ID_LENGTH
Definition: SessionInfo.h:127
virtual std::vector< SessionInfoPtr > getIf(std::function< bool(const SessionInfoPtr &)> predicate)=0
void eraseIf(std::function< bool(const SessionInfoPtr &)> predicate) override
CachedSessionStore(int idle_session_duration, int max_session_duration, int capacity, DisconnectCallback disconnect_callback)
void eraseByUser(const std::string &user_name)
DisconnectCallback disconnect_callback_
SessionInfoPtr add(const Catalog_Namespace::UserMetadata &user_meta, std::shared_ptr< Catalog > cat, ExecutorDeviceType device) override
SessionInfoPtr getUnlocked(const std::string &session_id) override
void disconnect(const std::string session_id)
SessionInfoPtr getByPublicID(const std::string &public_id)
std::shared_timed_mutex shared_mutex
heavyai::shared_mutex & getLock() override
std::vector< SessionInfoPtr > getAllSessions()
bool isSessionInUse(const SessionInfoPtr &session_ptr) override
heavyai::shared_mutex mtx_
void eraseByDB(const std::string &db_name)
std::vector< SessionInfoPtr > getUserSessions(const std::string &user_name)
std::shared_ptr< SessionInfo > SessionInfoPtr
Definition: SessionsStore.h:27