Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Only collect halo2 metrics with spans during keygen #935

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions extensions/native/compiler/src/constraints/halo2/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ use crate::{
#[derive(Debug, Clone)]
pub struct Halo2ConstraintCompiler<C: Config> {
pub num_public_values: usize,
#[allow(unused_variables)]
pub collect_metrics: bool,
pub phantom: PhantomData<C>,
}

Expand Down Expand Up @@ -71,9 +73,14 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
pub fn new(num_public_values: usize) -> Self {
Self {
num_public_values,
collect_metrics: false,
phantom: PhantomData,
}
}
pub fn with_collect_metrics(mut self) -> Self {
self.collect_metrics = true;
self
}
// Create halo2-lib constraints from a list of operations in the DSL.
// Assume: C::N = C::F = C::EF is type Fr
pub fn constrain_halo2(&self, halo2_state: &mut Halo2State<C>, operations: TracedVec<DslIr<C>>)
Expand All @@ -95,10 +102,13 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {

let mut vkey_hash = None;
let mut committed_values_digest = None;

#[cfg(feature = "bench-metrics")]
let mut old_stats = stats_snapshot(ctx, range.clone());
for (instruction, backtrace) in operations {
#[cfg(feature = "bench-metrics")]
let old_stats = stats_snapshot(ctx, range.clone());
if self.collect_metrics {
old_stats = stats_snapshot(ctx, range.clone());
}
let res = catch_unwind(AssertUnwindSafe(|| {
match instruction {
DslIr::ImmV(a, b) => {
Expand Down Expand Up @@ -437,7 +447,7 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
res.unwrap();
}
#[cfg(feature = "bench-metrics")]
{
if self.collect_metrics {
let mut new_stats = stats_snapshot(ctx, range.clone());
new_stats.diff(&old_stats);
new_stats.increment(cell_tracker.get_full_name());
Expand Down
28 changes: 16 additions & 12 deletions extensions/native/recursion/src/halo2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod testing_utils;
#[cfg(test)]
mod tests;

use std::{fmt::Debug, fs::File};
use std::fmt::Debug;

use axvm_native_compiler::{
constraints::halo2::compiler::{Halo2ConstraintCompiler, Halo2State},
Expand Down Expand Up @@ -67,6 +67,7 @@ impl Halo2Prover {
builder: BaseCircuitBuilder<Fr>,
dsl_operations: DslOperations<C>,
witness: Witness<C>,
#[allow(unused_variables)] collect_metrics: bool,
) -> BaseCircuitBuilder<Fr> {
let mut state = Halo2State {
builder,
Expand All @@ -75,6 +76,12 @@ impl Halo2Prover {
state.load_witness(witness);

let backend = Halo2ConstraintCompiler::<C>::new(dsl_operations.num_public_values);
#[cfg(feature = "bench-metrics")]
let backend = if collect_metrics {
backend.with_collect_metrics()
} else {
backend
};
backend.constrain_halo2(&mut state, dsl_operations.operations);

state.builder
Expand All @@ -91,7 +98,7 @@ impl Halo2Prover {
witness: Witness<C>,
) -> Vec<Vec<Fr>> {
let builder = Self::builder(CircuitBuilderStage::Mock, k);
let mut builder = Self::populate(builder, dsl_operations, witness);
let mut builder = Self::populate(builder, dsl_operations, witness, true);

let public_instances = builder.instances();
println!("Public instances: {:?}", public_instances);
Expand All @@ -113,7 +120,7 @@ impl Halo2Prover {
witness: Witness<C>,
) -> Halo2ProvingPinning {
let builder = Self::builder(CircuitBuilderStage::Keygen, k);
let mut builder = Self::populate(builder, dsl_operations, witness);
let mut builder = Self::populate(builder, dsl_operations, witness, true);
builder.calculate_params(Some(20));

let params = read_params(k as u32);
Expand Down Expand Up @@ -141,8 +148,8 @@ impl Halo2Prover {
.map(|x| x.len())
.collect_vec();

let file = File::create("halo2_final.json").unwrap();
serde_json::to_writer(file, &break_points).unwrap();
// let file = File::create("halo2_final.json").unwrap();
// serde_json::to_writer(file, &break_points).unwrap();
Halo2ProvingPinning {
pk,
config_params,
Expand All @@ -161,10 +168,13 @@ impl Halo2Prover {
witness: Witness<C>,
) -> Snark {
let k = config_params.k;
let params = read_params(k as u32);
#[cfg(feature = "bench-metrics")]
let start = std::time::Instant::now();
let builder = Self::builder(CircuitBuilderStage::Prover, k)
.use_params(config_params)
.use_break_points(break_points);
let builder = Self::populate(builder, dsl_operations, witness);
let builder = Self::populate(builder, dsl_operations, witness, false);
#[cfg(feature = "bench-metrics")]
{
let stats = builder.statistics();
Expand All @@ -173,12 +183,6 @@ impl Halo2Prover {
let total_cell = total_advices + total_lookups + stats.gate.total_fixed;
metrics::gauge!("halo2_total_cells").set(total_cell as f64);
}

let params = read_params(k as u32);

#[cfg(feature = "bench-metrics")]
let start = std::time::Instant::now();

let snark = gen_snark_shplonk(&params, pk, builder, None::<&str>);

#[cfg(feature = "bench-metrics")]
Expand Down