Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(prover): Parallelize circuit metadata uploading for BWG #2520

Merged
merged 11 commits into from
Jul 30, 2024
18 changes: 18 additions & 0 deletions core/lib/config/src/configs/fri_witness_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ pub struct FriWitnessGeneratorConfig {
pub shall_save_to_public_bucket: bool,

pub prometheus_listener_port: Option<u16>,

/// This value corresponds to the maximum number of circuits kept in memory at any given time for a BWG.
/// Acts as a throttling mechanism for circuits; the trade-off here is speed vs memory usage.
/// With more circuits in flight, harness does not need to wait for BWG runner to process them.
/// But every single circuit in flight eats memory (up to 50MB).
/// WARNING: Do NOT change this value unless you're absolutely sure you know what you're doing.
/// It affects the performance and resource usage of BWGs.
#[serde(default = "FriWitnessGeneratorConfig::default_max_circuits_in_flight")]
pub max_circuits_in_flight: usize,
}

#[derive(Debug)]
Expand Down Expand Up @@ -87,4 +96,13 @@ impl FriWitnessGeneratorConfig {
pub fn last_l1_batch_to_process(&self) -> u32 {
self.last_l1_batch_to_process.unwrap_or(u32::MAX)
}

/// 500 was picked as a mid-ground between allowing enough circuits in flight to speed up circuit generation,
/// whilst keeping memory as low as possible. At the moment, max size of a circuit is ~50MB.
/// This number is important when there are issues with saving circuits (network issues, service unavailability, etc.)
/// Maximum theoretic extra memory consumed is up to 25GB (50MB * 500 circuits), but in reality, worse case scenarios are closer to 5GB (the average space distribution).
/// During normal operations (> P95), this will incur an overhead of ~100MB.
const fn default_max_circuits_in_flight() -> usize {
500
}
}
1 change: 1 addition & 0 deletions core/lib/config/src/testonly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ impl Distribution<configs::FriWitnessGeneratorConfig> for EncodeDist {
last_l1_batch_to_process: self.sample(rng),
shall_save_to_public_bucket: self.sample(rng),
prometheus_listener_port: self.sample(rng),
max_circuits_in_flight: self.sample(rng),
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions core/lib/env_config/src/fri_witness_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod tests {
last_l1_batch_to_process: None,
shall_save_to_public_bucket: true,
prometheus_listener_port: Some(3333u16),
max_circuits_in_flight: 500,
}
}

Expand All @@ -43,6 +44,7 @@ mod tests {
FRI_WITNESS_MAX_ATTEMPTS=4
FRI_WITNESS_SHALL_SAVE_TO_PUBLIC_BUCKET=true
FRI_WITNESS_PROMETHEUS_LISTENER_PORT=3333
FRI_WITNESS_MAX_CIRCUITS_IN_FLIGHT=500
"#;
lock.set_env(config);

Expand Down
1 change: 1 addition & 0 deletions core/lib/protobuf_config/src/proto/config/prover.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ message WitnessGenerator {
optional uint32 scheduler_generation_timeout_in_secs = 11; // optional;
optional uint32 recursion_tip_timeout_in_secs = 12; // optional;
optional uint32 prometheus_listener_port = 13; // optional;
optional uint64 max_circuits_in_flight = 14; // optional;
reserved 3, 4, 6;
reserved "dump_arguments_for_blocks", "force_process_block", "blocks_proving_percentage";
}
Expand Down
4 changes: 4 additions & 0 deletions core/lib/protobuf_config/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ impl ProtoRepr for proto::WitnessGenerator {
.map(|x| x.try_into())
.transpose()
.context("prometheus_listener_port")?,
max_circuits_in_flight: required(&self.max_circuits_in_flight)
.and_then(|x| Ok((*x).try_into()?))
.context("max_circuits_in_flight")?,
})
}

Expand All @@ -219,6 +222,7 @@ impl ProtoRepr for proto::WitnessGenerator {
.scheduler_generation_timeout_in_secs
.map(|x| x.into()),
prometheus_listener_port: this.prometheus_listener_port.map(|x| x.into()),
max_circuits_in_flight: Some(this.max_circuits_in_flight as u64),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions etc/env/file_based/general.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ witness_generator:
max_attempts: 10
shall_save_to_public_bucket: true
prometheus_listener_port: 3116
max_circuits_in_flight: 500
witness_vector_generator:
prover_instance_wait_timeout_in_secs: 200
prover_instance_poll_time_in_milli_secs: 250
Expand Down
2 changes: 1 addition & 1 deletion prover/crates/bin/witness_generator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ zksync_prover_fri_utils.workspace = true
zksync_core_leftovers.workspace = true

zkevm_test_harness = { workspace = true }
circuit_definitions = { workspace = true, features = [ "log_tracing" ] }
circuit_definitions = { workspace = true, features = ["log_tracing"] }

anyhow.workspace = true
tracing.workspace = true
Expand Down
153 changes: 119 additions & 34 deletions prover/crates/bin/witness_generator/src/basic_circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use circuit_definitions::{
encodings::recursion_request::RecursionQueueSimulator,
zkevm_circuits::fsm_input_output::ClosedFormInputCompactFormWitness,
};
use tokio::sync::Semaphore;
use tracing::Instrument;
use zkevm_test_harness::geometry_config::get_geometry_config;
use zksync_config::configs::FriWitnessGeneratorConfig;
Expand Down Expand Up @@ -107,6 +108,7 @@ impl BasicWitnessGenerator {
object_store: Arc<dyn ObjectStore>,
basic_job: BasicWitnessGeneratorJob,
started_at: Instant,
max_circuits_in_flight: usize,
) -> Option<BasicCircuitArtifacts> {
let BasicWitnessGeneratorJob { block_number, job } = basic_job;

Expand All @@ -116,7 +118,16 @@ impl BasicWitnessGenerator {
block_number.0
);

Some(process_basic_circuits_job(&*object_store, started_at, block_number, job).await)
Some(
process_basic_circuits_job(
object_store,
started_at,
block_number,
job,
max_circuits_in_flight,
)
.await,
)
}
}

Expand Down Expand Up @@ -177,11 +188,14 @@ impl JobProcessor for BasicWitnessGenerator {
started_at: Instant,
) -> tokio::task::JoinHandle<anyhow::Result<Option<BasicCircuitArtifacts>>> {
let object_store = Arc::clone(&self.object_store);
let max_circuits_in_flight = self.config.max_circuits_in_flight;
tokio::spawn(async move {
let block_number = job.block_number;
Ok(Self::process_job_impl(object_store, job, started_at)
.instrument(tracing::info_span!("basic_circuit", %block_number))
.await)
Ok(
Self::process_job_impl(object_store, job, started_at, max_circuits_in_flight)
.instrument(tracing::info_span!("basic_circuit", %block_number))
.await,
)
})
}

Expand Down Expand Up @@ -246,13 +260,14 @@ impl JobProcessor for BasicWitnessGenerator {

#[tracing::instrument(skip_all, fields(l1_batch = %block_number))]
async fn process_basic_circuits_job(
object_store: &dyn ObjectStore,
object_store: Arc<dyn ObjectStore>,
started_at: Instant,
block_number: L1BatchNumber,
job: WitnessInputData,
max_circuits_in_flight: usize,
) -> BasicCircuitArtifacts {
let (circuit_urls, queue_urls, scheduler_witness, aux_output_witness) =
generate_witness(block_number, object_store, job).await;
generate_witness(block_number, object_store, job, max_circuits_in_flight).await;
WITNESS_GENERATOR_METRICS.witness_generation_time[&AggregationRound::BasicCircuits.into()]
.observe(started_at.elapsed());
tracing::info!(
Expand Down Expand Up @@ -350,8 +365,8 @@ async fn save_recursion_queue(
block_number: L1BatchNumber,
circuit_id: u8,
recursion_queue_simulator: RecursionQueueSimulator<GoldilocksField>,
closed_form_inputs: &[ClosedFormInputCompactFormWitness<GoldilocksField>],
object_store: &dyn ObjectStore,
closed_form_inputs: Vec<ClosedFormInputCompactFormWitness<GoldilocksField>>,
object_store: Arc<dyn ObjectStore>,
) -> (u8, String, usize) {
let key = ClosedFormInputKey {
block_number,
Expand Down Expand Up @@ -381,8 +396,9 @@ type Witness = (
#[tracing::instrument(skip_all, fields(l1_batch = %block_number))]
async fn generate_witness(
block_number: L1BatchNumber,
object_store: &dyn ObjectStore,
object_store: Arc<dyn ObjectStore>,
input: WitnessInputData,
max_circuits_in_flight: usize,
) -> Witness {
let bootloader_contents = expand_bootloader_contents(
&input.vm_run_data.initial_heap_content,
Expand All @@ -407,7 +423,10 @@ async fn generate_witness(

let make_circuits_span = tracing::info_span!("make_circuits");
let make_circuits_span_copy = make_circuits_span.clone();
let make_circuits = tokio::task::spawn_blocking(move || {
// Blocking call from harness that does the CPU heavy lifting.
// Provides circuits and recursion queue via callback functions and returns scheduler witnesses.
// Circuits are "streamed" one by one as they're being generated.
let make_circuits_handle = tokio::task::spawn_blocking(move || {
let span = tracing::info_span!(parent: make_circuits_span_copy, "make_circuits_blocking");

let witness_storage = WitnessStorage::new(input.vm_run_data.witness_block_state);
Expand Down Expand Up @@ -446,43 +465,109 @@ async fn generate_witness(
|circuit| {
let parent_span = span.clone();
tracing::info_span!(parent: parent_span, "send_circuit").in_scope(|| {
circuit_sender.blocking_send(circuit).unwrap();
circuit_sender
.blocking_send(circuit)
.expect("failed to send circuit from harness");
});
},
|a, b, c| queue_sender.blocking_send((a as u8, b, c)).unwrap(),
|a, b, c| {
queue_sender
.blocking_send((a as u8, b, c))
.expect("failed to send recursion queue from harness")
},
);
(scheduler_witness, block_witness)
})
.instrument(make_circuits_span);

let mut circuit_urls = vec![];
let mut recursion_urls = vec![];

let mut circuits_present = HashSet::<u8>::new();
let mut save_circuit_handles = vec![];

let save_circuits_span = tracing::info_span!("save_circuits");
let save_circuits = async {
loop {
tokio::select! {
Some(circuit) = circuit_receiver.recv().instrument(tracing::info_span!("wait_for_circuit")) => {
circuits_present.insert(circuit.numeric_circuit_type());
circuit_urls.push(
save_circuit(block_number, circuit, circuit_urls.len(), object_store).await,
);
}
Some((circuit_id, queue, inputs)) = queue_receiver.recv().instrument(tracing::info_span!("wait_for_queue")) => {
let urls = save_recursion_queue(block_number, circuit_id, queue, &inputs, object_store).await;
recursion_urls.push(urls);
}
else => break,
};

// Future which receives circuits and saves them async.
let circuit_receiver_handle = async {
// Ordering determines how we compose the circuit proofs in Leaf Aggregation Round.
// Sequence is used to determine circuit ordering (the sequencing of instructions) .
// If the order is tampered with, proving will fail (as the proof would be computed for a different sequence of instruction).
let mut circuit_sequence = 0;

let semaphore = Arc::new(Semaphore::new(max_circuits_in_flight));

while let Some(circuit) = circuit_receiver
.recv()
.instrument(tracing::info_span!("wait_for_circuit"))
.await
{
let sequence = circuit_sequence;
circuit_sequence += 1;
let object_store = object_store.clone();
let semaphore = semaphore.clone();
let permit = semaphore
.acquire_owned()
.await
.expect("failed to get permit for running save circuit task");
save_circuit_handles.push(tokio::task::spawn(async move {
let (circuit_id, circuit_url) =
save_circuit(block_number, circuit, sequence, object_store).await;
drop(permit);
(circuit_id, circuit_url)
}));
}
}.instrument(save_circuits_span);
}
.instrument(save_circuits_span);

let mut save_queue_handles = vec![];

let save_queues_span = tracing::info_span!("save_queues");

let (witnesses, ()) = tokio::join!(make_circuits, save_circuits,);
// Future which receives recursion queues and saves them async.
// Note that this section needs no semaphore as there's # of circuit ids (16) queues at most.
// All queues combined are < 10MB.
let queue_receiver_handle = async {
while let Some((circuit_id, queue, inputs)) = queue_receiver
.recv()
.instrument(tracing::info_span!("wait_for_queue"))
.await
{
let object_store = object_store.clone();
EmilLuta marked this conversation as resolved.
Show resolved Hide resolved
save_queue_handles.push(tokio::task::spawn(save_recursion_queue(
block_number,
circuit_id,
queue,
inputs,
object_store,
)));
}
}
.instrument(save_queues_span);

let (witnesses, _, _) = tokio::join!(
make_circuits_handle,
circuit_receiver_handle,
queue_receiver_handle
);
let (mut scheduler_witness, block_aux_witness) = witnesses.unwrap();

recursion_urls.retain(|(circuit_id, _, _)| circuits_present.contains(circuit_id));
// Harness returns recursion queues for all circuits, but for proving only the queues that have circuits matter.
// `circuits_present` stores which circuits exist and is used to filter queues in `recursion_urls` later.
let mut circuits_present = HashSet::<u8>::new();

let circuit_urls = futures::future::join_all(save_circuit_handles)
.await
.into_iter()
.map(|result| {
let (circuit_id, circuit_url) = result.expect("failed to save circuit");
circuits_present.insert(circuit_id);
(circuit_id, circuit_url)
})
.collect();

let recursion_urls = futures::future::join_all(save_queue_handles)
.await
.into_iter()
.map(|result| result.expect("failed to save queue"))
.filter(|(circuit_id, _, _)| circuits_present.contains(circuit_id))
.collect();

scheduler_witness.previous_block_meta_hash = input.previous_batch_metadata.meta_hash.0;
scheduler_witness.previous_block_aux_hash = input.previous_batch_metadata.aux_hash.0;
Expand Down
2 changes: 1 addition & 1 deletion prover/crates/bin/witness_generator/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use zksync_core_leftovers::temp_config_store::{load_database_secrets, load_gener
use zksync_env_config::object_store::ProverObjectStoreConfig;
use zksync_object_store::ObjectStoreFactory;
use zksync_prover_dal::{ConnectionPool, Prover, ProverDal};
use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION;
use zksync_queued_job_processor::JobProcessor;
use zksync_types::basic_fri_types::AggregationRound;
use zksync_utils::wait_for_tasks::ManagedTasks;
Expand All @@ -35,7 +36,6 @@ mod utils;

#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION;

#[cfg(not(target_env = "msvc"))]
#[global_allocator]
Expand Down
3 changes: 2 additions & 1 deletion prover/crates/bin/witness_generator/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
collections::HashMap,
io::{BufWriter, Write as _},
sync::Arc,
};

use circuit_definitions::circuit_definitions::{
Expand Down Expand Up @@ -130,7 +131,7 @@ pub async fn save_circuit(
block_number: L1BatchNumber,
circuit: ZkSyncBaseLayerCircuit,
sequence_number: usize,
object_store: &dyn ObjectStore,
object_store: Arc<dyn ObjectStore>,
) -> (u8, String) {
let circuit_id = circuit.numeric_circuit_type();
let circuit_key = FriCircuitKey {
Expand Down
Loading