From ed949d7c07bfe6857a64a6309dd4dcf22c016ef2 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Thu, 16 Jan 2025 10:34:21 -0800 Subject: [PATCH] Mastic clean up (#1187) * mastic: Remove `verify_key` from tests that don't use it Some tests in the `mastic` module generate verification keys that aren't actually used by the test. Incidentally, the same tests were generating keys of the wrong length. * mastic: De-duplicate agg share length computation Add a function that computes the length of the aggregate share in field elements as a function of the aggregation parameter. * vidpf: Improve `VidpfPublicShare::encoded_len()` Avoid iterating over the weights to compute the length of the encoded public share; just take the length of the first weight and multiply by the number of correction words. This computation assumes the length of each weight is equal to the weight parameter at ever level of the VIDPF tree. This certainly is true, but add a test to validate this assumption anyway. * vidpf: Move `eval_prefix_tree_with_siblings()` to `impl` This method is currently implemented for `Vidpf>`, but it applies to the more general `Vidpf`. * vdaf: Remove `domain_separation_tag()` from `Vdaf` trait This method is used in Prio3 and Poplar1 for domain separation with the version of the document that specifies them. This version control is not applicable to future VDAFs defined by future documents. Remove the method from the trait and add it to implementations of `Prio3` and `Poplar1`. * vidpf: Rename `weight_parameter` to `weight_len` The associated type `ValueParameter` is likely always going to be a `usize` that expresses the length. In the future we might consider hardcoding this change in the API. --- src/vdaf.rs | 12 --- src/vdaf/mastic.rs | 53 +++---------- src/vdaf/poplar1.rs | 14 +++- src/vdaf/prio3.rs | 14 +++- src/vidpf.rs | 184 ++++++++++++++++++++++++-------------------- 5 files changed, 138 insertions(+), 139 deletions(-) diff --git a/src/vdaf.rs b/src/vdaf.rs index b53a6f58..c992869b 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -197,18 +197,6 @@ pub trait Vdaf: Clone + Debug { /// The number of Aggregators. The Client generates as many input shares as there are /// Aggregators. fn num_aggregators(&self) -> usize; - - /// Generate the domain separation tag for this VDAF. The output is used for domain separation - /// by the XOF. - fn domain_separation_tag(&self, usage: u16) -> [u8; 8] { - // Prefix is 8 bytes and defined by the spec. Copy these values in - let mut dst = [0; 8]; - dst[0] = VERSION; - dst[1] = 0; // algorithm class - dst[2..6].clone_from_slice(self.algorithm_id().to_be_bytes().as_slice()); - dst[6..8].clone_from_slice(usage.to_be_bytes().as_slice()); - dst - } } /// The Client's role in the execution of a VDAF. diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index bfb4ddf1..df43bbb5 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -94,6 +94,11 @@ where bits, }) } + + fn agg_share_len(&self, agg_param: &MasticAggregationParam) -> usize { + // The aggregate share consists of the counter and truncated weight for each candidate prefix. + (1 + self.szk.typ.output_len()) * agg_param.level_and_prefixes.prefixes().len() + } } /// Mastic aggregation parameter. @@ -158,7 +163,7 @@ where mastic: &Mastic, bytes: &mut Cursor<&[u8]>, ) -> Result { - VidpfPublicShare::decode_with_param(&(mastic.bits, mastic.vidpf.weight_parameter), bytes) + VidpfPublicShare::decode_with_param(&(mastic.bits, mastic.vidpf.weight_len), bytes) } } @@ -252,8 +257,7 @@ where (mastic, agg_param): &(&Mastic, &MasticAggregationParam), bytes: &mut Cursor<&[u8]>, ) -> Result { - let len = (1 + mastic.szk.typ.output_len()) * agg_param.level_and_prefixes.prefixes().len(); - decode_fieldvec(len, bytes).map(AggregateShare) + decode_fieldvec(mastic.agg_share_len(agg_param), bytes).map(AggregateShare) } } @@ -268,8 +272,7 @@ where (mastic, agg_param): &(&Mastic, &MasticAggregationParam), bytes: &mut Cursor<&[u8]>, ) -> Result { - let len = (1 + mastic.szk.typ.output_len()) * agg_param.level_and_prefixes.prefixes().len(); - decode_fieldvec(len, bytes).map(OutputShare) + decode_fieldvec(mastic.agg_share_len(agg_param), bytes).map(OutputShare) } } @@ -425,10 +428,10 @@ impl<'a, T: Type, P: Xof, const SEED_SIZE: usize> for MasticPrepareState { fn decode_with_param( - (mastic, agg_param): &(&Mastic, &MasticAggregationParam), + decoder @ (mastic, agg_param): &(&Mastic, &MasticAggregationParam), bytes: &mut Cursor<&[u8]>, ) -> Result { - let output_shares = MasticOutputShare::decode_with_param(&(*mastic, *agg_param), bytes)?; + let output_shares = MasticOutputShare::decode_with_param(decoder, bytes)?; let szk_query_state = (mastic.szk.typ.joint_rand_len() > 0 && agg_param.require_weight_check) .then(|| Seed::decode(bytes)) @@ -774,14 +777,7 @@ where output_shares: M, ) -> Result, VdafError> { let mut agg_share = - MasticAggregateShare::::from(vec![ - T::Field::zero(); - (1 + self.szk.typ.output_len()) - * agg_param - .level_and_prefixes - .prefixes() - .len() - ]); + MasticAggregateShare::from(vec![T::Field::zero(); self.agg_share_len(agg_param)]); for output_share in output_shares.into_iter() { agg_share.accumulate(&output_share)?; } @@ -803,10 +799,7 @@ where let num_prefixes = agg_param.level_and_prefixes.prefixes().len(); let AggregateShare(agg) = agg_shares.into_iter().try_fold( - AggregateShare(vec![ - T::Field::zero(); - num_prefixes * (1 + self.szk.typ.output_len()) - ]), + AggregateShare(vec![T::Field::zero(); self.agg_share_len(agg_param)]), |mut agg, agg_share| { agg.merge(&agg_share)?; Result::<_, VdafError>::Ok(agg) @@ -866,8 +859,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let inputs = [ @@ -947,8 +938,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); @@ -1000,8 +989,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); @@ -1023,8 +1010,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let inputs = [ @@ -1102,8 +1087,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); @@ -1122,8 +1105,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); @@ -1144,8 +1125,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let inputs = [ @@ -1234,8 +1213,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); @@ -1265,8 +1242,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); @@ -1298,8 +1273,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); @@ -1323,8 +1296,6 @@ mod tests { let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; - let mut verify_key = [0u8; 16]; - thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 49d36654..f6283c09 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -11,7 +11,7 @@ use crate::{ prng::Prng, vdaf::{ xof::{Seed, Xof, XofTurboShake128}, - Aggregatable, Aggregator, Client, Collector, PrepareTransition, Vdaf, VdafError, + Aggregatable, Aggregator, Client, Collector, PrepareTransition, Vdaf, VdafError, VERSION, }, }; use rand_core::RngCore; @@ -862,6 +862,18 @@ impl, const SEED_SIZE: usize> Vdaf for Poplar1 { } impl, const SEED_SIZE: usize> Poplar1 { + /// Generate the domain separation tag for this VDAF. The output is used for domain separation + /// by the XOF. + fn domain_separation_tag(&self, usage: u16) -> [u8; 8] { + // Prefix is 8 bytes and defined by the spec. Copy these values in + let mut dst = [0; 8]; + dst[0] = VERSION; + dst[1] = 0; // algorithm class + dst[2..6].clone_from_slice(self.algorithm_id().to_be_bytes().as_slice()); + dst[6..8].clone_from_slice(usage.to_be_bytes().as_slice()); + dst + } + fn shard_with_random( &self, ctx: &[u8], diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index d602a0e7..4457e708 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -54,7 +54,7 @@ use crate::prng::Prng; use crate::vdaf::xof::{IntoFieldVec, Seed, Xof}; use crate::vdaf::{ Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition, - Share, ShareDecodingParameter, Vdaf, VdafError, + Share, ShareDecodingParameter, Vdaf, VdafError, VERSION, }; #[cfg(feature = "experimental")] use fixed::traits::Fixed; @@ -548,6 +548,18 @@ where .into_field_vec(self.typ.query_rand_len() * self.num_proofs()) } + /// Generate the domain separation tag for this VDAF. The output is used for domain separation + /// by the XOF. + fn domain_separation_tag(&self, usage: u16) -> [u8; 8] { + // Prefix is 8 bytes and defined by the spec. Copy these values in + let mut dst = [0; 8]; + dst[0] = VERSION; + dst[1] = 0; // algorithm class + dst[2..6].clone_from_slice(self.algorithm_id().to_be_bytes().as_slice()); + dst[6..8].clone_from_slice(usage.to_be_bytes().as_slice()); + dst + } + fn random_size(&self) -> usize { if self.typ.joint_rand_len() == 0 { // One seed per helper (share, proof) pair, plus one seed for proving randomness diff --git a/src/vidpf.rs b/src/vidpf.rs index 3e52d35e..6fd6837f 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -62,7 +62,7 @@ pub trait VidpfValue: IdpfValue + Clone + Debug + PartialEq + ConstantTimeEq {} /// An instance of the VIDPF. pub struct Vidpf { pub(crate) bits: u16, - pub(crate) weight_parameter: W::ValueParameter, + pub(crate) weight_len: W::ValueParameter, } impl Vidpf { @@ -71,13 +71,10 @@ impl Vidpf { /// # Arguments /// /// * `bits`, the length of the input in bits. - /// * `weight_parameter`, the length of the weight in number of field elements. - pub fn new(bits: usize, weight_parameter: W::ValueParameter) -> Result { + /// * `weight_len`, the length of the weight in number of field elements. + pub fn new(bits: usize, weight_len: W::ValueParameter) -> Result { let bits = u16::try_from(bits).map_err(|_| VidpfError::BitLengthTooLong)?; - Ok(Self { - bits, - weight_parameter, - }) + Ok(Self { bits, weight_len }) } /// Splits an incremental point function `F` into two private keys @@ -201,7 +198,7 @@ impl Vidpf { let mut r = VidpfEvalResult { state: VidpfEvalState::init_from_key(id, key), - share: W::zero(&self.weight_parameter), // not used + share: W::zero(&self.weight_len), // not used }; if input.len() > public.cw.len() { @@ -247,7 +244,7 @@ impl Vidpf { // Convert and correct the payload. let (next_seed, w) = self.convert(seed_keep, ctx, nonce); let mut weight = ::conditional_select( - &::zero(&self.weight_parameter), + &::zero(&self.weight_len), &cw.weight, next_ctrl, ); @@ -298,6 +295,70 @@ impl Vidpf { Ok(beta_share) } + /// Ensure `prefix_tree` contains the prefix tree for `prefixes`, as well as the sibling of + /// each node in the prefix tree. The return value is the weights for the prefixes + /// concatenated together. + #[allow(clippy::too_many_arguments)] + pub(crate) fn eval_prefix_tree_with_siblings( + &self, + ctx: &[u8], + id: VidpfServerId, + public: &VidpfPublicShare, + key: &VidpfKey, + nonce: &[u8], + prefixes: &[VidpfInput], + prefix_tree: &mut BinaryTree>, + ) -> Result, VidpfError> { + let mut out_shares = Vec::with_capacity(prefixes.len()); + + for prefix in prefixes { + if prefix.len() > public.cw.len() { + return Err(VidpfError::InvalidInputLength); + } + + let mut sub_tree = prefix_tree.root.get_or_insert_with(|| { + Box::new(Node::new(VidpfEvalResult { + state: VidpfEvalState::init_from_key(id, key), + share: W::zero(&self.weight_len), // not used + })) + }); + + for (idx, cw) in self.index_iter(prefix)?.zip(public.cw.iter()) { + let left = sub_tree.left.get_or_insert_with(|| { + Box::new(Node::new(self.eval_next( + ctx, + cw, + idx.left_sibling(), + &sub_tree.value.state, + nonce, + ))) + }); + let right = sub_tree.right.get_or_insert_with(|| { + Box::new(Node::new(self.eval_next( + ctx, + cw, + idx.right_sibling(), + &sub_tree.value.state, + nonce, + ))) + }); + + sub_tree = if idx.bit.unwrap_u8() == 0 { + left + } else { + right + }; + } + + out_shares.push(sub_tree.value.share.clone()); + } + + for out_share in out_shares.iter_mut() { + out_share.conditional_negate(Choice::from(id)); + } + Ok(out_shares) + } + fn extend(seed: VidpfSeed, ctx: &[u8], nonce: &[u8]) -> ExtendedSeed { let mut rng = XofFixedKeyAes128::seed_stream( &Seed(seed), @@ -334,7 +395,7 @@ impl Vidpf { let mut next_seed = VidpfSeed::default(); seed_stream.fill_bytes(&mut next_seed); - let weight = W::generate(&mut seed_stream, &self.weight_parameter); + let weight = W::generate(&mut seed_stream, &self.weight_len); (next_seed, weight) } @@ -371,72 +432,6 @@ impl Vidpf { } } -impl Vidpf> { - /// Ensure `prefix_tree` contains the prefix tree for `prefixes`, as well as the sibling of - /// each node in the prefix tree. The return value is the weights for the prefixes - /// concatenated together. - #[allow(clippy::too_many_arguments)] - pub(crate) fn eval_prefix_tree_with_siblings( - &self, - ctx: &[u8], - id: VidpfServerId, - public: &VidpfPublicShare>, - key: &VidpfKey, - nonce: &[u8], - prefixes: &[VidpfInput], - prefix_tree: &mut BinaryTree>>, - ) -> Result>, VidpfError> { - let mut out_shares = Vec::with_capacity(prefixes.len()); - - for prefix in prefixes { - if prefix.len() > public.cw.len() { - return Err(VidpfError::InvalidInputLength); - } - - let mut sub_tree = prefix_tree.root.get_or_insert_with(|| { - Box::new(Node::new(VidpfEvalResult { - state: VidpfEvalState::init_from_key(id, key), - share: VidpfWeight::zero(&self.weight_parameter), // not used - })) - }); - - for (idx, cw) in self.index_iter(prefix)?.zip(public.cw.iter()) { - let left = sub_tree.left.get_or_insert_with(|| { - Box::new(Node::new(self.eval_next( - ctx, - cw, - idx.left_sibling(), - &sub_tree.value.state, - nonce, - ))) - }); - let right = sub_tree.right.get_or_insert_with(|| { - Box::new(Node::new(self.eval_next( - ctx, - cw, - idx.right_sibling(), - &sub_tree.value.state, - nonce, - ))) - }); - - sub_tree = if idx.bit.unwrap_u8() == 0 { - left - } else { - right - }; - } - - out_shares.push(sub_tree.value.share.clone()); - } - - for out_share in out_shares.iter_mut() { - out_share.conditional_negate(Choice::from(id)); - } - Ok(out_shares) - } -} - /// VIDPF key. /// /// Private key of an aggregation server. @@ -528,18 +523,24 @@ impl Encode for VidpfPublicShare { } fn encoded_len(&self) -> Option { - let control_bits_count = self.cw.len() * 2; - let mut len = (control_bits_count + 7) / 8 + self.cw.len() * 48; - for cw in self.cw.iter() { - len += cw.weight.encoded_len()?; - } + // We assume the weight has the same length at each level of the tree. + let weight_len = self + .cw + .first() + .map_or(Some(0), |cw| cw.weight.encoded_len())?; + + let mut len = 0; + len += (2 * self.cw.len() + 7) / 8; // packed control bits + len += self.cw.len() * VIDPF_SEED_SIZE; // seeds + len += self.cw.len() * weight_len; // weights + len += self.cw.len() * VIDPF_PROOF_SIZE; // nod proofs Some(len) } } impl ParameterizedDecode<(usize, W::ValueParameter)> for VidpfPublicShare { fn decode_with_param( - (bits, weight_parameter): &(usize, W::ValueParameter), + (bits, weight_len): &(usize, W::ValueParameter), bytes: &mut Cursor<&[u8]>, ) -> Result { let packed_control_len = (bits + 3) / 4; @@ -564,7 +565,7 @@ impl ParameterizedDecode<(usize, W::ValueParameter)> for VidpfPub .collect::, _>>()?; // Weights - let weights = std::iter::repeat_with(|| W::decode_with_param(weight_parameter, bytes)) + let weights = std::iter::repeat_with(|| W::decode_with_param(weight_len, bytes)) .take(*bits) .collect::, _>>()?; @@ -864,7 +865,10 @@ mod tests { use crate::{ codec::{Encode, ParameterizedDecode}, idpf::IdpfValue, - vidpf::{Vidpf, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId}, + vidpf::{ + Vidpf, VidpfCorrectionWord, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, + VidpfServerId, + }, }; use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN}; @@ -990,6 +994,18 @@ mod tests { state_1 = r1.state; } } + + // Assert that the length of the weight is the same at each level of the tree. This + // assumption is made in `PublicShare::encoded_len()`. + #[test] + fn public_share_weight_len() { + let input = VidpfInput::from_bools(&vec![false; 237]); + let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); + let (vidpf, public, _, _) = vidpf_gen_setup(b"some application", &input, &weight); + for VidpfCorrectionWord { weight, .. } in public.cw { + assert_eq!(weight.0.len(), vidpf.weight_len); + } + } } mod weight {