feat(prover): Refactor WitnessGenerator #2845

Merged · 8 commits · Sep 12, 2024 · Changes from all commits
8 changes: 2 additions & 6 deletions prover/crates/bin/proof_fri_compressor/src/compressor.rs
@@ -59,7 +59,6 @@ impl ProofCompressor {
 
     #[tracing::instrument(skip(proof, _compression_mode))]
     pub fn compress_proof(
-        l1_batch: L1BatchNumber,
         proof: ZkSyncRecursionLayerProof,
         _compression_mode: u8,
         keystore: Keystore,
@@ -171,16 +170,13 @@ impl JobProcessor for ProofCompressor {
 
     async fn process_job(
         &self,
-        job_id: &L1BatchNumber,
+        _job_id: &L1BatchNumber,
         job: ZkSyncRecursionLayerProof,
         _started_at: Instant,
     ) -> JoinHandle<anyhow::Result<Self::JobArtifacts>> {
         let compression_mode = self.compression_mode;
-        let block_number = *job_id;
         let keystore = self.keystore.clone();
-        tokio::task::spawn_blocking(move || {
-            Self::compress_proof(block_number, job, compression_mode, keystore)
-        })
+        tokio::task::spawn_blocking(move || Self::compress_proof(job, compression_mode, keystore))
     }
 
     async fn save_result(
50 changes: 50 additions & 0 deletions prover/crates/bin/witness_generator/src/artifacts.rs
@@ -0,0 +1,50 @@
use std::time::Instant;

use async_trait::async_trait;
use zksync_object_store::ObjectStore;
use zksync_prover_dal::{ConnectionPool, Prover};

/// Blob URLs produced when saving artifacts of an aggregation round.
#[derive(Debug)]
pub(crate) struct AggregationBlobUrls {
    pub aggregations_urls: String,
    pub circuit_ids_and_urls: Vec<(u8, String)>,
}

/// Blob URLs produced when saving artifacts of the scheduler round.
#[derive(Debug)]
pub(crate) struct SchedulerBlobUrls {
    pub circuit_ids_and_urls: Vec<(u8, String)>,
    pub closed_form_inputs_and_urls: Vec<(u8, String, usize)>,
    pub scheduler_witness_url: String,
}

/// URLs returned by `save_artifacts`, one variant per witness-generator kind.
pub(crate) enum BlobUrls {
    Url(String),
    Aggregation(AggregationBlobUrls),
    Scheduler(SchedulerBlobUrls),
}

/// Common interface for fetching, saving, and persisting the artifacts of a
/// witness-generation job.
#[async_trait]
pub(crate) trait ArtifactsManager {
    type InputMetadata;
    type InputArtifacts;
    type OutputArtifacts;

    /// Fetches the job's input artifacts from the object store.
    async fn get_artifacts(
        metadata: &Self::InputMetadata,
        object_store: &dyn ObjectStore,
    ) -> anyhow::Result<Self::InputArtifacts>;

    /// Uploads the job's output artifacts to the object store and returns their blob URLs.
    async fn save_artifacts(
        job_id: u32,
        artifacts: Self::OutputArtifacts,
        object_store: &dyn ObjectStore,
    ) -> BlobUrls;

    /// Records the finished job (blob URLs, timing, status) in the prover database.
    async fn update_database(
        connection_pool: &ConnectionPool<Prover>,
        job_id: u32,
        started_at: Instant,
        blob_urls: BlobUrls,
        artifacts: Self::OutputArtifacts,
    ) -> anyhow::Result<()>;
}
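
To make the contract concrete, here is a minimal sketch of an implementor. The NoopWitnessGenerator type and its artifact types are invented for illustration; only the ArtifactsManager trait itself comes from this PR.

use std::time::Instant;

use async_trait::async_trait;
use zksync_object_store::ObjectStore;
use zksync_prover_dal::{ConnectionPool, Prover};

use crate::artifacts::{ArtifactsManager, BlobUrls};

// Hypothetical generator used only to illustrate the trait contract.
struct NoopWitnessGenerator;

#[async_trait]
impl ArtifactsManager for NoopWitnessGenerator {
    type InputMetadata = u32;      // e.g. a batch number
    type InputArtifacts = Vec<u8>; // raw input blob
    type OutputArtifacts = Vec<u8>;

    async fn get_artifacts(
        metadata: &Self::InputMetadata,
        _object_store: &dyn ObjectStore,
    ) -> anyhow::Result<Self::InputArtifacts> {
        // A real implementor would fetch the blob keyed by `metadata` from the store.
        Ok(metadata.to_be_bytes().to_vec())
    }

    async fn save_artifacts(
        job_id: u32,
        _artifacts: Self::OutputArtifacts,
        _object_store: &dyn ObjectStore,
    ) -> BlobUrls {
        // A real implementor would upload the artifacts and return the resulting URLs.
        BlobUrls::Url(format!("noop_witness_{job_id}"))
    }

    async fn update_database(
        _connection_pool: &ConnectionPool<Prover>,
        _job_id: u32,
        _started_at: Instant,
        _blob_urls: BlobUrls,
        _artifacts: Self::OutputArtifacts,
    ) -> anyhow::Result<()> {
        // A real implementor would persist job status and blob URLs here.
        Ok(())
    }
}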
108 changes: 108 additions & 0 deletions prover/crates/bin/witness_generator/src/basic_circuits/artifacts.rs
@@ -0,0 +1,108 @@
use std::time::Instant;

use async_trait::async_trait;
use zksync_object_store::ObjectStore;
use zksync_prover_dal::{ConnectionPool, Prover, ProverDal};
use zksync_prover_fri_types::AuxOutputWitnessWrapper;
use zksync_prover_fri_utils::get_recursive_layer_circuit_id_for_base_layer;
use zksync_types::{basic_fri_types::AggregationRound, L1BatchNumber};

use crate::{
    artifacts::{ArtifactsManager, BlobUrls},
    basic_circuits::{BasicCircuitArtifacts, BasicWitnessGenerator, BasicWitnessGeneratorJob},
    utils::SchedulerPartialInputWrapper,
};

#[async_trait]
impl ArtifactsManager for BasicWitnessGenerator {
    type InputMetadata = L1BatchNumber;
    type InputArtifacts = BasicWitnessGeneratorJob;
    type OutputArtifacts = BasicCircuitArtifacts;

    async fn get_artifacts(
        metadata: &Self::InputMetadata,
        object_store: &dyn ObjectStore,
    ) -> anyhow::Result<Self::InputArtifacts> {
        let l1_batch_number = *metadata;
        let data = object_store.get(l1_batch_number).await.unwrap();
        Ok(BasicWitnessGeneratorJob {
            block_number: l1_batch_number,
            data,
        })
    }

    async fn save_artifacts(
        job_id: u32,
        artifacts: Self::OutputArtifacts,
        object_store: &dyn ObjectStore,
    ) -> BlobUrls {
        let aux_output_witness_wrapper = AuxOutputWitnessWrapper(artifacts.aux_output_witness);
        object_store
            .put(L1BatchNumber(job_id), &aux_output_witness_wrapper)
            .await
            .unwrap();
        let wrapper = SchedulerPartialInputWrapper(artifacts.scheduler_witness);
        let url = object_store
            .put(L1BatchNumber(job_id), &wrapper)
            .await
            .unwrap();

        BlobUrls::Url(url)
    }

    #[tracing::instrument(skip_all, fields(l1_batch = %job_id))]
    async fn update_database(
        connection_pool: &ConnectionPool<Prover>,
        job_id: u32,
        started_at: Instant,
        blob_urls: BlobUrls,
        _artifacts: Self::OutputArtifacts,
    ) -> anyhow::Result<()> {
        let blob_urls = match blob_urls {
            BlobUrls::Scheduler(blobs) => blobs,
            _ => unreachable!(),
        };

        let mut connection = connection_pool
            .connection()
            .await
            .expect("failed to get database connection");
        let mut transaction = connection
            .start_transaction()
            .await
            .expect("failed to get database transaction");
        let protocol_version_id = transaction
            .fri_witness_generator_dal()
            .protocol_version_for_l1_batch(L1BatchNumber(job_id))
            .await;
        transaction
            .fri_prover_jobs_dal()
            .insert_prover_jobs(
                L1BatchNumber(job_id),
                blob_urls.circuit_ids_and_urls,
                AggregationRound::BasicCircuits,
                0,
                protocol_version_id,
            )
            .await;
        transaction
            .fri_witness_generator_dal()
            .create_aggregation_jobs(
                L1BatchNumber(job_id),
                &blob_urls.closed_form_inputs_and_urls,
                &blob_urls.scheduler_witness_url,
                get_recursive_layer_circuit_id_for_base_layer,
                protocol_version_id,
            )
            .await;
        transaction
            .fri_witness_generator_dal()
            .mark_witness_job_as_successful(L1BatchNumber(job_id), started_at.elapsed())
            .await;
        transaction
            .commit()
            .await
            .expect("failed to commit database transaction");
        Ok(())
    }
}
153 changes: 153 additions & 0 deletions prover/crates/bin/witness_generator/src/basic_circuits/job_processor.rs
@@ -0,0 +1,153 @@
use std::{sync::Arc, time::Instant};

use anyhow::Context as _;
use tracing::Instrument;
use zksync_prover_dal::ProverDal;
use zksync_prover_fri_types::{get_current_pod_name, AuxOutputWitnessWrapper};
use zksync_queued_job_processor::{async_trait, JobProcessor};
use zksync_types::{basic_fri_types::AggregationRound, L1BatchNumber};

use crate::{
    artifacts::{ArtifactsManager, BlobUrls, SchedulerBlobUrls},
    basic_circuits::{BasicCircuitArtifacts, BasicWitnessGenerator, BasicWitnessGeneratorJob},
    metrics::WITNESS_GENERATOR_METRICS,
};

#[async_trait]
impl JobProcessor for BasicWitnessGenerator {
    type Job = BasicWitnessGeneratorJob;
    type JobId = L1BatchNumber;
    // The artifact is optional to support skipping blocks when sampling is enabled.
    type JobArtifacts = Option<BasicCircuitArtifacts>;

    const SERVICE_NAME: &'static str = "fri_basic_circuit_witness_generator";

    async fn get_next_job(&self) -> anyhow::Result<Option<(Self::JobId, Self::Job)>> {
        let mut prover_connection = self.prover_connection_pool.connection().await?;
        let last_l1_batch_to_process = self.config.last_l1_batch_to_process();
        let pod_name = get_current_pod_name();
        match prover_connection
            .fri_witness_generator_dal()
            .get_next_basic_circuit_witness_job(
                last_l1_batch_to_process,
                self.protocol_version,
                &pod_name,
            )
            .await
        {
            Some(block_number) => {
                tracing::info!(
                    "Processing FRI basic witness-gen for block {}",
                    block_number
                );
                let started_at = Instant::now();
                let job = Self::get_artifacts(&block_number, &*self.object_store).await?;

                WITNESS_GENERATOR_METRICS.blob_fetch_time[&AggregationRound::BasicCircuits.into()]
                    .observe(started_at.elapsed());

                Ok(Some((block_number, job)))
            }
            None => Ok(None),
        }
    }

    async fn save_failure(&self, job_id: L1BatchNumber, _started_at: Instant, error: String) -> () {
        self.prover_connection_pool
            .connection()
            .await
            .unwrap()
            .fri_witness_generator_dal()
            .mark_witness_job_failed(&error, job_id)
            .await;
    }

    #[allow(clippy::async_yields_async)]
    async fn process_job(
        &self,
        _job_id: &Self::JobId,
        job: BasicWitnessGeneratorJob,
        started_at: Instant,
    ) -> tokio::task::JoinHandle<anyhow::Result<Option<BasicCircuitArtifacts>>> {
        let object_store = Arc::clone(&self.object_store);
        let max_circuits_in_flight = self.config.max_circuits_in_flight;
        tokio::spawn(async move {
            let block_number = job.block_number;
            Ok(
                Self::process_job_impl(object_store, job, started_at, max_circuits_in_flight)
                    .instrument(tracing::info_span!("basic_circuit", %block_number))
                    .await,
            )
        })
    }

    #[tracing::instrument(skip_all, fields(l1_batch = %job_id))]
    async fn save_result(
        &self,
        job_id: L1BatchNumber,
        started_at: Instant,
        optional_artifacts: Option<BasicCircuitArtifacts>,
    ) -> anyhow::Result<()> {
        match optional_artifacts {
            None => Ok(()),
            Some(artifacts) => {
                let blob_started_at = Instant::now();
                let circuit_urls = artifacts.circuit_urls.clone();
                let queue_urls = artifacts.queue_urls.clone();

                let aux_output_witness_wrapper =
                    AuxOutputWitnessWrapper(artifacts.aux_output_witness.clone());
                if self.config.shall_save_to_public_bucket {
                    self.public_blob_store.as_deref()
                        .expect("public_object_store shall not be empty while running with shall_save_to_public_bucket config")
                        .put(job_id, &aux_output_witness_wrapper)
                        .await
                        .unwrap();
                }

                let scheduler_witness_url =
                    match Self::save_artifacts(job_id.0, artifacts.clone(), &*self.object_store)
                        .await
                    {
                        BlobUrls::Url(url) => url,
                        _ => unreachable!(),
                    };

                WITNESS_GENERATOR_METRICS.blob_save_time[&AggregationRound::BasicCircuits.into()]
                    .observe(blob_started_at.elapsed());

                Self::update_database(
                    &self.prover_connection_pool,
                    job_id.0,
                    started_at,
                    BlobUrls::Scheduler(SchedulerBlobUrls {
                        circuit_ids_and_urls: circuit_urls,
                        closed_form_inputs_and_urls: queue_urls,
                        scheduler_witness_url,
                    }),
                    artifacts,
                )
                .await?;
                Ok(())
            }
        }
    }

    fn max_attempts(&self) -> u32 {
        self.config.max_attempts
    }

    async fn get_job_attempts(&self, job_id: &L1BatchNumber) -> anyhow::Result<u32> {
        let mut prover_storage = self
            .prover_connection_pool
            .connection()
            .await
            .context("failed to acquire DB connection for BasicWitnessGenerator")?;
        prover_storage
            .fri_witness_generator_dal()
            .get_basic_circuit_witness_job_attempts(*job_id)
            .await
            .map(|attempts| attempts.unwrap_or(0))
            .context("failed to get job attempts for BasicWitnessGenerator")
    }
}
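
For orientation, the methods above are driven by the generic polling loop in zksync_queued_job_processor. The sketch below shows the rough shape of that loop using only the trait methods visible in this diff; run_schematic and its control flow are illustrative assumptions, not the library's actual code.

use std::time::{Duration, Instant};

use zksync_queued_job_processor::JobProcessor;

// Illustrative sketch of the loop driving a JobProcessor; the real loop in
// zksync_queued_job_processor differs in detail (graceful shutdown, backoff, metrics).
async fn run_schematic<P: JobProcessor>(processor: P) -> anyhow::Result<()> {
    loop {
        // Claim the next queued job, if any; otherwise wait and poll again.
        let Some((job_id, job)) = processor.get_next_job().await? else {
            tokio::time::sleep(Duration::from_secs(1)).await;
            continue;
        };
        let started_at = Instant::now();
        // The heavy work runs on a spawned task; `process_job` hands back its JoinHandle.
        let handle = processor.process_job(&job_id, job, started_at).await;
        match handle.await? {
            // On success, persist artifacts and mark the job done.
            Ok(artifacts) => processor.save_result(job_id, started_at, artifacts).await?,
            // On failure, record the error so the job can be retried
            // until `max_attempts()` is exhausted.
            Err(err) => processor.save_failure(job_id, started_at, err.to_string()).await,
        }
    }
}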