diff --git a/bindings/matrix-sdk-ffi/src/authentication.rs b/bindings/matrix-sdk-ffi/src/authentication.rs index d7fd97cf932..6ac3c59197f 100644 --- a/bindings/matrix-sdk-ffi/src/authentication.rs +++ b/bindings/matrix-sdk-ffi/src/authentication.rs @@ -1,4 +1,8 @@ -use std::collections::HashMap; +use std::{ + collections::HashMap, + fmt::{self, Debug}, + sync::Arc, +}; use matrix_sdk::{ oidc::{ @@ -15,6 +19,8 @@ use matrix_sdk::{ }; use url::Url; +use crate::client::Client; + #[derive(uniffi::Object)] pub struct HomeserverLoginDetails { pub(crate) url: String, @@ -47,6 +53,42 @@ impl HomeserverLoginDetails { } } +/// An object encapsulating the SSO login flow +#[derive(uniffi::Object)] +pub struct SsoHandler { + /// The wrapped Client. + pub(crate) client: Arc, + + /// The underlying URL for authentication. + pub(crate) url: String, +} + +#[uniffi::export(async_runtime = "tokio")] +impl SsoHandler { + /// Returns the URL for starting SSO authentication. The URL should be + /// opened in a web view. Once the web view succeeds, call `finish` with + /// the callback URL. + pub fn url(&self) -> String { + self.url.clone() + } + + /// Completes the SSO login process. + pub async fn finish(&self, callback_url: String) -> Result<(), SsoError> { + let auth = self.client.inner.matrix_auth(); + let url = Url::parse(&callback_url).map_err(|_| SsoError::CallbackUrlInvalid)?; + let builder = + auth.login_with_sso_callback(url).map_err(|_| SsoError::CallbackUrlInvalid)?; + builder.await.map_err(|_| SsoError::LoginWithTokenFailed)?; + Ok(()) + } +} + +impl Debug for SsoHandler { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(fmt, "SsoHandler") + } +} + #[derive(Debug, thiserror::Error, uniffi::Error)] #[uniffi(flat_error)] pub enum SsoError { diff --git a/bindings/matrix-sdk-ffi/src/client.rs b/bindings/matrix-sdk-ffi/src/client.rs index a502e1924b1..c43f8207477 100644 --- a/bindings/matrix-sdk-ffi/src/client.rs +++ b/bindings/matrix-sdk-ffi/src/client.rs @@ -1,6 +1,6 @@ use std::{ collections::HashMap, - fmt::{self, Debug}, + fmt::Debug, mem::ManuallyDrop, path::Path, sync::{Arc, RwLock}, @@ -61,7 +61,7 @@ use url::Url; use super::{room::Room, session_verification::SessionVerificationController, RUNTIME}; use crate::{ - authentication::{HomeserverLoginDetails, OidcConfiguration, OidcError, SsoError}, + authentication::{HomeserverLoginDetails, OidcConfiguration, OidcError, SsoError, SsoHandler}, client, encryption::Encryption, notification::NotificationClient, @@ -173,54 +173,6 @@ impl From for TransmissionProgress { } } -/// An object encapsulating the SSO login flow -#[derive(uniffi::Object)] -pub struct SsoHandler { - /// The wrapped Client. - client: Arc, - - /// The underlying URL for authentication. - url: String, -} - -#[uniffi::export(async_runtime = "tokio")] -impl SsoHandler { - /// Returns the URL for starting SSO authentication. The URL should be - /// opened in a web view. Once the web view succeeds, call `finish` with - /// the callback URL. - pub fn url(&self) -> String { - self.url.clone() - } - - /// Completes the SSO login process. - pub async fn finish(&self, callback_url: String) -> Result<(), SsoError> { - let auth = self.client.inner.matrix_auth(); - - let url = Url::parse(&callback_url).map_err(|_| SsoError::CallbackUrlInvalid)?; - - #[derive(Deserialize)] - struct QueryParameters { - #[serde(rename = "loginToken")] - login_token: Option, - } - - let query_string = url.query().unwrap_or(""); - let query: QueryParameters = - serde_html_form::from_str(query_string).map_err(|_| SsoError::CallbackUrlInvalid)?; - let token = query.login_token.ok_or(SsoError::CallbackUrlInvalid)?; - - auth.login_token(token.as_str()).await.map_err(|_| SsoError::LoginWithTokenFailed)?; - - Ok(()) - } -} - -impl Debug for SsoHandler { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(fmt, "SsoHandler") - } -} - #[derive(uniffi::Object)] pub struct Client { pub(crate) inner: ManuallyDrop, @@ -1590,100 +1542,3 @@ impl MediaFileHandle { ) } } - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use assert_matches::assert_matches; - use matrix_sdk_test::{async_test, test_json}; - use serde::Deserialize; - use url::Url; - use wiremock::{ - matchers::{method, path}, - Mock, MockServer, ResponseTemplate, - }; - - use crate::{authentication::SsoError, client_builder::ClientBuilder}; - - #[async_test] - async fn test_start_sso_login_adds_redirect_url_to_login_url() { - let homeserver = make_mock_homeserver().await; - let builder = ClientBuilder::new().server_name_or_homeserver_url(homeserver.uri()); - let client = Arc::new(builder.build_inner().await.expect("Should build client")); - - let handler = client - .start_sso_login("app://redirect".to_owned(), None) - .await - .expect("Should create SSO handler"); - - let url = Url::parse(&handler.url).expect("Should generate a valid SSO login URL"); - - #[derive(Deserialize)] - struct QueryParameters { - #[serde(rename = "redirectUrl")] - redirect_url: Option, - } - - let query_string = url.query().unwrap_or(""); - let query: QueryParameters = serde_html_form::from_str(query_string) - .expect("Should deserialize query parameters from SSO login URL"); - - assert_eq!(query.redirect_url, Some("app://redirect".to_owned())); - } - - #[async_test] - async fn test_finish_sso_login_with_login_token_succeeds() { - let homeserver = make_mock_homeserver().await; - let builder = ClientBuilder::new().server_name_or_homeserver_url(homeserver.uri()); - let client = Arc::new(builder.build_inner().await.expect("Should build client")); - - let handler = client - .start_sso_login("app://redirect".to_owned(), None) - .await - .expect("Should create SSO handler"); - - handler - .finish("app://redirect?loginToken=foo".to_owned()) - .await - .expect("Should log in with token"); - } - - #[tokio::test] - async fn test_finish_sso_login_without_login_token_fails() { - let homeserver = make_mock_homeserver().await; - let builder = ClientBuilder::new().server_name_or_homeserver_url(homeserver.uri()); - let client = Arc::new(builder.build_inner().await.expect("Should build client")); - - let handler = client - .start_sso_login("app://redirect".to_owned(), None) - .await - .expect("Should create SSO handler"); - - let result = handler.finish("app://redirect?foo=bar".to_owned()).await; - - assert_matches!(result, Err(SsoError::CallbackUrlInvalid)); - } - - /* Helper functions */ - - async fn make_mock_homeserver() -> MockServer { - let homeserver = MockServer::start().await; - Mock::given(method("GET")) - .and(path("/_matrix/client/versions")) - .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS)) - .mount(&homeserver) - .await; - Mock::given(method("GET")) - .and(path("/_matrix/client/r0/login")) - .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::LOGIN_TYPES)) - .mount(&homeserver) - .await; - Mock::given(method("POST")) - .and(path("/_matrix/client/r0/login")) - .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::LOGIN)) - .mount(&homeserver) - .await; - homeserver - } -} diff --git a/crates/matrix-sdk/src/matrix_auth/mod.rs b/crates/matrix-sdk/src/matrix_auth/mod.rs index 55699192d76..a51edd2458e 100644 --- a/crates/matrix-sdk/src/matrix_auth/mod.rs +++ b/crates/matrix-sdk/src/matrix_auth/mod.rs @@ -38,7 +38,9 @@ use ruma::{ serde::JsonObject, }; use serde::{Deserialize, Serialize}; +use thiserror::Error; use tracing::{debug, error, info, instrument}; +use url::Url; use crate::{ authentication::AuthData, @@ -73,6 +75,14 @@ pub struct MatrixAuth { client: Client, } +/// Errors that can occur when using the SSO API. +#[derive(Debug, Error)] +pub enum SsoError { + /// The supplied callback URL used to complete SSO is invalid. + #[error("callback URL invalid")] + CallbackUrlInvalid, +} + impl MatrixAuth { pub(crate) fn new(client: Client) -> Self { Self { client } @@ -292,6 +302,24 @@ impl MatrixAuth { LoginBuilder::new_token(self.clone(), token.to_owned()) } + /// A higher level wrapper around the methods to complete an SSO login after + /// the user has logged in through a webview. This method should be used + /// in tandem with [`MatrixAuth::get_sso_login_url`]. + pub fn login_with_sso_callback(&self, callback_url: Url) -> Result { + #[derive(Deserialize)] + struct QueryParameters { + #[serde(rename = "loginToken")] + login_token: Option, + } + + let query_string = callback_url.query().unwrap_or(""); + let query: QueryParameters = + serde_html_form::from_str(query_string).map_err(|_| SsoError::CallbackUrlInvalid)?; + let token = query.login_token.ok_or(SsoError::CallbackUrlInvalid)?; + + Ok(self.login_token(token.as_str())) + } + /// Log into the server via Single Sign-On. /// /// This takes care of the whole SSO flow: diff --git a/crates/matrix-sdk/tests/integration/matrix_auth.rs b/crates/matrix-sdk/tests/integration/matrix_auth.rs index ce34a8f4ac0..ef732b37f06 100644 --- a/crates/matrix-sdk/tests/integration/matrix_auth.rs +++ b/crates/matrix-sdk/tests/integration/matrix_auth.rs @@ -191,6 +191,42 @@ async fn test_login_with_sso_token() { assert!(logged_in, "Client should be logged in"); } +#[async_test] +async fn test_login_with_sso_callback() { + let (client, server) = no_retry_test_client_with_server().await; + + Mock::given(method("GET")) + .and(path("/_matrix/client/r0/login")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::LOGIN_TYPES)) + .mount(&server) + .await; + + let auth = client.matrix_auth(); + let can_sso = auth + .get_login_types() + .await + .unwrap() + .flows + .iter() + .any(|flow| matches!(flow, LoginType::Sso(_))); + assert!(can_sso); + + let sso_url = auth.get_sso_login_url("http://127.0.0.1:3030", None).await; + sso_url.unwrap(); + + Mock::given(method("POST")) + .and(path("/_matrix/client/r0/login")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::LOGIN)) + .mount(&server) + .await; + + let callback_url = Url::parse("http://127.0.0.1:3030?loginToken=averysmalltoken").unwrap(); + auth.login_with_sso_callback(callback_url).unwrap().await.unwrap(); + + let logged_in = client.logged_in(); + assert!(logged_in, "Client should be logged in"); +} + #[async_test] async fn test_login_error() { let (client, server) = no_retry_test_client_with_server().await;