Skip to content

Commit

Permalink
Add the Olm Session cache back in the CryptoStoreWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
poljar committed Aug 6, 2024
1 parent e747d4c commit 07250db
Show file tree
Hide file tree
Showing 15 changed files with 175 additions and 35 deletions.
3 changes: 2 additions & 1 deletion crates/matrix-sdk-crypto/src/dehydrated_devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ impl DehydratedDevices {
let user_identity = self.inner.store().private_identity();

let account = Account::new_dehydrated(user_id);
let store = Arc::new(CryptoStoreWrapper::new(user_id, MemoryStore::new()));
let store =
Arc::new(CryptoStoreWrapper::new(user_id, account.device_id(), MemoryStore::new()));

let verification_machine = VerificationMachine::new(
account.static_data().clone(),
Expand Down
5 changes: 3 additions & 2 deletions crates/matrix-sdk-crypto/src/gossiping/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ mod tests {
let device_id = DeviceId::new();

let account = Account::with_device_id(&user_id, &device_id);
let store = Arc::new(CryptoStoreWrapper::new(&user_id, MemoryStore::new()));
let store = Arc::new(CryptoStoreWrapper::new(&user_id, &device_id, MemoryStore::new()));
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let verification =
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
Expand All @@ -1197,7 +1197,8 @@ mod tests {
let another_device =
DeviceData::from_account(&Account::with_device_id(&user_id, alice2_device_id()));

let store = Arc::new(CryptoStoreWrapper::new(&user_id, MemoryStore::new()));
let store =
Arc::new(CryptoStoreWrapper::new(&user_id, account.device_id(), MemoryStore::new()));
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let verification =
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
Expand Down
3 changes: 2 additions & 1 deletion crates/matrix-sdk-crypto/src/identities/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,8 @@ impl DeviceData {
store: &CryptoStoreWrapper,
) -> OlmResult<Option<Session>> {
if let Some(sender_key) = self.curve25519_key() {
if let Some(mut sessions) = store.get_sessions(&sender_key.to_base64()).await? {
if let Some(sessions) = store.get_sessions(&sender_key.to_base64()).await? {
let mut sessions = sessions.lock().await;
sessions.sort_by_key(|s| s.creation_time);

Ok(sessions.last().cloned())
Expand Down
2 changes: 1 addition & 1 deletion crates/matrix-sdk-crypto/src/identities/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ pub(crate) mod testing {
let user_id = user_id.to_owned();
let account = Account::with_device_id(&user_id, device_id);
let static_account = account.static_data().clone();
let store = Arc::new(CryptoStoreWrapper::new(&user_id, MemoryStore::new()));
let store = Arc::new(CryptoStoreWrapper::new(&user_id, device_id, MemoryStore::new()));
let verification =
VerificationMachine::new(static_account.clone(), identity.clone(), store.clone());
let store = Store::new(static_account, identity, store, verification);
Expand Down
12 changes: 10 additions & 2 deletions crates/matrix-sdk-crypto/src/identities/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,11 @@ pub(crate) mod tests {
let verification_machine = VerificationMachine::new(
Account::with_device_id(second.user_id(), second.device_id()).static_data,
private_identity,
Arc::new(CryptoStoreWrapper::new(second.user_id(), MemoryStore::new())),
Arc::new(CryptoStoreWrapper::new(
second.user_id(),
second.device_id(),
MemoryStore::new(),
)),
);

let first = Device {
Expand Down Expand Up @@ -1139,7 +1143,11 @@ pub(crate) mod tests {
let verification_machine = VerificationMachine::new(
Account::with_device_id(device.user_id(), device.device_id()).static_data,
id.clone(),
Arc::new(CryptoStoreWrapper::new(device.user_id(), MemoryStore::new())),
Arc::new(CryptoStoreWrapper::new(
device.user_id(),
device.device_id(),
MemoryStore::new(),
)),
);

let public_identity = identity.to_public_identity().await.unwrap();
Expand Down
11 changes: 7 additions & 4 deletions crates/matrix-sdk-crypto/src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ impl OlmMachine {
let account = Account::rehydrate(pickle_key, self.user_id(), device_id, device_data)?;
let static_account = account.static_data().clone();

let store = Arc::new(CryptoStoreWrapper::new(self.user_id(), MemoryStore::new()));
let store =
Arc::new(CryptoStoreWrapper::new(self.user_id(), device_id, MemoryStore::new()));
let device = DeviceData::from_account(&account);
store.save_pending_changes(PendingChanges { account: Some(account) }).await?;
store
Expand Down Expand Up @@ -356,7 +357,7 @@ impl OlmMachine {
});

let identity = Arc::new(Mutex::new(identity));
let store = Arc::new(CryptoStoreWrapper::new(user_id, store));
let store = Arc::new(CryptoStoreWrapper::new(user_id, device_id, store));
Ok(OlmMachine::new_helper(device_id, store, static_account, identity, maybe_backup_key))
}

Expand Down Expand Up @@ -2930,7 +2931,7 @@ pub(crate) mod tests {
.unwrap()
.unwrap();

assert!(!session.is_empty())
assert!(!session.lock().await.is_empty())
}

#[async_test]
Expand Down Expand Up @@ -2975,13 +2976,15 @@ pub(crate) mod tests {
// a resolution in seconds, it's very likely that we're going to end up
// with the same timestamps, so we manually masage them to be 10s apart.
let session_id = {
let mut sessions = alice_machine
let sessions = alice_machine
.store()
.get_sessions(&bob_machine.identity_keys().curve25519.to_base64())
.await
.unwrap()
.unwrap();

let mut sessions = sessions.lock().await;

let mut use_time = SystemTime::now();

let mut session_id = None;
Expand Down
4 changes: 2 additions & 2 deletions crates/matrix-sdk-crypto/src/olm/account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,7 @@ impl Account {
if let Some(sessions) = existing_sessions {
// Try to decrypt the message using each Session we share with the
// given curve25519 sender key.
for mut session in sessions {
for session in sessions.lock().await.iter_mut() {
match session.decrypt(message).await {
Ok(p) => {
// success!
Expand Down Expand Up @@ -1280,7 +1280,7 @@ impl Account {
OlmMessage::PreKey(prekey_message) => {
// First try to decrypt using an existing session.
if let Some(sessions) = existing_sessions {
for mut session in sessions {
for session in sessions.lock().await.iter_mut() {
if prekey_message.session_id() != session.session_id() {
// wrong session
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,11 @@ mod tests {
}

fn create_store(me: &TestUser) -> Store {
let store_wrapper = Arc::new(CryptoStoreWrapper::new(&me.user_id, MemoryStore::new()));
let store_wrapper = Arc::new(CryptoStoreWrapper::new(
&me.user_id,
me.account.device_id(),
MemoryStore::new(),
));

let verification_machine = VerificationMachine::new(
me.account.deref().clone(),
Expand Down Expand Up @@ -1078,7 +1082,11 @@ mod tests {
Arc::new(Mutex::new(PrivateCrossSigningIdentity::new(
account.user_id().to_owned(),
))),
Arc::new(CryptoStoreWrapper::new(account.user_id(), MemoryStore::new())),
Arc::new(CryptoStoreWrapper::new(
account.user_id(),
account.device_id(),
MemoryStore::new(),
)),
),
own_identity: None,
device_owner_identity: None,
Expand Down
9 changes: 6 additions & 3 deletions crates/matrix-sdk-crypto/src/session_manager/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,11 @@ impl SessionManager {
} else if let Some(sender_key) = device.curve25519_key() {
let sessions = self.store.get_sessions(&sender_key.to_base64()).await?;

let is_missing =
if let Some(sessions) = sessions { sessions.is_empty() } else { true };
let is_missing = if let Some(sessions) = sessions {
sessions.lock().await.is_empty()
} else {
true
};

let is_timed_out = self.is_user_timed_out(&user_id, &device_id);

Expand Down Expand Up @@ -668,7 +671,7 @@ mod tests {
let device_id = device_id();

let account = Account::with_device_id(user_id, device_id);
let store = Arc::new(CryptoStoreWrapper::new(user_id, MemoryStore::new()));
let store = Arc::new(CryptoStoreWrapper::new(user_id, device_id, MemoryStore::new()));
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id)));
let verification = VerificationMachine::new(
account.static_data().clone(),
Expand Down
114 changes: 107 additions & 7 deletions crates/matrix-sdk-crypto/src/store/crypto_store_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@ use std::{future, ops::Deref, sync::Arc};
use futures_core::Stream;
use futures_util::StreamExt;
use matrix_sdk_common::store_locks::CrossProcessStoreLock;
use ruma::{OwnedUserId, UserId};
use tokio::sync::broadcast;
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use tokio::sync::{broadcast, Mutex};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
use tracing::warn;

use super::{DeviceChanges, IdentityChanges, LockableCryptoStore};
use super::{caches::SessionStore, DeviceChanges, IdentityChanges, LockableCryptoStore};
use crate::{
olm::InboundGroupSession,
store,
store::{Changes, DynCryptoStore, IntoCryptoStore, RoomKeyInfo, RoomKeyWithheldInfo},
GossippedSecret, OwnUserIdentityData,
store::{self, Changes, DynCryptoStore, IntoCryptoStore, RoomKeyInfo, RoomKeyWithheldInfo},
GossippedSecret, OwnUserIdentityData, Session,
};

/// A wrapper for crypto store implementations that adds update notifiers.
Expand All @@ -23,8 +22,13 @@ use crate::{
#[derive(Debug)]
pub(crate) struct CryptoStoreWrapper {
user_id: OwnedUserId,
device_id: OwnedDeviceId,

store: Arc<DynCryptoStore>,

/// A cache for the Olm Sessions.
sessions: SessionStore,

/// The sender side of a broadcast stream that is notified whenever we get
/// an update to an inbound group session.
room_keys_received_sender: broadcast::Sender<Vec<RoomKeyInfo>>,
Expand All @@ -44,7 +48,7 @@ pub(crate) struct CryptoStoreWrapper {
}

impl CryptoStoreWrapper {
pub(crate) fn new(user_id: &UserId, store: impl IntoCryptoStore) -> Self {
pub(crate) fn new(user_id: &UserId, device_id: &DeviceId, store: impl IntoCryptoStore) -> Self {
let room_keys_received_sender = broadcast::Sender::new(10);
let room_keys_withheld_received_sender = broadcast::Sender::new(10);
let secrets_broadcaster = broadcast::Sender::new(10);
Expand All @@ -54,7 +58,9 @@ impl CryptoStoreWrapper {

Self {
user_id: user_id.to_owned(),
device_id: device_id.to_owned(),
store: store.into_crypto_store(),
sessions: SessionStore::new(),
room_keys_received_sender,
room_keys_withheld_received_sender,
secrets_broadcaster,
Expand Down Expand Up @@ -90,6 +96,22 @@ impl CryptoStoreWrapper {
let devices = changes.devices.to_owned();
let identities = changes.identities.to_owned();

if devices
.changed
.iter()
.any(|d| d.user_id() == self.user_id && d.device_id() == self.device_id)
{
// If our own device key changes, we need to clear the
// session cache because the sessions contain a copy of our
// device key.
self.sessions.clear().await;
} else {
// Otherwise add the sessions to the cache.
for session in &changes.sessions {
self.sessions.add(session.clone()).await;
}
}

self.store.save_changes(changes).await?;

if !room_key_updates.is_empty() {
Expand Down Expand Up @@ -118,6 +140,34 @@ impl CryptoStoreWrapper {
Ok(())
}

pub async fn get_sessions(
&self,
sender_key: &str,
) -> store::Result<Option<Arc<Mutex<Vec<Session>>>>> {
let sessions = self.sessions.get(sender_key).await;

let sessions = if sessions.is_none() {
let mut entries = self.sessions.entries.write().await;

let sessions = entries.get(sender_key);

if sessions.is_some() {
sessions.cloned()
} else {
let sessions = self.store.get_sessions(sender_key).await?;
let sessions = Arc::new(Mutex::new(sessions.unwrap_or_default()));

entries.insert(sender_key.to_owned(), sessions.clone());

Some(sessions)
}
} else {
sessions
};

Ok(sessions)
}

/// Save a list of inbound group sessions to the store.
///
/// # Arguments
Expand Down Expand Up @@ -232,3 +282,53 @@ impl Deref for CryptoStoreWrapper {
self.store.deref()
}
}

#[cfg(test)]
mod test {
use matrix_sdk_test::async_test;
use ruma::user_id;

use super::*;
use crate::machine::tests::get_machine_pair_with_setup_sessions_test_helper;

#[async_test]
async fn cache_cleared_after_device_update() {
let user_id = user_id!("@alice:example.com");
let (first, second) =
get_machine_pair_with_setup_sessions_test_helper(user_id, user_id, false).await;

let sender_key = second.identity_keys().curve25519.to_base64();

first
.store()
.inner
.store
.sessions
.get(&sender_key)
.await
.expect("We should have a session in the cache.");

let device_data = first
.get_device(user_id, first.device_id(), None)
.await
.unwrap()
.expect("We should have access to our own device.")
.inner;

// When we save a new version of our device keys
first
.store()
.save_changes(Changes {
devices: DeviceChanges { changed: vec![device_data], ..Default::default() },
..Default::default()
})
.await
.unwrap();

// Then the session is no longer in the cache
assert!(
first.store().inner.store.sessions.get(&sender_key).await.is_none(),
"The session should no longer be in the cache after our own device keys changed"
);
}
}
7 changes: 7 additions & 0 deletions crates/matrix-sdk-crypto/src/store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,13 @@ impl Store {
self.save_changes(changes).await
}

pub(crate) async fn get_sessions(
&self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
self.inner.store.get_sessions(sender_key).await
}

pub(crate) async fn save_changes(&self, changes: Changes) -> Result<()> {
self.inner.store.save_changes(changes).await
}
Expand Down
2 changes: 1 addition & 1 deletion crates/matrix-sdk-crypto/src/verification/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ mod tests {
let _ = VerificationMachine::new(
alice.static_data,
identity,
Arc::new(CryptoStoreWrapper::new(alice_id(), MemoryStore::new())),
Arc::new(CryptoStoreWrapper::new(alice_id(), alice_device_id(), MemoryStore::new())),
);
}

Expand Down
8 changes: 6 additions & 2 deletions crates/matrix-sdk-crypto/src/verification/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -861,14 +861,18 @@ pub(crate) mod tests {
bob_store.save_devices(vec![alice_device]);

let alice_store = VerificationStore {
inner: Arc::new(CryptoStoreWrapper::new(alice.user_id(), alice_store)),
inner: Arc::new(CryptoStoreWrapper::new(
alice.user_id(),
alice.device_id(),
alice_store,
)),
account: alice.static_data.clone(),
private_identity: alice_private_identity.into(),
};

let bob_store = VerificationStore {
account: bob.static_data.clone(),
inner: Arc::new(CryptoStoreWrapper::new(bob.user_id(), bob_store)),
inner: Arc::new(CryptoStoreWrapper::new(bob.user_id(), bob.device_id(), bob_store)),
private_identity: bob_private_identity.into(),
};

Expand Down
Loading

0 comments on commit 07250db

Please sign in to comment.