diff --git a/presage-cli/src/main.rs b/presage-cli/src/main.rs index 3ff2b2af7..646e7e183 100644 --- a/presage-cli/src/main.rs +++ b/presage-cli/src/main.rs @@ -1,4 +1,3 @@ -use core::fmt; use std::convert::TryInto; use std::path::Path; use std::path::PathBuf; @@ -492,7 +491,12 @@ async fn run<S: Store>(subcommand: Cmd, config_store: S) -> anyhow::Result<()> { let stdin = io::stdin(); let reader = BufReader::new(stdin); if let Some(confirmation_code) = reader.lines().next_line().await? { - manager.confirm_verification_code(confirmation_code).await?; + let registered_manager = + manager.confirm_verification_code(confirmation_code).await?; + println!( + "Account identifier: {}", + registered_manager.registration_data().aci() + ); } } Cmd::LinkDevice { @@ -709,17 +713,3 @@ fn parse_base64_profile_key(s: &str) -> anyhow::Result<ProfileKey> { .map_err(|_| anyhow!("profile key of invalid length"))?; Ok(ProfileKey::create(bytes)) } - -struct DebugGroup<'a>(&'a Group); - -impl fmt::Debug for DebugGroup<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let group = &self.0; - f.debug_struct("Group") - .field("title", &group.title) - .field("avatar", &group.avatar) - .field("revision", &group.revision) - .field("description", &group.description) - .finish() - } -} diff --git a/presage-store-sled/Cargo.toml b/presage-store-sled/Cargo.toml index 23983eacd..bd4102d93 100644 --- a/presage-store-sled/Cargo.toml +++ b/presage-store-sled/Cargo.toml @@ -23,6 +23,7 @@ thiserror = "1.0" prost = "> 0.10, <= 0.12" sha2 = "0.10" quickcheck_macros = "1.0.0" +chrono = "0.4.35" [dev-dependencies] anyhow = "1.0" diff --git a/presage-store-sled/src/content.rs b/presage-store-sled/src/content.rs new file mode 100644 index 000000000..bfc713ff7 --- /dev/null +++ b/presage-store-sled/src/content.rs @@ -0,0 +1,470 @@ +use std::{ + ops::{Bound, RangeBounds, RangeFull}, + sync::Arc, +}; + +use log::debug; +use presage::{ + libsignal_service::{ + content::Content, + groups_v2::Group, + models::Contact, + prelude::Uuid, + zkgroup::{profiles::ProfileKey, GroupMasterKeyBytes}, + Profile, + }, + store::{ContentExt, ContentsStore, StickerPack, Thread}, + AvatarBytes, +}; +use prost::Message; +use serde::de::DeserializeOwned; +use sha2::{Digest, Sha256}; +use sled::IVec; + +use crate::{protobuf::ContentProto, SledStore, SledStoreError}; + +const SLED_TREE_PROFILE_AVATARS: &str = "profile_avatars"; +const SLED_TREE_PROFILE_KEYS: &str = "profile_keys"; +const SLED_TREE_STICKER_PACKS: &str = "sticker_packs"; +const SLED_TREE_CONTACTS: &str = "contacts"; +const SLED_TREE_GROUP_AVATARS: &str = "group_avatars"; +const SLED_TREE_GROUPS: &str = "groups"; +const SLED_TREE_PROFILES: &str = "profiles"; +const SLED_TREE_THREADS_PREFIX: &str = "threads"; + +impl ContentsStore for SledStore { + type ContentsStoreError = SledStoreError; + + type ContactsIter = SledContactsIter; + type GroupsIter = SledGroupsIter; + type MessagesIter = SledMessagesIter; + type StickerPacksIter = SledStickerPacksIter; + + fn clear_profiles(&mut self) -> Result<(), Self::ContentsStoreError> { + let db = self.write(); + db.drop_tree(SLED_TREE_PROFILES)?; + db.drop_tree(SLED_TREE_PROFILE_KEYS)?; + db.drop_tree(SLED_TREE_PROFILE_AVATARS)?; + db.flush()?; + Ok(()) + } + + fn clear_contents(&mut self) -> Result<(), Self::ContentsStoreError> { + let db = self.write(); + db.drop_tree(SLED_TREE_CONTACTS)?; + db.drop_tree(SLED_TREE_GROUPS)?; + + for tree in db + .tree_names() + .into_iter() + .filter(|n| n.starts_with(SLED_TREE_THREADS_PREFIX.as_bytes())) + { + db.drop_tree(tree)?; + } + + db.flush()?; + Ok(()) + } + + fn clear_contacts(&mut self) -> Result<(), SledStoreError> { + self.write().drop_tree(SLED_TREE_CONTACTS)?; + Ok(()) + } + + fn save_contact(&mut self, contact: &Contact) -> Result<(), SledStoreError> { + self.insert(SLED_TREE_CONTACTS, contact.uuid, contact)?; + debug!("saved contact"); + Ok(()) + } + + fn contacts(&self) -> Result<Self::ContactsIter, SledStoreError> { + Ok(SledContactsIter { + iter: self.read().open_tree(SLED_TREE_CONTACTS)?.iter(), + #[cfg(feature = "encryption")] + cipher: self.cipher.clone(), + }) + } + + fn contact_by_id(&self, id: &Uuid) -> Result<Option<Contact>, SledStoreError> { + self.get(SLED_TREE_CONTACTS, id) + } + + /// Groups + + fn clear_groups(&mut self) -> Result<(), SledStoreError> { + let db = self.write(); + db.drop_tree(SLED_TREE_GROUPS)?; + db.flush()?; + Ok(()) + } + + fn groups(&self) -> Result<Self::GroupsIter, SledStoreError> { + Ok(SledGroupsIter { + iter: self.read().open_tree(SLED_TREE_GROUPS)?.iter(), + #[cfg(feature = "encryption")] + cipher: self.cipher.clone(), + }) + } + + fn group( + &self, + master_key_bytes: GroupMasterKeyBytes, + ) -> Result<Option<Group>, SledStoreError> { + self.get(SLED_TREE_GROUPS, master_key_bytes) + } + + fn save_group( + &self, + master_key: GroupMasterKeyBytes, + group: &Group, + ) -> Result<(), SledStoreError> { + self.insert(SLED_TREE_GROUPS, master_key, group)?; + Ok(()) + } + + fn group_avatar( + &self, + master_key_bytes: GroupMasterKeyBytes, + ) -> Result<Option<AvatarBytes>, SledStoreError> { + self.get(SLED_TREE_GROUP_AVATARS, master_key_bytes) + } + + fn save_group_avatar( + &self, + master_key: GroupMasterKeyBytes, + avatar: &AvatarBytes, + ) -> Result<(), SledStoreError> { + self.insert(SLED_TREE_GROUP_AVATARS, master_key, avatar)?; + Ok(()) + } + + /// Messages + + fn clear_messages(&mut self) -> Result<(), SledStoreError> { + let db = self.write(); + for name in db.tree_names() { + if name + .as_ref() + .starts_with(SLED_TREE_THREADS_PREFIX.as_bytes()) + { + db.drop_tree(&name)?; + } + } + db.flush()?; + Ok(()) + } + + fn clear_thread(&mut self, thread: &Thread) -> Result<(), SledStoreError> { + log::trace!("clearing thread {thread}"); + + let db = self.write(); + db.drop_tree(messages_thread_tree_name(thread))?; + db.flush()?; + + Ok(()) + } + + fn save_message(&self, thread: &Thread, message: Content) -> Result<(), SledStoreError> { + let ts = message.timestamp(); + log::trace!("storing a message with thread: {thread}, timestamp: {ts}",); + + let tree = messages_thread_tree_name(thread); + let key = ts.to_be_bytes(); + + let proto: ContentProto = message.into(); + let value = proto.encode_to_vec(); + + self.insert(&tree, key, value)?; + + Ok(()) + } + + fn delete_message(&mut self, thread: &Thread, timestamp: u64) -> Result<bool, SledStoreError> { + let tree = messages_thread_tree_name(thread); + self.remove(&tree, timestamp.to_be_bytes()) + } + + fn message(&self, thread: &Thread, timestamp: u64) -> Result<Option<Content>, SledStoreError> { + // Big-Endian needed, otherwise wrong ordering in sled. + let val: Option<Vec<u8>> = + self.get(&messages_thread_tree_name(thread), timestamp.to_be_bytes())?; + match val { + Some(ref v) => { + let proto = ContentProto::decode(v.as_slice())?; + let content = proto.try_into()?; + Ok(Some(content)) + } + None => Ok(None), + } + } + + fn messages( + &self, + thread: &Thread, + range: impl RangeBounds<u64>, + ) -> Result<Self::MessagesIter, SledStoreError> { + let tree_thread = self.read().open_tree(messages_thread_tree_name(thread))?; + debug!("{} messages in this tree", tree_thread.len()); + + let iter = match (range.start_bound(), range.end_bound()) { + (Bound::Included(start), Bound::Unbounded) => tree_thread.range(start.to_be_bytes()..), + (Bound::Included(start), Bound::Excluded(end)) => { + tree_thread.range(start.to_be_bytes()..end.to_be_bytes()) + } + (Bound::Included(start), Bound::Included(end)) => { + tree_thread.range(start.to_be_bytes()..=end.to_be_bytes()) + } + (Bound::Unbounded, Bound::Included(end)) => tree_thread.range(..=end.to_be_bytes()), + (Bound::Unbounded, Bound::Excluded(end)) => tree_thread.range(..end.to_be_bytes()), + (Bound::Unbounded, Bound::Unbounded) => tree_thread.range::<[u8; 8], RangeFull>(..), + (Bound::Excluded(_), _) => { + unreachable!("range that excludes the initial value") + } + }; + + Ok(SledMessagesIter { + #[cfg(feature = "encryption")] + cipher: self.cipher.clone(), + iter, + }) + } + + fn upsert_profile_key(&mut self, uuid: &Uuid, key: ProfileKey) -> Result<bool, SledStoreError> { + self.insert(SLED_TREE_PROFILE_KEYS, uuid.as_bytes(), key) + } + + fn profile_key(&self, uuid: &Uuid) -> Result<Option<ProfileKey>, SledStoreError> { + self.get(SLED_TREE_PROFILE_KEYS, uuid.as_bytes()) + } + + fn save_profile( + &mut self, + uuid: Uuid, + key: ProfileKey, + profile: Profile, + ) -> Result<(), SledStoreError> { + let key = self.profile_key_for_uuid(uuid, key); + self.insert(SLED_TREE_PROFILES, key, profile)?; + Ok(()) + } + + fn profile(&self, uuid: Uuid, key: ProfileKey) -> Result<Option<Profile>, SledStoreError> { + let key = self.profile_key_for_uuid(uuid, key); + self.get(SLED_TREE_PROFILES, key) + } + + fn save_profile_avatar( + &mut self, + uuid: Uuid, + key: ProfileKey, + avatar: &AvatarBytes, + ) -> Result<(), SledStoreError> { + let key = self.profile_key_for_uuid(uuid, key); + self.insert(SLED_TREE_PROFILE_AVATARS, key, avatar)?; + Ok(()) + } + + fn profile_avatar( + &self, + uuid: Uuid, + key: ProfileKey, + ) -> Result<Option<AvatarBytes>, SledStoreError> { + let key = self.profile_key_for_uuid(uuid, key); + self.get(SLED_TREE_PROFILE_AVATARS, key) + } + + fn add_sticker_pack(&mut self, pack: &StickerPack) -> Result<(), SledStoreError> { + self.insert(SLED_TREE_STICKER_PACKS, pack.id.clone(), pack)?; + Ok(()) + } + + fn remove_sticker_pack(&mut self, id: &[u8]) -> Result<bool, SledStoreError> { + self.remove(SLED_TREE_STICKER_PACKS, id) + } + + fn sticker_pack(&self, id: &[u8]) -> Result<Option<StickerPack>, SledStoreError> { + self.get(SLED_TREE_STICKER_PACKS, id) + } + + fn sticker_packs(&self) -> Result<Self::StickerPacksIter, SledStoreError> { + Ok(SledStickerPacksIter { + cipher: self.cipher.clone(), + iter: self.read().open_tree(SLED_TREE_STICKER_PACKS)?.iter(), + }) + } +} + +pub struct SledContactsIter { + #[cfg(feature = "encryption")] + cipher: Option<Arc<presage_store_cipher::StoreCipher>>, + iter: sled::Iter, +} + +impl SledContactsIter { + #[cfg(feature = "encryption")] + fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { + if let Some(cipher) = self.cipher.as_ref() { + Ok(cipher.decrypt_value(value)?) + } else { + Ok(serde_json::from_slice(value)?) + } + } + + #[cfg(not(feature = "encryption"))] + fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { + Ok(serde_json::from_slice(value)?) + } +} + +impl Iterator for SledContactsIter { + type Item = Result<Contact, SledStoreError>; + + fn next(&mut self) -> Option<Self::Item> { + self.iter + .next()? + .map_err(SledStoreError::from) + .and_then(|(_key, value)| self.decrypt_value(&value)) + .into() + } +} + +pub struct SledGroupsIter { + #[cfg(feature = "encryption")] + cipher: Option<Arc<presage_store_cipher::StoreCipher>>, + iter: sled::Iter, +} + +impl SledGroupsIter { + #[cfg(feature = "encryption")] + fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { + if let Some(cipher) = self.cipher.as_ref() { + Ok(cipher.decrypt_value(value)?) + } else { + Ok(serde_json::from_slice(value)?) + } + } + + #[cfg(not(feature = "encryption"))] + fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { + Ok(serde_json::from_slice(value)?) + } +} + +impl Iterator for SledGroupsIter { + type Item = Result<(GroupMasterKeyBytes, Group), SledStoreError>; + + fn next(&mut self) -> Option<Self::Item> { + Some(self.iter.next()?.map_err(SledStoreError::from).and_then( + |(group_master_key_bytes, value)| { + let group = self.decrypt_value(&value)?; + Ok(( + group_master_key_bytes + .as_ref() + .try_into() + .map_err(|_| SledStoreError::GroupDecryption)?, + group, + )) + }, + )) + } +} + +pub struct SledStickerPacksIter { + #[cfg(feature = "encryption")] + cipher: Option<Arc<presage_store_cipher::StoreCipher>>, + iter: sled::Iter, +} + +impl Iterator for SledStickerPacksIter { + type Item = Result<StickerPack, SledStoreError>; + + #[cfg(feature = "encryption")] + fn next(&mut self) -> Option<Self::Item> { + self.iter + .next()? + .map_err(SledStoreError::from) + .and_then(|(_key, value)| { + if let Some(cipher) = self.cipher.as_ref() { + cipher.decrypt_value(&value).map_err(SledStoreError::from) + } else { + serde_json::from_slice(&value).map_err(SledStoreError::from) + } + }) + .into() + } + + #[cfg(not(feature = "encryption"))] + fn next(&mut self) -> Option<Self::Item> { + self.iter + .next()? + .map_err(SledStoreError::from) + .and_then(|(_key, value)| serde_json::from_slice(&value).map_err(SledStoreError::from)) + .into() + } +} + +pub struct SledMessagesIter { + #[cfg(feature = "encryption")] + cipher: Option<Arc<presage_store_cipher::StoreCipher>>, + iter: sled::Iter, +} + +impl SledMessagesIter { + #[cfg(feature = "encryption")] + fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { + if let Some(cipher) = self.cipher.as_ref() { + Ok(cipher.decrypt_value(value)?) + } else { + Ok(serde_json::from_slice(value)?) + } + } + + #[cfg(not(feature = "encryption"))] + fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { + Ok(serde_json::from_slice(value)?) + } +} + +impl SledMessagesIter { + fn decode( + &self, + elem: Result<(IVec, IVec), sled::Error>, + ) -> Option<Result<Content, SledStoreError>> { + elem.map_err(SledStoreError::from) + .and_then(|(_, value)| self.decrypt_value(&value).map_err(SledStoreError::from)) + .and_then(|data: Vec<u8>| ContentProto::decode(&data[..]).map_err(SledStoreError::from)) + .map_or_else(|e| Some(Err(e)), |p| Some(p.try_into())) + } +} + +impl Iterator for SledMessagesIter { + type Item = Result<Content, SledStoreError>; + + fn next(&mut self) -> Option<Self::Item> { + let elem = self.iter.next()?; + self.decode(elem) + } +} + +impl DoubleEndedIterator for SledMessagesIter { + fn next_back(&mut self) -> Option<Self::Item> { + let elem = self.iter.next_back()?; + self.decode(elem) + } +} + +fn messages_thread_tree_name(t: &Thread) -> String { + use base64::prelude::*; + let key = match t { + Thread::Contact(uuid) => { + format!("{SLED_TREE_THREADS_PREFIX}:contact:{uuid}") + } + Thread::Group(group_id) => format!( + "{SLED_TREE_THREADS_PREFIX}:group:{}", + BASE64_STANDARD.encode(group_id) + ), + }; + let mut hasher = Sha256::new(); + hasher.update(key.as_bytes()); + format!("{SLED_TREE_THREADS_PREFIX}:{:x}", hasher.finalize()) +} diff --git a/presage-store-sled/src/error.rs b/presage-store-sled/src/error.rs index 498bcb780..56b21bf5f 100644 --- a/presage-store-sled/src/error.rs +++ b/presage-store-sled/src/error.rs @@ -13,6 +13,8 @@ pub enum SledStoreError { StoreCipher(#[from] presage_store_cipher::StoreCipherError), #[error("JSON error: {0}")] Json(#[from] serde_json::Error), + #[error("base64 decode error: {0}")] + Base64Decode(#[from] base64::DecodeError), #[error("Prost error: {0}")] ProtobufDecode(#[from] prost::DecodeError), #[error("I/O error: {0}")] @@ -27,8 +29,9 @@ pub enum SledStoreError { impl StoreError for SledStoreError {} -impl SledStoreError { - pub(crate) fn into_signal_error(self) -> SignalProtocolError { - SignalProtocolError::InvalidState("presage error", self.to_string()) +impl From<SledStoreError> for SignalProtocolError { + fn from(error: SledStoreError) -> Self { + log::error!("presage store error: {error}"); + Self::InvalidState("presage store error", error.to_string()) } } diff --git a/presage-store-sled/src/lib.rs b/presage-store-sled/src/lib.rs index 0ac4de35b..eabf9e972 100644 --- a/presage-store-sled/src/lib.rs +++ b/presage-store-sled/src/lib.rs @@ -1,60 +1,37 @@ use std::{ - ops::{Bound, Range, RangeBounds, RangeFull}, + ops::Range, path::Path, sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, time::{SystemTime, UNIX_EPOCH}, }; -use async_trait::async_trait; use base64::prelude::*; -use log::{debug, error, trace, warn}; -use presage::libsignal_service::zkgroup::GroupMasterKeyBytes; -use presage::libsignal_service::{ - self, - content::Content, - groups_v2::Group, - models::Contact, - pre_keys::PreKeysStore, - prelude::{ProfileKey, Uuid}, - protocol::{ - Direction, GenericSignedPreKey, IdentityKey, IdentityKeyPair, IdentityKeyStore, - KyberPreKeyId, KyberPreKeyRecord, KyberPreKeyStore, PreKeyId, PreKeyRecord, PreKeyStore, - ProtocolAddress, ProtocolStore, SenderKeyRecord, SenderKeyStore, SessionRecord, - SessionStore, SignalProtocolError, SignedPreKeyId, SignedPreKeyRecord, SignedPreKeyStore, +use log::debug; +use presage::{ + libsignal_service::{ + prelude::{ProfileKey, Uuid}, + protocol::{IdentityKey, IdentityKeyPair, PrivateKey}, + utils::{ + serde_identity_key, serde_optional_identity_key, serde_optional_private_key, + serde_private_key, + }, }, - push_service::DEFAULT_DEVICE_ID, - session_store::SessionStoreExt, - Profile, ServiceAddress, + manager::RegistrationData, + store::{ContentsStore, StateStore, Store}, }; -use presage::store::{ContentExt, ContentsStore, StateStore, StickerPack, Store, Thread}; -use presage::{manager::RegistrationData, proto::verified, AvatarBytes}; -use prost::Message; +use protocol::{AciSledStore, PniSledStore, SledProtocolStore, SledTrees}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use sled::{Batch, IVec}; - -use crate::protobuf::ContentProto; +mod content; mod error; mod protobuf; +mod protocol; pub use error::SledStoreError; +use sled::IVec; -const SLED_TREE_CONTACTS: &str = "contacts"; -const SLED_TREE_GROUPS: &str = "groups"; -const SLED_TREE_GROUP_AVATARS: &str = "group_avatars"; -const SLED_TREE_IDENTITIES: &str = "identities"; -const SLED_TREE_PRE_KEYS: &str = "pre_keys"; -const SLED_TREE_SENDER_KEYS: &str = "sender_keys"; -const SLED_TREE_SESSIONS: &str = "sessions"; -const SLED_TREE_SIGNED_PRE_KEYS: &str = "signed_pre_keys"; -const SLED_TREE_KYBER_PRE_KEYS: &str = "kyber_pre_keys"; const SLED_TREE_STATE: &str = "state"; -const SLED_TREE_THREADS_PREFIX: &str = "threads"; -const SLED_TREE_PROFILES: &str = "profiles"; -const SLED_TREE_PROFILE_AVATARS: &str = "profile_avatars"; -const SLED_TREE_PROFILE_KEYS: &str = "profile_keys"; -const SLED_TREE_STICKER_PACKS: &str = "sticker_packs"; const SLED_KEY_NEXT_SIGNED_PRE_KEY_ID: &str = "next_signed_pre_key_id"; const SLED_KEY_NEXT_PQ_PRE_KEY_ID: &str = "next_pq_pre_key_id"; @@ -98,11 +75,13 @@ pub enum SchemaVersion { V3 = 3, // Introduction of avatars, requires dropping all profiles from the cache V4 = 4, + /// ACI and PNI identity key pairs are moved into dedicated storage keys from registration data + V5 = 5, } impl SchemaVersion { fn current() -> SchemaVersion { - Self::V4 + Self::V5 } /// return an iterator on all the necessary migration steps from another version @@ -116,6 +95,7 @@ impl SchemaVersion { 2 => SchemaVersion::V2, 3 => SchemaVersion::V3, 4 => SchemaVersion::V4, + 5 => SchemaVersion::V5, _ => unreachable!("oops, this not supposed to happen!"), }) } @@ -227,16 +207,16 @@ impl SledStore { } #[cfg(feature = "encryption")] - fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { + fn decrypt_value<T: DeserializeOwned>(&self, value: IVec) -> Result<T, SledStoreError> { if let Some(cipher) = self.cipher.as_ref() { - Ok(cipher.decrypt_value(value)?) + Ok(cipher.decrypt_value(&value)?) } else { - Ok(serde_json::from_slice(value)?) + Ok(serde_json::from_slice(&value)?) } } #[cfg(not(feature = "encryption"))] - fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { + fn decrypt_value<T: DeserializeOwned>(&self, value: IVec) -> Result<T, SledStoreError> { Ok(serde_json::from_slice(value)?) } @@ -262,11 +242,22 @@ impl SledStore { self.read() .open_tree(tree)? .get(key)? - .map(|p| self.decrypt_value(&p)) + .map(|p| self.decrypt_value(p)) .transpose() .map_err(SledStoreError::from) } + pub fn iter<'a, V: DeserializeOwned + 'a>( + &'a self, + tree: &str, + ) -> Result<impl Iterator<Item = Result<V, SledStoreError>> + 'a, SledStoreError> { + Ok(self + .read() + .open_tree(tree)? + .iter() + .flat_map(|res| res.map(|(_, value)| self.decrypt_value::<V>(value)))) + } + fn insert<K, V>(&self, tree: &str, key: K, value: V) -> Result<bool, SledStoreError> where K: AsRef<[u8]>, @@ -289,22 +280,6 @@ impl SledStore { Ok(removed.is_some()) } - /// build a hashed messages thread key - fn messages_thread_tree_name(&self, t: &Thread) -> String { - let key = match t { - Thread::Contact(uuid) => { - format!("{SLED_TREE_THREADS_PREFIX}:contact:{uuid}") - } - Thread::Group(group_id) => format!( - "{SLED_TREE_THREADS_PREFIX}:group:{}", - BASE64_STANDARD.encode(group_id) - ), - }; - let mut hasher = Sha256::new(); - hasher.update(key.as_bytes()); - format!("{SLED_TREE_THREADS_PREFIX}:{:x}", hasher.finalize()) - } - fn profile_key_for_uuid(&self, uuid: Uuid, key: ProfileKey) -> String { let key = uuid.into_bytes().into_iter().chain(key.get_bytes()); @@ -312,6 +287,29 @@ impl SledStore { hasher.update(key.collect::<Vec<_>>()); format!("{:x}", hasher.finalize()) } + + fn get_identity_key_pair<T: SledTrees>( + &self, + ) -> Result<Option<IdentityKeyPair>, SledStoreError> { + let key_base64: Option<String> = self.get(SLED_TREE_STATE, T::identity_keypair())?; + let Some(key_base64) = key_base64 else { + return Ok(None); + }; + let key_bytes = BASE64_STANDARD.decode(key_base64)?; + IdentityKeyPair::try_from(&*key_bytes) + .map(Some) + .map_err(|e| SledStoreError::ProtobufDecode(prost::DecodeError::new(e.to_string()))) + } + + fn set_identity_key_pair<T: SledTrees>( + &self, + key_pair: IdentityKeyPair, + ) -> Result<(), SledStoreError> { + let key_bytes = key_pair.serialize(); + let key_base64 = BASE64_STANDARD.encode(key_bytes); + self.insert(SLED_TREE_STATE, T::identity_keypair(), key_base64)?; + Ok(()) + } } fn migrate( @@ -349,15 +347,43 @@ fn migrate( } SchemaVersion::V3 => { debug!("migrating from schema v2 to v3: dropping encrypted group cache"); - let db = store.write(); - db.drop_tree(SLED_TREE_GROUPS)?; - db.flush()?; + store.clear_groups()?; } SchemaVersion::V4 => { debug!("migrating from schema v3 to v4: dropping profile cache"); - let db = store.write(); - db.drop_tree(SLED_TREE_PROFILES)?; - db.flush()?; + store.clear_profiles()?; + } + SchemaVersion::V5 => { + debug!("migrating from schema v4 to v5: moving identity key pairs"); + + #[derive(Deserialize)] + struct RegistrationDataV4Keys { + #[serde(with = "serde_private_key", rename = "private_key")] + pub(crate) aci_private_key: PrivateKey, + #[serde(with = "serde_identity_key", rename = "public_key")] + pub(crate) aci_public_key: IdentityKey, + #[serde(with = "serde_optional_private_key", default)] + pub(crate) pni_private_key: Option<PrivateKey>, + #[serde(with = "serde_optional_identity_key", default)] + pub(crate) pni_public_key: Option<IdentityKey>, + } + + let registration_data: Option<RegistrationDataV4Keys> = + store.get(SLED_TREE_STATE, SLED_KEY_REGISTRATION)?; + if let Some(data) = registration_data { + store.set_aci_identity_key_pair(IdentityKeyPair::new( + data.aci_public_key, + data.aci_private_key, + ))?; + if let Some((public_key, private_key)) = + data.pni_public_key.zip(data.pni_private_key) + { + store.set_pni_identity_key_pair(IdentityKeyPair::new( + public_key, + private_key, + ))?; + } + } } _ => return Err(SledStoreError::MigrationConflict), } @@ -393,8 +419,6 @@ fn migrate( Ok(()) } -impl ProtocolStore for SledStore {} - impl StateStore for SledStore { type StateStoreError = SledStoreError; @@ -402,864 +426,76 @@ impl StateStore for SledStore { self.get(SLED_TREE_STATE, SLED_KEY_REGISTRATION) } - fn save_registration_data(&mut self, state: &RegistrationData) -> Result<(), SledStoreError> { - self.insert(SLED_TREE_STATE, SLED_KEY_REGISTRATION, state)?; - Ok(()) - } - - fn is_registered(&self) -> bool { - self.load_registration_data().unwrap_or_default().is_some() - } - - fn clear_registration(&mut self) -> Result<(), SledStoreError> { - let db = self.write(); - db.remove(SLED_KEY_REGISTRATION)?; - - db.drop_tree(SLED_TREE_IDENTITIES)?; - db.drop_tree(SLED_TREE_PRE_KEYS)?; - db.drop_tree(SLED_TREE_SENDER_KEYS)?; - db.drop_tree(SLED_TREE_SESSIONS)?; - db.drop_tree(SLED_TREE_SIGNED_PRE_KEYS)?; - db.drop_tree(SLED_TREE_KYBER_PRE_KEYS)?; - db.drop_tree(SLED_TREE_STATE)?; - db.drop_tree(SLED_TREE_PROFILES)?; - db.drop_tree(SLED_TREE_PROFILE_KEYS)?; - - db.flush()?; - - Ok(()) - } -} - -impl ContentsStore for SledStore { - type ContentsStoreError = SledStoreError; - - type ContactsIter = SledContactsIter; - type GroupsIter = SledGroupsIter; - type MessagesIter = SledMessagesIter; - type StickerPacksIter = SledStickerPacksIter; - - fn clear_contacts(&mut self) -> Result<(), SledStoreError> { - self.write().drop_tree(SLED_TREE_CONTACTS)?; - Ok(()) - } - - fn save_contact(&mut self, contact: &Contact) -> Result<(), SledStoreError> { - self.insert(SLED_TREE_CONTACTS, contact.uuid, contact)?; - debug!("saved contact"); - Ok(()) - } - - fn contacts(&self) -> Result<Self::ContactsIter, SledStoreError> { - Ok(SledContactsIter { - iter: self.read().open_tree(SLED_TREE_CONTACTS)?.iter(), - #[cfg(feature = "encryption")] - cipher: self.cipher.clone(), - }) - } - - fn contact_by_id(&self, id: &Uuid) -> Result<Option<Contact>, SledStoreError> { - self.get(SLED_TREE_CONTACTS, id) - } - - /// Groups - - fn clear_groups(&mut self) -> Result<(), SledStoreError> { - let db = self.write(); - db.drop_tree(SLED_TREE_GROUPS)?; - db.flush()?; - Ok(()) - } - - fn groups(&self) -> Result<Self::GroupsIter, SledStoreError> { - Ok(SledGroupsIter { - iter: self.read().open_tree(SLED_TREE_GROUPS)?.iter(), - #[cfg(feature = "encryption")] - cipher: self.cipher.clone(), - }) - } - - fn group( + fn set_aci_identity_key_pair( &self, - master_key_bytes: GroupMasterKeyBytes, - ) -> Result<Option<Group>, SledStoreError> { - self.get(SLED_TREE_GROUPS, master_key_bytes) + key_pair: IdentityKeyPair, + ) -> Result<(), Self::StateStoreError> { + self.set_identity_key_pair::<AciSledStore>(key_pair) } - fn save_group( + fn set_pni_identity_key_pair( &self, - master_key: GroupMasterKeyBytes, - group: &Group, - ) -> Result<(), SledStoreError> { - self.insert(SLED_TREE_GROUPS, master_key, group)?; - Ok(()) - } - - fn group_avatar( - &self, - master_key_bytes: GroupMasterKeyBytes, - ) -> Result<Option<AvatarBytes>, SledStoreError> { - self.get(SLED_TREE_GROUP_AVATARS, master_key_bytes) - } - - fn save_group_avatar( - &self, - master_key: GroupMasterKeyBytes, - avatar: &AvatarBytes, - ) -> Result<(), SledStoreError> { - self.insert(SLED_TREE_GROUP_AVATARS, master_key, avatar)?; - Ok(()) - } - - /// Messages - - fn clear_messages(&mut self) -> Result<(), SledStoreError> { - let db = self.write(); - for name in db.tree_names() { - if name - .as_ref() - .starts_with(SLED_TREE_THREADS_PREFIX.as_bytes()) - { - db.drop_tree(&name)?; - } - } - db.flush()?; - Ok(()) - } - - fn clear_thread(&mut self, thread: &Thread) -> Result<(), SledStoreError> { - log::trace!("clearing thread {thread}"); - - let db = self.write(); - db.drop_tree(self.messages_thread_tree_name(thread))?; - db.flush()?; - - Ok(()) + key_pair: IdentityKeyPair, + ) -> Result<(), Self::StateStoreError> { + self.set_identity_key_pair::<PniSledStore>(key_pair) } - fn save_message(&self, thread: &Thread, message: Content) -> Result<(), SledStoreError> { - let ts = message.timestamp(); - log::trace!("storing a message with thread: {thread}, timestamp: {ts}",); - - let tree = self.messages_thread_tree_name(thread); - let key = ts.to_be_bytes(); - - let proto: ContentProto = message.into(); - let value = proto.encode_to_vec(); - - self.insert(&tree, key, value)?; - + fn save_registration_data(&mut self, state: &RegistrationData) -> Result<(), SledStoreError> { + self.insert(SLED_TREE_STATE, SLED_KEY_REGISTRATION, state)?; Ok(()) } - fn delete_message(&mut self, thread: &Thread, timestamp: u64) -> Result<bool, SledStoreError> { - let tree = self.messages_thread_tree_name(thread); - self.remove(&tree, timestamp.to_be_bytes()) + fn is_registered(&self) -> bool { + self.load_registration_data().unwrap_or_default().is_some() } - fn message( - &self, - thread: &Thread, - timestamp: u64, - ) -> Result<Option<libsignal_service::prelude::Content>, SledStoreError> { - // Big-Endian needed, otherwise wrong ordering in sled. - let val: Option<Vec<u8>> = self.get( - &self.messages_thread_tree_name(thread), - timestamp.to_be_bytes(), - )?; - match val { - Some(ref v) => { - let proto = ContentProto::decode(v.as_slice())?; - let content = proto.try_into()?; - Ok(Some(content)) - } - None => Ok(None), + fn clear_registration(&mut self) -> Result<(), SledStoreError> { + // drop registration data (includes identity keys) + { + let db = self.write(); + db.remove(SLED_KEY_REGISTRATION)?; + db.drop_tree(SLED_TREE_STATE)?; + db.flush()?; } - } - - fn messages( - &self, - thread: &Thread, - range: impl RangeBounds<u64>, - ) -> Result<Self::MessagesIter, SledStoreError> { - let tree_thread = self - .read() - .open_tree(self.messages_thread_tree_name(thread))?; - debug!("{} messages in this tree", tree_thread.len()); - - let iter = match (range.start_bound(), range.end_bound()) { - (Bound::Included(start), Bound::Unbounded) => tree_thread.range(start.to_be_bytes()..), - (Bound::Included(start), Bound::Excluded(end)) => { - tree_thread.range(start.to_be_bytes()..end.to_be_bytes()) - } - (Bound::Included(start), Bound::Included(end)) => { - tree_thread.range(start.to_be_bytes()..=end.to_be_bytes()) - } - (Bound::Unbounded, Bound::Included(end)) => tree_thread.range(..=end.to_be_bytes()), - (Bound::Unbounded, Bound::Excluded(end)) => tree_thread.range(..end.to_be_bytes()), - (Bound::Unbounded, Bound::Unbounded) => tree_thread.range::<[u8; 8], RangeFull>(..), - (Bound::Excluded(_), _) => { - unreachable!("range that excludes the initial value") - } - }; - Ok(SledMessagesIter { - #[cfg(feature = "encryption")] - cipher: self.cipher.clone(), - iter, - }) - } + // drop all saved profile (+avatards) and profile keys + self.clear_profiles()?; - fn upsert_profile_key(&mut self, uuid: &Uuid, key: ProfileKey) -> Result<bool, SledStoreError> { - self.insert(SLED_TREE_PROFILE_KEYS, uuid.as_bytes(), key) - } - - fn profile_key(&self, uuid: &Uuid) -> Result<Option<ProfileKey>, SledStoreError> { - self.get(SLED_TREE_PROFILE_KEYS, uuid.as_bytes()) - } + // drop all keys + self.aci_protocol_store().clear()?; + self.pni_protocol_store().clear()?; - fn save_profile( - &mut self, - uuid: Uuid, - key: ProfileKey, - profile: Profile, - ) -> Result<(), SledStoreError> { - let key = self.profile_key_for_uuid(uuid, key); - self.insert(SLED_TREE_PROFILES, key, profile)?; - Ok(()) - } - - fn profile(&self, uuid: Uuid, key: ProfileKey) -> Result<Option<Profile>, SledStoreError> { - let key = self.profile_key_for_uuid(uuid, key); - self.get(SLED_TREE_PROFILES, key) - } - - fn save_profile_avatar( - &mut self, - uuid: Uuid, - key: ProfileKey, - avatar: &AvatarBytes, - ) -> Result<(), SledStoreError> { - let key = self.profile_key_for_uuid(uuid, key); - self.insert(SLED_TREE_PROFILE_AVATARS, key, avatar)?; - Ok(()) - } - - fn profile_avatar( - &self, - uuid: Uuid, - key: ProfileKey, - ) -> Result<Option<AvatarBytes>, SledStoreError> { - let key = self.profile_key_for_uuid(uuid, key); - self.get(SLED_TREE_PROFILE_AVATARS, key) - } - - fn add_sticker_pack(&mut self, pack: &StickerPack) -> Result<(), SledStoreError> { - self.insert(SLED_TREE_STICKER_PACKS, pack.id.clone(), pack)?; - Ok(()) - } - - fn remove_sticker_pack(&mut self, id: &[u8]) -> Result<bool, SledStoreError> { - self.remove(SLED_TREE_STICKER_PACKS, id) - } - - fn sticker_pack(&self, id: &[u8]) -> Result<Option<StickerPack>, SledStoreError> { - self.get(SLED_TREE_STICKER_PACKS, id) - } - - fn sticker_packs(&self) -> Result<Self::StickerPacksIter, SledStoreError> { - Ok(SledStickerPacksIter { - cipher: self.cipher.clone(), - iter: self.read().open_tree(SLED_TREE_STICKER_PACKS)?.iter(), - }) - } -} - -#[async_trait(?Send)] -impl PreKeysStore for SledStore { - async fn next_pre_key_id(&self) -> Result<u32, SignalProtocolError> { - Ok(self - .get(SLED_TREE_STATE, SLED_KEY_PRE_KEYS_OFFSET_ID) - .map_err(|_| SignalProtocolError::InvalidPreKeyId)? - .unwrap_or(0)) - } - - async fn set_next_pre_key_id(&mut self, id: u32) -> Result<(), SignalProtocolError> { - self.insert(SLED_TREE_STATE, SLED_KEY_PRE_KEYS_OFFSET_ID, id) - .map_err(|_| SignalProtocolError::InvalidPreKeyId)?; - Ok(()) - } - - async fn next_signed_pre_key_id(&self) -> Result<u32, SignalProtocolError> { - Ok(self - .get(SLED_TREE_STATE, SLED_KEY_NEXT_SIGNED_PRE_KEY_ID) - .map_err(|_| SignalProtocolError::InvalidSignedPreKeyId)? - .unwrap_or(0)) - } - - async fn set_next_signed_pre_key_id(&mut self, id: u32) -> Result<(), SignalProtocolError> { - self.insert(SLED_TREE_STATE, SLED_KEY_NEXT_SIGNED_PRE_KEY_ID, id) - .map_err(|_| SignalProtocolError::InvalidSignedPreKeyId)?; - Ok(()) - } - - async fn next_pq_pre_key_id(&self) -> Result<u32, SignalProtocolError> { - Ok(self - .get(SLED_TREE_STATE, SLED_KEY_NEXT_PQ_PRE_KEY_ID) - .map_err(|_| SignalProtocolError::InvalidKyberPreKeyId)? - .unwrap_or(0)) - } - - async fn set_next_pq_pre_key_id(&mut self, id: u32) -> Result<(), SignalProtocolError> { - self.insert(SLED_TREE_STATE, SLED_KEY_NEXT_PQ_PRE_KEY_ID, id) - .map_err(|_| SignalProtocolError::InvalidKyberPreKeyId)?; Ok(()) } } impl Store for SledStore { type Error = SledStoreError; + type AciStore = SledProtocolStore<AciSledStore>; + type PniStore = SledProtocolStore<PniSledStore>; fn clear(&mut self) -> Result<(), SledStoreError> { self.clear_registration()?; - - let db = self.write(); - db.drop_tree(SLED_TREE_CONTACTS)?; - db.drop_tree(SLED_TREE_GROUPS)?; - db.drop_tree(SLED_TREE_PROFILES)?; - db.drop_tree(SLED_TREE_PROFILE_AVATARS)?; - - for tree in db - .tree_names() - .into_iter() - .filter(|n| n.starts_with(SLED_TREE_THREADS_PREFIX.as_bytes())) - { - db.drop_tree(tree)?; - } - - db.flush()?; + self.clear_contents()?; Ok(()) } -} - -pub struct SledContactsIter { - #[cfg(feature = "encryption")] - cipher: Option<Arc<presage_store_cipher::StoreCipher>>, - iter: sled::Iter, -} - -impl SledContactsIter { - #[cfg(feature = "encryption")] - fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { - if let Some(cipher) = self.cipher.as_ref() { - Ok(cipher.decrypt_value(value)?) - } else { - Ok(serde_json::from_slice(value)?) - } - } - #[cfg(not(feature = "encryption"))] - fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { - Ok(serde_json::from_slice(value)?) + fn aci_protocol_store(&self) -> Self::AciStore { + SledProtocolStore::aci_protocol_store(self.clone()) } -} - -impl Iterator for SledContactsIter { - type Item = Result<Contact, SledStoreError>; - fn next(&mut self) -> Option<Self::Item> { - self.iter - .next()? - .map_err(SledStoreError::from) - .and_then(|(_key, value)| self.decrypt_value(&value)) - .into() - } -} - -pub struct SledGroupsIter { - #[cfg(feature = "encryption")] - cipher: Option<Arc<presage_store_cipher::StoreCipher>>, - iter: sled::Iter, -} - -impl SledGroupsIter { - #[cfg(feature = "encryption")] - fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { - if let Some(cipher) = self.cipher.as_ref() { - Ok(cipher.decrypt_value(value)?) - } else { - Ok(serde_json::from_slice(value)?) - } - } - - #[cfg(not(feature = "encryption"))] - fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { - Ok(serde_json::from_slice(value)?) - } -} - -impl Iterator for SledGroupsIter { - type Item = Result<(GroupMasterKeyBytes, Group), SledStoreError>; - - fn next(&mut self) -> Option<Self::Item> { - Some(self.iter.next()?.map_err(SledStoreError::from).and_then( - |(group_master_key_bytes, value)| { - let group = self.decrypt_value(&value)?; - Ok(( - group_master_key_bytes - .as_ref() - .try_into() - .map_err(|_| SledStoreError::GroupDecryption)?, - group, - )) - }, - )) - } -} - -pub struct SledStickerPacksIter { - #[cfg(feature = "encryption")] - cipher: Option<Arc<presage_store_cipher::StoreCipher>>, - iter: sled::Iter, -} - -impl Iterator for SledStickerPacksIter { - type Item = Result<StickerPack, SledStoreError>; - - #[cfg(feature = "encryption")] - fn next(&mut self) -> Option<Self::Item> { - self.iter - .next()? - .map_err(SledStoreError::from) - .and_then(|(_key, value)| { - if let Some(cipher) = self.cipher.as_ref() { - cipher.decrypt_value(&value).map_err(SledStoreError::from) - } else { - serde_json::from_slice(&value).map_err(SledStoreError::from) - } - }) - .into() - } - - #[cfg(not(feature = "encryption"))] - fn next(&mut self) -> Option<Self::Item> { - self.iter - .next()? - .map_err(SledStoreError::from) - .and_then(|(_key, value)| serde_json::from_slice(&value).map_err(SledStoreError::from)) - .into() - } -} - -#[async_trait(?Send)] -impl PreKeyStore for SledStore { - async fn get_pre_key(&self, prekey_id: PreKeyId) -> Result<PreKeyRecord, SignalProtocolError> { - let buf: Vec<u8> = self - .get(SLED_TREE_PRE_KEYS, prekey_id.to_string()) - .ok() - .flatten() - .ok_or(SignalProtocolError::InvalidPreKeyId)?; - - PreKeyRecord::deserialize(&buf) - } - - async fn save_pre_key( - &mut self, - prekey_id: PreKeyId, - record: &PreKeyRecord, - ) -> Result<(), SignalProtocolError> { - self.insert( - SLED_TREE_PRE_KEYS, - prekey_id.to_string(), - record.serialize()?, - ) - .expect("failed to store pre-key"); - Ok(()) - } - - async fn remove_pre_key(&mut self, prekey_id: PreKeyId) -> Result<(), SignalProtocolError> { - self.remove(SLED_TREE_PRE_KEYS, prekey_id.to_string()) - .expect("failed to remove pre-key"); - Ok(()) - } -} - -#[async_trait(?Send)] -impl SignedPreKeyStore for SledStore { - async fn get_signed_pre_key( - &self, - signed_prekey_id: SignedPreKeyId, - ) -> Result<SignedPreKeyRecord, SignalProtocolError> { - let buf: Vec<u8> = self - .get(SLED_TREE_SIGNED_PRE_KEYS, signed_prekey_id.to_string()) - .ok() - .flatten() - .ok_or(SignalProtocolError::InvalidSignedPreKeyId)?; - SignedPreKeyRecord::deserialize(&buf) - } - - async fn save_signed_pre_key( - &mut self, - signed_prekey_id: SignedPreKeyId, - record: &SignedPreKeyRecord, - ) -> Result<(), SignalProtocolError> { - self.insert( - SLED_TREE_SIGNED_PRE_KEYS, - signed_prekey_id.to_string(), - record.serialize()?, - ) - .map_err(|e| { - log::error!("sled error: {}", e); - SignalProtocolError::InvalidState("save_signed_pre_key", "sled error".into()) - })?; - Ok(()) - } -} - -#[async_trait(?Send)] -impl KyberPreKeyStore for SledStore { - async fn get_kyber_pre_key( - &self, - kyber_prekey_id: KyberPreKeyId, - ) -> Result<KyberPreKeyRecord, SignalProtocolError> { - let buf: Vec<u8> = self - .get(SLED_TREE_KYBER_PRE_KEYS, kyber_prekey_id.to_string()) - .ok() - .flatten() - .ok_or(SignalProtocolError::InvalidKyberPreKeyId)?; - KyberPreKeyRecord::deserialize(&buf) - } - - async fn save_kyber_pre_key( - &mut self, - kyber_prekey_id: KyberPreKeyId, - record: &KyberPreKeyRecord, - ) -> Result<(), SignalProtocolError> { - self.insert( - SLED_TREE_KYBER_PRE_KEYS, - kyber_prekey_id.to_string(), - record.serialize()?, - ) - .map_err(|e| { - log::error!("sled error: {}", e); - SignalProtocolError::InvalidState("save_kyber_pre_key", "sled error".into()) - })?; - Ok(()) - } - - async fn mark_kyber_pre_key_used( - &mut self, - kyber_prekey_id: KyberPreKeyId, - ) -> Result<(), SignalProtocolError> { - let removed = self - .remove(SLED_TREE_KYBER_PRE_KEYS, kyber_prekey_id.to_string()) - .map_err(|e| { - log::error!("sled error: {}", e); - SignalProtocolError::InvalidState("mark_kyber_pre_key_used", "sled error".into()) - })?; - if removed { - log::trace!("removed kyber pre-key {kyber_prekey_id}"); - } - Ok(()) - } -} - -#[async_trait(?Send)] -impl SessionStore for SledStore { - async fn load_session( - &self, - address: &ProtocolAddress, - ) -> Result<Option<SessionRecord>, SignalProtocolError> { - let session = self - .get(SLED_TREE_SESSIONS, address.to_string()) - .map_err(SledStoreError::into_signal_error)?; - trace!("loading session {} / exists={}", address, session.is_some()); - session - .map(|b: Vec<u8>| SessionRecord::deserialize(&b)) - .transpose() - } - - async fn store_session( - &mut self, - address: &ProtocolAddress, - record: &SessionRecord, - ) -> Result<(), SignalProtocolError> { - trace!("storing session {}", address); - self.insert(SLED_TREE_SESSIONS, address.to_string(), record.serialize()?) - .map_err(SledStoreError::into_signal_error)?; - Ok(()) - } -} - -#[async_trait] -impl SessionStoreExt for SledStore { - async fn get_sub_device_sessions( - &self, - address: &ServiceAddress, - ) -> Result<Vec<u32>, SignalProtocolError> { - let session_prefix = format!("{}.", address.uuid); - trace!("get_sub_device_sessions {}", session_prefix); - let session_ids: Vec<u32> = self - .read() - .open_tree(SLED_TREE_SESSIONS) - .map_err(Into::into) - .map_err(SledStoreError::into_signal_error)? - .scan_prefix(&session_prefix) - .filter_map(|r| { - let (key, _) = r.ok()?; - let key_str = String::from_utf8_lossy(&key); - let device_id = key_str.strip_prefix(&session_prefix)?; - device_id.parse().ok() - }) - .filter(|d| *d != DEFAULT_DEVICE_ID) - .collect(); - Ok(session_ids) - } - - async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), SignalProtocolError> { - trace!("deleting session {}", address); - self.write() - .open_tree(SLED_TREE_SESSIONS) - .map_err(Into::into) - .map_err(SledStoreError::into_signal_error)? - .remove(address.to_string()) - .map_err(|_e| SignalProtocolError::SessionNotFound(address.clone()))?; - Ok(()) - } - - async fn delete_all_sessions( - &self, - address: &ServiceAddress, - ) -> Result<usize, SignalProtocolError> { - let db = self.write(); - let sessions_tree = db - .open_tree(SLED_TREE_SESSIONS) - .map_err(Into::into) - .map_err(SledStoreError::into_signal_error)?; - - let mut batch = Batch::default(); - sessions_tree - .scan_prefix(address.uuid.to_string()) - .filter_map(|r| { - let (key, _) = r.ok()?; - Some(key) - }) - .for_each(|k| batch.remove(k)); - - db.apply_batch(batch) - .map_err(SledStoreError::Db) - .map_err(SledStoreError::into_signal_error)?; - - let len = sessions_tree.len(); - sessions_tree.clear().map_err(|_e| { - SignalProtocolError::InvalidSessionStructure("failed to delete all sessions") - })?; - Ok(len) - } -} - -#[async_trait(?Send)] -impl IdentityKeyStore for SledStore { - async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> { - trace!("getting identity_key_pair"); - let data = self - .load_registration_data() - .map_err(SledStoreError::into_signal_error)? - .ok_or(SignalProtocolError::InvalidState( - "failed to load identity key pair", - "no registration data".into(), - ))?; - - Ok(IdentityKeyPair::new( - IdentityKey::new(data.aci_public_key()), - data.aci_private_key(), - )) - } - - async fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError> { - let data = self - .load_registration_data() - .map_err(SledStoreError::into_signal_error)? - .ok_or(SignalProtocolError::InvalidState( - "failed to load registration ID", - "no registration data".into(), - ))?; - Ok(data.registration_id) - } - - async fn save_identity( - &mut self, - address: &ProtocolAddress, - identity_key: &IdentityKey, - ) -> Result<bool, SignalProtocolError> { - trace!("saving identity"); - let existed_before = self - .insert( - SLED_TREE_IDENTITIES, - address.to_string(), - identity_key.serialize(), - ) - .map_err(|e| { - error!("error saving identity for {:?}: {}", address, e); - e.into_signal_error() - })?; - - self.save_trusted_identity_message( - address, - *identity_key, - if existed_before { - verified::State::Unverified - } else { - verified::State::Default - }, - ); - - Ok(true) - } - - async fn is_trusted_identity( - &self, - address: &ProtocolAddress, - right_identity_key: &IdentityKey, - _direction: Direction, - ) -> Result<bool, SignalProtocolError> { - match self - .get(SLED_TREE_IDENTITIES, address.to_string()) - .map_err(SledStoreError::into_signal_error)? - .map(|b: Vec<u8>| IdentityKey::decode(&b)) - .transpose()? - { - None => { - // when we encounter a new identity, we trust it by default - warn!("trusting new identity {:?}", address); - Ok(true) - } - // when we encounter some identity we know, we need to decide whether we trust it or not - Some(left_identity_key) => { - if left_identity_key == *right_identity_key { - Ok(true) - } else { - match self.trust_new_identities { - OnNewIdentity::Trust => Ok(true), - OnNewIdentity::Reject => Ok(false), - } - } - } - } - } - - async fn get_identity( - &self, - address: &ProtocolAddress, - ) -> Result<Option<IdentityKey>, SignalProtocolError> { - self.get(SLED_TREE_IDENTITIES, address.to_string()) - .map_err(SledStoreError::into_signal_error)? - .map(|b: Vec<u8>| IdentityKey::decode(&b)) - .transpose() - } -} - -#[async_trait(?Send)] -impl SenderKeyStore for SledStore { - async fn store_sender_key( - &mut self, - sender: &ProtocolAddress, - distribution_id: Uuid, - record: &SenderKeyRecord, - ) -> Result<(), SignalProtocolError> { - let key = format!( - "{}.{}/{}", - sender.name(), - sender.device_id(), - distribution_id - ); - self.insert(SLED_TREE_SENDER_KEYS, key, record.serialize()?) - .map_err(SledStoreError::into_signal_error)?; - Ok(()) - } - - async fn load_sender_key( - &mut self, - sender: &ProtocolAddress, - distribution_id: Uuid, - ) -> Result<Option<SenderKeyRecord>, SignalProtocolError> { - let key = format!( - "{}.{}/{}", - sender.name(), - sender.device_id(), - distribution_id - ); - self.get(SLED_TREE_SENDER_KEYS, key) - .map_err(SledStoreError::into_signal_error)? - .map(|b: Vec<u8>| SenderKeyRecord::deserialize(&b)) - .transpose() - } -} - -pub struct SledMessagesIter { - #[cfg(feature = "encryption")] - cipher: Option<Arc<presage_store_cipher::StoreCipher>>, - iter: sled::Iter, -} - -impl SledMessagesIter { - #[cfg(feature = "encryption")] - fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { - if let Some(cipher) = self.cipher.as_ref() { - Ok(cipher.decrypt_value(value)?) - } else { - Ok(serde_json::from_slice(value)?) - } - } - - #[cfg(not(feature = "encryption"))] - fn decrypt_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T, SledStoreError> { - Ok(serde_json::from_slice(value)?) - } -} - -impl SledMessagesIter { - fn decode( - &self, - elem: Result<(IVec, IVec), sled::Error>, - ) -> Option<Result<Content, SledStoreError>> { - elem.map_err(SledStoreError::from) - .and_then(|(_, value)| self.decrypt_value(&value).map_err(SledStoreError::from)) - .and_then(|data: Vec<u8>| ContentProto::decode(&data[..]).map_err(SledStoreError::from)) - .map_or_else(|e| Some(Err(e)), |p| Some(p.try_into())) - } -} - -impl Iterator for SledMessagesIter { - type Item = Result<Content, SledStoreError>; - - fn next(&mut self) -> Option<Self::Item> { - let elem = self.iter.next()?; - self.decode(elem) - } -} - -impl DoubleEndedIterator for SledMessagesIter { - fn next_back(&mut self) -> Option<Self::Item> { - let elem = self.iter.next_back()?; - self.decode(elem) + fn pni_protocol_store(&self) -> Self::PniStore { + SledProtocolStore::pni_protocol_store(self.clone()) } } #[cfg(test)] mod tests { - use core::fmt; - - use base64::prelude::*; use presage::libsignal_service::{ content::{ContentBody, Metadata}, prelude::Uuid, proto::DataMessage, - protocol::{ - self, Direction, GenericSignedPreKey, IdentityKeyStore, PreKeyRecord, PreKeyStore, - SessionRecord, SessionStore, SignedPreKeyRecord, SignedPreKeyStore, - }, ServiceAddress, }; use presage::store::ContentsStore; @@ -1267,43 +503,12 @@ mod tests { use super::SledStore; - #[derive(Debug, Clone)] - struct ProtocolAddress(protocol::ProtocolAddress); - - #[derive(Clone)] - struct KeyPair(protocol::KeyPair); - #[derive(Debug, Clone)] struct Thread(presage::store::Thread); #[derive(Debug, Clone)] struct Content(presage::libsignal_service::content::Content); - impl fmt::Debug for KeyPair { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "{}", - BASE64_STANDARD.encode(self.0.public_key.serialize()) - ) - } - } - - impl Arbitrary for ProtocolAddress { - fn arbitrary(g: &mut Gen) -> Self { - let name: String = Arbitrary::arbitrary(g); - let device_id: u32 = Arbitrary::arbitrary(g); - ProtocolAddress(protocol::ProtocolAddress::new(name, device_id.into())) - } - } - - impl Arbitrary for KeyPair { - fn arbitrary(_g: &mut Gen) -> Self { - // Gen is not rand::CryptoRng here, see https://github.com/BurntSushi/quickcheck/issues/241 - KeyPair(protocol::KeyPair::generate(&mut rand::thread_rng())) - } - } - impl Arbitrary for Content { fn arbitrary(g: &mut Gen) -> Self { let timestamp: u64 = Arbitrary::arbitrary(g); @@ -1342,71 +547,6 @@ mod tests { } } - #[quickcheck_async::tokio] - async fn test_save_get_trust_identity(addr: ProtocolAddress, key_pair: KeyPair) -> bool { - let mut db = SledStore::temporary().unwrap(); - let identity_key = protocol::IdentityKey::new(key_pair.0.public_key); - db.save_identity(&addr.0, &identity_key).await.unwrap(); - let id = db.get_identity(&addr.0).await.unwrap().unwrap(); - if id != identity_key { - return false; - } - db.is_trusted_identity(&addr.0, &id, Direction::Receiving) - .await - .unwrap() - } - - #[quickcheck_async::tokio] - async fn test_store_load_session(addr: ProtocolAddress) -> bool { - let session = SessionRecord::new_fresh(); - - let mut db = SledStore::temporary().unwrap(); - db.store_session(&addr.0, &session).await.unwrap(); - if db.load_session(&addr.0).await.unwrap().is_none() { - return false; - } - let loaded_session = db.load_session(&addr.0).await.unwrap().unwrap(); - session.serialize().unwrap() == loaded_session.serialize().unwrap() - } - - #[quickcheck_async::tokio] - async fn test_prekey_store(id: u32, key_pair: KeyPair) -> bool { - let id = id.into(); - let mut db = SledStore::temporary().unwrap(); - let pre_key_record = PreKeyRecord::new(id, &key_pair.0); - db.save_pre_key(id, &pre_key_record).await.unwrap(); - if db.get_pre_key(id).await.unwrap().serialize().unwrap() - != pre_key_record.serialize().unwrap() - { - return false; - } - - db.remove_pre_key(id).await.unwrap(); - db.get_pre_key(id).await.is_err() - } - - #[quickcheck_async::tokio] - async fn test_signed_prekey_store( - id: u32, - timestamp: u64, - key_pair: KeyPair, - signature: Vec<u8>, - ) -> bool { - let mut db = SledStore::temporary().unwrap(); - let id = id.into(); - let signed_pre_key_record = SignedPreKeyRecord::new(id, timestamp, &key_pair.0, &signature); - db.save_signed_pre_key(id, &signed_pre_key_record) - .await - .unwrap(); - - db.get_signed_pre_key(id) - .await - .unwrap() - .serialize() - .unwrap() - == signed_pre_key_record.serialize().unwrap() - } - fn content_with_timestamp( content: &Content, ts: u64, diff --git a/presage-store-sled/src/protocol.rs b/presage-store-sled/src/protocol.rs new file mode 100644 index 000000000..61530feca --- /dev/null +++ b/presage-store-sled/src/protocol.rs @@ -0,0 +1,761 @@ +use std::marker::PhantomData; + +use async_trait::async_trait; +use log::{error, trace, warn}; +use presage::{ + libsignal_service::{ + pre_keys::{KyberPreKeyStoreExt, PreKeysStore}, + prelude::Uuid, + protocol::{ + Direction, GenericSignedPreKey, IdentityKey, IdentityKeyPair, IdentityKeyStore, + KyberPreKeyId, KyberPreKeyRecord, KyberPreKeyStore, PreKeyId, PreKeyRecord, + PreKeyStore, ProtocolAddress, ProtocolStore, SenderKeyRecord, SenderKeyStore, + SessionRecord, SessionStore, SignalProtocolError, SignedPreKeyId, SignedPreKeyRecord, + SignedPreKeyStore, + }, + push_service::DEFAULT_DEVICE_ID, + session_store::SessionStoreExt, + ServiceAddress, + }, + proto::verified, + store::{ContentsStore, StateStore}, +}; +use sled::Batch; + +use crate::{ + OnNewIdentity, SledStore, SledStoreError, SLED_KEY_NEXT_PQ_PRE_KEY_ID, + SLED_KEY_NEXT_SIGNED_PRE_KEY_ID, SLED_KEY_PRE_KEYS_OFFSET_ID, +}; + +#[derive(Clone)] +pub struct SledProtocolStore<T: SledTrees> { + pub(crate) store: SledStore, + _trees: PhantomData<T>, +} + +impl SledProtocolStore<AciSledStore> { + pub(crate) fn aci_protocol_store(store: SledStore) -> Self { + Self { + store, + _trees: Default::default(), + } + } +} + +impl SledProtocolStore<PniSledStore> { + pub(crate) fn pni_protocol_store(store: SledStore) -> Self { + Self { + store, + _trees: Default::default(), + } + } +} + +pub trait SledTrees: Clone { + fn identities() -> &'static str; + fn state() -> &'static str; + fn pre_keys() -> &'static str; + fn signed_pre_keys() -> &'static str; + fn kyber_pre_keys() -> &'static str; + fn kyber_pre_keys_last_resort() -> &'static str; + fn sender_keys() -> &'static str; + fn sessions() -> &'static str; + fn identity_keypair() -> &'static str; +} + +#[derive(Clone)] +pub struct AciSledStore; + +impl SledTrees for AciSledStore { + fn identities() -> &'static str { + "identities" + } + + fn state() -> &'static str { + "state" + } + + fn pre_keys() -> &'static str { + "pre_keys" + } + + fn signed_pre_keys() -> &'static str { + "sender_keys" + } + + fn kyber_pre_keys() -> &'static str { + "signed_pre_keys" + } + + fn kyber_pre_keys_last_resort() -> &'static str { + "kyber_pre_keys_last_resort" + } + + fn sender_keys() -> &'static str { + "kyber_pre_keys" + } + + fn sessions() -> &'static str { + "sessions" + } + + fn identity_keypair() -> &'static str { + "aci_identity_key_pair" + } +} + +#[derive(Clone)] +pub struct PniSledStore; + +impl SledTrees for PniSledStore { + fn identities() -> &'static str { + "identities" + } + + fn state() -> &'static str { + "pni_state" + } + + fn pre_keys() -> &'static str { + "pni_pre_keys" + } + + fn signed_pre_keys() -> &'static str { + "pni_sender_keys" + } + + fn kyber_pre_keys() -> &'static str { + "pni_signed_pre_keys" + } + + fn kyber_pre_keys_last_resort() -> &'static str { + "pni_kyber_pre_keys_last_resort" + } + + fn sender_keys() -> &'static str { + "pni_kyber_pre_keys" + } + + fn sessions() -> &'static str { + "pni_sessions" + } + + fn identity_keypair() -> &'static str { + "pni_identity_key_pair" + } +} + +impl<T: SledTrees> SledProtocolStore<T> { + pub(crate) fn clear(&self) -> Result<(), SledStoreError> { + let db = self.store.db.write().expect("poisoned mutex"); + db.drop_tree(T::pre_keys())?; + db.drop_tree(T::sender_keys())?; + db.drop_tree(T::sessions())?; + db.drop_tree(T::signed_pre_keys())?; + db.drop_tree(T::kyber_pre_keys())?; + Ok(()) + } +} + +impl<T: SledTrees> ProtocolStore for SledProtocolStore<T> {} + +#[async_trait(?Send)] +impl<T: SledTrees> PreKeyStore for SledProtocolStore<T> { + async fn get_pre_key(&self, prekey_id: PreKeyId) -> Result<PreKeyRecord, SignalProtocolError> { + let buf: Vec<u8> = self + .store + .get(T::pre_keys(), prekey_id.to_string()) + .ok() + .flatten() + .ok_or(SignalProtocolError::InvalidPreKeyId)?; + + PreKeyRecord::deserialize(&buf) + } + + async fn save_pre_key( + &mut self, + prekey_id: PreKeyId, + record: &PreKeyRecord, + ) -> Result<(), SignalProtocolError> { + self.store + .insert(T::pre_keys(), prekey_id.to_string(), record.serialize()?) + .expect("failed to store pre-key"); + Ok(()) + } + + async fn remove_pre_key(&mut self, prekey_id: PreKeyId) -> Result<(), SignalProtocolError> { + self.store + .remove(T::pre_keys(), prekey_id.to_string()) + .expect("failed to remove pre-key"); + Ok(()) + } +} + +#[async_trait(?Send)] +impl<T: SledTrees> PreKeysStore for SledProtocolStore<T> { + async fn next_pre_key_id(&self) -> Result<u32, SignalProtocolError> { + Ok(self + .store + .get(T::state(), SLED_KEY_PRE_KEYS_OFFSET_ID) + .map_err(|_| SignalProtocolError::InvalidPreKeyId)? + .unwrap_or(0)) + } + + async fn set_next_pre_key_id(&mut self, id: u32) -> Result<(), SignalProtocolError> { + self.store + .insert(T::state(), SLED_KEY_PRE_KEYS_OFFSET_ID, id) + .map_err(|_| SignalProtocolError::InvalidPreKeyId)?; + Ok(()) + } + + async fn next_signed_pre_key_id(&self) -> Result<u32, SignalProtocolError> { + Ok(self + .store + .get(T::state(), SLED_KEY_NEXT_SIGNED_PRE_KEY_ID) + .map_err(|_| SignalProtocolError::InvalidSignedPreKeyId)? + .unwrap_or(0)) + } + + async fn set_next_signed_pre_key_id(&mut self, id: u32) -> Result<(), SignalProtocolError> { + self.store + .insert(T::state(), SLED_KEY_NEXT_SIGNED_PRE_KEY_ID, id) + .map_err(|_| SignalProtocolError::InvalidSignedPreKeyId)?; + Ok(()) + } + + async fn next_pq_pre_key_id(&self) -> Result<u32, SignalProtocolError> { + Ok(self + .store + .get(T::state(), SLED_KEY_NEXT_PQ_PRE_KEY_ID) + .map_err(|_| SignalProtocolError::InvalidKyberPreKeyId)? + .unwrap_or(0)) + } + + async fn set_next_pq_pre_key_id(&mut self, id: u32) -> Result<(), SignalProtocolError> { + self.store + .insert(T::state(), SLED_KEY_NEXT_PQ_PRE_KEY_ID, id) + .map_err(|_| SignalProtocolError::InvalidKyberPreKeyId)?; + Ok(()) + } + + async fn signed_pre_keys_count(&self) -> Result<usize, SignalProtocolError> { + Ok(self + .store + .db + .read() + .expect("poisoned mutex") + .open_tree(T::signed_pre_keys()) + .map_err(|e| { + log::error!("sled error: {}", e); + SignalProtocolError::InvalidState("signed_pre_keys_count", "sled error".into()) + })? + .into_iter() + .keys() + .filter_map(Result::ok) + .count()) + } + + /// number of kyber pre-keys we currently have in store + async fn kyber_pre_keys_count(&self, last_resort: bool) -> Result<usize, SignalProtocolError> { + Ok(self + .store + .db + .read() + .expect("poisoned mutex") + .open_tree(if last_resort { + T::kyber_pre_keys_last_resort() + } else { + T::kyber_pre_keys() + }) + .map_err(|e| { + log::error!("sled error: {}", e); + SignalProtocolError::InvalidState("save_signed_pre_key", "sled error".into()) + })? + .into_iter() + .keys() + .filter_map(Result::ok) + .count()) + } +} + +#[async_trait(?Send)] +impl<T: SledTrees> SignedPreKeyStore for SledProtocolStore<T> { + async fn get_signed_pre_key( + &self, + signed_prekey_id: SignedPreKeyId, + ) -> Result<SignedPreKeyRecord, SignalProtocolError> { + let buf: Vec<u8> = self + .store + .get(T::signed_pre_keys(), signed_prekey_id.to_string()) + .ok() + .flatten() + .ok_or(SignalProtocolError::InvalidSignedPreKeyId)?; + SignedPreKeyRecord::deserialize(&buf) + } + + async fn save_signed_pre_key( + &mut self, + signed_prekey_id: SignedPreKeyId, + record: &SignedPreKeyRecord, + ) -> Result<(), SignalProtocolError> { + self.store + .insert( + T::signed_pre_keys(), + signed_prekey_id.to_string(), + record.serialize()?, + ) + .map_err(|e| { + log::error!("sled error: {}", e); + SignalProtocolError::InvalidState("save_signed_pre_key", "sled error".into()) + })?; + Ok(()) + } +} + +#[async_trait(?Send)] +impl<T: SledTrees> KyberPreKeyStore for SledProtocolStore<T> { + async fn get_kyber_pre_key( + &self, + kyber_prekey_id: KyberPreKeyId, + ) -> Result<KyberPreKeyRecord, SignalProtocolError> { + let buf: Vec<u8> = self + .store + .get(T::kyber_pre_keys(), kyber_prekey_id.to_string()) + .ok() + .flatten() + .ok_or(SignalProtocolError::InvalidKyberPreKeyId)?; + KyberPreKeyRecord::deserialize(&buf) + } + + async fn save_kyber_pre_key( + &mut self, + kyber_prekey_id: KyberPreKeyId, + record: &KyberPreKeyRecord, + ) -> Result<(), SignalProtocolError> { + self.store + .insert( + T::kyber_pre_keys(), + kyber_prekey_id.to_string(), + record.serialize()?, + ) + .map_err(|e| { + log::error!("sled error: {}", e); + SignalProtocolError::InvalidState("save_kyber_pre_key", "sled error".into()) + })?; + Ok(()) + } + + async fn mark_kyber_pre_key_used( + &mut self, + kyber_prekey_id: KyberPreKeyId, + ) -> Result<(), SignalProtocolError> { + let removed = self + .store + .remove(T::kyber_pre_keys(), kyber_prekey_id.to_string()) + .map_err(|e| { + log::error!("sled error: {}", e); + SignalProtocolError::InvalidState("mark_kyber_pre_key_used", "sled error".into()) + })?; + if removed { + log::trace!("removed kyber pre-key {kyber_prekey_id}"); + } + Ok(()) + } +} + +#[async_trait(?Send)] +impl<T: SledTrees> KyberPreKeyStoreExt for SledProtocolStore<T> { + async fn store_last_resort_kyber_pre_key( + &mut self, + kyber_prekey_id: KyberPreKeyId, + record: &KyberPreKeyRecord, + ) -> Result<(), SignalProtocolError> { + trace!("store_last_resort_kyber_pre_key"); + self.store + .insert( + T::kyber_pre_keys_last_resort(), + kyber_prekey_id.to_string(), + record.serialize()?, + ) + .map_err(|e| { + log::error!("sled error: {}", e); + SignalProtocolError::InvalidState( + "store_last_resort_kyber_pre_key", + "sled error".into(), + ) + })?; + Ok(()) + } + + async fn load_last_resort_kyber_pre_keys( + &self, + ) -> Result<Vec<KyberPreKeyRecord>, SignalProtocolError> { + trace!("load_last_resort_kyber_pre_keys"); + self.store + .iter(T::kyber_pre_keys_last_resort())? + .filter_map(|data: Result<Vec<u8>, SledStoreError>| data.ok()) + .map(|data| KyberPreKeyRecord::deserialize(&data)) + .collect() + } + + async fn remove_kyber_pre_key( + &mut self, + kyber_prekey_id: KyberPreKeyId, + ) -> Result<(), SignalProtocolError> { + self.store + .remove(T::kyber_pre_keys_last_resort(), kyber_prekey_id.to_string())?; + self.store + .remove(T::kyber_pre_keys_last_resort(), kyber_prekey_id.to_string())?; + Ok(()) + } + + /// Analogous to markAllOneTimeKyberPreKeysStaleIfNecessary + async fn mark_all_one_time_kyber_pre_keys_stale_if_necessary( + &mut self, + _stale_time: chrono::DateTime<chrono::Utc>, + ) -> Result<(), SignalProtocolError> { + unimplemented!("should not be used yet") + } + + /// Analogue of deleteAllStaleOneTimeKyberPreKeys + async fn delete_all_stale_one_time_kyber_pre_keys( + &mut self, + _threshold: chrono::DateTime<chrono::Utc>, + _min_count: usize, + ) -> Result<(), SignalProtocolError> { + unimplemented!("should not be used yet") + } +} + +#[async_trait(?Send)] +impl<T: SledTrees> SessionStore for SledProtocolStore<T> { + async fn load_session( + &self, + address: &ProtocolAddress, + ) -> Result<Option<SessionRecord>, SignalProtocolError> { + let session = self.store.get(T::sessions(), address.to_string())?; + trace!("loading session {} / exists={}", address, session.is_some()); + session + .map(|b: Vec<u8>| SessionRecord::deserialize(&b)) + .transpose() + } + + async fn store_session( + &mut self, + address: &ProtocolAddress, + record: &SessionRecord, + ) -> Result<(), SignalProtocolError> { + trace!("storing session {}", address); + self.store + .insert(T::sessions(), address.to_string(), record.serialize()?)?; + Ok(()) + } +} + +#[async_trait(?Send)] +impl<T: SledTrees> SessionStoreExt for SledProtocolStore<T> { + async fn get_sub_device_sessions( + &self, + address: &ServiceAddress, + ) -> Result<Vec<u32>, SignalProtocolError> { + let session_prefix = format!("{}.", address.uuid); + trace!("get_sub_device_sessions {}", session_prefix); + let session_ids: Vec<u32> = self + .store + .read() + .open_tree(T::sessions()) + .map_err(SledStoreError::Db)? + .scan_prefix(&session_prefix) + .filter_map(|r| { + let (key, _) = r.ok()?; + let key_str = String::from_utf8_lossy(&key); + let device_id = key_str.strip_prefix(&session_prefix)?; + device_id.parse().ok() + }) + .filter(|d| *d != DEFAULT_DEVICE_ID) + .collect(); + Ok(session_ids) + } + + async fn delete_session(&self, address: &ProtocolAddress) -> Result<(), SignalProtocolError> { + trace!("deleting session {}", address); + self.store + .write() + .open_tree(T::sessions()) + .map_err(SledStoreError::Db)? + .remove(address.to_string()) + .map_err(|_e| SignalProtocolError::SessionNotFound(address.clone()))?; + Ok(()) + } + + async fn delete_all_sessions( + &self, + address: &ServiceAddress, + ) -> Result<usize, SignalProtocolError> { + let db = self.store.write(); + let sessions_tree = db.open_tree(T::sessions()).map_err(SledStoreError::Db)?; + + let mut batch = Batch::default(); + sessions_tree + .scan_prefix(address.uuid.to_string()) + .filter_map(|r| { + let (key, _) = r.ok()?; + Some(key) + }) + .for_each(|k| batch.remove(k)); + + db.apply_batch(batch).map_err(SledStoreError::Db)?; + + let len = sessions_tree.len(); + sessions_tree.clear().map_err(|_e| { + SignalProtocolError::InvalidSessionStructure("failed to delete all sessions") + })?; + Ok(len) + } +} + +#[async_trait(?Send)] +impl<T: SledTrees> IdentityKeyStore for SledProtocolStore<T> { + async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> { + trace!("getting identity_key_pair"); + self.store.get_identity_key_pair::<T>()?.ok_or_else(|| { + SignalProtocolError::InvalidState( + "get_identity_key_pair", + "no identity key pair found".to_owned(), + ) + }) + } + + async fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError> { + let data = + self.store + .load_registration_data()? + .ok_or(SignalProtocolError::InvalidState( + "failed to load registration ID", + "no registration data".into(), + ))?; + Ok(data.registration_id) + } + + async fn save_identity( + &mut self, + address: &ProtocolAddress, + identity_key: &IdentityKey, + ) -> Result<bool, SignalProtocolError> { + trace!("saving identity"); + let existed_before = self + .store + .insert( + T::identities(), + address.to_string(), + identity_key.serialize(), + ) + .map_err(|e| { + error!("error saving identity for {:?}: {}", address, e); + e + })?; + + self.store.save_trusted_identity_message( + address, + *identity_key, + if existed_before { + verified::State::Unverified + } else { + verified::State::Default + }, + ); + + Ok(true) + } + + async fn is_trusted_identity( + &self, + address: &ProtocolAddress, + right_identity_key: &IdentityKey, + _direction: Direction, + ) -> Result<bool, SignalProtocolError> { + match self + .store + .get(T::identities(), address.to_string())? + .map(|b: Vec<u8>| IdentityKey::decode(&b)) + .transpose()? + { + None => { + // when we encounter a new identity, we trust it by default + warn!("trusting new identity {:?}", address); + Ok(true) + } + // when we encounter some identity we know, we need to decide whether we trust it or not + Some(left_identity_key) => { + if left_identity_key == *right_identity_key { + Ok(true) + } else { + match self.store.trust_new_identities { + OnNewIdentity::Trust => Ok(true), + OnNewIdentity::Reject => Ok(false), + } + } + } + } + } + + async fn get_identity( + &self, + address: &ProtocolAddress, + ) -> Result<Option<IdentityKey>, SignalProtocolError> { + self.store + .get(T::identities(), address.to_string())? + .map(|b: Vec<u8>| IdentityKey::decode(&b)) + .transpose() + } +} + +#[async_trait(?Send)] +impl<T: SledTrees> SenderKeyStore for SledProtocolStore<T> { + async fn store_sender_key( + &mut self, + sender: &ProtocolAddress, + distribution_id: Uuid, + record: &SenderKeyRecord, + ) -> Result<(), SignalProtocolError> { + let key = format!( + "{}.{}/{}", + sender.name(), + sender.device_id(), + distribution_id + ); + self.store + .insert(T::sender_keys(), key, record.serialize()?)?; + Ok(()) + } + + async fn load_sender_key( + &mut self, + sender: &ProtocolAddress, + distribution_id: Uuid, + ) -> Result<Option<SenderKeyRecord>, SignalProtocolError> { + let key = format!( + "{}.{}/{}", + sender.name(), + sender.device_id(), + distribution_id + ); + self.store + .get(T::sender_keys(), key)? + .map(|b: Vec<u8>| SenderKeyRecord::deserialize(&b)) + .transpose() + } +} + +#[cfg(test)] +mod tests { + use core::fmt; + + use base64::prelude::*; + use presage::{ + libsignal_service::protocol::{ + self, Direction, GenericSignedPreKey, IdentityKeyStore, PreKeyRecord, PreKeyStore, + SessionRecord, SessionStore, SignedPreKeyRecord, SignedPreKeyStore, + }, + store::Store, + }; + use quickcheck::{Arbitrary, Gen}; + + use super::SledStore; + + #[derive(Debug, Clone)] + struct ProtocolAddress(protocol::ProtocolAddress); + + #[derive(Clone)] + struct KeyPair(protocol::KeyPair); + + impl fmt::Debug for KeyPair { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "{}", + BASE64_STANDARD.encode(self.0.public_key.serialize()) + ) + } + } + + impl Arbitrary for ProtocolAddress { + fn arbitrary(g: &mut Gen) -> Self { + let name: String = Arbitrary::arbitrary(g); + let device_id: u32 = Arbitrary::arbitrary(g); + ProtocolAddress(protocol::ProtocolAddress::new(name, device_id.into())) + } + } + + impl Arbitrary for KeyPair { + fn arbitrary(_g: &mut Gen) -> Self { + // Gen is not rand::CryptoRng here, see https://github.com/BurntSushi/quickcheck/issues/241 + KeyPair(protocol::KeyPair::generate(&mut rand::thread_rng())) + } + } + + #[quickcheck_async::tokio] + async fn test_save_get_trust_identity(addr: ProtocolAddress, key_pair: KeyPair) -> bool { + let mut db = SledStore::temporary().unwrap().aci_protocol_store(); + let identity_key = protocol::IdentityKey::new(key_pair.0.public_key); + db.save_identity(&addr.0, &identity_key).await.unwrap(); + let id = db.get_identity(&addr.0).await.unwrap().unwrap(); + if id != identity_key { + return false; + } + db.is_trusted_identity(&addr.0, &id, Direction::Receiving) + .await + .unwrap() + } + + #[quickcheck_async::tokio] + async fn test_store_load_session(addr: ProtocolAddress) -> bool { + let session = SessionRecord::new_fresh(); + + let mut db = SledStore::temporary().unwrap().aci_protocol_store(); + db.store_session(&addr.0, &session).await.unwrap(); + if db.load_session(&addr.0).await.unwrap().is_none() { + return false; + } + let loaded_session = db.load_session(&addr.0).await.unwrap().unwrap(); + session.serialize().unwrap() == loaded_session.serialize().unwrap() + } + + #[quickcheck_async::tokio] + async fn test_prekey_store(id: u32, key_pair: KeyPair) -> bool { + let id = id.into(); + let mut db = SledStore::temporary().unwrap().aci_protocol_store(); + let pre_key_record = PreKeyRecord::new(id, &key_pair.0); + db.save_pre_key(id, &pre_key_record).await.unwrap(); + if db.get_pre_key(id).await.unwrap().serialize().unwrap() + != pre_key_record.serialize().unwrap() + { + return false; + } + + db.remove_pre_key(id).await.unwrap(); + db.get_pre_key(id).await.is_err() + } + + #[quickcheck_async::tokio] + async fn test_signed_prekey_store( + id: u32, + timestamp: u64, + key_pair: KeyPair, + signature: Vec<u8>, + ) -> bool { + let mut db = SledStore::temporary().unwrap().aci_protocol_store(); + let id = id.into(); + let signed_pre_key_record = SignedPreKeyRecord::new(id, timestamp, &key_pair.0, &signature); + db.save_signed_pre_key(id, &signed_pre_key_record) + .await + .unwrap(); + + db.get_signed_pre_key(id) + .await + .unwrap() + .serialize() + .unwrap() + == signed_pre_key_record.serialize().unwrap() + } +} diff --git a/presage/Cargo.toml b/presage/Cargo.toml index b9b70d357..4db0ffa7b 100644 --- a/presage/Cargo.toml +++ b/presage/Cargo.toml @@ -7,8 +7,8 @@ edition = "2021" license = "AGPL-3.0-only" [dependencies] -libsignal-service = { git = "https://github.com/whisperfish/libsignal-service-rs", rev = "c072491aa3e2b604b45b9f2b764552b7d382898c" } -libsignal-service-hyper = { git = "https://github.com/whisperfish/libsignal-service-rs", rev = "c072491aa3e2b604b45b9f2b764552b7d382898c" } +libsignal-service = { git = "https://github.com/whisperfish/libsignal-service-rs", rev = "26c036e" } +libsignal-service-hyper = { git = "https://github.com/whisperfish/libsignal-service-rs", rev = "26c036e" } base64 = "0.21" futures = "0.3" diff --git a/presage/src/errors.rs b/presage/src/errors.rs index f4c0629bf..dc9a583a7 100644 --- a/presage/src/errors.rs +++ b/presage/src/errors.rs @@ -41,7 +41,7 @@ pub enum Error<S: std::error::Error> { #[error("no provisioning message received")] NoProvisioningMessageReceived, #[error("qr code error")] - LinkError, + LinkingError, #[error("missing key {0} in config DB")] MissingKeyError(Cow<'static, str>), #[error("message pipe not started, you need to start receiving messages before you can send anything back")] diff --git a/presage/src/manager/confirmation.rs b/presage/src/manager/confirmation.rs index badde9bee..30ba89a9e 100644 --- a/presage/src/manager/confirmation.rs +++ b/presage/src/manager/confirmation.rs @@ -1,16 +1,17 @@ use libsignal_service::configuration::{ServiceConfiguration, SignalServers}; use libsignal_service::messagepipe::ServiceCredentials; use libsignal_service::prelude::phonenumber::PhoneNumber; -use libsignal_service::protocol::KeyPair; +use libsignal_service::protocol::IdentityKeyPair; use libsignal_service::provisioning::generate_registration_id; use libsignal_service::push_service::{ AccountAttributes, DeviceCapabilities, PushService, RegistrationMethod, ServiceIds, + VerifyAccountResponse, }; use libsignal_service::zkgroup::profiles::ProfileKey; +use libsignal_service::AccountManager; use libsignal_service_hyper::push_service::HyperPushService; use log::trace; -use rand::rngs::StdRng; -use rand::{RngCore, SeedableRng}; +use rand::RngCore; use crate::manager::registered::RegistrationData; use crate::store::Store; @@ -35,13 +36,13 @@ impl<S: Store> Manager<S, Confirmation> { /// Returns a [registered manager](Manager::load_registered) that you can use /// to send and receive messages. pub async fn confirm_verification_code( - self, + mut self, confirmation_code: impl AsRef<str>, ) -> Result<Manager<S, Registered>, Error<S::Error>> { trace!("confirming verification code"); - let registration_id = generate_registration_id(&mut StdRng::from_entropy()); - let pni_registration_id = generate_registration_id(&mut StdRng::from_entropy()); + let registration_id = generate_registration_id(&mut self.rng); + let pni_registration_id = generate_registration_id(&mut self.rng); let Confirmation { signal_servers, @@ -60,13 +61,13 @@ impl<S: Store> Manager<S, Confirmation> { }; let service_configuration: ServiceConfiguration = signal_servers.into(); - let mut push_service = HyperPushService::new( + let mut identified_push_service = HyperPushService::new( service_configuration, Some(credentials), crate::USER_AGENT.to_string(), ); - let session = push_service + let session = identified_push_service .submit_verification_code(&session_id, confirmation_code.as_ref()) .await?; @@ -76,23 +77,34 @@ impl<S: Store> Manager<S, Confirmation> { return Err(Error::UnverifiedRegistrationSession); } - let mut rng = StdRng::from_entropy(); - // generate a 52 bytes signaling key let mut signaling_key = [0u8; 52]; - rng.fill_bytes(&mut signaling_key); + self.rng.fill_bytes(&mut signaling_key); + // generate a 32 bytes profile key let mut profile_key = [0u8; 32]; - rng.fill_bytes(&mut profile_key); - + self.rng.fill_bytes(&mut profile_key); let profile_key = ProfileKey::generate(profile_key); - let skip_device_transfer = false; - let registered = push_service - .submit_registration_request( - RegistrationMethod::SessionId(&session_id), + // generate new identity keys used in `register_account` and below + self.store + .set_aci_identity_key_pair(IdentityKeyPair::generate(&mut self.rng))?; + self.store + .set_pni_identity_key_pair(IdentityKeyPair::generate(&mut self.rng))?; + + let skip_device_transfer = true; + let mut account_manager = AccountManager::new(identified_push_service, Some(profile_key)); + + let VerifyAccountResponse { + aci, + pni, + storage_capable: _, + number: _, + } = account_manager + .register_account( + &mut self.rng, + RegistrationMethod::SessionId(&session.id), AccountAttributes { - name: None, signaling_key: Some(signaling_key.to_vec()), registration_id, pni_registration_id, @@ -104,41 +116,30 @@ impl<S: Store> Manager<S, Confirmation> { unidentified_access_key: Some(profile_key.derive_access_key().to_vec()), unrestricted_unidentified_access: false, // TODO: make this configurable? discoverable_by_phone_number: true, - capabilities: DeviceCapabilities { - gv2: true, - gv1_migration: true, - ..Default::default() - }, + name: None, + capabilities: DeviceCapabilities::default(), }, + &mut self.store.aci_protocol_store(), + &mut self.store.pni_protocol_store(), skip_device_transfer, ) .await?; - let aci_identity_key_pair = KeyPair::generate(&mut rng); - let pni_identity_key_pair = KeyPair::generate(&mut rng); - trace!("confirmed! (and registered)"); let mut manager = Manager { - rng, + rng: self.rng, store: self.store, state: Registered::with_data(RegistrationData { signal_servers: self.state.signal_servers, device_name: None, phone_number, - service_ids: ServiceIds { - aci: registered.uuid, - pni: registered.pni, - }, + service_ids: ServiceIds { aci, pni }, password, signaling_key, device_id: None, registration_id, pni_registration_id: Some(pni_registration_id), - aci_private_key: aci_identity_key_pair.private_key, - aci_public_key: aci_identity_key_pair.public_key, - pni_private_key: Some(pni_identity_key_pair.private_key), - pni_public_key: Some(pni_identity_key_pair.public_key), profile_key, }), }; diff --git a/presage/src/manager/linking.rs b/presage/src/manager/linking.rs index eebfe9fcd..bbda35e7a 100644 --- a/presage/src/manager/linking.rs +++ b/presage/src/manager/linking.rs @@ -1,7 +1,10 @@ use futures::channel::{mpsc, oneshot}; use futures::{future, StreamExt}; use libsignal_service::configuration::{ServiceConfiguration, SignalServers}; -use libsignal_service::provisioning::{link_device, SecondaryDeviceProvisioning}; +use libsignal_service::protocol::IdentityKeyPair; +use libsignal_service::provisioning::{ + link_device, NewDeviceRegistration, SecondaryDeviceProvisioning, +}; use libsignal_service_hyper::push_service::HyperPushService; use log::info; use rand::distributions::{Alphanumeric, DistString}; @@ -77,13 +80,10 @@ impl<S: Store> Manager<S, Linking> { let (tx, mut rx) = mpsc::channel(1); - // XXX: this is obviously wrong - let mut pni_store = store.clone(); - let (wait_for_qrcode_scan, registration_data) = future::join( link_device( - &mut store, - &mut pni_store, + &mut store.aci_protocol_store(), + &mut store.pni_protocol_store(), &mut rng, push_service, &password, @@ -94,10 +94,10 @@ impl<S: Store> Manager<S, Linking> { if let Some(SecondaryDeviceProvisioning::Url(url)) = rx.next().await { info!("generating qrcode from provisioning link: {}", &url); if provisioning_link_channel.send(url).is_err() { - return Err(Error::LinkError); + return Err(Error::LinkingError); } } else { - return Err(Error::LinkError); + return Err(Error::LinkingError); } if let Some(SecondaryDeviceProvisioning::NewDeviceRegistration(data)) = rx.next().await @@ -113,24 +113,40 @@ impl<S: Store> Manager<S, Linking> { wait_for_qrcode_scan?; match registration_data { - Ok(d) => { + Ok(NewDeviceRegistration { + phone_number, + device_id, + registration_id, + pni_registration_id, + service_ids, + aci_private_key, + aci_public_key, + pni_private_key, + pni_public_key, + profile_key, + }) => { let registration_data = RegistrationData { signal_servers, device_name: Some(device_name), - phone_number: d.phone_number, - service_ids: d.service_ids, + phone_number, + service_ids, password, signaling_key, - device_id: Some(d.device_id.into()), - registration_id: d.registration_id, - pni_registration_id: Some(d.pni_registration_id), - aci_public_key: d.aci_public_key, - aci_private_key: d.aci_private_key, - pni_public_key: Some(d.pni_public_key), - pni_private_key: Some(d.pni_private_key), - profile_key: d.profile_key, + device_id: Some(device_id.into()), + registration_id, + pni_registration_id: Some(pni_registration_id), + profile_key, }; + store.set_aci_identity_key_pair(IdentityKeyPair::new( + aci_public_key, + aci_private_key, + ))?; + store.set_pni_identity_key_pair(IdentityKeyPair::new( + pni_public_key, + pni_private_key, + ))?; + store.save_registration_data(®istration_data)?; info!( "successfully registered device {}", @@ -143,7 +159,8 @@ impl<S: Store> Manager<S, Linking> { state: Registered::with_data(registration_data), }; - // Register pre-keys with the server. If this fails, this can lead to issues receiving, in that case clear the registration and propagate the error. + // Register pre-keys with the server. If this fails, this can lead to issues + // receiving, in that case clear the registration and propagate the error. if let Err(e) = manager.register_pre_keys().await { store.clear_registration()?; Err(e) diff --git a/presage/src/manager/mod.rs b/presage/src/manager/mod.rs index a17558672..f0b59d17d 100644 --- a/presage/src/manager/mod.rs +++ b/presage/src/manager/mod.rs @@ -38,59 +38,3 @@ impl<Store, State: fmt::Debug> fmt::Debug for Manager<Store, State> { .finish_non_exhaustive() } } - -#[cfg(test)] -mod tests { - use base64::engine::general_purpose; - use base64::Engine; - use libsignal_service::prelude::ProfileKey; - use libsignal_service::protocol::KeyPair; - use rand::RngCore; - use serde_json::json; - - use crate::manager::RegistrationData; - - #[test] - fn test_state_before_pni() { - let mut rng = rand::thread_rng(); - let key_pair = KeyPair::generate(&mut rng); - let mut profile_key = [0u8; 32]; - rng.fill_bytes(&mut profile_key); - let profile_key = ProfileKey::generate(profile_key); - let mut signaling_key = [0u8; 52]; - rng.fill_bytes(&mut signaling_key); - - // this is before public_key and private_key were renamed to aci_public_key and aci_private_key - // and pni_public_key + pni_private_key were added - let previous_state = json!({ - "signal_servers": "Production", - "device_name": "Test", - "phone_number": { - "code": { - "value": 1, - "source": "plus" - }, - "national": { - "value": 5550199, - "zeros": 0 - }, - "extension": null, - "carrier": null - }, - "uuid": "ff9a89d9-8052-4af0-a91d-2a0dfa0c6b95", - "password": "HelloWorldOfPasswords", - "signaling_key": general_purpose::STANDARD.encode(signaling_key), - "device_id": 42, - "registration_id": 64, - "private_key": general_purpose::STANDARD.encode(key_pair.private_key.serialize()), - "public_key": general_purpose::STANDARD.encode(key_pair.public_key.serialize()), - "profile_key": general_purpose::STANDARD.encode(profile_key.get_bytes()), - }); - - let data: RegistrationData = - serde_json::from_value(previous_state).expect("should deserialize"); - assert_eq!(data.aci_public_key, key_pair.public_key); - assert!(data.aci_private_key == key_pair.private_key); - assert!(data.pni_public_key.is_none()); - } -} diff --git a/presage/src/manager/registered.rs b/presage/src/manager/registered.rs index e2301a9c4..c6d402c7d 100644 --- a/presage/src/manager/registered.rs +++ b/presage/src/manager/registered.rs @@ -20,8 +20,7 @@ use libsignal_service::proto::{ AttachmentPointer, DataMessage, EditMessage, GroupContextV2, NullMessage, SyncMessage, Verified, }; -use libsignal_service::protocol::SenderCertificate; -use libsignal_service::protocol::{PrivateKey, PublicKey}; +use libsignal_service::protocol::{IdentityKeyStore, SenderCertificate}; use libsignal_service::provisioning::{generate_registration_id, ProvisioningError}; use libsignal_service::push_service::{ AccountAttributes, DeviceCapabilities, PushService, ServiceError, ServiceIdType, ServiceIds, @@ -31,10 +30,7 @@ use libsignal_service::receiver::MessageReceiver; use libsignal_service::sender::{AttachmentSpec, AttachmentUploadError}; use libsignal_service::sticker_cipher::derive_key; use libsignal_service::unidentified_access::UnidentifiedAccess; -use libsignal_service::utils::{ - serde_optional_private_key, serde_optional_public_key, serde_private_key, serde_public_key, - serde_signaling_key, -}; +use libsignal_service::utils::serde_signaling_key; use libsignal_service::websocket::SignalWebSocket; use libsignal_service::zkgroup::groups::{GroupMasterKey, GroupSecretParams}; use libsignal_service::zkgroup::profiles::ProfileKey; @@ -109,14 +105,6 @@ pub struct RegistrationData { pub registration_id: u32, #[serde(default)] pub pni_registration_id: Option<u32>, - #[serde(with = "serde_private_key", rename = "private_key")] - pub(crate) aci_private_key: PrivateKey, - #[serde(with = "serde_public_key", rename = "public_key")] - pub(crate) aci_public_key: PublicKey, - #[serde(with = "serde_optional_private_key", default)] - pub(crate) pni_private_key: Option<PrivateKey>, - #[serde(with = "serde_optional_public_key", default)] - pub(crate) pni_public_key: Option<PublicKey>, #[serde(with = "serde_profile_key")] pub(crate) profile_key: ProfileKey, } @@ -141,16 +129,6 @@ impl RegistrationData { pub fn device_name(&self) -> Option<&str> { self.device_name.as_deref() } - - /// Account identity public key - pub fn aci_public_key(&self) -> PublicKey { - self.aci_public_key - } - - /// Account identity private key - pub fn aci_private_key(&self) -> PrivateKey { - self.aci_private_key - } } impl<S: Store> Manager<S, Registered> { @@ -271,22 +249,22 @@ impl<S: Store> Manager<S, Registered> { Some(self.state.data.profile_key), ); - // TODO: Do the same for PNI once implemented upstream. - let (pre_keys_offset_id, next_signed_pre_key_id, next_pq_pre_key_id) = account_manager + account_manager .update_pre_key_bundle( - &mut self.store.clone(), + &mut self.store.aci_protocol_store(), ServiceIdType::AccountIdentity, &mut self.rng, true, ) .await?; - self.store.set_next_pre_key_id(pre_keys_offset_id).await?; - self.store - .set_next_signed_pre_key_id(next_signed_pre_key_id) - .await?; - self.store - .set_next_pq_pre_key_id(next_pq_pre_key_id) + account_manager + .update_pre_key_bundle( + &mut self.store.pni_protocol_store(), + ServiceIdType::PhoneNumberIdentity, + &mut self.rng, + true, + ) .await?; trace!("registered pre keys"); @@ -327,8 +305,11 @@ impl<S: Store> Manager<S, Registered> { unrestricted_unidentified_access: false, discoverable_by_phone_number: true, capabilities: DeviceCapabilities { - gv2: true, - gv1_migration: true, + gift_badges: true, + payment_activation: false, + pni: true, + sender_key: true, + stories: false, ..Default::default() }, }) @@ -606,12 +587,12 @@ impl<S: Store> Manager<S, Registered> { &mut self, mode: ReceivingMode, ) -> Result<impl Stream<Item = Content>, Error<S::Error>> { - struct StreamState<S, C: ContentsStore + Send + Sync> { - encrypted_messages: S, + struct StreamState<Receiver, Store, AciStore> { + encrypted_messages: Receiver, message_receiver: MessageReceiver<HyperPushService>, - service_cipher: ServiceCipher<C>, + service_cipher: ServiceCipher<AciStore>, push_service: HyperPushService, - store: C, + store: Store, groups_manager: GroupsManager<HyperPushService, InMemoryCredentialsCache>, mode: ReceivingMode, } @@ -811,6 +792,10 @@ impl<S: Store> Manager<S, Registered> { let mut sender = self.new_message_sender().await?; let online_only = false; + // TODO: Populate this flag based on the recipient information + // + // Issue <https://github.com/whisperfish/presage/issues/252> + let include_pni_signature = false; let recipient = recipient_addr.into(); let mut content_body: ContentBody = message.into(); @@ -855,6 +840,7 @@ impl<S: Store> Manager<S, Registered> { content_body.clone(), timestamp, online_only, + include_pni_signature, ) .await?; @@ -953,7 +939,12 @@ impl<S: Store> Manager<S, Registered> { key: profile_key.derive_access_key().to_vec(), certificate: sender_certificate.clone(), }); - recipients.push((member.uuid.into(), unidentified_access)); + let include_pni_signature = true; + recipients.push(( + member.uuid.into(), + unidentified_access, + include_pni_signature, + )); } let online_only = false; @@ -1002,7 +993,15 @@ impl<S: Store> Manager<S, Registered> { /// Clears all sessions established wiht [recipient](ServiceAddress). pub async fn clear_sessions(&self, recipient: &ServiceAddress) -> Result<(), Error<S::Error>> { - self.store.delete_all_sessions(recipient).await?; + use libsignal_service::session_store::SessionStoreExt; + self.store + .aci_protocol_store() + .delete_all_sessions(recipient) + .await?; + self.store + .pni_protocol_store() + .delete_all_sessions(recipient) + .await?; Ok(()) } @@ -1152,30 +1151,37 @@ impl<S: Store> Manager<S, Registered> { } /// Creates a new message sender. - async fn new_message_sender(&self) -> Result<MessageSender<S>, Error<S::Error>> { - let local_addr = ServiceAddress { - uuid: self.state.data.service_ids.aci, - }; - + async fn new_message_sender(&self) -> Result<MessageSender<S::AciStore>, Error<S::Error>> { let identified_websocket = self.identified_websocket(false).await?; let unidentified_websocket = self.unidentified_websocket().await?; + let aci_protocol_store = self.store.aci_protocol_store(); + let aci_identity_keypair = aci_protocol_store.get_identity_key_pair().await?; + let pni_identity_keypair = self + .store + .pni_protocol_store() + .get_identity_key_pair() + .await?; + Ok(MessageSender::new( identified_websocket, unidentified_websocket, self.identified_push_service(), self.new_service_cipher()?, self.rng.clone(), - self.store.clone(), - local_addr, + aci_protocol_store, + self.state.data.service_ids.aci, + self.state.data.service_ids.pni, + aci_identity_keypair, + Some(pni_identity_keypair), self.state.device_id().into(), )) } /// Creates a new service cipher. - fn new_service_cipher(&self) -> Result<ServiceCipher<S>, Error<S::Error>> { + fn new_service_cipher(&self) -> Result<ServiceCipher<S::AciStore>, Error<S::Error>> { let service_cipher = ServiceCipher::new( - self.store.clone(), + self.store.aci_protocol_store(), self.rng.clone(), self.state .service_configuration() diff --git a/presage/src/store.rs b/presage/src/store.rs index c40b7ec08..cbe58ada9 100644 --- a/presage/src/store.rs +++ b/presage/src/store.rs @@ -12,7 +12,7 @@ use libsignal_service::{ sync_message::{self, Sent}, verified, DataMessage, EditMessage, GroupContextV2, SyncMessage, Verified, }, - protocol::{IdentityKey, ProtocolAddress, ProtocolStore, SenderKeyStore}, + protocol::{IdentityKey, IdentityKeyPair, ProtocolAddress, ProtocolStore, SenderKeyStore}, session_store::SessionStoreExt, zkgroup::GroupMasterKeyBytes, Profile, @@ -32,6 +32,16 @@ pub trait StateStore { /// Load registered (or linked) state fn load_registration_data(&self) -> Result<Option<RegistrationData>, Self::StateStoreError>; + fn set_aci_identity_key_pair( + &self, + key_pair: IdentityKeyPair, + ) -> Result<(), Self::StateStoreError>; + + fn set_pni_identity_key_pair( + &self, + key_pair: IdentityKeyPair, + ) -> Result<(), Self::StateStoreError>; + /// Save registered (or linked) state fn save_registration_data( &mut self, @@ -63,6 +73,12 @@ pub trait ContentsStore: Send + Sync { /// Iterator over all stored sticker packs type StickerPacksIter: Iterator<Item = Result<StickerPack, Self::ContentsStoreError>>; + // Clear all profiles + fn clear_profiles(&mut self) -> Result<(), Self::ContentsStoreError>; + + // Clear all stored messages + fn clear_contents(&mut self) -> Result<(), Self::ContentsStoreError>; + // Messages /// Clear all stored messages. @@ -290,22 +306,24 @@ pub trait ContentsStore: Send + Sync { /// The manager store trait combining all other stores into a single one pub trait Store: StateStore<StateStoreError = Self::Error> - + PreKeysStore + ContentsStore<ContentsStoreError = Self::Error> - + ProtocolStore - + SenderKeyStore - + SessionStoreExt + Send + Sync + Clone + 'static { type Error: StoreError; + type AciStore: ProtocolStore + PreKeysStore + SenderKeyStore + SessionStoreExt + Sync + Clone; + type PniStore: ProtocolStore + PreKeysStore + SenderKeyStore + SessionStoreExt + Sync + Clone; /// Clear the entire store /// /// This can be useful when resetting an existing client. fn clear(&mut self) -> Result<(), <Self as StateStore>::StateStoreError>; + + fn aci_protocol_store(&self) -> Self::AciStore; + + fn pni_protocol_store(&self) -> Self::PniStore; } /// A thread specifies where a message was sent, either to or from a contact or in a group.