From 6472417ba79567d4972f725e1332faae625b9f3c Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Wed, 20 Nov 2024 18:19:08 -0800 Subject: [PATCH] Switch to VDAF-13. I currently pin to a specific commit in libprio-rs tracking the VDAF-13 implementation, as that implementation is still ongoing. I am fairly confident that the only breaking interface changes have already been implemented: there is a new "application context string", which is computed as the concatenation of the current DAP draft in use and the task ID. Also, bump the version tag while I'm at it, since I touched it anyway. --- Cargo.lock | 58 +++++++++++++++---- Cargo.toml | 5 +- aggregator/src/aggregator.rs | 8 ++- .../src/aggregator/aggregate_init_tests.rs | 1 + .../aggregator/aggregation_job_continue.rs | 11 ++-- .../src/aggregator/aggregation_job_creator.rs | 13 +++++ .../src/aggregator/aggregation_job_driver.rs | 39 +++++++------ .../aggregation_job_driver/tests.rs | 12 ++++ .../tests/aggregation_job_continue.rs | 14 +++++ .../tests/aggregation_job_init.rs | 3 + aggregator/src/aggregator/test_util.rs | 6 +- aggregator_core/src/datastore/tests.rs | 31 +++++++++- client/src/lib.rs | 7 ++- collector/src/lib.rs | 13 +++-- core/Cargo.toml | 1 + core/src/hpke.rs | 6 +- core/src/lib.rs | 4 ++ core/src/test_util/mod.rs | 16 +++-- core/src/vdaf.rs | 15 ++++- .../integration/simulation/bad_client.rs | 57 +++++++++++------- 20 files changed, 242 insertions(+), 78 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a08ba2500..f42862e40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1064,6 +1064,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constcat" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4938185353434999ef52c81753c8cca8955ed38042fc29913db3751916f3b7ab" + [[package]] name = "content_inspector" version = "0.2.4" @@ -1437,7 +1443,7 @@ dependencies = [ "num-bigint", "num-rational", "pad-adapter", - "prio", + "prio 0.16.7", "serde", "serde_json", "thiserror 2.0.3", @@ -2569,7 +2575,7 @@ dependencies = [ "pem", "postgres-protocol", "postgres-types", - "prio", + "prio 0.17.0-alpha.0", "prometheus", "quickcheck", "quickcheck_macros", @@ -2678,7 +2684,7 @@ dependencies = [ "opentelemetry", "postgres-protocol", "postgres-types", - "prio", + "prio 0.17.0-alpha.0", "rand", "regex", "reqwest", @@ -2719,7 +2725,7 @@ dependencies = [ "janus_messages 0.8.0-prerelease-1", "mockito", "ohttp", - "prio", + "prio 0.17.0-alpha.0", "rand", "reqwest", "thiserror 1.0.69", @@ -2745,7 +2751,7 @@ dependencies = [ "janus_core", "janus_messages 0.8.0-prerelease-1", "mockito", - "prio", + "prio 0.17.0-alpha.0", "rand", "reqwest", "retry-after", @@ -2768,6 +2774,7 @@ dependencies = [ "bytes", "chrono", "clap", + "constcat", "derivative", "fixed", "futures", @@ -2780,7 +2787,7 @@ dependencies = [ "k8s-openapi", "kube", "mockito", - "prio", + "prio 0.17.0-alpha.0", "quickcheck", "rand", "regex", @@ -2832,7 +2839,7 @@ dependencies = [ "k8s-openapi", "kube", "opentelemetry", - "prio", + "prio 0.17.0-alpha.0", "quickcheck", "rand", "regex", @@ -2871,7 +2878,7 @@ dependencies = [ "janus_core", "janus_interop_binaries", "janus_messages 0.8.0-prerelease-1", - "prio", + "prio 0.17.0-alpha.0", "rand", "regex", "reqwest", @@ -2902,7 +2909,7 @@ dependencies = [ "derivative", "hex", "num_enum", - "prio", + "prio 0.16.7", "rand", "serde", "thiserror 1.0.69", @@ -2920,7 +2927,7 @@ dependencies = [ "hex", "num_enum", "pretty_assertions", - "prio", + "prio 0.17.0-alpha.0", "rand", "serde", "serde_test", @@ -2942,7 +2949,7 @@ dependencies = [ "janus_collector", "janus_core", "janus_messages 0.8.0-prerelease-1", - "prio", + "prio 0.17.0-alpha.0", "rand", "reqwest", "serde_json", @@ -4048,6 +4055,33 @@ name = "prio" version = "0.16.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e139374dfbada620828b9736ecef3959c8fea97caa8d9ba39d3e84966380978" +dependencies = [ + "aes 0.8.4", + "bitvec", + "byteorder", + "ctr 0.9.2", + "fiat-crypto", + "fixed", + "getrandom", + "hmac 0.12.1", + "num-bigint", + "num-integer", + "num-iter", + "num-rational", + "num-traits", + "rand", + "rand_core", + "serde", + "sha2 0.10.8", + "sha3", + "subtle", + "thiserror 1.0.69", +] + +[[package]] +name = "prio" +version = "0.17.0-alpha.0" +source = "git+https://github.com/divviup/libprio-rs?rev=c85e537682a7932edc0c44c80049df0429d6fa4c#c85e537682a7932edc0c44c80049df0429d6fa4c" dependencies = [ "aes 0.8.4", "bitvec", @@ -4071,7 +4105,7 @@ dependencies = [ "sha2 0.10.8", "sha3", "subtle", - "thiserror 1.0.69", + "thiserror 2.0.3", "zipf", ]
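For concreteness, the "application context string" described in the commit message is just the DAP draft identifier concatenated with the raw task ID bytes. A minimal standalone sketch of the computation (illustrative only: the function name is hypothetical, the 32-byte length matches TaskId::LEN, and the patch's real helper is vdaf_application_context in core/src/vdaf.rs below):

    // Sketch: build "dap-13" || task_id, per the commit message above.
    fn application_context(task_id: &[u8; 32]) -> Vec<u8> {
        let mut ctx = Vec::with_capacity(b"dap-13".len() + task_id.len());
        ctx.extend_from_slice(b"dap-13"); // current DAP draft identifier
        ctx.extend_from_slice(task_id); // raw task ID bytes
        ctx
    }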
diff --git a/Cargo.toml b/Cargo.toml index 1d46ff5cd..5efa6fe8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ cfg-if = "1.0.0" chrono = { version = "0.4.38", default-features = false } clap = { version = "4.5.21", features = ["cargo", "derive", "env"] } console-subscriber = "0.4.1" +constcat = "0.5" deadpool = "0.12.1" deadpool-postgres = "0.14.0" derivative = "2.2.0" @@ -73,7 +74,9 @@ postgres-types = "0.2.8" pretty_assertions = "1.4.1" # Disable default features so that individual workspace crates can choose to # re-enable them -prio = { version = "0.16.7", default-features = false, features = ["experimental"] } +# TODO(#3436): switch to a released version of libprio, once there is a released version implementing VDAF-13 +# prio = { version = "0.16.7", default-features = false, features = ["experimental"] } +prio = { git = "https://github.com/divviup/libprio-rs", rev = "c85e537682a7932edc0c44c80049df0429d6fa4c", default-features = false, features = ["experimental"] } prometheus = "0.13.4" querystring = "1.1.0" quickcheck = { version = "1.0.3", default-features = false } diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 901221e71..83aa58e65 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -58,7 +58,7 @@ use janus_core::{ retries::retry_http_request_notify, time::{Clock, DurationExt, IntervalExt, TimeExt}, vdaf::{ - new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128, + new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128, vdaf_application_context, Prio3SumVecField64MultiproofHmacSha256Aes128, VdafInstance, VERIFY_KEY_LENGTH, }, Runtime, @@ -1978,12 +1978,13 @@ impl VdafOps { move || { let span = info_span!(parent: parent_span, "handle_aggregate_init_generic threadpool task"); + let ctx = vdaf_application_context(task.id()); req .prepare_inits() .par_iter() .enumerate() - .try_for_each_with((sender, span), |(sender, span), (ord, prepare_init)| { + .try_for_each(|(ord, prepare_init)| { let _entered = span.enter(); // If decryption fails, then the aggregator MUST fail with error `hpke-decrypt-error`.
(§4.4.2.2) @@ -2182,6 +2183,7 @@ impl VdafOps { trace_span!("VDAF preparation (helper initialization)").in_scope(|| { vdaf.helper_initialized( verify_key.as_bytes(), + &ctx, &agg_param, /* report ID is used as VDAF nonce */ prepare_init.report_share().metadata().id().as_ref(), @@ -2189,7 +2191,7 @@ impl VdafOps { &input_share, prepare_init.message(), ) - .and_then(|transition| transition.evaluate(&vdaf)) + .and_then(|transition| transition.evaluate(&ctx, &vdaf)) .map_err(|error| { handle_ping_pong_error( task.id(), diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index c9a36d53f..3f6c02313 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -132,6 +132,7 @@ where ) -> (ReportShare, VdafTranscript<VERIFY_KEY_LENGTH, V>) { let transcript = run_vdaf( &self.vdaf, + self.task.id(), self.task.vdaf_verify_key().unwrap().as_bytes(), &self.aggregation_param, report_metadata.id(), diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 1e2136904..eff2fa4ee 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -16,7 +16,7 @@ use janus_aggregator_core::{ }, task::AggregatorTask, }; -use janus_core::time::Clock; +use janus_core::{time::Clock, vdaf::vdaf_application_context}; use janus_messages::{ AggregationJobContinueReq, AggregationJobResp, PrepareResp, PrepareStepResult, ReportError, Role, @@ -159,10 +159,10 @@ impl VdafOps { move || { let span = info_span!(parent: parent_span, "step_aggregation_job threadpool task"); + let ctx = vdaf_application_context(task.id()); - prep_steps_and_ras.into_par_iter().try_for_each_with( - (sender, span), - |(sender, span), (prep_step, report_aggregation, prep_state)| { + prep_steps_and_ras.into_par_iter().try_for_each( + |(prep_step, report_aggregation, prep_state)| { let _entered = span.enter(); let (report_aggregation_state, prepare_step_result, output_share) = @@ -170,6 +170,7 @@ impl VdafOps { .in_scope(|| { // Continue with the incoming message. vdaf.helper_continued( + &ctx, PingPongState::Continued(prep_state.clone()), aggregation_job.aggregation_parameter(), prep_step.message(), @@ -181,7 +182,7 @@ impl VdafOps { transition, } => { let (new_state, message) = - transition.evaluate(vdaf.as_ref())?; + transition.evaluate(&ctx, vdaf.as_ref())?; let (report_aggregation_state, output_share) = match new_state { // Helper did not finish.
Store the new diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index 438853253..f87c5bcd8 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -976,6 +976,7 @@ mod tests { let leader_report_metadata = ReportMetadata::new(random(), report_time); let leader_transcript = run_vdaf( vdaf.as_ref(), + leader_task.id(), leader_task.vdaf_verify_key().unwrap().as_bytes(), &(), leader_report_metadata.id(), @@ -1166,6 +1167,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), @@ -1344,6 +1346,7 @@ mod tests { let first_report_metadata = ReportMetadata::new(random(), report_time); let first_transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), first_report_metadata.id(), @@ -1360,6 +1363,7 @@ mod tests { let second_report_metadata = ReportMetadata::new(random(), report_time); let second_transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), second_report_metadata.id(), @@ -1556,6 +1560,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), @@ -1740,6 +1745,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), @@ -1937,6 +1943,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), @@ -2098,6 +2105,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), @@ -2211,6 +2219,7 @@ mod tests { let last_report_metadata = ReportMetadata::new(random(), report_time); let last_transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), last_report_metadata.id(), @@ -2358,6 +2367,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), @@ -2474,6 +2484,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), @@ -2629,6 +2640,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time_1); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), @@ -2649,6 +2661,7 @@ mod tests { let report_metadata = ReportMetadata::new(random(), report_time_2); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &(), report_metadata.id(), diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index b1d6c92af..04b41bc4d 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ 
b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -32,6 +32,7 @@ use janus_aggregator_core::{ use janus_core::{ retries::{is_retryable_http_client_error, is_retryable_http_status}, time::Clock, + vdaf::vdaf_application_context, vdaf_dispatch, }; use janus_messages::{ @@ -326,8 +327,8 @@ where // // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. // This will cause any attempts to send on `sender` to return a `SendError`, which will be - // returned from the function passed to `try_for_each_with`; `try_for_each_with` will - // terminate early on receiving an error. + // returned from the function passed to `try_for_each`; `try_for_each` will terminate early + // on receiving an error. let (ra_sender, mut ra_receiver) = mpsc::unbounded_channel(); let (pi_and_sa_sender, mut pi_and_sa_receiver) = mpsc::unbounded_channel(); let producer_task = tokio::task::spawn_blocking({ @@ -342,12 +343,12 @@ where parent: parent_span, "step_aggregation_job_aggregate_init threadpool task" ); + let ctx = vdaf_application_context(&task_id); // Compute report shares to send to helper, and decrypt our input shares & // initialize preparation state. - report_aggregations.into_par_iter().try_for_each_with( - (span, ra_sender, pi_and_sa_sender), - |(span, ra_sender, pi_and_sa_sender), report_aggregation| { + report_aggregations.into_par_iter().try_for_each( + |report_aggregation| { let _entered = span.enter(); // Extract report data from the report aggregation state. @@ -416,6 +417,7 @@ where match trace_span!("VDAF preparation (leader initialization)").in_scope(|| { vdaf.leader_initialized( verify_key.as_bytes(), + &ctx, aggregation_job.aggregation_parameter(), // DAP report ID is used as VDAF nonce report_aggregation.report_id().as_ref(), @@ -591,8 +593,8 @@ where // // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. // This will cause any attempts to send on `sender` to return a `SendError`, which will be - // returned from the function passed to `try_for_each_with`; `try_for_each_with` will - // terminate early on receiving an error. + // returned from the function passed to `try_for_each`; `try_for_each` will terminate early + // on receiving an error. let (ra_sender, mut ra_receiver) = mpsc::unbounded_channel(); let (pc_and_sa_sender, mut pc_and_sa_receiver) = mpsc::unbounded_channel(); let producer_task = tokio::task::spawn_blocking({ @@ -606,10 +608,11 @@ where parent: parent_span, "step_aggregation_job_aggregate_continue threadpool task" ); + let ctx = vdaf_application_context(&task_id); - report_aggregations.into_par_iter().try_for_each_with( - (span, ra_sender, pc_and_sa_sender), - |(span, ra_sender, pc_and_sa_sender), report_aggregation| { + report_aggregations + .into_par_iter() + .try_for_each(|report_aggregation| { let _entered = span.enter(); let transition = match report_aggregation.state() { @@ -623,7 +626,7 @@ where }; let result = trace_span!("VDAF preparation (leader transition evaluation)") - .in_scope(|| transition.evaluate(vdaf.as_ref())); + .in_scope(|| transition.evaluate(&ctx, vdaf.as_ref())); let (prep_state, message) = match result { Ok((state, message)) => (state, message), Err(error) => { @@ -655,8 +658,7 @@ where }, )) .map_err(|_| ()) - }, - ) + }) } }); @@ -782,8 +784,8 @@ where // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. 
// This will cause any attempts to send on `sender` to return a `SendError`, which will be - // returned from the function passed to `try_for_each_with`; `try_for_each_with` will - // terminate early on receiving an error. + // returned from the function passed to `try_for_each`; `try_for_each` will terminate early + // on receiving an error. let (ra_sender, mut ra_receiver) = mpsc::unbounded_channel(); let aggregation_job = Arc::new(aggregation_job); let producer_task = tokio::task::spawn_blocking({ @@ -798,10 +800,10 @@ where parent: parent_span, "process_response_from_helper threadpool task" ); + let ctx = vdaf_application_context(&task_id); - stepped_aggregations.into_par_iter().zip(helper_resp.prepare_resps()).try_for_each_with( - (span, ra_sender), - |(span, ra_sender), (stepped_aggregation, helper_prep_resp)| { + stepped_aggregations.into_par_iter().zip(helper_resp.prepare_resps()).try_for_each( + |(stepped_aggregation, helper_prep_resp)| { let _entered = span.enter(); let (new_state, output_share) = match helper_prep_resp.result() { @@ -811,6 +813,7 @@ where let state_and_message = trace_span!("VDAF preparation (leader continuation)") .in_scope(|| { vdaf.leader_continued( + &ctx, stepped_aggregation.leader_state.clone(), aggregation_job.aggregation_parameter(), helper_prep_msg, diff --git a/aggregator/src/aggregator/aggregation_job_driver/tests.rs b/aggregator/src/aggregator/aggregation_job_driver/tests.rs index 5d2617b0e..4191d7600 100644 --- a/aggregator/src/aggregator/aggregation_job_driver/tests.rs +++ b/aggregator/src/aggregator/aggregation_job_driver/tests.rs @@ -89,6 +89,7 @@ async fn aggregation_job_driver() { let transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata.id(), @@ -366,6 +367,7 @@ async fn step_time_interval_aggregation_job_init_single_step() { let transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &(), report_metadata.id(), @@ -697,6 +699,7 @@ async fn step_time_interval_aggregation_job_init_two_steps() { let transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata.id(), @@ -962,6 +965,7 @@ async fn step_time_interval_aggregation_job_init_partially_garbage_collected() { let gc_eligible_transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &(), gc_eligible_report_metadata.id(), @@ -969,6 +973,7 @@ async fn step_time_interval_aggregation_job_init_partially_garbage_collected() { ); let gc_ineligible_transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &(), gc_ineligible_report_metadata.id(), @@ -1314,6 +1319,7 @@ async fn step_leader_selected_aggregation_job_init_single_step() { let transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &(), report_metadata.id(), @@ -1601,6 +1607,7 @@ async fn step_leader_selected_aggregation_job_init_two_steps() { let transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata.id(), @@ -1856,6 +1863,7 @@ async fn step_time_interval_aggregation_job_continue() { let aggregation_param = dummy::AggregationParam(7); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata.id(), @@ -2172,6 +2180,7 @@ async fn step_leader_selected_aggregation_job_continue() { let aggregation_param = dummy::AggregationParam(7); let transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata.id(), @@ -2458,6
+2467,7 @@ async fn setup_cancel_aggregation_job_test() -> CancelAggregationJobTestCase { let transcript = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &(), report_metadata.id(), @@ -2726,6 +2736,7 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { let report_metadata = ReportMetadata::new(random(), time); let transcript = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &(), report_metadata.id(), @@ -2968,6 +2979,7 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { let report_metadata = ReportMetadata::new(random(), time); let transcript = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &(), report_metadata.id(), diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs index d772e5991..207a8a646 100644 --- a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs @@ -68,6 +68,7 @@ async fn aggregate_continue() { ); let transcript_0 = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_0.id(), @@ -94,6 +95,7 @@ async fn aggregate_continue() { ); let transcript_1 = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_1.id(), @@ -123,6 +125,7 @@ async fn aggregate_continue() { ); let transcript_2 = run_vdaf( vdaf.as_ref(), + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_2.id(), @@ -398,6 +401,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { let report_metadata_0 = ReportMetadata::new(random(), report_time_0); let transcript_0 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_0.id(), @@ -423,6 +427,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { let report_metadata_1 = ReportMetadata::new(random(), report_time_1); let transcript_1 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_1.id(), @@ -450,6 +455,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { ); let transcript_2 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_2.id(), @@ -745,6 +751,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { let report_metadata_3 = ReportMetadata::new(random(), report_time_3); let transcript_3 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_3.id(), @@ -772,6 +779,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { ); let transcript_4 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_4.id(), @@ -799,6 +807,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { ); let transcript_5 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &aggregation_param, report_metadata_5.id(), @@ -1036,6 +1045,7 @@ async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { let aggregation_param = dummy::AggregationParam(7); let transcript = run_vdaf( &dummy::Vdaf::new(2), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &aggregation_param, &report_id, @@ -1151,6 +1161,7 @@ async fn aggregate_continue_prep_step_fails() { let aggregation_param = dummy::AggregationParam(7); let transcript = run_vdaf( &vdaf, + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &aggregation_param, &report_id, @@ -1320,6 +1331,7 @@ async fn 
aggregate_continue_unexpected_transition() { let aggregation_param = dummy::AggregationParam(7); let transcript = run_vdaf( &dummy::Vdaf::new(2), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &aggregation_param, &report_id, @@ -1433,6 +1445,7 @@ async fn aggregate_continue_out_of_order_transition() { let aggregation_param = dummy::AggregationParam(7); let transcript_0 = run_vdaf( &dummy::Vdaf::new(2), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &aggregation_param, &report_id_0, @@ -1442,6 +1455,7 @@ async fn aggregate_continue_out_of_order_transition() { let report_id_1 = random(); let transcript_1 = run_vdaf( &dummy::Vdaf::new(2), + task.id(), task.vdaf_verify_key().unwrap().as_bytes(), &aggregation_param, &report_id_1, diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs index 058127e95..9158968fa 100644 --- a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs +++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs @@ -287,6 +287,7 @@ async fn aggregate_init() { ); let transcript_5 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &dummy::AggregationParam(0), report_metadata_5.id(), @@ -317,6 +318,7 @@ async fn aggregate_init() { ); let transcript_6 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &dummy::AggregationParam(0), report_metadata_6.id(), @@ -347,6 +349,7 @@ async fn aggregate_init() { ); let transcript_7 = run_vdaf( &vdaf, + task.id(), verify_key.as_bytes(), &dummy::AggregationParam(0), report_metadata_7.id(), diff --git a/aggregator/src/aggregator/test_util.rs b/aggregator/src/aggregator/test_util.rs index 4bb32a55b..e2b7cbe38 100644 --- a/aggregator/src/aggregator/test_util.rs +++ b/aggregator/src/aggregator/test_util.rs @@ -6,7 +6,7 @@ use janus_aggregator_core::{ use janus_core::{ hpke::{self, HpkeApplicationInfo, HpkeKeypair, Label}, time::MockClock, - vdaf::VdafInstance, + vdaf::{vdaf_application_context, VdafInstance}, }; use janus_messages::{ Extension, HpkeConfig, InputShareAad, PlaintextInputShare, Report, ReportId, ReportMetadata, @@ -74,7 +74,9 @@ pub fn create_report_custom( let vdaf = Prio3Count::new_count(2).unwrap(); let report_metadata = ReportMetadata::new(id, report_timestamp); - let (public_share, measurements) = vdaf.shard(&true, id.as_ref()).unwrap(); + let (public_share, measurements) = vdaf + .shard(&vdaf_application_context(task.id()), &true, id.as_ref()) + .unwrap(); let associated_data = InputShareAad::new( *task.id(), diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 7aa055872..6bec8e244 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -2251,10 +2251,18 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { install_test_trace_subscriber(); + let task_id = random(); let report_id = random(); let vdaf = Arc::new(dummy::Vdaf::new(2)); let aggregation_param = dummy::AggregationParam(5); - let vdaf_transcript = run_vdaf(vdaf.as_ref(), &[], &aggregation_param, &report_id, &13); + let vdaf_transcript = run_vdaf( + vdaf.as_ref(), + &task_id, + &[], + &aggregation_param, + &report_id, + &13, + ); for (ord, (role, state)) in [ ( @@ -2567,11 +2575,19 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let clock = 
MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP); let ds = ephemeral_datastore.datastore(clock.clone()).await; + let task_id = random(); let report_id = random(); let vdaf = Arc::new(dummy::Vdaf::new(2)); let aggregation_param = dummy::AggregationParam(7); - let vdaf_transcript = run_vdaf(vdaf.as_ref(), &[], &aggregation_param, &report_id, &13); + let vdaf_transcript = run_vdaf( + vdaf.as_ref(), + &task_id, + &[], + &aggregation_param, + &report_id, + &13, + ); let task = TaskBuilder::new( task::BatchMode::TimeInterval, @@ -2714,11 +2730,19 @@ async fn create_report_aggregation_from_client_reports_table( let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP); let ds = ephemeral_datastore.datastore(clock.clone()).await; + let task_id = random(); let report_id = random(); let vdaf = Arc::new(dummy::Vdaf::new(2)); let aggregation_param = dummy::AggregationParam(7); - let vdaf_transcript = run_vdaf(vdaf.as_ref(), &[], &aggregation_param, &report_id, &13); + let vdaf_transcript = run_vdaf( + vdaf.as_ref(), + &task_id, + &[], + &aggregation_param, + &report_id, + &13, + ); let task = TaskBuilder::new( task::BatchMode::TimeInterval, @@ -5141,6 +5165,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { let report_id_0_1 = random(); let transcript = run_vdaf( &dummy::Vdaf::default(), + task_1.id(), task_1.vdaf_verify_key().unwrap().as_bytes(), &dummy::AggregationParam(0), &report_id_0_1, diff --git a/client/src/lib.rs b/client/src/lib.rs index 2277b641a..9a4e7bbe0 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -53,6 +53,7 @@ use janus_core::{ retries::{http_request_exponential_backoff, retry_http_request}, time::{Clock, RealClock, TimeExt}, url_ensure_trailing_slash, + vdaf::vdaf_application_context, }; use janus_messages::{ Duration, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId, @@ -576,7 +577,11 @@ impl> Client { /// uploaded. fn prepare_report(&self, measurement: &V::Measurement, time: &Time) -> Result { let report_id: ReportId = random(); - let (public_share, input_shares) = self.vdaf.shard(measurement, report_id.as_ref())?; + let (public_share, input_shares) = self.vdaf.shard( + &vdaf_application_context(&self.parameters.task_id), + measurement, + report_id.as_ref(), + )?; assert_eq!(input_shares.len(), 2); // DAP only supports VDAFs using two aggregators. 
let time = time diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 9fa8a7c26..53d2a7cae 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -915,7 +915,7 @@ mod tests { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let vdaf = Prio3::new_count(2).unwrap(); - let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &true); + let transcript = run_vdaf(&vdaf, &random(), &random(), &(), &random(), &true); let collector = setup_collector(&mut server, vdaf); let (auth_header, auth_value) = collector.authentication.request_authentication(); @@ -1016,7 +1016,7 @@ mod tests { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let vdaf = Prio3::new_sum(2, 8).unwrap(); - let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &144); + let transcript = run_vdaf(&vdaf, &random(), &random(), &(), &random(), &144); let collector = setup_collector(&mut server, vdaf); let batch_interval = Interval::new( @@ -1084,7 +1084,7 @@ mod tests { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let vdaf = Prio3::new_histogram(2, 4, 2).unwrap(); - let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &3); + let transcript = run_vdaf(&vdaf, &random(), &random(), &(), &random(), &3); let collector = setup_collector(&mut server, vdaf); let batch_interval = Interval::new( @@ -1159,6 +1159,7 @@ mod tests { let transcript = run_vdaf( &vdaf, &random(), + &random(), &(), &random(), &vec![FP32_16_INV, FP32_8_INV, FP32_4_INV], @@ -1231,7 +1232,7 @@ mod tests { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let vdaf = Prio3::new_count(2).unwrap(); - let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &true); + let transcript = run_vdaf(&vdaf, &random(), &random(), &(), &random(), &true); let collector = setup_collector(&mut server, vdaf); let batch_id = random(); @@ -1294,7 +1295,7 @@ mod tests { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let vdaf = Prio3::new_count(2).unwrap(); - let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &true); + let transcript = run_vdaf(&vdaf, &random(), &random(), &(), &random(), &true); let server_url = Url::parse(&server.url()).unwrap(); let hpke_keypair = HpkeKeypair::test(); let collector = Collector::builder( @@ -1955,7 +1956,7 @@ mod tests { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let vdaf = Prio3::new_count(2).unwrap(); - let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &true); + let transcript = run_vdaf(&vdaf, &random(), &random(), &(), &random(), &true); let collector = setup_collector(&mut server, vdaf); let (auth_header, auth_value) = collector.authentication.request_authentication(); diff --git a/core/Cargo.toml b/core/Cargo.toml index bac14c6c4..0b7ef7c55 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -41,6 +41,7 @@ base64.workspace = true bytes.workspace = true chrono = { workspace = true, features = ["clock"] } clap.workspace = true +constcat.workspace = true derivative.workspace = true fixed = { workspace = true, optional = true } futures = { workspace = true } diff --git a/core/src/hpke.rs b/core/src/hpke.rs index 01a8b2a89..d5bef83f8 100644 --- a/core/src/hpke.rs +++ b/core/src/hpke.rs @@ -1,5 +1,7 @@ //! Encryption and decryption of messages using HPKE (RFC 9180). 
+use crate::DAP_VERSION_IDENTIFIER; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use constcat::concat; use derivative::Derivative; use hpke_dispatch::{HpkeError, Kem, Keypair}; use janus_messages::{ @@ -63,8 +65,8 @@ impl Label { /// Get the message-specific portion of the application info string for this label. pub fn as_bytes(&self) -> &'static [u8] { match self { - Self::InputShare => b"dap-09 input share", - Self::AggregateShare => b"dap-09 aggregate share", + Self::InputShare => concat!(DAP_VERSION_IDENTIFIER, " input share").as_bytes(), + Self::AggregateShare => concat!(DAP_VERSION_IDENTIFIER, " aggregate share").as_bytes(), } } } diff --git a/core/src/lib.rs b/core/src/lib.rs index cf6b97182..3f3398071 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -44,6 +44,10 @@ pub mod taskprov { pub const TASKPROV_HEADER: &str = "dap-taskprov"; } +/// This value is used in a few places throughout the protocol to identify the draft of DAP being +/// implemented. +const DAP_VERSION_IDENTIFIER: &str = "dap-13"; + /// Returns the given [`Url`], possibly modified to end with a slash. /// /// Aggregator endpoint URLs should end with a slash if they will be used with [`Url::join`], diff --git a/core/src/test_util/mod.rs b/core/src/test_util/mod.rs index 8427d316e..03e135b72 100644 --- a/core/src/test_util/mod.rs +++ b/core/src/test_util/mod.rs @@ -1,5 +1,6 @@ +use crate::vdaf::vdaf_application_context; use assert_matches::assert_matches; -use janus_messages::{ReportId, Role}; +use janus_messages::{ReportId, Role, TaskId}; use prio::{ topology::ping_pong::{ PingPongContinuedValue, PingPongMessage, PingPongState, PingPongTopology, @@ -88,20 +89,24 @@ pub fn run_vdaf< V: vdaf::Aggregator<VERIFY_KEY_LENGTH, 16> + vdaf::Client<16>, >( vdaf: &V, + task_id: &TaskId, verify_key: &[u8; VERIFY_KEY_LENGTH], aggregation_param: &V::AggregationParam, report_id: &ReportId, measurement: &V::Measurement, ) -> VdafTranscript<VERIFY_KEY_LENGTH, V> { + let ctx = vdaf_application_context(task_id); + let mut leader_prepare_transitions = Vec::new(); let mut helper_prepare_transitions = Vec::new(); // Shard inputs into input shares, and initialize the initial PrepareTransitions.
- let (public_share, input_shares) = vdaf.shard(measurement, report_id.as_ref()).unwrap(); + let (public_share, input_shares) = vdaf.shard(&ctx, measurement, report_id.as_ref()).unwrap(); let (leader_state, leader_message) = vdaf .leader_initialized( verify_key, + &ctx, aggregation_param, report_id.as_ref(), &public_share, @@ -118,6 +123,7 @@ pub fn run_vdaf< let helper_transition = vdaf .helper_initialized( verify_key, + &ctx, aggregation_param, report_id.as_ref(), &public_share, @@ -125,7 +131,7 @@ pub fn run_vdaf< &leader_message, ) .unwrap(); - let (helper_state, helper_message) = helper_transition.evaluate(vdaf).unwrap(); + let (helper_state, helper_message) = helper_transition.evaluate(&ctx, vdaf).unwrap(); helper_prepare_transitions.push(HelperPrepareTransition { transition: helper_transition, @@ -155,6 +161,7 @@ pub fn run_vdaf< let state_and_message = match role { Role::Leader => vdaf .leader_continued( + &ctx, curr_state.clone(), aggregation_param, last_peer_message, @@ -162,6 +169,7 @@ pub fn run_vdaf< .unwrap(), Role::Helper => vdaf .helper_continued( + &ctx, curr_state.clone(), aggregation_param, last_peer_message, @@ -172,7 +180,7 @@ pub fn run_vdaf< match state_and_message { PingPongContinuedValue::WithMessage { transition } => { - let (state, message) = transition.clone().evaluate(vdaf).unwrap(); + let (state, message) = transition.clone().evaluate(&ctx, vdaf).unwrap(); match role { Role::Leader => { leader_prepare_transitions.push(LeaderPrepareTransition { diff --git a/core/src/vdaf.rs b/core/src/vdaf.rs index 7fdae75d2..aa6e76f3e 100644 --- a/core/src/vdaf.rs +++ b/core/src/vdaf.rs @@ -1,5 +1,6 @@ +use crate::DAP_VERSION_IDENTIFIER; use derivative::Derivative; -use janus_messages::taskprov; +use janus_messages::{taskprov, TaskId}; use prio::{ field::Field64, flp::{ @@ -23,6 +24,18 @@ const ALGORITHM_ID_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMACSHA256_AES128: u32 = 0xF /// of the VDAF specification. pub const VERIFY_KEY_LENGTH_HMACSHA256_AES128: usize = 32; +/// Computes the VDAF "application context", which is an opaque byte string used by many VDAF +/// operations for domain-separation purposes. In DAP, this is the DAP draft number concatenated +/// with the task ID. +pub fn vdaf_application_context( + task_id: &TaskId, +) -> [u8; DAP_VERSION_IDENTIFIER.len() + TaskId::LEN] { + let mut ctx = [0u8; DAP_VERSION_IDENTIFIER.len() + TaskId::LEN]; + ctx[..DAP_VERSION_IDENTIFIER.len()].copy_from_slice(DAP_VERSION_IDENTIFIER.as_bytes()); + ctx[DAP_VERSION_IDENTIFIER.len()..].copy_from_slice(task_id.as_ref()); + ctx +} + /// Bitsize parameter for the `Prio3FixedPointBoundedL2VecSum` VDAF. 
#[cfg(feature = "fpvec_bounded_l2")] #[derive(Debug, Derivative, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] diff --git a/integration_tests/tests/integration/simulation/bad_client.rs b/integration_tests/tests/integration/simulation/bad_client.rs index 65f26f6fe..d59faf7e5 100644 --- a/integration_tests/tests/integration/simulation/bad_client.rs +++ b/integration_tests/tests/integration/simulation/bad_client.rs @@ -1,5 +1,3 @@ -use std::{net::Ipv4Addr, sync::Arc}; - use assert_matches::assert_matches; use http::header::CONTENT_TYPE; use janus_aggregator::aggregator::http_handlers::AggregatorHandlerBuilder; @@ -14,11 +12,11 @@ use janus_core::{ retries::retry_http_request, test_util::{install_test_trace_subscriber, runtime::TestRuntime}, time::{Clock, RealClock, TimeExt}, - vdaf::{vdaf_dp_strategies, VdafInstance}, + vdaf::{vdaf_application_context, vdaf_dp_strategies, VdafInstance}, }; use janus_messages::{ HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId, - ReportMetadata, Role, Time, + ReportMetadata, Role, TaskId, Time, }; use prio::{ codec::{Decode, Encode, ParameterizedDecode}, @@ -31,6 +29,7 @@ use prio::{ }, }; use rand::{distributions::Standard, random, Rng}; +use std::{net::Ipv4Addr, sync::Arc}; use tokio::net::TcpListener; use trillium_tokio::Stopper; use url::Url; @@ -49,7 +48,11 @@ pub(super) async fn upload_replay_report( let report_id = ReportId::from([ 173, 234, 101, 107, 42, 222, 166, 86, 178, 173, 234, 101, 107, 42, 222, 166, ]); - let (public_share, input_shares) = vdaf.shard(&measurement, report_id.as_ref())?; + let (public_share, input_shares) = vdaf.shard( + &vdaf_application_context(task.id()), + &measurement, + report_id.as_ref(), + )?; let rounded_time = report_time .to_batch_interval_start(task.time_precision()) .unwrap(); @@ -75,7 +78,11 @@ pub(super) async fn upload_report_not_rounded( http_client: &reqwest::Client, ) -> Result<(), janus_client::Error> { let report_id: ReportId = random(); - let (public_share, input_shares) = vdaf.shard(&measurement, report_id.as_ref())?; + let (public_share, input_shares) = vdaf.shard( + &vdaf_application_context(task.id()), + &measurement, + report_id.as_ref(), + )?; let report = prepare_report( http_client, @@ -99,9 +106,10 @@ pub(super) async fn upload_report_invalid_measurement( let mut encoded_measurement = Vec::from([Field128::zero(); MAX_REPORTS]); encoded_measurement[0] = Field128::one(); encoded_measurement[1] = Field128::one(); - let report_id: ReportId = random(); + let task_id = random(); + let report_id = random(); let (public_share, input_shares) = - shard_encoded_measurement(vdaf, encoded_measurement, report_id); + shard_encoded_measurement(vdaf, &task_id, encoded_measurement, report_id); let report = prepare_report( http_client, @@ -119,6 +127,7 @@ pub(super) async fn upload_report_invalid_measurement( /// algorithm on it to produce a public share and a set of input shares. 
fn shard_encoded_measurement( vdaf: &Prio3Histogram, + task_id: &TaskId, encoded_measurement: Vec<Field128>, report_id: ReportId, ) -> (Prio3PublicShare<16>, Vec<Prio3InputShare<Field128, 16>>) { @@ -133,6 +142,8 @@ fn shard_encoded_measurement( const NUM_PROOFS: u8 = 1; + let ctx = vdaf_application_context(task_id); + assert_eq!(encoded_measurement.len(), MAX_REPORTS); let chunk_length = optimal_chunk_length(MAX_REPORTS); let circuit: Histogram<Field128, ParallelSum<Field128, Mul<Field128>>> = @@ -142,7 +153,7 @@ fn shard_encoded_measurement( let helper_measurement_share_seed = Seed::<16>::generate().unwrap(); let mut helper_measurement_share_rng = XofTurboShake128::seed_stream( &helper_measurement_share_seed, - &vdaf.domain_separation_tag(DST_MEASUREMENT_SHARE), + &vdaf.domain_separation_tag(DST_MEASUREMENT_SHARE, &ctx), &[HELPER_AGGREGATOR_ID], ); let mut expanded_helper_measurement_share: Vec<Field128> = @@ -158,7 +169,7 @@ fn shard_encoded_measurement( let helper_joint_rand_blind = Seed::<16>::generate().unwrap(); let mut helper_joint_rand_part_xof = XofTurboShake128::init( helper_joint_rand_blind.as_ref(), - &vdaf.domain_separation_tag(DST_JOINT_RAND_PART), + &vdaf.domain_separation_tag(DST_JOINT_RAND_PART, &ctx), ); helper_joint_rand_part_xof.update(&[HELPER_AGGREGATOR_ID]); helper_joint_rand_part_xof.update(report_id.as_ref()); @@ -170,7 +181,7 @@ fn shard_encoded_measurement( let leader_joint_rand_blind = Seed::<16>::generate().unwrap(); let mut leader_joint_rand_part_xof = XofTurboShake128::init( leader_joint_rand_blind.as_ref(), - &vdaf.domain_separation_tag(DST_JOINT_RAND_PART), + &vdaf.domain_separation_tag(DST_JOINT_RAND_PART, &ctx), ); leader_joint_rand_part_xof.update(&[LEADER_AGGREGATOR_ID]); leader_joint_rand_part_xof.update(report_id.as_ref()); @@ -179,15 +190,17 @@ fn shard_encoded_measurement( } let leader_joint_rand_seed_part = leader_joint_rand_part_xof.into_seed(); - let mut joint_rand_seed_xof = - XofTurboShake128::init(&[0; 16], &vdaf.domain_separation_tag(DST_JOINT_RAND_SEED)); + let mut joint_rand_seed_xof = XofTurboShake128::init( + &[0; 16], + &vdaf.domain_separation_tag(DST_JOINT_RAND_SEED, &ctx), + ); joint_rand_seed_xof.update(leader_joint_rand_seed_part.as_ref()); joint_rand_seed_xof.update(helper_joint_rand_seed_part.as_ref()); let joint_rand_seed = joint_rand_seed_xof.into_seed(); let mut joint_rand: Vec<Field128> = Vec::with_capacity(circuit.joint_rand_len()); let mut joint_rand_xof = XofTurboShake128::seed_stream( &joint_rand_seed, - &vdaf.domain_separation_tag(DST_JOINT_RANDOMNESS), + &vdaf.domain_separation_tag(DST_JOINT_RANDOMNESS, &ctx), &[NUM_PROOFS], ); for _ in 0..circuit.joint_rand_len() { @@ -205,7 +218,7 @@ fn shard_encoded_measurement( let helper_proof_share_seed = Seed::<16>::generate().unwrap(); let mut helper_proof_share_xof = XofTurboShake128::seed_stream( &helper_proof_share_seed, - &vdaf.domain_separation_tag(DST_PROOF_SHARE), + &vdaf.domain_separation_tag(DST_PROOF_SHARE, &ctx), &[NUM_PROOFS, HELPER_AGGREGATOR_ID], ); for leader_elem in leader_proof_share.iter_mut() { @@ -439,14 +452,17 @@ fn shard_encoded_measurement_correct() { let mut encoded_measurement = Vec::from([Field128::zero(); MAX_REPORTS]); encoded_measurement[0] = Field128::one(); - let report_id: ReportId = random(); + let task_id = random(); + let report_id = random(); let (public_share, input_shares) = - shard_encoded_measurement(&vdaf, encoded_measurement, report_id); + shard_encoded_measurement(&vdaf, &task_id, encoded_measurement, report_id); let verify_key: [u8; 16] = random(); + let ctx = vdaf_application_context(&task_id); let (leader_prepare_state,
leader_prepare_share) = vdaf .prepare_init( &verify_key, + &ctx, 0, &(), report_id.as_ref(), @@ -457,6 +473,7 @@ fn shard_encoded_measurement_correct() { let (helper_prepare_state, helper_prepare_share) = vdaf .prepare_init( &verify_key, + &ctx, 1, &(), report_id.as_ref(), @@ -465,13 +482,13 @@ fn shard_encoded_measurement_correct() { ) .unwrap(); let prepare_message = vdaf - .prepare_shares_to_prepare_message(&(), [leader_prepare_share, helper_prepare_share]) + .prepare_shares_to_prepare_message(&ctx, &(), [leader_prepare_share, helper_prepare_share]) .unwrap(); let leader_transition = vdaf - .prepare_next(leader_prepare_state, prepare_message.clone()) + .prepare_next(&ctx, leader_prepare_state, prepare_message.clone()) .unwrap(); let helper_transition = vdaf - .prepare_next(helper_prepare_state, prepare_message) + .prepare_next(&ctx, helper_prepare_state, prepare_message) .unwrap(); let leader_output_share = assert_matches!(leader_transition, PrepareTransition::Finish(output_share) => output_share);
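For reference, the shape of the resulting VDAF-13 call sequence, with the application context threaded through every operation, looks like this for a one-round VDAF such as Prio3Count. This is a sketch mirroring the updated run_vdaf above; it assumes the in-flight API of the pinned libprio-rs commit exactly as used throughout this patch, and elides error handling:

    use janus_core::vdaf::vdaf_application_context;
    use janus_messages::TaskId;
    use prio::{topology::ping_pong::PingPongTopology, vdaf::prio3::Prio3};
    use rand::random;

    fn ping_pong_round_sketch() {
        let vdaf = Prio3::new_count(2).unwrap();
        let task_id: TaskId = random();
        // Computed once per task, then passed to every sharding/preparation call.
        let ctx = vdaf_application_context(&task_id);
        let verify_key = [0u8; 16];
        let nonce = [0u8; 16]; // in DAP, the report ID serves as the VDAF nonce

        let (public_share, input_shares) = vdaf.shard(&ctx, &true, &nonce).unwrap();
        let (_leader_state, leader_message) = vdaf
            .leader_initialized(&verify_key, &ctx, &(), &nonce, &public_share, &input_shares[0])
            .unwrap();
        let helper_transition = vdaf
            .helper_initialized(
                &verify_key,
                &ctx,
                &(),
                &nonce,
                &public_share,
                &input_shares[1],
                &leader_message,
            )
            .unwrap();
        let (_helper_state, _helper_message) = helper_transition.evaluate(&ctx, &vdaf).unwrap();
    }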
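Separately from the VDAF-13 changes, the "shutdown on cancellation" comments reworded above rely on a pattern that is easy to miss: results are sent from a rayon thread pool over an unbounded tokio channel, and dropping the receiver makes every subsequent send fail, which short-circuits the parallel loop. A self-contained sketch of that mechanism (the names here are illustrative, not Janus APIs):

    use rayon::prelude::*;
    use tokio::sync::mpsc;

    fn produce(items: Vec<u64>, sender: mpsc::UnboundedSender<u64>) {
        // If the receiver is dropped (e.g. the HTTP request was cancelled),
        // send() returns a SendError; the closure maps it to Err(()), and
        // try_for_each then terminates the parallel iteration early.
        let _ = items
            .into_par_iter()
            .try_for_each(|item| sender.send(item * 2).map_err(|_| ()));
    }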