Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various cleanups after merging ACI/PNI #298

Merged
merged 5 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions libsignal-service/examples/storage.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use libsignal_service::pre_keys::{PreKeysStore, ServiceKyberPreKeyStore};
use libsignal_service::pre_keys::{KyberPreKeyStoreExt, PreKeysStore};
use libsignal_service::protocol::{
Direction, IdentityKey, IdentityKeyPair, IdentityKeyStore, KyberPreKeyId,
KyberPreKeyRecord, KyberPreKeyStore, PreKeyId, PreKeyRecord, PreKeyStore,
Expand Down Expand Up @@ -92,7 +92,7 @@ impl SignedPreKeyStore for ExampleStore {
}

#[async_trait::async_trait(?Send)]
impl ServiceKyberPreKeyStore for ExampleStore {
impl KyberPreKeyStoreExt for ExampleStore {
async fn store_last_resort_kyber_pre_key(
&mut self,
_kyber_prekey_id: KyberPreKeyId,
Expand Down Expand Up @@ -227,6 +227,19 @@ impl PreKeysStore for ExampleStore {
) -> Result<(), SignalProtocolError> {
todo!()
}

async fn signed_pre_keys_count(
&self,
) -> Result<usize, SignalProtocolError> {
todo!()
}

async fn kyber_pre_keys_count(
&self,
_last_resort: bool,
) -> Result<usize, SignalProtocolError> {
todo!()
}
}

#[allow(dead_code)]
Expand Down
44 changes: 17 additions & 27 deletions libsignal-service/src/account_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use libsignal_protocol::{
kem, GenericSignedPreKey, IdentityKey, IdentityKeyStore, KeyPair,
KyberPreKeyRecord, PrivateKey, ProtocolStore, PublicKey, SenderKeyStore,
SignalProtocolError, SignedPreKeyRecord,
SignedPreKeyRecord,
};
use prost::Message;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -64,20 +64,6 @@
ProfileCipherError(#[from] ProfileCipherError),
}

#[derive(thiserror::Error, Debug)]
pub enum LinkError {
#[error(transparent)]
ServiceError(#[from] ServiceError),
#[error("TsUrl has an invalid UUID field")]
InvalidUuid,
#[error("TsUrl has an invalid pub_key field")]
InvalidPublicKey,
#[error("Protocol error {0}")]
ProtocolError(#[from] SignalProtocolError),
#[error(transparent)]
ProvisioningError(#[from] ProvisioningError),
}

#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Profile {
pub name: Option<ProfileName<String>>,
Expand Down Expand Up @@ -111,7 +97,6 @@
service_id_type: ServiceIdType,
csprng: &mut R,
use_last_resort_key: bool,
force: bool,
) -> Result<(), ServiceError> {
let prekey_status = match self
.service
Expand Down Expand Up @@ -143,7 +128,9 @@
if prekey_status.count >= PRE_KEY_MINIMUM
&& prekey_status.pq_count >= PRE_KEY_MINIMUM
{
if !force {
if protocol_store.signed_pre_keys_count().await? > 0
&& protocol_store.kyber_pre_keys_count(true).await? > 0
{
tracing::debug!("Available keys sufficient");
return Ok(());
}
Expand All @@ -159,6 +146,7 @@
.load_last_resort_kyber_pre_keys()
.instrument(tracing::trace_span!("fetch last resort key"))
.await?;

// XXX: Maybe this check should be done in the generate_pre_keys function?
let has_last_resort_key = !last_resort_keys.is_empty();

Expand Down Expand Up @@ -186,7 +174,7 @@
.transpose()?
};

let identity_key = *identity_key_pair.identity_key().public_key();
let identity_key = *identity_key_pair.identity_key();

let pre_keys: Vec<_> = pre_keys
.into_iter()
Expand Down Expand Up @@ -289,16 +277,18 @@
aci_identity_store: &dyn IdentityKeyStore,
pni_identity_store: &dyn IdentityKeyStore,
credentials: ServiceCredentials,
) -> Result<(), LinkError> {
) -> Result<(), ProvisioningError> {
let query: HashMap<_, _> = url.query_pairs().collect();
let ephemeral_id = query.get("uuid").ok_or(LinkError::InvalidUuid)?;
let pub_key =
query.get("pub_key").ok_or(LinkError::InvalidPublicKey)?;
let ephemeral_id =
query.get("uuid").ok_or(ProvisioningError::MissingUuid)?;
let pub_key = query
.get("pub_key")
.ok_or(ProvisioningError::MissingPublicKey)?;
let pub_key = BASE64_RELAXED
.decode(&**pub_key)
.map_err(|_e| LinkError::InvalidPublicKey)?;
.map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
let pub_key = PublicKey::deserialize(&pub_key)
.map_err(|_e| LinkError::InvalidPublicKey)?;
.map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;

let aci_identity_key_pair =
aci_identity_store.get_identity_key_pair().await?;
Expand Down Expand Up @@ -361,7 +351,7 @@
aci_protocol_store: &mut Aci,
pni_protocol_store: &mut Pni,
skip_device_transfer: bool,
) -> Result<VerifyAccountResponse, LinkError> {
) -> Result<VerifyAccountResponse, ProvisioningError> {
let aci_identity_key_pair = aci_protocol_store
.get_identity_key_pair()
.instrument(tracing::trace_span!("get ACI identity key pair"))
Expand Down Expand Up @@ -421,8 +411,8 @@
registration_method,
account_attributes,
skip_device_transfer,
*aci_identity_key,
*pni_identity_key,
aci_identity_key,
pni_identity_key,
dar,
)
.await?;
Expand Down Expand Up @@ -724,10 +714,10 @@
} else {
loop {
let regid = generate_registration_id(csprng);
if pni_registration_ids
.iter()
.find(|(_k, v)| **v == regid)
.is_none()

Check warning on line 720 in libsignal-service/src/account_manager.rs

View workflow job for this annotation

GitHub Actions / clippy

called `is_none()` after searching an `Iterator` with `find`

warning: called `is_none()` after searching an `Iterator` with `find` --> libsignal-service/src/account_manager.rs:717:24 | 717 | if pni_registration_ids | ________________________^ 718 | | .iter() 719 | | .find(|(_k, v)| **v == regid) 720 | | .is_none() | |__________________________________^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#search_is_some = note: `#[warn(clippy::search_is_some)]` on by default help: consider using | 717 ~ if !pni_registration_ids 718 + .iter().any(|(_k, v)| **v == regid) |
{
break regid;
}
Expand Down
3 changes: 2 additions & 1 deletion libsignal-service/src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
envelope::Envelope,
push_service::ServiceError,
sender::OutgoingPushMessage,
session_store::SessionStoreExt,
utils::BASE64_RELAXED,
ServiceAddress,
};
Expand Down Expand Up @@ -75,7 +76,7 @@

impl<S, R> ServiceCipher<S, R>
where
S: ProtocolStore + KyberPreKeyStore + SenderKeyStore + Clone,
S: ProtocolStore + SenderKeyStore + SessionStoreExt + Clone,
R: Rng + CryptoRng,
{
pub fn new(
Expand Down Expand Up @@ -181,7 +182,7 @@
&mut self.protocol_store.clone(),
&mut self.protocol_store.clone(),
&mut self.protocol_store.clone(),
&mut self.protocol_store.clone(),

Check warning on line 185 in libsignal-service/src/cipher.rs

View workflow job for this annotation

GitHub Actions / clippy

the function `message_decrypt_prekey` doesn't need a mutable reference

warning: the function `message_decrypt_prekey` doesn't need a mutable reference --> libsignal-service/src/cipher.rs:185:21 | 185 | &mut self.protocol_store.clone(), | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#unnecessary_mut_passed = note: `#[warn(clippy::unnecessary_mut_passed)]` on by default
&mut self.protocol_store.clone(),
&mut self.csprng,
)
Expand Down
1 change: 1 addition & 0 deletions libsignal-service/src/groups_v2/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ pub fn decrypt_group(
fn current_days_seconds() -> (u64, u64) {
let days_seconds = |date: NaiveDate| {
date.and_time(NaiveTime::from_hms_opt(0, 0, 0).unwrap())
.and_utc()
.timestamp() as u64
};

Expand Down
5 changes: 2 additions & 3 deletions libsignal-service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ pub mod utils;
pub mod websocket;

pub use crate::account_manager::{
decrypt_device_name, AccountManager, LinkError, Profile,
ProfileManagerError,
decrypt_device_name, AccountManager, Profile, ProfileManagerError,
};
pub use crate::service_address::*;

Expand Down Expand Up @@ -96,7 +95,7 @@ pub mod prelude {
profiles::ProfileKey,
};

pub use libsignal_protocol::DeviceId;
pub use libsignal_protocol::{DeviceId, IdentityKeyStore};
}

pub use libsignal_protocol as protocol;
Expand Down
28 changes: 19 additions & 9 deletions libsignal-service/src/pre_keys.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{convert::TryFrom, time::SystemTime};

use crate::utils::{serde_base64, serde_public_key};
use crate::utils::{serde_base64, serde_identity_key};
use async_trait::async_trait;
use libsignal_protocol::{
error::SignalProtocolError, kem, GenericSignedPreKey, IdentityKeyPair,
IdentityKeyStore, KeyPair, KyberPreKeyId, KyberPreKeyRecord,
KyberPreKeyStore, PreKeyRecord, PreKeyStore, PublicKey, SignedPreKeyRecord,
SignedPreKeyStore,
error::SignalProtocolError, kem, GenericSignedPreKey, IdentityKey,
IdentityKeyPair, IdentityKeyStore, KeyPair, KyberPreKeyId,
KyberPreKeyRecord, KyberPreKeyStore, PreKeyRecord, PreKeyStore,
SignedPreKeyRecord, SignedPreKeyStore,
};

use serde::{Deserialize, Serialize};
Expand All @@ -16,7 +16,7 @@ use tracing::Instrument;
/// Additional methods for the Kyber pre key store
///
/// Analogue of Android's ServiceKyberPreKeyStore
pub trait ServiceKyberPreKeyStore: KyberPreKeyStore {
pub trait KyberPreKeyStoreExt: KyberPreKeyStore {
async fn store_last_resort_kyber_pre_key(
&mut self,
kyber_prekey_id: KyberPreKeyId,
Expand Down Expand Up @@ -55,7 +55,7 @@ pub trait PreKeysStore:
+ IdentityKeyStore
+ SignedPreKeyStore
+ KyberPreKeyStore
+ ServiceKyberPreKeyStore
+ KyberPreKeyStoreExt
{
/// ID of the next pre key
async fn next_pre_key_id(&self) -> Result<u32, SignalProtocolError>;
Expand Down Expand Up @@ -83,6 +83,16 @@ pub trait PreKeysStore:
&mut self,
id: u32,
) -> Result<(), SignalProtocolError>;

/// number of signed pre-keys we currently have in store
async fn signed_pre_keys_count(&self)
-> Result<usize, SignalProtocolError>;

/// number of kyber pre-keys we currently have in store
async fn kyber_pre_keys_count(
&self,
last_resort: bool,
) -> Result<usize, SignalProtocolError>;
}

#[derive(Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -169,8 +179,8 @@ impl TryFrom<KyberPreKeyRecord> for KyberPreKeyEntity {
pub struct PreKeyState {
pub pre_keys: Vec<PreKeyEntity>,
pub signed_pre_key: SignedPreKeyEntity,
#[serde(with = "serde_public_key")]
pub identity_key: PublicKey,
#[serde(with = "serde_identity_key")]
pub identity_key: IdentityKey,
#[serde(skip_serializing_if = "Option::is_none")]
pub pq_last_resort_key: Option<KyberPreKeyEntity>,
pub pq_pre_keys: Vec<KyberPreKeyEntity>,
Expand Down
12 changes: 3 additions & 9 deletions libsignal-service/src/provisioning/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,7 @@ impl ProvisioningCipher {
.body
.expect("no body in ProvisionMessage");
if body[0] != VERSION {
return Err(ProvisioningError::InvalidData {
reason: "Bad version number".into(),
});
return Err(ProvisioningError::BadVersionNumber);
}

let iv = &body[IV_OFFSET..(IV_LENGTH + IV_OFFSET)];
Expand All @@ -166,9 +164,7 @@ impl ProvisioningCipher {
let our_mac = verifier.finalize().into_bytes();
debug_assert_eq!(our_mac.len(), mac.len());
if &our_mac[..32] != mac {
return Err(ProvisioningError::InvalidData {
reason: "wrong MAC".into(),
});
return Err(ProvisioningError::MismatchedMac);
}

// libsignal-service-java uses Pkcs5,
Expand All @@ -177,9 +173,7 @@ impl ProvisioningCipher {
let cipher = cbc::Decryptor::<Aes256>::new(parts1.into(), iv.into());
let input = cipher
.decrypt_padded_vec_mut::<Pkcs7>(cipher_text)
.map_err(|e| ProvisioningError::InvalidData {
reason: format!("CBC/Padding error: {:?}", e).into(),
})?;
.map_err(ProvisioningError::AesPaddingError)?;

Ok(prost::Message::decode(Bytes::from(input))?)
}
Expand Down
Loading
Loading