Skip to content

Commit

Permalink
Merge pull request #88 from nightly-labs/refactor
Browse files Browse the repository at this point in the history
Add additional lock on the session itself to spread the access
  • Loading branch information
Giems authored Feb 7, 2024
2 parents 2834f7e + 5b4a832 commit eee8b4b
Show file tree
Hide file tree
Showing 15 changed files with 135 additions and 103 deletions.
8 changes: 4 additions & 4 deletions server/src/http/connect_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ pub async fn connect_session(
State(client_to_sessions): State<ClientToSessions>,
Json(request): Json<HttpConnectSessionRequest>,
) -> Result<Json<HttpConnectSessionResponse>, (StatusCode, String)> {
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let mut session_write = match sessions_read.get(&request.session_id) {
Some(session) => session.write().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -43,7 +43,7 @@ pub async fn connect_session(
};

// Insert user socket
session
session_write
.connect_user(
&request.device,
&request.public_keys,
Expand Down
10 changes: 5 additions & 5 deletions server/src/http/get_pending_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ pub async fn get_pending_request(
State(sessions): State<Sessions>,
Json(request): Json<HttpGetPendingRequestRequest>,
) -> Result<Json<HttpGetPendingRequestResponse>, (StatusCode, String)> {
let sessions = sessions.read().await;
let session = match sessions.get(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let session_read = match sessions_read.get(&request.session_id) {
Some(session) => session.read().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -38,14 +38,14 @@ pub async fn get_pending_request(
}
};

if session.client_state.client_id != Some(request.client_id.clone()) {
if session_read.client_state.client_id != Some(request.client_id.clone()) {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::UserNotConnected.to_string(),
));
}

match session.pending_requests.get(&request.request_id) {
match session_read.pending_requests.get(&request.request_id) {
Some(pending_request) => {
return Ok(Json(HttpGetPendingRequestResponse {
request: pending_request.clone(),
Expand Down
10 changes: 5 additions & 5 deletions server/src/http/get_pending_requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ pub async fn get_pending_requests(
State(sessions): State<Sessions>,
Json(request): Json<HttpGetPendingRequestsRequest>,
) -> Result<Json<HttpGetPendingRequestsResponse>, (StatusCode, String)> {
let sessions = sessions.read().await;
let session = match sessions.get(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let session_read = match sessions_read.get(&request.session_id) {
Some(session) => session.read().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -38,14 +38,14 @@ pub async fn get_pending_requests(
}
};

if session.client_state.client_id != Some(request.client_id.clone()) {
if session_read.client_state.client_id != Some(request.client_id.clone()) {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::UserNotConnected.to_string(),
));
}

let pending_requests = session
let pending_requests = session_read
.pending_requests
.values()
.cloned()
Expand Down
16 changes: 8 additions & 8 deletions server/src/http/get_session_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ pub async fn get_session_info(
State(sessions): State<Sessions>,
Json(request): Json<HttpGetSessionInfoRequest>,
) -> Result<Json<HttpGetSessionInfoResponse>, (StatusCode, String)> {
let sessions = sessions.read().await;
let session = match sessions.get(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let session_read = match sessions_read.get(&request.session_id) {
Some(session) => session.read().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -40,11 +40,11 @@ pub async fn get_session_info(
};

let response = HttpGetSessionInfoResponse {
status: session.status.clone(),
persistent: session.persistent,
version: session.version.clone(),
network: session.network.clone(),
app_metadata: session.app_state.metadata.clone(),
status: session_read.status.clone(),
persistent: session_read.persistent,
version: session_read.version.clone(),
network: session_read.network.clone(),
app_metadata: session_read.app_state.metadata.clone(),
};
return Ok(Json(response));
}
12 changes: 6 additions & 6 deletions server/src/http/resolve_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ pub async fn resolve_request(
Json(request): Json<HttpResolveRequestRequest>,
) -> Result<Json<HttpResolveRequestResponse>, (StatusCode, String)> {
// Get session
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let mut session_write = match sessions_read.get(&request.session_id) {
Some(session) => session.write().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -39,14 +39,14 @@ pub async fn resolve_request(
};

// Check if client_id matches
if session.client_state.client_id != Some(request.client_id.clone()) {
if session_write.client_state.client_id != Some(request.client_id.clone()) {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::UserNotConnected.to_string(),
));
}
// Remove request from pending requests
if let None = session.pending_requests.remove(&request.request_id) {
if let None = session_write.pending_requests.remove(&request.request_id) {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::RequestDoesNotExist.to_string(),
Expand All @@ -58,7 +58,7 @@ pub async fn resolve_request(
response_id: request.request_id.clone(),
content: request.content.clone(),
});
if let Err(_) = session.send_to_app(app_msg).await {
if let Err(_) = session_write.send_to_app(app_msg).await {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::AppDisconnected.to_string(),
Expand Down
58 changes: 38 additions & 20 deletions server/src/sesssion_cleaner.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::{time::Duration, vec};

use futures::SinkExt;

use crate::{
state::{ClientToSessions, ModifySession, Sessions},
utils::get_timestamp_in_milliseconds,
};
use futures::SinkExt;
use log::info;
use std::{time::Duration, vec};

pub fn start_cleaning_sessions(sessions: Sessions, client_to_sessions: ClientToSessions) {
let sessions = sessions.clone();
Expand All @@ -17,37 +16,56 @@ pub fn start_cleaning_sessions(sessions: Sessions, client_to_sessions: ClientToS
loop {
// Wait for tick
interval.tick().await;
let mut sessions_to_remove = vec![];

// Remove all sessions that expired
let mut sessions_to_remove = vec![];
let now = get_timestamp_in_milliseconds();
let mut sessions = sessions.write().await;
for (session_id, session) in sessions.iter() {
// Check if the session expired

info!("[{:?}]: Cleaning sessions", now);

// Lock sessions
let mut sessions_write = sessions.write().await;

// Iterate over all sessions and check if they expired
for (session_id, session) in sessions_write.iter() {
// Default session time is two weeks
if session.creation_timestamp + 1000 * 60 * 60 * 24 * 14 < now {
if session.read().await.creation_timestamp + 1000 * 60 * 60 * 24 * 14 < now {
sessions_to_remove.push(session_id.clone());
}
}

info!(
"[{:?}]: {} sessions to remove",
now,
sessions_to_remove.len()
);

// Remove all sessions that expired
for session_id in sessions_to_remove {
let session = sessions.get_mut(&session_id).unwrap();
// safe unwrap because we just checked if the session exists
let session = sessions_write.get_mut(&session_id).unwrap();
let mut session_write = session.write().await;

// Remove session from client_to_sessions
match &session.client_state.client_id {
Some(client_id) => {
client_to_sessions
.remove_session(client_id.clone(), session_id.clone())
.await;
}
None => {}
if let Some(client_id) = &session_write.client_state.client_id {
client_to_sessions
.remove_session(client_id.clone(), session_id.clone())
.await;
}

// Disconnect app
// Send to all apps
for (_, socket) in &mut session.app_state.app_socket {
let _ = socket.close();
for (_, socket) in &mut session_write.app_state.app_socket {
socket.close().await.unwrap_or_default();
}

sessions.remove(&session_id);
// Release write lock on session
drop(session_write);

sessions_write.remove(&session_id);
}

info!("[{:?}]: Sessions cleaning finished", now);
}
});
}
10 changes: 5 additions & 5 deletions server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio::sync::RwLock;

pub type SessionId = String;
pub type ClientId = String;
pub type Sessions = Arc<RwLock<HashMap<SessionId, Session>>>;
pub type Sessions = Arc<RwLock<HashMap<SessionId, RwLock<Session>>>>;
pub type ClientSockets = Arc<RwLock<HashMap<ClientId, RwLock<SplitSink<WebSocket, Message>>>>>;
pub type ClientToSessions = Arc<RwLock<HashMap<ClientId, RwLock<HashSet<SessionId>>>>>;

Expand All @@ -37,14 +37,14 @@ pub trait DisconnectUser {
#[async_trait]
impl DisconnectUser for Sessions {
async fn disconnect_user(&self, session_id: SessionId) -> Result<()> {
let mut sessions = self.write().await;
let session = match sessions.get_mut(&session_id) {
Some(session) => session,
let sessions_read = self.read().await;
let mut session_write = match sessions_read.get(&session_id) {
Some(session) => session.write().await,
None => return Err(anyhow::anyhow!("Session does not exist")), // Session does not exist
};

// Update session user state
session.disconnect_user().await;
session_write.disconnect_user().await;

Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions server/src/structs/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ impl Session {
}
pub async fn close_app_socket(&mut self, id: &Uuid) -> Result<()> {
info!("Drop app connection for session {}", self.session_id);
match &mut self.app_state.app_socket.remove(id) {
Some(app_socket) => {
match self.app_state.app_socket.remove(id) {
Some(mut app_socket) => {
app_socket.close().await?;
warn!("Drop app connection for session {}", self.session_id);
return Ok(());
Expand Down
18 changes: 9 additions & 9 deletions server/src/ws/app_handler/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ pub async fn app_handler(

match app_msg {
AppToServer::RequestPayload(sing_transactions_request) => {
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let mut session_write = match sessions_read.get(&session_id) {
Some(session) => session.write().await,
None => {
// Should never happen
return;
}
};
let response_id: String = sing_transactions_request.response_id.clone();

session.pending_requests.insert(
session_write.pending_requests.insert(
response_id.clone(),
PendingRequest {
content: sing_transactions_request.content.clone(),
Expand All @@ -154,7 +154,7 @@ pub async fn app_handler(
session_id: session_id.clone(),
});

let client_id = match &session.client_state.client_id {
let client_id = match &session_write.client_state.client_id {
Some(id) => id,
None => {
// Should never happen
Expand All @@ -169,11 +169,11 @@ pub async fn app_handler(
.await
{
// Fall back to notification
if let Some(notification) = &session.notification {
if let Some(notification) = &session_write.notification {
let notification_payload = NotificationPayload {
network: session.network.clone(),
app_metadata: session.app_state.metadata.clone(),
device: session
network: session_write.network.clone(),
app_metadata: session_write.app_state.metadata.clone(),
device: session_write
.client_state
.device
.clone()
Expand Down
22 changes: 13 additions & 9 deletions server/src/ws/app_handler/methods/disconnect_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ pub async fn disconnect_session(
client_to_sessions: &ClientToSessions,
) -> Result<()> {
// Lock the whole sessions map as we might need to remove a session
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&session_id) {
Some(session) => session,
let mut sessions_write = sessions.write().await;
let mut session_write = match sessions_write.get_mut(&session_id) {
Some(session) => session.write().await,
None => {
// Should never happen
bail!("Session not found, session_id: {}", session_id);
}
};

// Close user socket
if let Some(client_id) = &session.client_state.client_id {
if let Some(client_id) = &session_write.client_state.client_id {
let app_disconnected_event = ServerToClient::AppDisconnectedEvent(AppDisconnectedEvent {
session_id: session_id.clone(),
reason: "App disconnected".to_string(),
Expand All @@ -47,24 +47,28 @@ pub async fn disconnect_session(
}

// Close app socket
if let Err(err) = session.close_app_socket(&connection_id).await {
if let Err(err) = session_write.close_app_socket(&connection_id).await {
warn!(
"Error sending app disconnected event to connection_id: {}, session_id: {}, err: {}",
connection_id, session_id, err
);
}

// Update session status based on session type
if session.persistent {
session.update_status(SessionStatus::AppDisconnected);
if session_write.persistent {
session_write.update_status(SessionStatus::AppDisconnected);
} else {
// Remove session
if let Some(client_id) = session.client_state.client_id.clone() {
if let Some(client_id) = session_write.client_state.client_id.clone() {
client_to_sessions
.remove_session(client_id, session_id.clone())
.await;
}
sessions.remove(&session_id);

// Drop session lock
drop(session_write);

sessions_write.remove(&session_id);
}

Ok(())
Expand Down
Loading

0 comments on commit eee8b4b

Please sign in to comment.