diff --git a/src/api/notifications.rs b/src/api/notifications.rs index 7e76021b35..3df91e91d9 100644 --- a/src/api/notifications.rs +++ b/src/api/notifications.rs @@ -20,7 +20,7 @@ use tokio_tungstenite::{ }; use crate::{ - auth::ClientIp, + auth::{ClientIp, WsAccessTokenHeader}, db::{ models::{Cipher, Folder, Send as DbSend, User}, DbConn, @@ -111,11 +111,19 @@ fn websockets_hub<'r>( ws: rocket_ws::WebSocket, data: WsAccessToken, ip: ClientIp, + header_token: WsAccessTokenHeader, ) -> Result { let addr = ip.ip; info!("Accepting Rocket WS connection from {addr}"); - let Some(token) = data.access_token else { err_code!("Invalid claim", 401) }; + let token = if let Some(token) = data.access_token { + token + } else if let Some(token) = header_token.access_token { + token + } else { + err_code!("Invalid claim", 401) + }; + let Ok(claims) = crate::auth::decode_login(&token) else { err_code!("Invalid token", 401) }; let (mut rx, guard) = { diff --git a/src/auth.rs b/src/auth.rs index 6879bb6e59..e23aa32dc9 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -825,3 +825,26 @@ impl<'r> FromRequest<'r> for ClientIp { }) } } + +pub struct WsAccessTokenHeader { + pub access_token: Option, +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for WsAccessTokenHeader { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let headers = request.headers(); + + // Get access_token + let access_token = match headers.get_one("Authorization") { + Some(a) => a.rsplit("Bearer ").next().map(String::from), + None => None, + }; + + Outcome::Success(Self { + access_token, + }) + } +}