[ML-KEM] incremental API #757

Draft
wants to merge 19 commits into base: main

Changes from 1 commit
284 changes: 281 additions & 3 deletions libcrux-ml-kem/src/ind_cca/incremental.rs
@@ -23,7 +23,7 @@ use super::{

pub mod types {
use core::array::from_fn;
use std::eprintln;
use std::{eprintln, vec::Vec};

use ind_cpa::unpacked::IndCpaPublicKeyUnpacked;

@@ -157,6 +157,58 @@ pub mod types {
pub(super) randomness: [u8; 32],
}

impl<const K: usize, Vector: Operations> EncapsState<K, Vector> {
/// Get the number of bytes required for the state.
pub const fn num_bytes() -> usize {
SHARED_SECRET_SIZE
+ vec_len_bytes::<K, Vector>()
+ PolynomialRingElement::<Vector>::num_bytes()
+ 32
}

/// Write the state as bytes into `out`.
pub fn to_bytes(self, out: &mut [u8]) {
debug_assert!(out.len() >= Self::num_bytes());

if there's any chance that this serialization will change in the future, it might be prudent to add a "version byte" to the head of this serialization. These serializations may be stored persistently for long periods of time on user devices, and may be processed by different versions of the libcrux library. Should any changes occur here (including length changes, or changes in the layout of bytes or additional information), new library versions should be able to process old library serialization, and old library versions should at least be able to fail gracefully to process new serializations (return a Result?). Some initial byte saying "this is serialization version 1" could aid with that.

Member Author:

That's a very good question. This is really just the flat byte representation of the state, and I don't see a reason this should ever change.
We could add a version byte, but I don't want to make things more complicated than necessary.
I think if something here changes, we should just create a new state struct.
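For reference, a minimal sketch of what a version-byte wrapper around this flat layout could look like. The `VERSION` constant, the function names, and the `Error::InvalidInput` variant are illustrative assumptions, not part of this PR:

```rust
// Sketch only: a one-byte version tag in front of the flat state layout.
// `VERSION`, both function names, and `Error::InvalidInput` are assumed names.
const VERSION: u8 = 1;

pub fn state_to_versioned_bytes<const K: usize, Vector: Operations>(
    state: EncapsState<K, Vector>,
    out: &mut [u8],
) {
    // One extra byte for the version tag in front of the flat layout.
    debug_assert!(out.len() > EncapsState::<K, Vector>::num_bytes());
    out[0] = VERSION;
    state.to_bytes(&mut out[1..]);
}

pub fn state_from_versioned_bytes<const K: usize, Vector: Operations>(
    bytes: &[u8],
) -> Result<EncapsState<K, Vector>, Error> {
    match bytes.first() {
        // Known version: the rest of the buffer is the flat layout above.
        Some(&VERSION) => Ok(EncapsState::from_bytes(&bytes[1..])),
        // Unknown or missing version byte: fail gracefully instead of panicking.
        _ => Err(Error::InvalidInput),
    }
}
```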


out[..SHARED_SECRET_SIZE].copy_from_slice(&self.shared_secret);
let mut offset = SHARED_SECRET_SIZE;

vec_to_bytes(&self.r_as_ntt, &mut out[offset..]);
offset += vec_len_bytes::<K, Vector>();

self.error2.to_bytes(&mut out[offset..]);
offset += PolynomialRingElement::<Vector>::num_bytes();

out[offset..offset + 32].copy_from_slice(&self.randomness);
}

/// Build a state from bytes
pub fn from_bytes(bytes: &[u8]) -> Self {
debug_assert!(bytes.len() >= Self::num_bytes());

This particular assert is probably best as a Result, if there's any possibility of future serialization change.

Member Author:

Yeah, the functions here still panic all over the place. I'll start cleaning things up now.
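A minimal sketch of the fallible variant discussed here; `Error::InvalidInput` is an assumed variant name, and the method would sit next to `from_bytes` in the same `impl` block:

```rust
// Sketch only: surface the length check as a `Result` instead of a debug assert.
// `Error::InvalidInput` is an assumed variant name, not part of this PR.
pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, Error> {
    if bytes.len() < Self::num_bytes() {
        return Err(Error::InvalidInput);
    }
    Ok(Self::from_bytes(bytes))
}
```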


let mut shared_secret = [0u8; SHARED_SECRET_SIZE];
shared_secret.copy_from_slice(&bytes[..SHARED_SECRET_SIZE]);
let mut offset = SHARED_SECRET_SIZE;

let mut r_as_ntt = from_fn(|_| PolynomialRingElement::<Vector>::ZERO());
vec_from_bytes(&bytes[offset..], &mut r_as_ntt);
offset += vec_len_bytes::<K, Vector>();

let error2 = PolynomialRingElement::<Vector>::from_bytes(&bytes[offset..]);
offset += PolynomialRingElement::<Vector>::num_bytes();

let mut randomness = [0u8; 32];
randomness.copy_from_slice(&bytes[offset..offset + 32]);

Self {
shared_secret,
r_as_ntt,
error2,
randomness,
}
}
}
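As a usage note (not part of the diff), the layout above round-trips as follows, assuming an `EncapsState<K, Vector>` value named `state` is in scope, e.g. inside a crate-internal test:

```rust
// Sketch only: round-trip the state through the flat byte layout
// (shared_secret | r_as_ntt | error2 | randomness).
let mut buf = std::vec![0u8; EncapsState::<K, Vector>::num_bytes()];
state.to_bytes(&mut buf); // consumes `state`
let restored = EncapsState::<K, Vector>::from_bytes(&buf);
```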

/// Trait container for multiplexing over platform dependent [`EncapsState`].
pub trait State {
fn as_any(&self) -> &dyn Any;

This trait will also need to be stored/retrieved from persistent storage, so also needs some mechanism of being mapped to/from bytes.

Member Author:

That was one of my main questions. I'll add it then. Serializing the state adds computational overhead, which is why keeping it in memory is nicer. But I'll add something and try to keep it light and safe.
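One possible shape for that, sketched here rather than taken from this PR: expose the byte length and a writer on the trait itself. Since the inherent `to_bytes` consumes `self`, an object-safe trait method would either take `&self` and copy internally or take `self: Box<Self>`; the method names below mirror the inherent ones but are assumptions:

```rust
// Sketch only: a serializable variant of the multiplexing `State` trait.
// The method names mirror `EncapsState`'s inherent API; they are not part of this PR.
pub trait State {
    fn as_any(&self) -> &dyn Any;

    /// Number of bytes the serialized state occupies.
    fn num_bytes(&self) -> usize;

    /// Write the state into `out`, which must hold at least `num_bytes()` bytes.
    fn to_bytes(&self, out: &mut [u8]);
}
```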

@@ -455,6 +507,36 @@ pub(crate) mod portable {
>(public_key_part, randomness)
}

pub(crate) fn encapsulate1_serialized<
const K: usize,
const CIPHERTEXT_SIZE: usize,
const C1_SIZE: usize,
const VECTOR_U_COMPRESSION_FACTOR: usize,
const VECTOR_U_BLOCK_LEN: usize,
const ETA1: usize,
const ETA1_RANDOMNESS_SIZE: usize,
const ETA2: usize,
const ETA2_RANDOMNESS_SIZE: usize,
>(
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
state: &mut [u8],
) -> Ciphertext1<C1_SIZE> {
super::encapsulate1_serialized::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
VECTOR_U_COMPRESSION_FACTOR,
VECTOR_U_BLOCK_LEN,
ETA1,
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
Vector,
Hash<K>,
>(public_key_part, randomness, state)
}

pub(crate) fn encapsulate2<
const K: usize,
const C2_SIZE: usize,
@@ -469,6 +551,20 @@
)
}

pub(crate) fn encapsulate2_serialized<
const K: usize,
const C2_SIZE: usize,
const VECTOR_V_COMPRESSION_FACTOR: usize,
>(
state: &[u8],
public_key_part: &PublicKey2<K, Vector>,
) -> Ciphertext2<C2_SIZE> {
super::encapsulate2_serialized::<K, C2_SIZE, VECTOR_V_COMPRESSION_FACTOR, Vector>(
state,
public_key_part,
)
}

pub(crate) fn decapsulate<
const K: usize,
const SECRET_KEY_SIZE: usize,
@@ -656,6 +752,36 @@ pub(crate) mod neon {
>(public_key_part, randomness)
}

pub(crate) fn encapsulate1_serialized<
const K: usize,
const CIPHERTEXT_SIZE: usize,
const C1_SIZE: usize,
const VECTOR_U_COMPRESSION_FACTOR: usize,
const VECTOR_U_BLOCK_LEN: usize,
const ETA1: usize,
const ETA1_RANDOMNESS_SIZE: usize,
const ETA2: usize,
const ETA2_RANDOMNESS_SIZE: usize,
>(
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
state: &mut [u8],
) -> Ciphertext1<C1_SIZE> {
super::encapsulate1_serialized::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
VECTOR_U_COMPRESSION_FACTOR,
VECTOR_U_BLOCK_LEN,
ETA1,
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
Vector,
Hash,
>(public_key_part, randomness, state)
}

pub(crate) fn encapsulate2<
const K: usize,
const C2_SIZE: usize,
@@ -670,6 +796,20 @@
)
}

pub(crate) fn encapsulate2_serialized<
const K: usize,
const C2_SIZE: usize,
const VECTOR_V_COMPRESSION_FACTOR: usize,
>(
state: &[u8],
public_key_part: &PublicKey2<K, Vector>,
) -> Ciphertext2<C2_SIZE> {
super::encapsulate2_serialized::<K, C2_SIZE, VECTOR_V_COMPRESSION_FACTOR, Vector>(
state,
public_key_part,
)
}

pub(crate) fn decapsulate<
const K: usize,
const SECRET_KEY_SIZE: usize,
@@ -912,7 +1052,8 @@ pub(crate) mod multiplexing {
use neon::{
as_neon_keypair, as_neon_state, decapsulate as decapsulate_neon,
decapsulate_incremental_key as decapsulate_incremental_key_neon,
encapsulate1 as encapsulate1_neon, encapsulate2 as encapsulate2_neon,
encapsulate1 as encapsulate1_neon, encapsulate1_serialized as encapsulate1_serialized_neon,
encapsulate2 as encapsulate2_neon, encapsulate2_serialized as encapsulate2_serialized_neon,
generate_keypair as generate_keypair_neon,
generate_keypair_serialized as generate_keypair_serialized_neon,
};
@@ -922,7 +1063,8 @@
as_portable_keypair as as_avx2_keypair, as_portable_state as as_avx2_state,
decapsulate as decapsulate_avx2,
decapsulate_incremental_key as decapsulate_incremental_key_avx2,
encapsulate1 as encapsulate1_avx2, encapsulate2 as encapsulate2_avx2,
encapsulate1 as encapsulate1_avx2, encapsulate1_serialized as encapsulate1_serialized_avx2,
encapsulate2 as encapsulate2_avx2, encapsulate2_serialized as encapsulate2_serialized_avx2,
generate_keypair as generate_keypair_avx2,
generate_keypair_serialized as generate_keypair_serialized_avx2,
};
@@ -1079,6 +1221,60 @@
}
}

pub(crate) fn encapsulate1_serialized<
const K: usize,
const CIPHERTEXT_SIZE: usize,
const C1_SIZE: usize,
const VECTOR_U_COMPRESSION_FACTOR: usize,
const VECTOR_U_BLOCK_LEN: usize,
const ETA1: usize,
const ETA1_RANDOMNESS_SIZE: usize,
const ETA2: usize,
const ETA2_RANDOMNESS_SIZE: usize,
>(
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
state: &mut [u8],
) -> Ciphertext1<C1_SIZE> {
if libcrux_platform::simd256_support() {
encapsulate1_serialized_avx2::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
VECTOR_U_COMPRESSION_FACTOR,
VECTOR_U_BLOCK_LEN,
ETA1,
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness, state)
} else if libcrux_platform::simd128_support() {
encapsulate1_serialized_neon::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
VECTOR_U_COMPRESSION_FACTOR,
VECTOR_U_BLOCK_LEN,
ETA1,
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness, state)
} else {
portable::encapsulate1_serialized::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
VECTOR_U_COMPRESSION_FACTOR,
VECTOR_U_BLOCK_LEN,
ETA1,
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
>(public_key_part, randomness, state)
}
}

pub(crate) fn encapsulate2<
const K: usize,
const C2_SIZE: usize,
@@ -1105,6 +1301,37 @@
>(state, &pk2))
}
}
pub(crate) fn encapsulate2_serialized<
const K: usize,
const C2_SIZE: usize,
const VECTOR_V_COMPRESSION_FACTOR: usize,
>(
state: &[u8],
public_key_part: &[u8],
) -> Result<Ciphertext2<C2_SIZE>, Error> {
if libcrux_platform::simd256_support() {
let pk2 = PublicKey2::try_from(public_key_part)?;
Ok(encapsulate2_serialized_avx2::<
K,
C2_SIZE,
VECTOR_V_COMPRESSION_FACTOR,
>(state, &pk2))
} else if libcrux_platform::simd128_support() {
let pk2 = PublicKey2::try_from(public_key_part)?;
Ok(encapsulate2_serialized_neon::<
K,
C2_SIZE,
VECTOR_V_COMPRESSION_FACTOR,
>(state, &pk2))
} else {
let pk2 = PublicKey2::try_from(public_key_part)?;
Ok(portable::encapsulate2_serialized::<
K,
C2_SIZE,
VECTOR_V_COMPRESSION_FACTOR,
>(state, &pk2))
}
}

pub(crate) fn decapsulate<
const K: usize,
@@ -1432,6 +1659,44 @@ pub(crate) fn encapsulate1<
(Ciphertext1 { value: ciphertext }, state)
}

pub(crate) fn encapsulate1_serialized<
const K: usize,
const CIPHERTEXT_SIZE: usize,
const C1_SIZE: usize,
const VECTOR_U_COMPRESSION_FACTOR: usize,
const VECTOR_U_BLOCK_LEN: usize,
const ETA1: usize,
const ETA1_RANDOMNESS_SIZE: usize,
const ETA2: usize,
const ETA2_RANDOMNESS_SIZE: usize,
Vector: Operations,
Hasher: Hash<K>,
>(
public_key_part: &PublicKey1,
randomness: [u8; SHARED_SECRET_SIZE],
state: &mut [u8],
) -> Ciphertext1<C1_SIZE> {
let (ct1, encaps_state) = encapsulate1::<
K,
CIPHERTEXT_SIZE,
C1_SIZE,
VECTOR_U_COMPRESSION_FACTOR,
VECTOR_U_BLOCK_LEN,
ETA1,
ETA1_RANDOMNESS_SIZE,
ETA2,
ETA2_RANDOMNESS_SIZE,
Vector,
Hasher,
>(public_key_part, randomness);

// Write out the state
encaps_state.to_bytes(state);

// Return the ciphertext
ct1
}

pub(crate) fn encapsulate2<
const K: usize,
const C2_SIZE: usize,
@@ -1453,6 +1718,19 @@
Ciphertext2 { value: ciphertext }
}

pub(crate) fn encapsulate2_serialized<
const K: usize,
const C2_SIZE: usize,
const VECTOR_V_COMPRESSION_FACTOR: usize,
Vector: Operations,
>(
state: &[u8],
public_key_part: &PublicKey2<K, Vector>,
) -> Ciphertext2<C2_SIZE> {
let state = EncapsState::from_bytes(state);
encapsulate2::<K, C2_SIZE, VECTOR_V_COMPRESSION_FACTOR, Vector>(&state, public_key_part)
}

pub(crate) fn decapsulate<
const K: usize,
const SECRET_KEY_SIZE: usize,