Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

crypto: Allow querying InboundGroupSessions using curve key #3806

Merged
merged 9 commits into from
Sep 2, 2024
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ use vodozemac::{
};

use super::{
BackedUpRoomKey, ExportedRoomKey, OutboundGroupSession, SenderData, SessionCreationError,
SessionKey,
BackedUpRoomKey, ExportedRoomKey, OutboundGroupSession, SenderData, SenderDataType,
SessionCreationError, SessionKey,
};
use crate::{
error::{EventError, MegolmResult},
Expand Down Expand Up @@ -477,6 +477,13 @@ impl InboundGroupSession {
pub(crate) fn mark_as_imported(&mut self) {
self.imported = true;
}

/// Return the [`SenderDataType`] of our [`SenderData`]. This is used during
/// serialization, to allow us to store the type in a separate queryable
/// column/property.
pub fn sender_data_type(&self) -> SenderDataType {
self.sender_data.to_type()
}
}

#[cfg(not(tarpaulin_include))]
Expand Down
2 changes: 1 addition & 1 deletion crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub(crate) use outbound::ShareState;
pub use outbound::{
EncryptionSettings, OutboundGroupSession, PickledOutboundGroupSession, ShareInfo,
};
pub use sender_data::{KnownSenderData, SenderData};
pub use sender_data::{KnownSenderData, SenderData, SenderDataType};
pub(crate) use sender_data_finder::SenderDataFinder;
use thiserror::Error;
pub use vodozemac::megolm::{ExportedSessionKey, SessionKey};
Expand Down
30 changes: 30 additions & 0 deletions crates/matrix-sdk-crypto/src/olm/group_sessions/sender_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,19 @@ impl SenderData {
SenderData::SenderVerified(..) => 4,
}
}

/// Return our type as a [`SenderDataType`].
pub fn to_type(&self) -> SenderDataType {
match self {
Self::UnknownDevice { .. } => SenderDataType::UnknownDevice,
Self::DeviceInfo { .. } => SenderDataType::DeviceInfo,
Self::SenderUnverifiedButPreviouslyVerified { .. } => {
SenderDataType::SenderUnverifiedButPreviouslyVerified
}
Self::SenderUnverified { .. } => SenderDataType::SenderUnverified,
Self::SenderVerified { .. } => SenderDataType::SenderVerified,
}
}
}

/// Used when deserialising and the sender_data property is missing.
Expand Down Expand Up @@ -266,6 +279,23 @@ impl From<SenderDataReader> for SenderData {
}
}

/// Used when serializing [`crate::olm::group_sessions::InboundGroupSession`]s.
/// We want just the type of the session's [`SenderData`] to be queryable, so we
/// store the type as a separate column/property in the database.
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
pub enum SenderDataType {
/// The [`SenderData`] is of type `UnknownDevice`.
UnknownDevice = 1,
/// The [`SenderData`] is of type `DeviceInfo`.
DeviceInfo = 2,
/// The [`SenderData`] is of type `SenderUnverifiedButPreviouslyVerified`.
SenderUnverifiedButPreviouslyVerified = 3,
/// The [`SenderData`] is of type `SenderUnverified`.
SenderUnverified = 4,
/// The [`SenderData`] is of type `SenderVerified`.
SenderVerified = 5,
}

#[cfg(test)]
mod tests {
use std::{cmp::Ordering, collections::BTreeMap};
Expand Down
2 changes: 1 addition & 1 deletion crates/matrix-sdk-crypto/src/olm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub(crate) use account::{OlmDecryptionInfo, SessionType};
pub use group_sessions::{
BackedUpRoomKey, EncryptionSettings, ExportedRoomKey, InboundGroupSession, KnownSenderData,
OutboundGroupSession, PickledInboundGroupSession, PickledOutboundGroupSession, SenderData,
SessionCreationError, SessionExportError, SessionKey, ShareInfo,
SenderDataType, SessionCreationError, SessionExportError, SessionKey, ShareInfo,
};
pub(crate) use group_sessions::{SenderDataFinder, ShareState};
pub use session::{PickledSession, Session};
Expand Down
150 changes: 149 additions & 1 deletion crates/matrix-sdk-crypto/src/store/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ macro_rules! cryptostore_integration_tests {
use $crate::{
olm::{
Account, Curve25519PublicKey, InboundGroupSession, OlmMessageHash,
PrivateCrossSigningIdentity, Session,
PrivateCrossSigningIdentity, SenderData, SenderDataType, Session
},
store::{
BackupDecryptionKey, Changes, CryptoStore, DeviceChanges, GossipRequest,
Expand All @@ -71,6 +71,9 @@ macro_rules! cryptostore_integration_tests {
EventEncryptionAlgorithm,
},
GossippedSecret, LocalTrust, DeviceData, SecretInfo, ToDeviceRequest, TrackedUser,
vodozemac::{
megolm::{GroupSession, SessionConfig},
},
};

use super::get_store;
Expand Down Expand Up @@ -561,6 +564,118 @@ macro_rules! cryptostore_integration_tests {
assert_eq!(store.inbound_group_session_counts(None).await.unwrap().total, 1);
}

#[async_test]
async fn test_fetch_inbound_group_sessions_for_device() {
// Given a store exists, containing inbound group sessions from different devices
let (account, store) =
get_loaded_store("fetch_inbound_group_sessions_for_device").await;

let dev1 = Curve25519PublicKey::from_base64(
"wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4"
).unwrap();
let dev2 = Curve25519PublicKey::from_base64(
"LTpv2DGMhggPAXO02+7f68CNEp6A40F0Yl8B094Y8gc"
).unwrap();

let dev_1_unknown_a = create_session(&account, &dev1, SenderDataType::UnknownDevice).await;
let dev_1_unknown_b = create_session(&account, &dev1, SenderDataType::UnknownDevice).await;

let dev_1_keys_a = create_session(&account, &dev1, SenderDataType::DeviceInfo).await;
let dev_1_keys_b = create_session(&account, &dev1, SenderDataType::DeviceInfo).await;
let dev_1_keys_c = create_session(&account, &dev1, SenderDataType::DeviceInfo).await;
let dev_1_keys_d = create_session(&account, &dev1, SenderDataType::DeviceInfo).await;

let dev_2_unknown = create_session(
&account, &dev2, SenderDataType::UnknownDevice).await;

let dev_2_keys = create_session(
&account, &dev2, SenderDataType::DeviceInfo).await;

let sessions = vec![
dev_1_unknown_a.clone(),
dev_1_unknown_b.clone(),
dev_1_keys_a.clone(),
dev_1_keys_b.clone(),
dev_1_keys_c.clone(),
dev_1_keys_d.clone(),
dev_2_unknown.clone(),
dev_2_keys.clone(),
];

let changes = Changes {
inbound_group_sessions: sessions,
..Default::default()
};
store.save_changes(changes).await.expect("Can't save group session");

// When we fetch the list of sessions for device 1, unknown
let sessions_1_u = store.get_inbound_group_sessions_for_device_batch(
dev1,
SenderDataType::UnknownDevice,
None,
10
).await.expect("Failed to get sessions for dev1");

// Then the expected sessions are returned
assert_session_lists_eq(sessions_1_u, [dev_1_unknown_a, dev_1_unknown_b], "device 1 sessions");

// And when we ask for the list of sessions for device 2, with device keys
let sessions_2_d = store
.get_inbound_group_sessions_for_device_batch(dev2, SenderDataType::DeviceInfo, None, 10)
.await
.expect("Failed to get sessions for dev2");

// Then the matching session is returned
assert_eq!(sessions_2_d, vec![dev_2_keys], "device 2 sessions");

// And we can fetch device 1, keys in batches.
// We call the batch function repeatedly, to ensure it terminates correctly.
let mut sessions_1_k = Vec::new();
let mut previous_last_session_id: Option<String> = None;
loop {
let mut sessions_1_k_batch = store.get_inbound_group_sessions_for_device_batch(
dev1,
SenderDataType::DeviceInfo,
previous_last_session_id,
2
).await.expect("Failed to get batch 1");

// If there are no results in the batch, we have reached the end of the results.
let Some(last_session) = sessions_1_k_batch.last() else {
break;
};

// Check that there are exactly two results in the batch
assert_eq!(sessions_1_k_batch.len(), 2);

previous_last_session_id = Some(last_session.session_id().to_owned());
sessions_1_k.append(&mut sessions_1_k_batch);
}

assert_session_lists_eq(
sessions_1_k,
[dev_1_keys_a, dev_1_keys_b, dev_1_keys_c, dev_1_keys_d],
"device 1 batched results"
);
}

/// Assert that two lists of sessions are the same, modulo ordering.
///
/// There is no requirement for `get_inbound_group_sessions_for_device_batch` to
/// return the results in a specific order. This helper ensures that the two lists
/// of inbound group sessions are equivalent, without worrying about the ordering.
fn assert_session_lists_eq<I, J>(actual: I, expected: J, message: &str)
where I: IntoIterator<Item = InboundGroupSession>, J: IntoIterator<Item = InboundGroupSession>
{
let sorter = |a: &InboundGroupSession, b: &InboundGroupSession| Ord::cmp(a.session_id(), b.session_id());

let mut actual = Vec::from_iter(actual);
actual.sort_unstable_by(sorter);
let mut expected = Vec::from_iter(expected);
expected.sort_unstable_by(sorter);
assert_eq!(actual, expected, "{}", message);
}

#[async_test]
async fn test_tracked_users() {
let dir = "test_tracked_users";
Expand Down Expand Up @@ -1112,6 +1227,39 @@ macro_rules! cryptostore_integration_tests {
fn session_info(session: &InboundGroupSession) -> (&RoomId, &str) {
(&session.room_id(), &session.session_id())
}

async fn create_session(
account: &Account,
device_curve_key: &Curve25519PublicKey,
sender_data_type: SenderDataType,
) -> InboundGroupSession {
let sender_data = match sender_data_type {
SenderDataType::UnknownDevice => {
SenderData::UnknownDevice { legacy_session: false, owner_check_failed: false }
}
SenderDataType::DeviceInfo => SenderData::DeviceInfo {
device_keys: account.device_keys().clone(),
legacy_session: false,
},
SenderDataType::SenderUnverifiedButPreviouslyVerified =>
panic!("SenderUnverifiedButPreviouslyVerified not supported"),
SenderDataType::SenderUnverified=> panic!("SenderUnverified not supported"),
SenderDataType::SenderVerified => panic!("SenderVerified not supported"),
};

let session_key = GroupSession::new(SessionConfig::default()).session_key();

InboundGroupSession::new(
device_curve_key.clone(),
account.device_keys().ed25519_key().unwrap(),
room_id!("!r:s.co"),
&session_key,
sender_data,
EventEncryptionAlgorithm::MegolmV1AesSha2,
None,
)
.unwrap()
}
}
};
}
Expand Down
65 changes: 63 additions & 2 deletions crates/matrix-sdk-crypto/src/store/memorystore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use ruma::{
};
use tokio::sync::RwLock;
use tracing::warn;
use vodozemac::Curve25519PublicKey;

use super::{
caches::{DeviceStore, GroupSessionStore},
Expand All @@ -35,7 +36,7 @@ use super::{
use crate::{
gossiping::{GossipRequest, GossippedSecret, SecretInfo},
identities::{DeviceData, UserIdentityData},
olm::{OutboundGroupSession, PrivateCrossSigningIdentity},
olm::{OutboundGroupSession, PrivateCrossSigningIdentity, SenderDataType},
types::events::room_key_withheld::RoomKeyWithheldEvent,
TrackedUser,
};
Expand Down Expand Up @@ -380,6 +381,48 @@ impl CryptoStore for MemoryStore {
Ok(RoomKeyCounts { total: self.inbound_group_sessions.count(), backed_up })
}

async fn get_inbound_group_sessions_for_device_batch(
&self,
sender_key: Curve25519PublicKey,
sender_data_type: SenderDataType,
after_session_id: Option<String>,
limit: usize,
) -> Result<Vec<InboundGroupSession>> {
// First, find all InboundGroupSessions, filtering for those that match the
// device and sender_data type.
let mut sessions: Vec<_> = self
.get_inbound_group_sessions()
.await?
.into_iter()
.filter(|session: &InboundGroupSession| {
session.creator_info.curve25519_key == sender_key
&& session.sender_data.to_type() == sender_data_type
})
.collect();

// Then, sort the sessions in order of ascending session ID...
sessions.sort_by_key(|s| s.session_id().to_owned());

// Figure out where in the array to start returning results from
let start_index = {
match after_session_id {
None => 0,
Some(id) => {
let idx = sessions
.iter()
.position(|session| session.session_id() == id)
.map(|idx| idx + 1);

// If `after_session_id` was not found in the array, go to the end of the array
idx.unwrap_or(sessions.len())
}
}
};

// Return up to `limit` items from the array, starting from `start_index`
Ok(sessions.drain(start_index..).take(limit).collect())
}

async fn inbound_group_sessions_for_backup(
&self,
backup_version: &str,
Expand Down Expand Up @@ -1102,13 +1145,14 @@ mod integration_tests {
use ruma::{
events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId,
};
use vodozemac::Curve25519PublicKey;

use super::MemoryStore;
use crate::{
cryptostore_integration_tests, cryptostore_integration_tests_time,
olm::{
InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
StaticAccountData,
SenderDataType, StaticAccountData,
},
store::{BackupKeys, Changes, CryptoStore, PendingChanges, RoomKeyCounts, RoomSettings},
types::events::room_key_withheld::RoomKeyWithheldEvent,
Expand Down Expand Up @@ -1230,6 +1274,23 @@ mod integration_tests {
self.0.inbound_group_session_counts(backup_version).await
}

async fn get_inbound_group_sessions_for_device_batch(
&self,
sender_key: Curve25519PublicKey,
sender_data_type: SenderDataType,
after_session_id: Option<String>,
limit: usize,
) -> Result<Vec<InboundGroupSession>, Self::Error> {
self.0
.get_inbound_group_sessions_for_device_batch(
sender_key,
sender_data_type,
after_session_id,
limit,
)
.await
}

async fn inbound_group_sessions_for_backup(
&self,
backup_version: &str,
Expand Down
Loading
Loading