fix(prover): Parallelize circuit metadata uploading for BWG (#2520)
Basic Witness Generation (BWG) creates the base-layer circuits to be
proven. BWG runs the out-of-circuit VM and provides circuits for the
instructions in the execution (note that these are ordered; messing with
the order would break proving).

Additionally, the harness alternates between heavy computation and
providing circuits: heavy computation -> provide circuits -> heavy
computation -> provide circuits. Providing circuits itself is rather
fast. Once a circuit is produced, it's sent to the BWG runner, which in
turn uploads it to GCS and saves it to the DB. The queue between the BWG
runner and the harness is of size 1. It was picked this way to ensure
that the harness won't overwhelm the BWG runner with circuits (if the
queue were bigger, the amount of RAM used would be queue size * circuit
size). The problem is that the harness is throttled on circuit
submission: it has to block and wait for the BWG runner to upload the
data (GCS + PG).
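
To make the bottleneck concrete, here is a minimal, self-contained sketch of the old flow, assuming a bounded tokio MPSC channel of size 1 between a blocking producer (the harness) and an async consumer (the BWG runner) that uploads each circuit before receiving the next. The Circuit type and upload function are illustrative stand-ins, not the real zksync types:

use std::time::Duration;
use tokio::sync::mpsc;

// Illustrative stand-ins for a harness-produced circuit and its upload.
struct Circuit(u64);

async fn upload(_circuit: Circuit) {
    // Simulates the GCS + Postgres writes.
    tokio::time::sleep(Duration::from_millis(200)).await;
}

#[tokio::main]
async fn main() {
    // Queue of size 1: the producer can be at most one circuit ahead of the consumer.
    let (sender, mut receiver) = mpsc::channel::<Circuit>(1);

    // CPU-heavy "harness" side runs on a blocking thread and pushes circuits as they are made.
    let harness = tokio::task::spawn_blocking(move || {
        for i in 0..10 {
            // blocking_send stalls until the previous circuit has been taken off the queue --
            // this is exactly the throttling described above.
            sender.blocking_send(Circuit(i)).expect("receiver dropped");
        }
    });

    // Old behaviour: upload each circuit to completion before receiving the next one.
    while let Some(circuit) = receiver.recv().await {
        upload(circuit).await;
    }

    harness.await.expect("harness task panicked");
}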

There are 2 alternatives here (a back-of-the-envelope memory comparison
follows the list):
- make the buffer bigger -- this would work, but it would start
consuming more memory. In theory, if you could size the buffer to fit
between the CPU-heavy phases, the harness wouldn't need to wait for
uploads and the BWG runner would catch up by the time the harness has a
new set of circuits. The problem with this approach is that it relies on
deterministic upload times (nothing is deterministic in distributed
systems) and a fixed circuit size (which isn't true, because batches can
have a varying number of circuits, of varying types).
- upload everything async -- this works wonders, but there will be
multiple circuits in flight being uploaded at any time, which also
increases RAM usage. In my testing, I've seen increases smaller than
100MB (an acceptable tradeoff). The edge cases are when GCS is down or
not working as intended (for instance, refusing connections). In such
scenarios the async tasks pile up, adding a lot more RAM usage (up to
50GB). Whilst this is a real problem, I consider it more of an edge
case. Furthermore, there are a few ways to make GCS more reliable.
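
For a rough sense of the memory trade-off between the two options, here is a back-of-the-envelope calculation using the ~50MB upper bound on circuit size mentioned elsewhere in this change; the buffer length and in-flight count are illustrative, not measurements:

fn main() {
    // Assumed upper bound on a serialized circuit, per the config comment below.
    const CIRCUIT_SIZE_MB: usize = 50;

    // Option 1: a bigger blocking buffer keeps every queued circuit resident in RAM.
    let buffer_len = 500;
    println!(
        "buffer of {buffer_len} circuits: up to ~{} GB",
        buffer_len * CIRCUIT_SIZE_MB / 1024
    );

    // Option 2: async uploads; extra memory scales with uploads actually in flight.
    let typical_in_flight = 2; // when GCS keeps up
    println!(
        "async with {typical_in_flight} uploads in flight: ~{} MB",
        typical_in_flight * CIRCUIT_SIZE_MB
    );
}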

The second option was picked here. The MPSC channel (the buffer) is
still of size 1 so that there is a single thing to suspect when RAM goes
up (we don't want to figure out whether there are too many things in the
buffer or problems with GCS -- in the current implementation it will
always be GCS). Furthermore, MPSC(1) doesn't add much overhead.
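
The diff below contains the real implementation; as a simplified, hedged sketch of the pattern -- an MPSC(1) feed, a semaphore capping how many uploads run at once, and a per-circuit sequence number so ordering survives the out-of-order uploads -- it looks roughly like this (Circuit and upload are again illustrative stand-ins):

use std::sync::Arc;
use tokio::sync::{mpsc, Semaphore};

struct Circuit(u64);

async fn upload(_sequence: usize, _circuit: Circuit) {
    // Stand-in for the real save path: GCS upload + DB bookkeeping.
}

async fn consume(mut receiver: mpsc::Receiver<Circuit>, max_circuits_in_flight: usize) {
    let semaphore = Arc::new(Semaphore::new(max_circuits_in_flight));
    let mut upload_handles = Vec::new();
    let mut sequence = 0usize;

    while let Some(circuit) = receiver.recv().await {
        // Assign the sequence number before spawning, so ordering stays fixed even
        // though uploads finish in arbitrary order.
        let current_sequence = sequence;
        sequence += 1;

        // The permit caps how many circuits sit in RAM awaiting upload; receiving
        // itself stays fast, so the harness's blocking_send returns almost immediately.
        let permit = semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("semaphore closed");
        upload_handles.push(tokio::spawn(async move {
            upload(current_sequence, circuit).await;
            drop(permit);
        }));
    }

    // Join every in-flight upload before the batch is considered done.
    for handle in upload_handles {
        handle.await.expect("upload task panicked");
    }
}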

The implementation has been tested locally, using mainnet and testnet
batches. Whilst the numbers are super encouraging (00:03:58 runtime,
with ~40GB RAM), note that the CPUs (and their clock speeds) in GCP are
slightly slower.
EmilLuta authored Jul 30, 2024
1 parent 0dda805 commit f49720f
Showing 10 changed files with 150 additions and 37 deletions.
18 changes: 18 additions & 0 deletions core/lib/config/src/configs/fri_witness_generator.rs
@@ -25,6 +25,15 @@ pub struct FriWitnessGeneratorConfig {
pub shall_save_to_public_bucket: bool,

pub prometheus_listener_port: Option<u16>,

/// This value corresponds to the maximum number of circuits kept in memory at any given time for a BWG.
/// Acts as a throttling mechanism for circuits; the trade-off here is speed vs memory usage.
/// With more circuits in flight, harness does not need to wait for BWG runner to process them.
/// But every single circuit in flight eats memory (up to 50MB).
/// WARNING: Do NOT change this value unless you're absolutely sure you know what you're doing.
/// It affects the performance and resource usage of BWGs.
#[serde(default = "FriWitnessGeneratorConfig::default_max_circuits_in_flight")]
pub max_circuits_in_flight: usize,
}

#[derive(Debug)]
@@ -87,4 +96,13 @@ impl FriWitnessGeneratorConfig {
pub fn last_l1_batch_to_process(&self) -> u32 {
self.last_l1_batch_to_process.unwrap_or(u32::MAX)
}

/// 500 was picked as a mid-ground between allowing enough circuits in flight to speed up circuit generation,
/// whilst keeping memory as low as possible. At the moment, max size of a circuit is ~50MB.
/// This number is important when there are issues with saving circuits (network issues, service unavailability, etc.)
/// Maximum theoretical extra memory consumed is up to 25GB (50MB * 500 circuits), but in reality, worst-case scenarios are closer to 5GB (the average space distribution).
/// During normal operations (> P95), this will incur an overhead of ~100MB.
const fn default_max_circuits_in_flight() -> usize {
500
}
}
1 change: 1 addition & 0 deletions core/lib/config/src/testonly.rs
@@ -571,6 +571,7 @@ impl Distribution<configs::FriWitnessGeneratorConfig> for EncodeDist {
last_l1_batch_to_process: self.sample(rng),
shall_save_to_public_bucket: self.sample(rng),
prometheus_listener_port: self.sample(rng),
max_circuits_in_flight: self.sample(rng),
}
}
}
2 changes: 2 additions & 0 deletions core/lib/env_config/src/fri_witness_generator.rs
@@ -27,6 +27,7 @@ mod tests {
last_l1_batch_to_process: None,
shall_save_to_public_bucket: true,
prometheus_listener_port: Some(3333u16),
max_circuits_in_flight: 500,
}
}

@@ -43,6 +44,7 @@
FRI_WITNESS_MAX_ATTEMPTS=4
FRI_WITNESS_SHALL_SAVE_TO_PUBLIC_BUCKET=true
FRI_WITNESS_PROMETHEUS_LISTENER_PORT=3333
FRI_WITNESS_MAX_CIRCUITS_IN_FLIGHT=500
"#;
lock.set_env(config);

1 change: 1 addition & 0 deletions core/lib/protobuf_config/src/proto/config/prover.proto
@@ -88,6 +88,7 @@ message WitnessGenerator {
optional uint32 scheduler_generation_timeout_in_secs = 11; // optional;
optional uint32 recursion_tip_timeout_in_secs = 12; // optional;
optional uint32 prometheus_listener_port = 13; // optional;
optional uint64 max_circuits_in_flight = 14; // optional;
reserved 3, 4, 6;
reserved "dump_arguments_for_blocks", "force_process_block", "blocks_proving_percentage";
}
4 changes: 4 additions & 0 deletions core/lib/protobuf_config/src/prover.rs
@@ -198,6 +198,9 @@ impl ProtoRepr for proto::WitnessGenerator {
.map(|x| x.try_into())
.transpose()
.context("prometheus_listener_port")?,
max_circuits_in_flight: required(&self.max_circuits_in_flight)
.and_then(|x| Ok((*x).try_into()?))
.context("max_circuits_in_flight")?,
})
}

@@ -219,6 +222,7 @@ impl ProtoRepr for proto::WitnessGenerator {
.scheduler_generation_timeout_in_secs
.map(|x| x.into()),
prometheus_listener_port: this.prometheus_listener_port.map(|x| x.into()),
max_circuits_in_flight: Some(this.max_circuits_in_flight as u64),
}
}
}
1 change: 1 addition & 0 deletions etc/env/file_based/general.yaml
@@ -157,6 +157,7 @@ witness_generator:
max_attempts: 10
shall_save_to_public_bucket: true
prometheus_listener_port: 3116
max_circuits_in_flight: 500
witness_vector_generator:
prover_instance_wait_timeout_in_secs: 200
prover_instance_poll_time_in_milli_secs: 250
2 changes: 1 addition & 1 deletion prover/crates/bin/witness_generator/Cargo.toml
@@ -29,7 +29,7 @@ zksync_prover_fri_utils.workspace = true
zksync_core_leftovers.workspace = true

zkevm_test_harness = { workspace = true }
circuit_definitions = { workspace = true, features = [ "log_tracing" ] }
circuit_definitions = { workspace = true, features = ["log_tracing"] }

anyhow.workspace = true
tracing.workspace = true
153 changes: 119 additions & 34 deletions prover/crates/bin/witness_generator/src/basic_circuits.rs
@@ -12,6 +12,7 @@ use circuit_definitions::{
encodings::recursion_request::RecursionQueueSimulator,
zkevm_circuits::fsm_input_output::ClosedFormInputCompactFormWitness,
};
use tokio::sync::Semaphore;
use tracing::Instrument;
use zkevm_test_harness::geometry_config::get_geometry_config;
use zksync_config::configs::FriWitnessGeneratorConfig;
@@ -107,6 +108,7 @@ impl BasicWitnessGenerator {
object_store: Arc<dyn ObjectStore>,
basic_job: BasicWitnessGeneratorJob,
started_at: Instant,
max_circuits_in_flight: usize,
) -> Option<BasicCircuitArtifacts> {
let BasicWitnessGeneratorJob { block_number, job } = basic_job;

@@ -116,7 +118,16 @@
block_number.0
);

Some(process_basic_circuits_job(&*object_store, started_at, block_number, job).await)
Some(
process_basic_circuits_job(
object_store,
started_at,
block_number,
job,
max_circuits_in_flight,
)
.await,
)
}
}

@@ -177,11 +188,14 @@ impl JobProcessor for BasicWitnessGenerator {
started_at: Instant,
) -> tokio::task::JoinHandle<anyhow::Result<Option<BasicCircuitArtifacts>>> {
let object_store = Arc::clone(&self.object_store);
let max_circuits_in_flight = self.config.max_circuits_in_flight;
tokio::spawn(async move {
let block_number = job.block_number;
Ok(Self::process_job_impl(object_store, job, started_at)
.instrument(tracing::info_span!("basic_circuit", %block_number))
.await)
Ok(
Self::process_job_impl(object_store, job, started_at, max_circuits_in_flight)
.instrument(tracing::info_span!("basic_circuit", %block_number))
.await,
)
})
}

@@ -246,13 +260,14 @@

#[tracing::instrument(skip_all, fields(l1_batch = %block_number))]
async fn process_basic_circuits_job(
object_store: &dyn ObjectStore,
object_store: Arc<dyn ObjectStore>,
started_at: Instant,
block_number: L1BatchNumber,
job: WitnessInputData,
max_circuits_in_flight: usize,
) -> BasicCircuitArtifacts {
let (circuit_urls, queue_urls, scheduler_witness, aux_output_witness) =
generate_witness(block_number, object_store, job).await;
generate_witness(block_number, object_store, job, max_circuits_in_flight).await;
WITNESS_GENERATOR_METRICS.witness_generation_time[&AggregationRound::BasicCircuits.into()]
.observe(started_at.elapsed());
tracing::info!(
@@ -350,8 +365,8 @@ async fn save_recursion_queue(
block_number: L1BatchNumber,
circuit_id: u8,
recursion_queue_simulator: RecursionQueueSimulator<GoldilocksField>,
closed_form_inputs: &[ClosedFormInputCompactFormWitness<GoldilocksField>],
object_store: &dyn ObjectStore,
closed_form_inputs: Vec<ClosedFormInputCompactFormWitness<GoldilocksField>>,
object_store: Arc<dyn ObjectStore>,
) -> (u8, String, usize) {
let key = ClosedFormInputKey {
block_number,
@@ -381,8 +396,9 @@ type Witness = (
#[tracing::instrument(skip_all, fields(l1_batch = %block_number))]
async fn generate_witness(
block_number: L1BatchNumber,
object_store: &dyn ObjectStore,
object_store: Arc<dyn ObjectStore>,
input: WitnessInputData,
max_circuits_in_flight: usize,
) -> Witness {
let bootloader_contents = expand_bootloader_contents(
&input.vm_run_data.initial_heap_content,
@@ -407,7 +423,10 @@

let make_circuits_span = tracing::info_span!("make_circuits");
let make_circuits_span_copy = make_circuits_span.clone();
let make_circuits = tokio::task::spawn_blocking(move || {
// Blocking call from harness that does the CPU heavy lifting.
// Provides circuits and recursion queue via callback functions and returns scheduler witnesses.
// Circuits are "streamed" one by one as they're being generated.
let make_circuits_handle = tokio::task::spawn_blocking(move || {
let span = tracing::info_span!(parent: make_circuits_span_copy, "make_circuits_blocking");

let witness_storage = WitnessStorage::new(input.vm_run_data.witness_block_state);
@@ -446,43 +465,109 @@
|circuit| {
let parent_span = span.clone();
tracing::info_span!(parent: parent_span, "send_circuit").in_scope(|| {
circuit_sender.blocking_send(circuit).unwrap();
circuit_sender
.blocking_send(circuit)
.expect("failed to send circuit from harness");
});
},
|a, b, c| queue_sender.blocking_send((a as u8, b, c)).unwrap(),
|a, b, c| {
queue_sender
.blocking_send((a as u8, b, c))
.expect("failed to send recursion queue from harness")
},
);
(scheduler_witness, block_witness)
})
.instrument(make_circuits_span);

let mut circuit_urls = vec![];
let mut recursion_urls = vec![];

let mut circuits_present = HashSet::<u8>::new();
let mut save_circuit_handles = vec![];

let save_circuits_span = tracing::info_span!("save_circuits");
let save_circuits = async {
loop {
tokio::select! {
Some(circuit) = circuit_receiver.recv().instrument(tracing::info_span!("wait_for_circuit")) => {
circuits_present.insert(circuit.numeric_circuit_type());
circuit_urls.push(
save_circuit(block_number, circuit, circuit_urls.len(), object_store).await,
);
}
Some((circuit_id, queue, inputs)) = queue_receiver.recv().instrument(tracing::info_span!("wait_for_queue")) => {
let urls = save_recursion_queue(block_number, circuit_id, queue, &inputs, object_store).await;
recursion_urls.push(urls);
}
else => break,
};

// Future which receives circuits and saves them async.
let circuit_receiver_handle = async {
// Ordering determines how we compose the circuit proofs in Leaf Aggregation Round.
// Sequence is used to determine circuit ordering (the sequencing of instructions).
// If the order is tampered with, proving will fail (as the proof would be computed for a different sequence of instruction).
let mut circuit_sequence = 0;

let semaphore = Arc::new(Semaphore::new(max_circuits_in_flight));

while let Some(circuit) = circuit_receiver
.recv()
.instrument(tracing::info_span!("wait_for_circuit"))
.await
{
let sequence = circuit_sequence;
circuit_sequence += 1;
let object_store = object_store.clone();
let semaphore = semaphore.clone();
let permit = semaphore
.acquire_owned()
.await
.expect("failed to get permit for running save circuit task");
save_circuit_handles.push(tokio::task::spawn(async move {
let (circuit_id, circuit_url) =
save_circuit(block_number, circuit, sequence, object_store).await;
drop(permit);
(circuit_id, circuit_url)
}));
}
}.instrument(save_circuits_span);
}
.instrument(save_circuits_span);

let mut save_queue_handles = vec![];

let save_queues_span = tracing::info_span!("save_queues");

let (witnesses, ()) = tokio::join!(make_circuits, save_circuits,);
// Future which receives recursion queues and saves them async.
// Note that this section needs no semaphore as there's # of circuit ids (16) queues at most.
// All queues combined are < 10MB.
let queue_receiver_handle = async {
while let Some((circuit_id, queue, inputs)) = queue_receiver
.recv()
.instrument(tracing::info_span!("wait_for_queue"))
.await
{
let object_store = object_store.clone();
save_queue_handles.push(tokio::task::spawn(save_recursion_queue(
block_number,
circuit_id,
queue,
inputs,
object_store,
)));
}
}
.instrument(save_queues_span);

let (witnesses, _, _) = tokio::join!(
make_circuits_handle,
circuit_receiver_handle,
queue_receiver_handle
);
let (mut scheduler_witness, block_aux_witness) = witnesses.unwrap();

recursion_urls.retain(|(circuit_id, _, _)| circuits_present.contains(circuit_id));
// Harness returns recursion queues for all circuits, but for proving only the queues that have circuits matter.
// `circuits_present` stores which circuits exist and is used to filter queues in `recursion_urls` later.
let mut circuits_present = HashSet::<u8>::new();

let circuit_urls = futures::future::join_all(save_circuit_handles)
.await
.into_iter()
.map(|result| {
let (circuit_id, circuit_url) = result.expect("failed to save circuit");
circuits_present.insert(circuit_id);
(circuit_id, circuit_url)
})
.collect();

let recursion_urls = futures::future::join_all(save_queue_handles)
.await
.into_iter()
.map(|result| result.expect("failed to save queue"))
.filter(|(circuit_id, _, _)| circuits_present.contains(circuit_id))
.collect();

scheduler_witness.previous_block_meta_hash = input.previous_batch_metadata.meta_hash.0;
scheduler_witness.previous_block_aux_hash = input.previous_batch_metadata.aux_hash.0;
2 changes: 1 addition & 1 deletion prover/crates/bin/witness_generator/src/main.rs
@@ -11,6 +11,7 @@ use zksync_core_leftovers::temp_config_store::{load_database_secrets, load_gener
use zksync_env_config::object_store::ProverObjectStoreConfig;
use zksync_object_store::ObjectStoreFactory;
use zksync_prover_dal::{ConnectionPool, Prover, ProverDal};
use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION;
use zksync_queued_job_processor::JobProcessor;
use zksync_types::basic_fri_types::AggregationRound;
use zksync_utils::wait_for_tasks::ManagedTasks;
@@ -35,7 +36,6 @@ mod utils;

#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION;

#[cfg(not(target_env = "msvc"))]
#[global_allocator]
3 changes: 2 additions & 1 deletion prover/crates/bin/witness_generator/src/utils.rs
@@ -1,6 +1,7 @@
use std::{
collections::HashMap,
io::{BufWriter, Write as _},
sync::Arc,
};

use circuit_definitions::circuit_definitions::{
@@ -130,7 +131,7 @@ pub async fn save_circuit(
block_number: L1BatchNumber,
circuit: ZkSyncBaseLayerCircuit,
sequence_number: usize,
object_store: &dyn ObjectStore,
object_store: Arc<dyn ObjectStore>,
) -> (u8, String) {
let circuit_id = circuit.numeric_circuit_type();
let circuit_key = FriCircuitKey {
