From ceeb19dbd5dc6469c958f08c495f8706f15885cb Mon Sep 17 00:00:00 2001 From: Michael Rosenberg Date: Mon, 9 Dec 2024 16:09:40 +0100 Subject: [PATCH] Added `max_measurement` field to `Prio3Sum` type (#1150) --- benches/cycle_counts.rs | 5 +- benches/speed_tests.rs | 11 +- binaries/src/bin/vdaf_message_sizes.rs | 6 +- src/field.rs | 15 +++ src/field/field255.rs | 6 + src/flp/szk.rs | 12 +- src/flp/types.rs | 164 ++++++++++++++++++------- src/vdaf/mastic.rs | 14 ++- src/vdaf/prio3.rs | 64 +++++----- src/vdaf/prio3_test.rs | 5 +- 10 files changed, 207 insertions(+), 95 deletions(-) diff --git a/benches/cycle_counts.rs b/benches/cycle_counts.rs index 8b1b3c184..a4598bed1 100644 --- a/benches/cycle_counts.rs +++ b/benches/cycle_counts.rs @@ -125,8 +125,9 @@ fn prio3_client_histogram_10() -> Vec> { .1 } -fn prio3_client_sum_32() -> Vec> { - let prio3 = Prio3::new_sum(2, 16).unwrap(); +fn prio3_client_sum_32() -> Vec> { + let bits = 16; + let prio3 = Prio3::new_sum(2, (1 << bits) - 1).unwrap(); let measurement = 1337; let nonce = [0; 16]; prio3 diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index 053cabb2e..94dd5d183 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -198,8 +198,10 @@ fn prio3(c: &mut Criterion) { let mut group = c.benchmark_group("prio3sum_shard"); for bits in [8, 32] { group.bench_with_input(BenchmarkId::from_parameter(bits), &bits, |b, bits| { - let vdaf = Prio3::new_sum(num_shares, *bits).unwrap(); - let measurement = (1 << bits) - 1; + // Doesn't matter for speed what we use for max measurement, or measurement + let max_measurement = (1 << bits) - 1; + let vdaf = Prio3::new_sum(num_shares, max_measurement).unwrap(); + let measurement = max_measurement; let nonce = black_box([0u8; 16]); b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }); @@ -209,8 +211,9 @@ fn prio3(c: &mut Criterion) { let mut group = c.benchmark_group("prio3sum_prepare_init"); for bits in [8, 32] { group.bench_with_input(BenchmarkId::from_parameter(bits), &bits, |b, bits| { - let vdaf = Prio3::new_sum(num_shares, *bits).unwrap(); - let measurement = (1 << bits) - 1; + let max_measurement = (1 << bits) - 1; + let vdaf = Prio3::new_sum(num_shares, max_measurement).unwrap(); + let measurement = max_measurement; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); diff --git a/binaries/src/bin/vdaf_message_sizes.rs b/binaries/src/bin/vdaf_message_sizes.rs index 998f15722..940be79a4 100644 --- a/binaries/src/bin/vdaf_message_sizes.rs +++ b/binaries/src/bin/vdaf_message_sizes.rs @@ -42,12 +42,12 @@ fn main() { ) ); - let bits = 32; - let prio3 = Prio3::new_sum(num_shares, bits).unwrap(); + let max_measurement = 0xffff_ffff; + let prio3 = Prio3::new_sum(num_shares, max_measurement).unwrap(); let measurement = 1337; println!( "prio3 sum ({} bits) share size = {}", - bits, + max_measurement.ilog2() + 1, vdaf_input_share_size::( prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() ) diff --git a/src/field.rs b/src/field.rs index 7b8460a4b..88bf40ff1 100644 --- a/src/field.rs +++ b/src/field.rs @@ -201,6 +201,9 @@ pub trait Integer: /// Returns one. fn one() -> Self; + + /// Returns ⌊log₂(self)⌋, or `None` if `self == 0` + fn checked_ilog2(&self) -> Option; } /// Extension trait for field elements that can be converted back and forth to an integer type. @@ -785,6 +788,10 @@ impl Integer for u32 { fn one() -> Self { 1 } + + fn checked_ilog2(&self) -> Option { + u32::checked_ilog2(*self) + } } impl Integer for u64 { @@ -798,6 +805,10 @@ impl Integer for u64 { fn one() -> Self { 1 } + + fn checked_ilog2(&self) -> Option { + u64::checked_ilog2(*self) + } } impl Integer for u128 { @@ -811,6 +822,10 @@ impl Integer for u128 { fn one() -> Self { 1 } + + fn checked_ilog2(&self) -> Option { + u128::checked_ilog2(*self) + } } make_field!( diff --git a/src/field/field255.rs b/src/field/field255.rs index 65fb443e2..8a3f74bda 100644 --- a/src/field/field255.rs +++ b/src/field/field255.rs @@ -388,6 +388,12 @@ mod tests { fn one() -> Self { Self::new(Vec::from([1])) } + + fn checked_ilog2(&self) -> Option { + // This is a test module, and this code is never used. If we need this in the future, + // use BigUint::bits() + unimplemented!() + } } impl TestFieldElementWithInteger for Field255 { diff --git a/src/flp/szk.rs b/src/flp/szk.rs index e25598d0f..4531d3bf9 100644 --- a/src/flp/szk.rs +++ b/src/flp/szk.rs @@ -794,8 +794,9 @@ mod tests { #[test] fn test_sum_proof_share_encode() { let mut nonce = [0u8; 16]; + let max_measurement = 13; thread_rng().fill(&mut nonce[..]); - let sum = Sum::::new(5).unwrap(); + let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); let algorithm_id = 5; let szk_typ = Szk::new_turboshake128(sum, algorithm_id); @@ -896,9 +897,10 @@ mod tests { #[test] fn test_sum_leader_proof_share_roundtrip() { + let max_measurement = 13; let mut nonce = [0u8; 16]; thread_rng().fill(&mut nonce[..]); - let sum = Sum::::new(5).unwrap(); + let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); let algorithm_id = 5; let szk_typ = Szk::new_turboshake128(sum, algorithm_id); @@ -936,9 +938,10 @@ mod tests { #[test] fn test_sum_helper_proof_share_roundtrip() { + let max_measurement = 13; let mut nonce = [0u8; 16]; thread_rng().fill(&mut nonce[..]); - let sum = Sum::::new(5).unwrap(); + let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); let algorithm_id = 5; let szk_typ = Szk::new_turboshake128(sum, algorithm_id); @@ -1138,7 +1141,8 @@ mod tests { #[test] fn test_sum() { - let sum = Sum::::new(5).unwrap(); + let max_measurement = 13; + let sum = Sum::::new(max_measurement).unwrap(); let five = Field128::from(5); let nine = sum.encode_measurement(&9).unwrap(); diff --git a/src/flp/types.rs b/src/flp/types.rs index 9403039ef..2431af986 100644 --- a/src/flp/types.rs +++ b/src/flp/types.rs @@ -2,7 +2,7 @@ //! A collection of [`Type`] implementations. -use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt}; +use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt, Integer}; use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; use crate::flp::{FlpError, Gadget, Type}; use crate::polynomial::poly_range_check; @@ -113,37 +113,57 @@ impl Type for Count { } } -/// This sum type. Each measurement is a integer in `[0, 2^bits)` and the aggregate is the sum of -/// the measurements. +/// The sum type. Each measurement is a integer in `[0, max_measurement]` and the aggregate is the +/// sum of the measurements. /// /// The validity circuit is based on the SIMD circuit construction of [[BBCG+19], Theorem 5.3]. /// /// [BBCG+19]: https://ia.cr/2019/188 #[derive(Clone, PartialEq, Eq)] pub struct Sum { + max_measurement: F::Integer, + + // Computed from max_measurement + offset: F::Integer, bits: usize, - range_checker: Vec, + // Constant + bit_range_checker: Vec, } impl Debug for Sum { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Sum").field("bits", &self.bits).finish() + f.debug_struct("Sum") + .field("max_measurement", &self.max_measurement) + .field("bits", &self.bits) + .finish() } } impl Sum { /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0, - /// 2^bits)`. - pub fn new(bits: usize) -> Result { - if !F::valid_integer_bitlength(bits) { - return Err(FlpError::Encode( - "invalid bits: number of bits exceeds maximum number of bits in this field" - .to_string(), + /// max_measurement]` where `max_measurement > 0`. Errors if `max_measurement == 0`. + pub fn new(max_measurement: F::Integer) -> Result { + if max_measurement == F::Integer::zero() { + return Err(FlpError::InvalidParameter( + "max measurement cannot be zero".to_string(), )); } + + // Number of bits needed to represent x is ⌊log₂(x)⌋ + 1 + let bits = max_measurement.checked_ilog2().unwrap() as usize + 1; + + // The offset we add to the summand for range-checking purposes + let one = F::Integer::try_from(1).unwrap(); + let offset = (one << bits) - one - max_measurement; + + // Construct a range checker to ensure encoded bits are in the range [0, 2) + let bit_range_checker = poly_range_check(0, 2); + Ok(Self { bits, - range_checker: poly_range_check(0, 2), + max_measurement, + offset, + bit_range_checker, }) } } @@ -154,8 +174,17 @@ impl Type for Sum { type Field = F; fn encode_measurement(&self, summand: &F::Integer) -> Result, FlpError> { - let v = F::encode_as_bitvector(*summand, self.bits)?.collect(); - Ok(v) + if summand > &self.max_measurement { + return Err(FlpError::Encode(format!( + "unexpected measurement: got {:?}; want ≤{:?}", + summand, self.max_measurement + ))); + } + + let enc_summand = F::encode_as_bitvector(*summand, self.bits)?; + let enc_summand_plus_offset = F::encode_as_bitvector(self.offset + *summand, self.bits)?; + + Ok(enc_summand.chain(enc_summand_plus_offset).collect()) } fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result { @@ -164,8 +193,8 @@ impl Type for Sum { fn gadget(&self) -> Vec>> { vec![Box::new(PolyEval::new( - self.range_checker.clone(), - self.bits, + self.bit_range_checker.clone(), + 2 * self.bits, ))] } @@ -178,25 +207,38 @@ impl Type for Sum { g: &mut Vec>>, input: &[F], joint_rand: &[F], - _num_shares: usize, + num_shares: usize, ) -> Result, FlpError> { self.valid_call_check(input, joint_rand)?; let gadget = &mut g[0]; - input.iter().map(|&b| gadget.call(&[b])).collect() + let bit_checks = input + .iter() + .map(|&b| gadget.call(&[b])) + .collect::, _>>()?; + + let range_check = { + let offset = F::from(self.offset); + let shares_inv = F::from(F::valid_integer_try_from(num_shares)?).inv(); + let sum = F::decode_bitvector(&input[..self.bits])?; + let sum_plus_offset = F::decode_bitvector(&input[self.bits..])?; + offset * shares_inv + sum - sum_plus_offset + }; + + Ok([bit_checks.as_slice(), &[range_check]].concat()) } fn truncate(&self, input: Vec) -> Result, FlpError> { self.truncate_call_check(&input)?; - let res = F::decode_bitvector(&input)?; + let res = F::decode_bitvector(&input[..self.bits])?; Ok(vec![res]) } fn input_len(&self) -> usize { - self.bits + 2 * self.bits } fn proof_len(&self) -> usize { - 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + 2 * ((1 + 2 * self.bits).next_power_of_two() - 1) + 2 } fn verifier_len(&self) -> usize { @@ -212,7 +254,7 @@ impl Type for Sum { } fn eval_output_len(&self) -> usize { - self.bits + 2 * self.bits + 1 } fn prove_rand_len(&self) -> usize { @@ -220,8 +262,8 @@ impl Type for Sum { } } -/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the -/// aggregate is the arithmetic average. +/// The average type. Each measurement is an integer in `[0, max_measurement]` and the aggregate is +/// the arithmetic average of the measurements. // This is just a `Sum` object under the hood. The only difference is that the aggregate result is // an f64, which we get by dividing by `num_measurements` #[derive(Clone, PartialEq, Eq)] @@ -232,6 +274,7 @@ pub struct Average { impl Debug for Average { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Average") + .field("max_measurement", &self.summer.max_measurement) .field("bits", &self.summer.bits) .finish() } @@ -239,9 +282,9 @@ impl Debug for Average { impl Average { /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0, - /// 2^bits)`. - pub fn new(bits: usize) -> Result { - let summer = Sum::new(bits)?; + /// max_measurement]` where `max_measurement > 0`. Errors if `max_measurement == 0`. + pub fn new(max_measurement: F::Integer) -> Result { + let summer = Sum::new(max_measurement)?; Ok(Average { summer }) } } @@ -288,7 +331,7 @@ impl Type for Average { } fn input_len(&self) -> usize { - self.summer.bits + self.summer.input_len() } fn proof_len(&self) -> usize { @@ -592,20 +635,19 @@ where } // Convert bool vector to field elems - let multihot_vec: Vec = measurement + let multihot_vec = measurement .iter() // We can unwrap because any Integer type can cast from bool - .map(|bit| F::from(F::valid_integer_try_from(*bit as usize).unwrap())) - .collect(); + .map(|bit| F::from(F::valid_integer_try_from(*bit as usize).unwrap())); // Encode the measurement weight in binary (actually, the weight plus some offset) let offset_weight_bits = { let offset_weight_reported = F::valid_integer_try_from(self.offset + weight_reported)?; - F::encode_as_bitvector(offset_weight_reported, self.bits_for_weight)?.collect() + F::encode_as_bitvector(offset_weight_reported, self.bits_for_weight)? }; // Report the concat of the two - Ok([multihot_vec, offset_weight_bits].concat()) + Ok(multihot_vec.chain(offset_weight_bits).collect()) } fn decode_result( @@ -1024,7 +1066,9 @@ mod tests { #[test] fn test_sum() { - let sum = Sum::new(11).unwrap(); + let max_measurement = 1458; + + let sum = Sum::new(max_measurement).unwrap(); let zero = TestField::zero(); let one = TestField::one(); let nine = TestField::from(9); @@ -1045,22 +1089,52 @@ mod tests { &sum.encode_measurement(&1337).unwrap(), &[TestField::from(1337)], ); - FlpTest::expect_valid::<3>(&Sum::new(0).unwrap(), &[], &[zero]); - FlpTest::expect_valid::<3>(&Sum::new(2).unwrap(), &[one, zero], &[one]); - FlpTest::expect_valid::<3>( - &Sum::new(9).unwrap(), - &[one, zero, one, one, zero, one, one, one, zero], - &[TestField::from(237)], - ); - // Test FLP on invalid input. - FlpTest::expect_invalid::<3>(&Sum::new(3).unwrap(), &[one, nine, zero]); - FlpTest::expect_invalid::<3>(&Sum::new(5).unwrap(), &[zero, zero, zero, zero, nine]); + { + let sum = Sum::new(3).unwrap(); + let meas = 1; + FlpTest::expect_valid::<3>( + &sum, + &sum.encode_measurement(&meas).unwrap(), + &[TestField::from(meas)], + ); + } + + { + let sum = Sum::new(400).unwrap(); + let meas = 237; + FlpTest::expect_valid::<3>( + &sum, + &sum.encode_measurement(&meas).unwrap(), + &[TestField::from(meas)], + ); + } + + // Test FLP on invalid input, specifically on field elements outside of {0,1} + { + let sum = Sum::new((1 << 3) - 1).unwrap(); + // The sum+offset value can be whatever. The binariness test should fail first + let sum_plus_offset = vec![zero; 3]; + FlpTest::expect_invalid::<3>( + &sum, + &[&[one, nine, zero], sum_plus_offset.as_slice()].concat(), + ); + } + { + let sum = Sum::new((1 << 5) - 1).unwrap(); + let sum_plus_offset = vec![zero; 5]; + FlpTest::expect_invalid::<3>( + &sum, + &[&[zero, zero, zero, zero, nine], sum_plus_offset.as_slice()].concat(), + ); + } } #[test] fn test_average() { - let average = Average::new(11).unwrap(); + let max_measurement = (1 << 11) - 13; + + let average = Average::new(max_measurement).unwrap(); let zero = TestField::zero(); let one = TestField::one(); let ten = TestField::from(10); diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index afbac9331..6e3426b5f 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -394,9 +394,12 @@ mod tests { #[test] fn test_mastic_shard_sum() { let algorithm_id = 6; - let sum_typ = Sum::::new(5).unwrap(); + let max_measurement = 29; + let sum_typ = Sum::::new(max_measurement).unwrap(); + let encoded_meas_len = sum_typ.input_len(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); - let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(encoded_meas_len); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -414,9 +417,12 @@ mod tests { #[test] fn test_input_share_encode_sum() { let algorithm_id = 6; - let sum_typ = Sum::::new(5).unwrap(); + let max_measurement = 29; + let sum_typ = Sum::::new(max_measurement).unwrap(); + let encoded_meas_len = sum_typ.input_len(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); - let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(encoded_meas_len); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 3936730ec..fdefb0933 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -33,7 +33,9 @@ use super::AggregatorWithNoise; use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode}; #[cfg(feature = "experimental")] use crate::dp::DifferentialPrivacyStrategy; -use crate::field::{decode_fieldvec, FftFriendlyFieldElement, FieldElement}; +use crate::field::{ + decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, +}; use crate::field::{Field128, Field64}; #[cfg(feature = "multithreaded")] use crate::flp::gadgets::ParallelSumMultithreaded; @@ -138,19 +140,16 @@ impl Prio3SumVecMultithreaded { /// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the /// aggregate is the sum. -pub type Prio3Sum = Prio3, XofTurboShake128, 16>; +pub type Prio3Sum = Prio3, XofTurboShake128, 16>; impl Prio3Sum { - /// Construct an instance of Prio3Sum with the given number of aggregators and required bit - /// length. The bit length must not exceed 64. - pub fn new_sum(num_aggregators: u8, bits: usize) -> Result { - if bits > 64 { - return Err(VdafError::Uncategorized(format!( - "bit length ({bits}) exceeds limit for aggregate type (64)" - ))); - } - - Prio3::new(num_aggregators, 1, 0x00000001, Sum::new(bits)?) + /// Construct an instance of `Prio3Sum` with the given number of aggregators, where each summand + /// must be in the range `[0, max_measurement]`. Errors if `max_measurement == 0`. + pub fn new_sum( + num_aggregators: u8, + max_measurement: ::Integer, + ) -> Result { + Prio3::new(num_aggregators, 1, 0x00000001, Sum::new(max_measurement)?) } } @@ -340,22 +339,19 @@ impl Prio3MultihotCountVecMultithreaded { pub type Prio3Average = Prio3, XofTurboShake128, 16>; impl Prio3Average { - /// Construct an instance of Prio3Average with the given number of aggregators and required bit - /// length. The bit length must not exceed 64. - pub fn new_average(num_aggregators: u8, bits: usize) -> Result { + /// Construct an instance of `Prio3Average` with the given number of aggregators, where each + /// summand must be in the range `[0, max_measurement]`. Errors if `max_measurement == 0`. + pub fn new_average( + num_aggregators: u8, + max_measurement: ::Integer, + ) -> Result { check_num_aggregators(num_aggregators)?; - if bits > 64 { - return Err(VdafError::Uncategorized(format!( - "bit length ({bits}) exceeds limit for aggregate type (64)" - ))); - } - Ok(Prio3 { num_aggregators, num_proofs: 1, algorithm_id: 0xFFFF0000, - typ: Average::new(bits)?, + typ: Average::new(max_measurement)?, phantom: PhantomData, }) } @@ -1700,11 +1696,13 @@ mod tests { #[test] fn test_prio3_sum() { - let prio3 = Prio3::new_sum(3, 16).unwrap(); + let max_measurement = 35_891; + + let prio3 = Prio3::new_sum(3, max_measurement).unwrap(); assert_eq!( - run_vdaf(CTX_STR, &prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), - (1 << 16) + 1 + run_vdaf(CTX_STR, &prio3, &(), [0, max_measurement, 0, 1, 1]).unwrap(), + max_measurement + 2, ); let mut verify_key = [0; 16]; @@ -1713,7 +1711,7 @@ mod tests { let (public_share, mut input_shares) = prio3.shard(CTX_STR, &1, &nonce).unwrap(); assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => { - data[0] += Field128::one(); + data[0] += Field64::one(); }); let result = run_vdaf_prepare( &prio3, @@ -1728,7 +1726,7 @@ mod tests { let (public_share, mut input_shares) = prio3.shard(CTX_STR, &1, &nonce).unwrap(); assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => { - data[0] += Field128::one(); + data[0] += Field64::one(); }); let result = run_vdaf_prepare( &prio3, @@ -2082,7 +2080,8 @@ mod tests { #[test] fn test_prio3_average() { - let prio3 = Prio3::new_average(2, 64).unwrap(); + let max_measurement = 43_208; + let prio3 = Prio3::new_average(2, max_measurement).unwrap(); assert_eq!(run_vdaf(CTX_STR, &prio3, &(), [17, 8]).unwrap(), 12.5f64); assert_eq!(run_vdaf(CTX_STR, &prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); @@ -2098,7 +2097,8 @@ mod tests { #[test] fn test_prio3_input_share() { - let prio3 = Prio3::new_sum(5, 16).unwrap(); + let max_measurement = 1; + let prio3 = Prio3::new_sum(5, max_measurement).unwrap(); let (_public_share, input_shares) = prio3.shard(CTX_STR, &1, &[0; 16]).unwrap(); // Check that seed shares are distinct. @@ -2217,7 +2217,8 @@ mod tests { let vdaf = Prio3::new_count(2).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 1); - let vdaf = Prio3::new_sum(2, 17).unwrap(); + let max_measurement = 13; + let vdaf = Prio3::new_sum(2, max_measurement).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 1); let vdaf = Prio3::new_histogram(2, 12, 3).unwrap(); @@ -2229,7 +2230,8 @@ mod tests { let vdaf = Prio3::new_count(2).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 1); - let vdaf = Prio3::new_sum(2, 17).unwrap(); + let max_measurement = 13; + let vdaf = Prio3::new_sum(2, max_measurement).unwrap(); fieldvec_roundtrip_test::>(&vdaf, &(), 1); let vdaf = Prio3::new_histogram(2, 12, 3).unwrap(); diff --git a/src/vdaf/prio3_test.rs b/src/vdaf/prio3_test.rs index 10b72c739..18c9817de 100644 --- a/src/vdaf/prio3_test.rs +++ b/src/vdaf/prio3_test.rs @@ -285,13 +285,14 @@ mod tests { #[ignore] #[test] fn test_vec_prio3_sum() { + const FAKE_MAX_MEASUREMENT_UPDATE_ME: u64 = 0; for test_vector_str in [ include_str!("test_vec/08/Prio3Sum_0.json"), include_str!("test_vec/08/Prio3Sum_1.json"), ] { check_test_vec(test_vector_str, |json_params, num_shares| { - let bits = json_params["bits"].as_u64().unwrap() as usize; - Prio3::new_sum(num_shares, bits).unwrap() + let _bits = json_params["bits"].as_u64().unwrap() as usize; + Prio3::new_sum(num_shares, FAKE_MAX_MEASUREMENT_UPDATE_ME).unwrap() }); } }