diff --git a/Cargo.toml b/Cargo.toml index 8aceea1a..22a9e1e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ rust-version = "1.60" version = "2.0.0" [features] -curve25519 = ["curve25519-dalek"] +curve25519 = ["curve25519-dalek/precomputed-tables"] default = ["ristretto255-voprf", "serde"] ristretto255 = ["curve25519-dalek", "voprf/ristretto255"] ristretto255-voprf = ["ristretto255", "voprf/ristretto255-ciphersuite"] @@ -23,8 +23,9 @@ std = ["getrandom"] argon2 = { version = "0.4", default-features = false, features = [ "alloc", ], optional = true } -curve25519-dalek = { version = "=4.0.0-pre.5", default-features = false, features = [ +curve25519-dalek = { version = "=4.0.0-rc.1", default-features = false, features = [ "rand_core", + "zeroize", ], optional = true } derive-where = { version = "1", features = ["zeroize-on-drop"] } digest = "0.10" @@ -78,4 +79,4 @@ name = "simple_login" required-features = ["argon2"] [patch.crates-io] -voprf = { git = "https://github.com/facebook/voprf" } +voprf = { git = "https://github.com/khonsulabs/voprf", branch = "curve25519-dalek-4.0.0-rc.1"} diff --git a/src/key_exchange/group/curve25519.rs b/src/key_exchange/group/curve25519.rs index 17de8da8..d562f898 100644 --- a/src/key_exchange/group/curve25519.rs +++ b/src/key_exchange/group/curve25519.rs @@ -81,7 +81,7 @@ impl KeGroup for Curve25519 { } fn public_key(sk: Self::Sk) -> Self::Pk { - (&ED25519_BASEPOINT_TABLE * &sk).to_montgomery() + (ED25519_BASEPOINT_TABLE * &sk).to_montgomery() } fn diffie_hellman(pk: Self::Pk, sk: Self::Sk) -> GenericArray { diff --git a/src/key_exchange/group/ristretto255.rs b/src/key_exchange/group/ristretto255.rs index f1aa98ab..f90bcb84 100644 --- a/src/key_exchange/group/ristretto255.rs +++ b/src/key_exchange/group/ristretto255.rs @@ -38,11 +38,8 @@ impl KeGroup for Ristretto255 { } fn deserialize_pk(bytes: &[u8]) -> Result { - if bytes.len() != 32 { - return Err(InternalError::PointError); - } - CompressedRistretto::from_slice(bytes) + .map_err(|_| InternalError::PointError)? .decompress() .filter(|point| point != &RistrettoPoint::identity()) .ok_or(InternalError::PointError)