diff --git a/Cargo.lock b/Cargo.lock index 2259c4c51d0..9b5cb13ce16 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3429,6 +3429,7 @@ dependencies = [ "assert_matches2", "async-trait", "base64 0.22.1", + "futures-util", "getrandom", "gloo-utils", "growable-bloom-filter", @@ -3507,6 +3508,7 @@ dependencies = [ "assert_matches", "async-trait", "deadpool-sqlite", + "futures-util", "glob", "itertools 0.12.1", "matrix-sdk-base", diff --git a/crates/matrix-sdk-crypto/src/store/integration_tests.rs b/crates/matrix-sdk-crypto/src/store/integration_tests.rs index 0757ab34126..43115403626 100644 --- a/crates/matrix-sdk-crypto/src/store/integration_tests.rs +++ b/crates/matrix-sdk-crypto/src/store/integration_tests.rs @@ -36,6 +36,7 @@ macro_rules! cryptostore_integration_tests { mod cryptostore_integration_tests { use std::collections::{BTreeMap, HashMap}; use std::time::Duration; + use futures_util::stream::StreamExt; use assert_matches::assert_matches; use matrix_sdk_test::async_test; @@ -48,7 +49,7 @@ macro_rules! cryptostore_integration_tests { use $crate::{ olm::{ Account, Curve25519PublicKey, InboundGroupSession, OlmMessageHash, - PrivateCrossSigningIdentity, Session, + PrivateCrossSigningIdentity, SenderData, SenderDataType, Session }, store::{ BackupDecryptionKey, Changes, CryptoStore, DeviceChanges, GossipRequest, @@ -70,7 +71,8 @@ macro_rules! cryptostore_integration_tests { DeviceKeys, EventEncryptionAlgorithm, }, - GossippedSecret, LocalTrust, DeviceData, SecretInfo, ToDeviceRequest, TrackedUser, + EncryptionSettings, GossippedSecret, LocalTrust, DeviceData, SecretInfo, + ToDeviceRequest, TrackedUser, }; use super::get_store; @@ -559,6 +561,69 @@ macro_rules! cryptostore_integration_tests { assert_eq!(store.inbound_group_session_counts(None).await.unwrap().total, 1); } + #[async_test] + async fn fetch_inbound_group_sessions_for_device() { + // Given a store exists, containing inbound group sessions from different devices + let (account, store) = + get_loaded_store("fetch_inbound_group_sessions_for_device").await; + + let dev1 = Curve25519PublicKey::from_base64( + "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4" + ).unwrap(); + let dev2 = Curve25519PublicKey::from_base64( + "LTpv2DGMhggPAXO02+7f68CNEp6A40F0Yl8B094Y8gc" + ).unwrap(); + + let dev_1_unknown_a = create_session( + &account, &dev1, SenderDataType::UnknownDevice).await; + + let dev_1_unknown_b = create_session( + &account, &dev1, SenderDataType::UnknownDevice).await; + + let dev_1_keys = create_session( + &account, &dev1, SenderDataType::DeviceInfo).await; + + let dev_2_unknown = create_session( + &account, &dev1, SenderDataType::UnknownDevice).await; + + let dev_2_keys = create_session( + &account, &dev1, SenderDataType::DeviceInfo).await; + + let sessions = vec![ + dev_1_unknown_a.clone(), + dev_1_unknown_b.clone(), + dev_1_keys.clone(), + dev_2_unknown.clone(), + dev_2_keys.clone(), + ]; + + let changes = Changes { + inbound_group_sessions: sessions, + ..Default::default() + }; + store.save_changes(changes).await.expect("Can't save group session"); + + // When we fetch the list of sessions for device 1, unknown + let sessions_1_u: Vec<_> = store + .get_inbound_group_sessions_for_device(&dev1, SenderDataType::UnknownDevice) + .await + .unwrap() + .collect().await; + + // Then the expected sessions are returned + assert_eq!(sessions_1_u, vec![dev_1_unknown_a, dev_1_unknown_b]); + + // And when we ask for the list of sessions for device 2, with device keys + let sessions_2_d: Vec<_> = store + .get_inbound_group_sessions_for_device(&dev2, SenderDataType::DeviceInfo) + .await + .unwrap() + .collect().await; + + // Then the matching session is returned + assert_eq!(sessions_2_d, vec![dev_2_keys]); + } + #[async_test] async fn test_tracked_users() { let dir = "test_tracked_users"; @@ -1110,6 +1175,14 @@ macro_rules! cryptostore_integration_tests { fn session_info(session: &InboundGroupSession) -> (&RoomId, &str) { (&session.room_id(), &session.session_id()) } + + async fn create_session( + account: &Account, + device_curve_key: &Curve25519PublicKey, + sender_data_type: SenderDataType + ) -> InboundGroupSession { + todo!() + } } }; } diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index dde91cda174..ab03b36398f 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -381,7 +381,7 @@ impl CryptoStore for MemoryStore { async fn get_inbound_group_sessions_for_device( &self, - _device_key: Curve25519PublicKey, + _device_key: &Curve25519PublicKey, _sender_data_type: SenderDataType, ) -> Result { todo!() @@ -1241,7 +1241,7 @@ mod integration_tests { async fn get_inbound_group_sessions_for_device( &self, - device_key: Curve25519PublicKey, + device_key: &Curve25519PublicKey, sender_data_type: SenderDataType, ) -> Result { self.0.get_inbound_group_sessions_for_device(device_key, sender_data_type).await diff --git a/crates/matrix-sdk-crypto/src/store/traits.rs b/crates/matrix-sdk-crypto/src/store/traits.rs index a8e062a32bb..3504f2b5a84 100644 --- a/crates/matrix-sdk-crypto/src/store/traits.rs +++ b/crates/matrix-sdk-crypto/src/store/traits.rs @@ -134,7 +134,7 @@ pub trait CryptoStore: AsyncTraitDeps { /// and we want to update the sender data based on the new information. async fn get_inbound_group_sessions_for_device( &self, - device_key: Curve25519PublicKey, + device_key: &Curve25519PublicKey, sender_data_type: SenderDataType, ) -> Result; @@ -422,7 +422,7 @@ impl CryptoStore for EraseCryptoStoreError { async fn get_inbound_group_sessions_for_device( &self, - device_key: Curve25519PublicKey, + device_key: &Curve25519PublicKey, sender_data_type: SenderDataType, ) -> Result { self.0 diff --git a/crates/matrix-sdk-indexeddb/Cargo.toml b/crates/matrix-sdk-indexeddb/Cargo.toml index a2249b490b4..2f939a821d4 100644 --- a/crates/matrix-sdk-indexeddb/Cargo.toml +++ b/crates/matrix-sdk-indexeddb/Cargo.toml @@ -50,6 +50,7 @@ getrandom = { version = "0.2.6", features = ["js"] } [dev-dependencies] assert_matches = { workspace = true } assert_matches2 = { workspace = true } +futures-util = { workspace = true } matrix-sdk-base = { workspace = true, features = ["testing"] } matrix-sdk-common = { workspace = true, features = ["js"] } matrix-sdk-crypto = { workspace = true, features = ["js", "testing"] } diff --git a/crates/matrix-sdk-sqlite/Cargo.toml b/crates/matrix-sdk-sqlite/Cargo.toml index da7de41c3b2..a4ccb848299 100644 --- a/crates/matrix-sdk-sqlite/Cargo.toml +++ b/crates/matrix-sdk-sqlite/Cargo.toml @@ -34,6 +34,7 @@ vodozemac = { workspace = true } [dev-dependencies] assert_matches = { workspace = true } +futures-util = { workspace = true } glob = "0.3.0" matrix-sdk-base = { workspace = true, features = ["testing"] } matrix-sdk-crypto = { workspace = true, features = ["testing"] } diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index f4b4b82a8e5..113a28fbb19 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -996,7 +996,7 @@ impl CryptoStore for SqliteCryptoStore { async fn get_inbound_group_sessions_for_device( &self, - _device_key: Curve25519PublicKey, + _device_key: &Curve25519PublicKey, _sender_data_type: SenderDataType, ) -> Result { todo!()