Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add additional lock on the session itself to spread the access #88

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions server/src/http/connect_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ pub async fn connect_session(
State(client_to_sessions): State<ClientToSessions>,
Json(request): Json<HttpConnectSessionRequest>,
) -> Result<Json<HttpConnectSessionResponse>, (StatusCode, String)> {
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let mut session_write = match sessions_read.get(&request.session_id) {
Some(session) => session.write().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -43,7 +43,7 @@ pub async fn connect_session(
};

// Insert user socket
session
session_write
.connect_user(
&request.device,
&request.public_keys,
Expand Down
10 changes: 5 additions & 5 deletions server/src/http/get_pending_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ pub async fn get_pending_request(
State(sessions): State<Sessions>,
Json(request): Json<HttpGetPendingRequestRequest>,
) -> Result<Json<HttpGetPendingRequestResponse>, (StatusCode, String)> {
let sessions = sessions.read().await;
let session = match sessions.get(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let session_read = match sessions_read.get(&request.session_id) {
Some(session) => session.read().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -38,14 +38,14 @@ pub async fn get_pending_request(
}
};

if session.client_state.client_id != Some(request.client_id.clone()) {
if session_read.client_state.client_id != Some(request.client_id.clone()) {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::UserNotConnected.to_string(),
));
}

match session.pending_requests.get(&request.request_id) {
match session_read.pending_requests.get(&request.request_id) {
Some(pending_request) => {
return Ok(Json(HttpGetPendingRequestResponse {
request: pending_request.clone(),
Expand Down
10 changes: 5 additions & 5 deletions server/src/http/get_pending_requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ pub async fn get_pending_requests(
State(sessions): State<Sessions>,
Json(request): Json<HttpGetPendingRequestsRequest>,
) -> Result<Json<HttpGetPendingRequestsResponse>, (StatusCode, String)> {
let sessions = sessions.read().await;
let session = match sessions.get(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let session_read = match sessions_read.get(&request.session_id) {
Some(session) => session.read().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -38,14 +38,14 @@ pub async fn get_pending_requests(
}
};

if session.client_state.client_id != Some(request.client_id.clone()) {
if session_read.client_state.client_id != Some(request.client_id.clone()) {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::UserNotConnected.to_string(),
));
}

let pending_requests = session
let pending_requests = session_read
.pending_requests
.values()
.cloned()
Expand Down
16 changes: 8 additions & 8 deletions server/src/http/get_session_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ pub async fn get_session_info(
State(sessions): State<Sessions>,
Json(request): Json<HttpGetSessionInfoRequest>,
) -> Result<Json<HttpGetSessionInfoResponse>, (StatusCode, String)> {
let sessions = sessions.read().await;
let session = match sessions.get(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let session_read = match sessions_read.get(&request.session_id) {
Some(session) => session.read().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -40,11 +40,11 @@ pub async fn get_session_info(
};

let response = HttpGetSessionInfoResponse {
status: session.status.clone(),
persistent: session.persistent,
version: session.version.clone(),
network: session.network.clone(),
app_metadata: session.app_state.metadata.clone(),
status: session_read.status.clone(),
persistent: session_read.persistent,
version: session_read.version.clone(),
network: session_read.network.clone(),
app_metadata: session_read.app_state.metadata.clone(),
};
return Ok(Json(response));
}
12 changes: 6 additions & 6 deletions server/src/http/resolve_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ pub async fn resolve_request(
Json(request): Json<HttpResolveRequestRequest>,
) -> Result<Json<HttpResolveRequestResponse>, (StatusCode, String)> {
// Get session
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&request.session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let mut session_write = match sessions_read.get(&request.session_id) {
Some(session) => session.write().await,
None => {
return Err((
StatusCode::BAD_REQUEST,
Expand All @@ -39,14 +39,14 @@ pub async fn resolve_request(
};

// Check if client_id matches
if session.client_state.client_id != Some(request.client_id.clone()) {
if session_write.client_state.client_id != Some(request.client_id.clone()) {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::UserNotConnected.to_string(),
));
}
// Remove request from pending requests
if let None = session.pending_requests.remove(&request.request_id) {
if let None = session_write.pending_requests.remove(&request.request_id) {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::RequestDoesNotExist.to_string(),
Expand All @@ -58,7 +58,7 @@ pub async fn resolve_request(
response_id: request.request_id.clone(),
content: request.content.clone(),
});
if let Err(_) = session.send_to_app(app_msg).await {
if let Err(_) = session_write.send_to_app(app_msg).await {
return Err((
StatusCode::BAD_REQUEST,
NightlyError::AppDisconnected.to_string(),
Expand Down
58 changes: 38 additions & 20 deletions server/src/sesssion_cleaner.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::{time::Duration, vec};

use futures::SinkExt;

use crate::{
state::{ClientToSessions, ModifySession, Sessions},
utils::get_timestamp_in_milliseconds,
};
use futures::SinkExt;
use log::info;
use std::{time::Duration, vec};

pub fn start_cleaning_sessions(sessions: Sessions, client_to_sessions: ClientToSessions) {
let sessions = sessions.clone();
Expand All @@ -17,37 +16,56 @@ pub fn start_cleaning_sessions(sessions: Sessions, client_to_sessions: ClientToS
loop {
// Wait for tick
interval.tick().await;
let mut sessions_to_remove = vec![];

// Remove all sessions that expired
let mut sessions_to_remove = vec![];
let now = get_timestamp_in_milliseconds();
let mut sessions = sessions.write().await;
for (session_id, session) in sessions.iter() {
// Check if the session expired

info!("[{:?}]: Cleaning sessions", now);

// Lock sessions
let mut sessions_write = sessions.write().await;

// Iterate over all sessions and check if they expired
for (session_id, session) in sessions_write.iter() {
// Default session time is two weeks
if session.creation_timestamp + 1000 * 60 * 60 * 24 * 14 < now {
if session.read().await.creation_timestamp + 1000 * 60 * 60 * 24 * 14 < now {
sessions_to_remove.push(session_id.clone());
}
}

info!(
"[{:?}]: {} sessions to remove",
now,
sessions_to_remove.len()
);

// Remove all sessions that expired
for session_id in sessions_to_remove {
let session = sessions.get_mut(&session_id).unwrap();
// safe unwrap because we just checked if the session exists
let session = sessions_write.get_mut(&session_id).unwrap();
let mut session_write = session.write().await;

// Remove session from client_to_sessions
match &session.client_state.client_id {
Some(client_id) => {
client_to_sessions
.remove_session(client_id.clone(), session_id.clone())
.await;
}
None => {}
if let Some(client_id) = &session_write.client_state.client_id {
client_to_sessions
.remove_session(client_id.clone(), session_id.clone())
.await;
}

// Disconnect app
// Send to all apps
for (_, socket) in &mut session.app_state.app_socket {
let _ = socket.close();
for (_, socket) in &mut session_write.app_state.app_socket {
socket.close().await.unwrap_or_default();
}

sessions.remove(&session_id);
// Release write lock on session
drop(session_write);

sessions_write.remove(&session_id);
}

info!("[{:?}]: Sessions cleaning finished", now);
}
});
}
10 changes: 5 additions & 5 deletions server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio::sync::RwLock;

pub type SessionId = String;
pub type ClientId = String;
pub type Sessions = Arc<RwLock<HashMap<SessionId, Session>>>;
pub type Sessions = Arc<RwLock<HashMap<SessionId, RwLock<Session>>>>;
pub type ClientSockets = Arc<RwLock<HashMap<ClientId, RwLock<SplitSink<WebSocket, Message>>>>>;
pub type ClientToSessions = Arc<RwLock<HashMap<ClientId, RwLock<HashSet<SessionId>>>>>;

Expand All @@ -37,14 +37,14 @@ pub trait DisconnectUser {
#[async_trait]
impl DisconnectUser for Sessions {
async fn disconnect_user(&self, session_id: SessionId) -> Result<()> {
let mut sessions = self.write().await;
let session = match sessions.get_mut(&session_id) {
Some(session) => session,
let sessions_read = self.read().await;
let mut session_write = match sessions_read.get(&session_id) {
Some(session) => session.write().await,
None => return Err(anyhow::anyhow!("Session does not exist")), // Session does not exist
};

// Update session user state
session.disconnect_user().await;
session_write.disconnect_user().await;

Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions server/src/structs/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ impl Session {
}
pub async fn close_app_socket(&mut self, id: &Uuid) -> Result<()> {
info!("Drop app connection for session {}", self.session_id);
match &mut self.app_state.app_socket.remove(id) {
Some(app_socket) => {
match self.app_state.app_socket.remove(id) {
Some(mut app_socket) => {
app_socket.close().await?;
warn!("Drop app connection for session {}", self.session_id);
return Ok(());
Expand Down
18 changes: 9 additions & 9 deletions server/src/ws/app_handler/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ pub async fn app_handler(

match app_msg {
AppToServer::RequestPayload(sing_transactions_request) => {
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&session_id) {
Some(session) => session,
let sessions_read = sessions.read().await;
let mut session_write = match sessions_read.get(&session_id) {
Some(session) => session.write().await,
None => {
// Should never happen
return;
}
};
let response_id: String = sing_transactions_request.response_id.clone();

session.pending_requests.insert(
session_write.pending_requests.insert(
response_id.clone(),
PendingRequest {
content: sing_transactions_request.content.clone(),
Expand All @@ -154,7 +154,7 @@ pub async fn app_handler(
session_id: session_id.clone(),
});

let client_id = match &session.client_state.client_id {
let client_id = match &session_write.client_state.client_id {
Some(id) => id,
None => {
// Should never happen
Expand All @@ -169,11 +169,11 @@ pub async fn app_handler(
.await
{
// Fall back to notification
if let Some(notification) = &session.notification {
if let Some(notification) = &session_write.notification {
let notification_payload = NotificationPayload {
network: session.network.clone(),
app_metadata: session.app_state.metadata.clone(),
device: session
network: session_write.network.clone(),
app_metadata: session_write.app_state.metadata.clone(),
device: session_write
.client_state
.device
.clone()
Expand Down
22 changes: 13 additions & 9 deletions server/src/ws/app_handler/methods/disconnect_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ pub async fn disconnect_session(
client_to_sessions: &ClientToSessions,
) -> Result<()> {
// Lock the whole sessions map as we might need to remove a session
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&session_id) {
Some(session) => session,
let mut sessions_write = sessions.write().await;
let mut session_write = match sessions_write.get_mut(&session_id) {
Some(session) => session.write().await,
None => {
// Should never happen
bail!("Session not found, session_id: {}", session_id);
}
};

// Close user socket
if let Some(client_id) = &session.client_state.client_id {
if let Some(client_id) = &session_write.client_state.client_id {
let app_disconnected_event = ServerToClient::AppDisconnectedEvent(AppDisconnectedEvent {
session_id: session_id.clone(),
reason: "App disconnected".to_string(),
Expand All @@ -47,24 +47,28 @@ pub async fn disconnect_session(
}

// Close app socket
if let Err(err) = session.close_app_socket(&connection_id).await {
if let Err(err) = session_write.close_app_socket(&connection_id).await {
warn!(
"Error sending app disconnected event to connection_id: {}, session_id: {}, err: {}",
connection_id, session_id, err
);
}

// Update session status based on session type
if session.persistent {
session.update_status(SessionStatus::AppDisconnected);
if session_write.persistent {
session_write.update_status(SessionStatus::AppDisconnected);
} else {
// Remove session
if let Some(client_id) = session.client_state.client_id.clone() {
if let Some(client_id) = session_write.client_state.client_id.clone() {
client_to_sessions
.remove_session(client_id, session_id.clone())
.await;
}
sessions.remove(&session_id);

// Drop session lock
drop(session_write);

sessions_write.remove(&session_id);
}

Ok(())
Expand Down
Loading
Loading