diff --git a/crypto-primitives/src/sponge/poseidon/constraints.rs b/crypto-primitives/src/sponge/poseidon/constraints.rs index 746eae6..be53260 100644 --- a/crypto-primitives/src/sponge/poseidon/constraints.rs +++ b/crypto-primitives/src/sponge/poseidon/constraints.rs @@ -169,12 +169,13 @@ impl PoseidonSpongeVar { ..(self.parameters.capacity + num_elements_squeezed + rate_start_index)], ); + // Repeat with updated output slices and rate start index + remaining_output = &mut remaining_output[num_elements_squeezed..]; + // Unless we are done with squeezing in this call, permute. - if remaining_output.len() != self.parameters.rate { + if !remaining_output.is_empty() { self.permute()?; } - // Repeat with updated output slices and rate start index - remaining_output = &mut remaining_output[num_elements_squeezed..]; rate_start_index = 0; } } diff --git a/crypto-primitives/src/sponge/poseidon/mod.rs b/crypto-primitives/src/sponge/poseidon/mod.rs index eee18b9..7740001 100644 --- a/crypto-primitives/src/sponge/poseidon/mod.rs +++ b/crypto-primitives/src/sponge/poseidon/mod.rs @@ -174,12 +174,13 @@ impl PoseidonSponge { ..(self.parameters.capacity + num_elements_squeezed + rate_start_index)], ); + // Repeat with updated output slices + output_remaining = &mut output_remaining[num_elements_squeezed..]; // Unless we are done with squeezing in this call, permute. - if output_remaining.len() != self.parameters.rate { + if !output_remaining.is_empty() { self.permute(); } - // Repeat with updated output slices - output_remaining = &mut output_remaining[num_elements_squeezed..]; + rate_start_index = 0; } } diff --git a/crypto-primitives/src/sponge/poseidon/tests.rs b/crypto-primitives/src/sponge/poseidon/tests.rs index cc44fd7..73955d6 100644 --- a/crypto-primitives/src/sponge/poseidon/tests.rs +++ b/crypto-primitives/src/sponge/poseidon/tests.rs @@ -1,10 +1,243 @@ -use crate::sponge::poseidon::{PoseidonConfig, PoseidonSponge}; +use crate::sponge::poseidon::{PoseidonConfig, PoseidonDefaultConfigField, PoseidonSponge}; use crate::sponge::test::Fr; use crate::sponge::{Absorb, AbsorbWithLength, CryptographicSponge, FieldBasedCryptographicSponge}; use crate::{absorb, collect_sponge_bytes, collect_sponge_field_elements}; use ark_ff::{One, PrimeField, UniformRand}; use ark_std::test_rng; +#[test] +// Remove once this PR matures +fn demo_bug() { + let sponge_params = Fr::get_default_poseidon_parameters(2, false).unwrap(); + + let rng = &mut test_rng(); + let input = (0..3).map(|_| Fr::rand(rng)).collect::>(); + + // works good + let e0 = { + let mut sponge = PoseidonSponge::::new(&sponge_params); + sponge.absorb(&input); + sponge.squeeze_native_field_elements(3) + }; + + // works good + let e1 = { + let mut sponge = PoseidonSponge::::new(&sponge_params); + sponge.absorb(&input); + let e0 = sponge.squeeze_native_field_elements(1); + let e1 = sponge.squeeze_native_field_elements(1); + let e2 = sponge.squeeze_native_field_elements(1); + e0.iter() + .chain(e1.iter()) + .chain(e2.iter()) + .cloned() + .collect::>() + }; + + // also works good + let e2 = { + let mut sponge = PoseidonSponge::::new(&sponge_params); + sponge.absorb(&input); + + let e0 = sponge.squeeze_native_field_elements(2); + let e1 = sponge.squeeze_native_field_elements(1); + e0.iter().chain(e1.iter()).cloned().collect::>() + }; + + // skips a permutation if sponge + // * in squeezing mode + // * number of elements are equal to rate + let e3 = { + let mut sponge = PoseidonSponge::::new(&sponge_params); + sponge.absorb(&input); + let e0 = sponge.squeeze_native_field_elements(1); + let e1 = sponge.squeeze_native_field_elements(2); + e0.iter().chain(e1.iter()).cloned().collect::>() + }; + + assert_eq!(e0, e1); + assert_eq!(e0, e2); + assert_eq!(e0, e3); // this will fail +} + +// Remove once this PR matures +fn run_cross_test(cfg: &PoseidonConfig) { + #[derive(Debug, PartialEq, Eq)] + enum SpongeMode { + Absorbing, + Squeezing, + } + + #[derive(Clone, Debug)] + struct Reference { + cfg: PoseidonConfig, + state: Vec, + absorbing: Vec, + squeeze_count: Option, + } + + // workaround to permute a state + fn permute(cfg: &PoseidonConfig, state: &mut [F]) { + let mut sponge = PoseidonSponge::new(&cfg); + sponge.state.copy_from_slice(state); + sponge.permute(); + state.copy_from_slice(&sponge.state) + } + + impl Reference { + fn new(cfg: &PoseidonConfig) -> Self { + let t = cfg.rate + cfg.capacity; + let state = vec![F::zero(); t]; + Self { + cfg: cfg.clone(), + state, + absorbing: Vec::new(), + squeeze_count: None, + } + } + + fn mode(&self) -> SpongeMode { + match self.squeeze_count { + Some(_) => { + assert!(self.absorbing.is_empty()); + SpongeMode::Squeezing + } + None => SpongeMode::Absorbing, + } + } + + fn absorb(&mut self, input: &[F]) { + if !input.is_empty() { + match self.mode() { + SpongeMode::Absorbing => self.absorbing.extend_from_slice(input), + SpongeMode::Squeezing => { + // Wash the state as mode changes + // This is not appied in SAFE sponge + permute(&self.cfg, &mut self.state); + // Append inputs to the absorbing line + self.absorbing.extend_from_slice(input); + // Change mode to absorbing + self.squeeze_count = None; + } + } + } + } + + fn _absorb(&mut self) { + let rate = self.cfg.rate; + self.absorbing.chunks(rate).for_each(|chunk| { + self.state + .iter_mut() + .skip(self.cfg.capacity) + .zip(chunk.iter()) + .for_each(|(s, c)| *s += *c); + permute(&self.cfg, &mut self.state); + }); + + // This case can only happen in the begining when the absorbing line is empty + // and user wants to squeeze elements. Notice that after moving to squueze mode + // if user calls absorb again with empty input it will be ignored + self.absorbing + .is_empty() + .then(|| permute(&self.cfg, &mut self.state)); + + // flush the absorbing line + self.absorbing.clear(); + + // Change to the squeezing mode + assert_eq!(self.mode(), SpongeMode::Absorbing); + self.squeeze_count = Some(0); + } + + pub fn squeeze(&mut self, n: usize) -> Vec { + match self.mode() { + SpongeMode::Absorbing => self._absorb(), + SpongeMode::Squeezing => { + assert!(self.absorbing.is_empty()); + assert!(self.squeeze_count.is_some()); + + // ??? + // **This seems nonsense to me** + // If, + // * number of squeeze is zero AND + // * in squeezing mode AND + // * output index is is at `rate` + // it applies a useless permutation. + // This is also not appied in SAFE sponge + + if n == 0 { + let squeeze_count = self.squeeze_count.unwrap(); + let out_index = self.squeeze_count.unwrap() % self.cfg.rate; + (out_index == 0 && squeeze_count != 0).then(|| { + permute(&self.cfg, &mut self.state); + self.squeeze_count = Some(0); + }); + } + } + } + + let rate = self.cfg.rate; + let mut output = Vec::new(); + for _ in 0..n { + let squeeze_count = self.squeeze_count.unwrap(); + let out_index = squeeze_count % rate; + + // proceed with a permutation if + // * the rate is full + // * and it is not the first output + (out_index == 0 && squeeze_count != 0).then(|| permute(&self.cfg, &mut self.state)); + + // skip the capacity elements + let out_index = out_index + self.cfg.capacity; + output.push(self.state[out_index]); + self.squeeze_count.as_mut().map(|c| *c += 1); + } + + output + } + } + + let mut sponge = PoseidonSponge::new(cfg); + let mut sponge_ref = Reference::new(cfg); + let mut rng = test_rng(); + + for _ in 0..1000 { + let test = (0..100) + .map(|_| { + use crate::ark_std::rand::Rng; + let do_absorb = rng.gen_bool(0.5); + let do_squeeze = rng.gen_bool(0.5); + + ( + (do_absorb, rng.gen_range(0..=cfg.rate * 2 + 1)), + (do_squeeze, rng.gen_range(0..=cfg.rate * 2 + 1)), + ) + }) + .collect::>(); + + // fuzz fuzz + for (_i, ((do_absorb, n_absorb), (do_squeeze, n_squeeze))) in test.into_iter().enumerate() { + do_absorb.then(|| { + let inputs = (0..n_absorb).map(|_| F::rand(&mut rng)).collect::>(); + sponge_ref.absorb(&inputs); + sponge.absorb(&inputs); + }); + do_squeeze.then(|| { + let out0 = sponge_ref.squeeze(n_squeeze); + let out1 = sponge.squeeze_field_elements(n_squeeze); + assert_eq!(out0, out1); + }); + } + } +} + +#[test] +// Remove once this PR matures +fn test_cross() { + let cfg = Fr::get_default_poseidon_parameters(2, false).unwrap(); + run_cross_test::(&cfg); +} + fn assert_different_encodings(a: &A, b: &A) { let bytes1 = a.to_sponge_bytes_as_vec(); let bytes2 = b.to_sponge_bytes_as_vec();