From f49720fbafdab8f102d908b2be3fa869482a92fa Mon Sep 17 00:00:00 2001 From: EmilLuta Date: Tue, 30 Jul 2024 14:41:21 +0200 Subject: [PATCH] fix(prover): Parallelize circuit metadata uploading for BWG (#2520) Basic Witness Generation (BWG) creates circuits for base layer to be proven. BWG runs out-of-circuit VM and provides circuits for the instructions in the execution (note that those are ordered, messing with the order would break proving). Additionally, harness will do heavy computation -> provide circuits -> heavy computation -> provide circuits. As such, providing circuits is rather fast. Once a circuit is produced, it's sent to BWG, which in turn sends it to GCS and saves it to DB. The queue between BWG runner and harness is of size 1. It was picked in this way to ensure that harness won't overwhelm bwg runner with circuits (if the queue was big, the amount of RAM used would be size of queue * size of circuit). The problem is that the harness will be throttled on circuit submission. It needs to blockingly wait for bwg runner to upload data (GCS + PG). There are 2 alternatives here: - make the buffer bigger -- this would work, but it would start consuming more memory. In theory, if you can match the size of the buffer to fit in between CPU cycles, harness wouldn't need to wait for upload and bwg runner would catch up by the time harness has a new set of circuits. The problem with this approach is that it relies on deterministic upload time (nothing can be deterministic in distributed systems) and fixed circuit size (this is not true, because batches can have varying number of circuits, of varying types). - upload everything async -- this work wonders, but there will be multiple circuits in flight being uploaded at any time. This will also increase RAM. In my testing, I've seen increases smaller than 100MB (acceptable tradeoff). The edge cases are when GCS is down or not working as intended (for instance, refusing connection). In such scenarios the async tasks will grow, adding to a lot more RAM usage (up to 50GB). Whilst this is a real problem, I consider it more of an edge case. Furthermore, there are a few ways to make GCS more reliable. The second option was picked here. MPSC channel (the buffer) is still of size 1 to have a single point of failure when RAM goes up (we don't want to figure out -- do we have too many things in the buffer, or do we have problems with GCS? -- in the current implementation, it will always be GCS). Furthermore, MPSC(1) doesn't cause much overhead. The implementation has been tested locally, using mainnet and testnet batches. Whilst the number are super encouraging (00:03:58 time, with ~40GB RAM), the CPUs (and their speed) is slightly lower in GCP. --- .../src/configs/fri_witness_generator.rs | 18 +++ core/lib/config/src/testonly.rs | 1 + .../env_config/src/fri_witness_generator.rs | 2 + .../src/proto/config/prover.proto | 1 + core/lib/protobuf_config/src/prover.rs | 4 + etc/env/file_based/general.yaml | 1 + .../crates/bin/witness_generator/Cargo.toml | 2 +- .../witness_generator/src/basic_circuits.rs | 153 ++++++++++++++---- .../crates/bin/witness_generator/src/main.rs | 2 +- .../crates/bin/witness_generator/src/utils.rs | 3 +- 10 files changed, 150 insertions(+), 37 deletions(-) diff --git a/core/lib/config/src/configs/fri_witness_generator.rs b/core/lib/config/src/configs/fri_witness_generator.rs index 281159271dd0..44eab27d3b5e 100644 --- a/core/lib/config/src/configs/fri_witness_generator.rs +++ b/core/lib/config/src/configs/fri_witness_generator.rs @@ -25,6 +25,15 @@ pub struct FriWitnessGeneratorConfig { pub shall_save_to_public_bucket: bool, pub prometheus_listener_port: Option, + + /// This value corresponds to the maximum number of circuits kept in memory at any given time for a BWG. + /// Acts as a throttling mechanism for circuits; the trade-off here is speed vs memory usage. + /// With more circuits in flight, harness does not need to wait for BWG runner to process them. + /// But every single circuit in flight eats memory (up to 50MB). + /// WARNING: Do NOT change this value unless you're absolutely sure you know what you're doing. + /// It affects the performance and resource usage of BWGs. + #[serde(default = "FriWitnessGeneratorConfig::default_max_circuits_in_flight")] + pub max_circuits_in_flight: usize, } #[derive(Debug)] @@ -87,4 +96,13 @@ impl FriWitnessGeneratorConfig { pub fn last_l1_batch_to_process(&self) -> u32 { self.last_l1_batch_to_process.unwrap_or(u32::MAX) } + + /// 500 was picked as a mid-ground between allowing enough circuits in flight to speed up circuit generation, + /// whilst keeping memory as low as possible. At the moment, max size of a circuit is ~50MB. + /// This number is important when there are issues with saving circuits (network issues, service unavailability, etc.) + /// Maximum theoretic extra memory consumed is up to 25GB (50MB * 500 circuits), but in reality, worse case scenarios are closer to 5GB (the average space distribution). + /// During normal operations (> P95), this will incur an overhead of ~100MB. + const fn default_max_circuits_in_flight() -> usize { + 500 + } } diff --git a/core/lib/config/src/testonly.rs b/core/lib/config/src/testonly.rs index c7864e629fb1..2c2934859fe5 100644 --- a/core/lib/config/src/testonly.rs +++ b/core/lib/config/src/testonly.rs @@ -571,6 +571,7 @@ impl Distribution for EncodeDist { last_l1_batch_to_process: self.sample(rng), shall_save_to_public_bucket: self.sample(rng), prometheus_listener_port: self.sample(rng), + max_circuits_in_flight: self.sample(rng), } } } diff --git a/core/lib/env_config/src/fri_witness_generator.rs b/core/lib/env_config/src/fri_witness_generator.rs index 5853a0178308..a79638624653 100644 --- a/core/lib/env_config/src/fri_witness_generator.rs +++ b/core/lib/env_config/src/fri_witness_generator.rs @@ -27,6 +27,7 @@ mod tests { last_l1_batch_to_process: None, shall_save_to_public_bucket: true, prometheus_listener_port: Some(3333u16), + max_circuits_in_flight: 500, } } @@ -43,6 +44,7 @@ mod tests { FRI_WITNESS_MAX_ATTEMPTS=4 FRI_WITNESS_SHALL_SAVE_TO_PUBLIC_BUCKET=true FRI_WITNESS_PROMETHEUS_LISTENER_PORT=3333 + FRI_WITNESS_MAX_CIRCUITS_IN_FLIGHT=500 "#; lock.set_env(config); diff --git a/core/lib/protobuf_config/src/proto/config/prover.proto b/core/lib/protobuf_config/src/proto/config/prover.proto index 80d45f40bbcb..4fe3861183bf 100644 --- a/core/lib/protobuf_config/src/proto/config/prover.proto +++ b/core/lib/protobuf_config/src/proto/config/prover.proto @@ -88,6 +88,7 @@ message WitnessGenerator { optional uint32 scheduler_generation_timeout_in_secs = 11; // optional; optional uint32 recursion_tip_timeout_in_secs = 12; // optional; optional uint32 prometheus_listener_port = 13; // optional; + optional uint64 max_circuits_in_flight = 14; // optional; reserved 3, 4, 6; reserved "dump_arguments_for_blocks", "force_process_block", "blocks_proving_percentage"; } diff --git a/core/lib/protobuf_config/src/prover.rs b/core/lib/protobuf_config/src/prover.rs index e1c31ee1fccd..e88338833053 100644 --- a/core/lib/protobuf_config/src/prover.rs +++ b/core/lib/protobuf_config/src/prover.rs @@ -198,6 +198,9 @@ impl ProtoRepr for proto::WitnessGenerator { .map(|x| x.try_into()) .transpose() .context("prometheus_listener_port")?, + max_circuits_in_flight: required(&self.max_circuits_in_flight) + .and_then(|x| Ok((*x).try_into()?)) + .context("max_circuits_in_flight")?, }) } @@ -219,6 +222,7 @@ impl ProtoRepr for proto::WitnessGenerator { .scheduler_generation_timeout_in_secs .map(|x| x.into()), prometheus_listener_port: this.prometheus_listener_port.map(|x| x.into()), + max_circuits_in_flight: Some(this.max_circuits_in_flight as u64), } } } diff --git a/etc/env/file_based/general.yaml b/etc/env/file_based/general.yaml index f91e6236a1ea..b6ba932d63c6 100644 --- a/etc/env/file_based/general.yaml +++ b/etc/env/file_based/general.yaml @@ -157,6 +157,7 @@ witness_generator: max_attempts: 10 shall_save_to_public_bucket: true prometheus_listener_port: 3116 + max_circuits_in_flight: 500 witness_vector_generator: prover_instance_wait_timeout_in_secs: 200 prover_instance_poll_time_in_milli_secs: 250 diff --git a/prover/crates/bin/witness_generator/Cargo.toml b/prover/crates/bin/witness_generator/Cargo.toml index fe73a02ba2af..7eb75bb3d82f 100644 --- a/prover/crates/bin/witness_generator/Cargo.toml +++ b/prover/crates/bin/witness_generator/Cargo.toml @@ -29,7 +29,7 @@ zksync_prover_fri_utils.workspace = true zksync_core_leftovers.workspace = true zkevm_test_harness = { workspace = true } -circuit_definitions = { workspace = true, features = [ "log_tracing" ] } +circuit_definitions = { workspace = true, features = ["log_tracing"] } anyhow.workspace = true tracing.workspace = true diff --git a/prover/crates/bin/witness_generator/src/basic_circuits.rs b/prover/crates/bin/witness_generator/src/basic_circuits.rs index 76d3dce5ac83..fe2d86716f4f 100644 --- a/prover/crates/bin/witness_generator/src/basic_circuits.rs +++ b/prover/crates/bin/witness_generator/src/basic_circuits.rs @@ -12,6 +12,7 @@ use circuit_definitions::{ encodings::recursion_request::RecursionQueueSimulator, zkevm_circuits::fsm_input_output::ClosedFormInputCompactFormWitness, }; +use tokio::sync::Semaphore; use tracing::Instrument; use zkevm_test_harness::geometry_config::get_geometry_config; use zksync_config::configs::FriWitnessGeneratorConfig; @@ -107,6 +108,7 @@ impl BasicWitnessGenerator { object_store: Arc, basic_job: BasicWitnessGeneratorJob, started_at: Instant, + max_circuits_in_flight: usize, ) -> Option { let BasicWitnessGeneratorJob { block_number, job } = basic_job; @@ -116,7 +118,16 @@ impl BasicWitnessGenerator { block_number.0 ); - Some(process_basic_circuits_job(&*object_store, started_at, block_number, job).await) + Some( + process_basic_circuits_job( + object_store, + started_at, + block_number, + job, + max_circuits_in_flight, + ) + .await, + ) } } @@ -177,11 +188,14 @@ impl JobProcessor for BasicWitnessGenerator { started_at: Instant, ) -> tokio::task::JoinHandle>> { let object_store = Arc::clone(&self.object_store); + let max_circuits_in_flight = self.config.max_circuits_in_flight; tokio::spawn(async move { let block_number = job.block_number; - Ok(Self::process_job_impl(object_store, job, started_at) - .instrument(tracing::info_span!("basic_circuit", %block_number)) - .await) + Ok( + Self::process_job_impl(object_store, job, started_at, max_circuits_in_flight) + .instrument(tracing::info_span!("basic_circuit", %block_number)) + .await, + ) }) } @@ -246,13 +260,14 @@ impl JobProcessor for BasicWitnessGenerator { #[tracing::instrument(skip_all, fields(l1_batch = %block_number))] async fn process_basic_circuits_job( - object_store: &dyn ObjectStore, + object_store: Arc, started_at: Instant, block_number: L1BatchNumber, job: WitnessInputData, + max_circuits_in_flight: usize, ) -> BasicCircuitArtifacts { let (circuit_urls, queue_urls, scheduler_witness, aux_output_witness) = - generate_witness(block_number, object_store, job).await; + generate_witness(block_number, object_store, job, max_circuits_in_flight).await; WITNESS_GENERATOR_METRICS.witness_generation_time[&AggregationRound::BasicCircuits.into()] .observe(started_at.elapsed()); tracing::info!( @@ -350,8 +365,8 @@ async fn save_recursion_queue( block_number: L1BatchNumber, circuit_id: u8, recursion_queue_simulator: RecursionQueueSimulator, - closed_form_inputs: &[ClosedFormInputCompactFormWitness], - object_store: &dyn ObjectStore, + closed_form_inputs: Vec>, + object_store: Arc, ) -> (u8, String, usize) { let key = ClosedFormInputKey { block_number, @@ -381,8 +396,9 @@ type Witness = ( #[tracing::instrument(skip_all, fields(l1_batch = %block_number))] async fn generate_witness( block_number: L1BatchNumber, - object_store: &dyn ObjectStore, + object_store: Arc, input: WitnessInputData, + max_circuits_in_flight: usize, ) -> Witness { let bootloader_contents = expand_bootloader_contents( &input.vm_run_data.initial_heap_content, @@ -407,7 +423,10 @@ async fn generate_witness( let make_circuits_span = tracing::info_span!("make_circuits"); let make_circuits_span_copy = make_circuits_span.clone(); - let make_circuits = tokio::task::spawn_blocking(move || { + // Blocking call from harness that does the CPU heavy lifting. + // Provides circuits and recursion queue via callback functions and returns scheduler witnesses. + // Circuits are "streamed" one by one as they're being generated. + let make_circuits_handle = tokio::task::spawn_blocking(move || { let span = tracing::info_span!(parent: make_circuits_span_copy, "make_circuits_blocking"); let witness_storage = WitnessStorage::new(input.vm_run_data.witness_block_state); @@ -446,43 +465,109 @@ async fn generate_witness( |circuit| { let parent_span = span.clone(); tracing::info_span!(parent: parent_span, "send_circuit").in_scope(|| { - circuit_sender.blocking_send(circuit).unwrap(); + circuit_sender + .blocking_send(circuit) + .expect("failed to send circuit from harness"); }); }, - |a, b, c| queue_sender.blocking_send((a as u8, b, c)).unwrap(), + |a, b, c| { + queue_sender + .blocking_send((a as u8, b, c)) + .expect("failed to send recursion queue from harness") + }, ); (scheduler_witness, block_witness) }) .instrument(make_circuits_span); - let mut circuit_urls = vec![]; - let mut recursion_urls = vec![]; - - let mut circuits_present = HashSet::::new(); + let mut save_circuit_handles = vec![]; let save_circuits_span = tracing::info_span!("save_circuits"); - let save_circuits = async { - loop { - tokio::select! { - Some(circuit) = circuit_receiver.recv().instrument(tracing::info_span!("wait_for_circuit")) => { - circuits_present.insert(circuit.numeric_circuit_type()); - circuit_urls.push( - save_circuit(block_number, circuit, circuit_urls.len(), object_store).await, - ); - } - Some((circuit_id, queue, inputs)) = queue_receiver.recv().instrument(tracing::info_span!("wait_for_queue")) => { - let urls = save_recursion_queue(block_number, circuit_id, queue, &inputs, object_store).await; - recursion_urls.push(urls); - } - else => break, - }; + + // Future which receives circuits and saves them async. + let circuit_receiver_handle = async { + // Ordering determines how we compose the circuit proofs in Leaf Aggregation Round. + // Sequence is used to determine circuit ordering (the sequencing of instructions) . + // If the order is tampered with, proving will fail (as the proof would be computed for a different sequence of instruction). + let mut circuit_sequence = 0; + + let semaphore = Arc::new(Semaphore::new(max_circuits_in_flight)); + + while let Some(circuit) = circuit_receiver + .recv() + .instrument(tracing::info_span!("wait_for_circuit")) + .await + { + let sequence = circuit_sequence; + circuit_sequence += 1; + let object_store = object_store.clone(); + let semaphore = semaphore.clone(); + let permit = semaphore + .acquire_owned() + .await + .expect("failed to get permit for running save circuit task"); + save_circuit_handles.push(tokio::task::spawn(async move { + let (circuit_id, circuit_url) = + save_circuit(block_number, circuit, sequence, object_store).await; + drop(permit); + (circuit_id, circuit_url) + })); } - }.instrument(save_circuits_span); + } + .instrument(save_circuits_span); + + let mut save_queue_handles = vec![]; + + let save_queues_span = tracing::info_span!("save_queues"); - let (witnesses, ()) = tokio::join!(make_circuits, save_circuits,); + // Future which receives recursion queues and saves them async. + // Note that this section needs no semaphore as there's # of circuit ids (16) queues at most. + // All queues combined are < 10MB. + let queue_receiver_handle = async { + while let Some((circuit_id, queue, inputs)) = queue_receiver + .recv() + .instrument(tracing::info_span!("wait_for_queue")) + .await + { + let object_store = object_store.clone(); + save_queue_handles.push(tokio::task::spawn(save_recursion_queue( + block_number, + circuit_id, + queue, + inputs, + object_store, + ))); + } + } + .instrument(save_queues_span); + + let (witnesses, _, _) = tokio::join!( + make_circuits_handle, + circuit_receiver_handle, + queue_receiver_handle + ); let (mut scheduler_witness, block_aux_witness) = witnesses.unwrap(); - recursion_urls.retain(|(circuit_id, _, _)| circuits_present.contains(circuit_id)); + // Harness returns recursion queues for all circuits, but for proving only the queues that have circuits matter. + // `circuits_present` stores which circuits exist and is used to filter queues in `recursion_urls` later. + let mut circuits_present = HashSet::::new(); + + let circuit_urls = futures::future::join_all(save_circuit_handles) + .await + .into_iter() + .map(|result| { + let (circuit_id, circuit_url) = result.expect("failed to save circuit"); + circuits_present.insert(circuit_id); + (circuit_id, circuit_url) + }) + .collect(); + + let recursion_urls = futures::future::join_all(save_queue_handles) + .await + .into_iter() + .map(|result| result.expect("failed to save queue")) + .filter(|(circuit_id, _, _)| circuits_present.contains(circuit_id)) + .collect(); scheduler_witness.previous_block_meta_hash = input.previous_batch_metadata.meta_hash.0; scheduler_witness.previous_block_aux_hash = input.previous_batch_metadata.aux_hash.0; diff --git a/prover/crates/bin/witness_generator/src/main.rs b/prover/crates/bin/witness_generator/src/main.rs index 38b2e46ef74b..d3b828b06558 100644 --- a/prover/crates/bin/witness_generator/src/main.rs +++ b/prover/crates/bin/witness_generator/src/main.rs @@ -11,6 +11,7 @@ use zksync_core_leftovers::temp_config_store::{load_database_secrets, load_gener use zksync_env_config::object_store::ProverObjectStoreConfig; use zksync_object_store::ObjectStoreFactory; use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; +use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION; use zksync_queued_job_processor::JobProcessor; use zksync_types::basic_fri_types::AggregationRound; use zksync_utils::wait_for_tasks::ManagedTasks; @@ -35,7 +36,6 @@ mod utils; #[cfg(not(target_env = "msvc"))] use jemallocator::Jemalloc; -use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION; #[cfg(not(target_env = "msvc"))] #[global_allocator] diff --git a/prover/crates/bin/witness_generator/src/utils.rs b/prover/crates/bin/witness_generator/src/utils.rs index 7671e2fd86db..97991fbd4d04 100644 --- a/prover/crates/bin/witness_generator/src/utils.rs +++ b/prover/crates/bin/witness_generator/src/utils.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, io::{BufWriter, Write as _}, + sync::Arc, }; use circuit_definitions::circuit_definitions::{ @@ -130,7 +131,7 @@ pub async fn save_circuit( block_number: L1BatchNumber, circuit: ZkSyncBaseLayerCircuit, sequence_number: usize, - object_store: &dyn ObjectStore, + object_store: Arc, ) -> (u8, String) { let circuit_id = circuit.numeric_circuit_type(); let circuit_key = FriCircuitKey {