diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index d6dc144889..b0b82a1d65 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -36,6 +36,8 @@ use crate::{ #[derive(Debug, Clone)] pub struct Halo2ConstraintCompiler { pub num_public_values: usize, + #[allow(unused_variables)] + pub collect_metrics: bool, pub phantom: PhantomData, } @@ -68,9 +70,14 @@ impl Halo2ConstraintCompiler { pub fn new(num_public_values: usize) -> Self { Self { num_public_values, + collect_metrics: false, phantom: PhantomData, } } + pub fn with_collect_metrics(mut self) -> Self { + self.collect_metrics = true; + self + } // Create halo2-lib constraints from a list of operations in the DSL. // Assume: C::N = C::F = C::EF is type Fr pub fn constrain_halo2(&self, halo2_state: &mut Halo2State, operations: TracedVec>) @@ -92,10 +99,13 @@ impl Halo2ConstraintCompiler { let mut vkey_hash = None; let mut committed_values_digest = None; - + #[cfg(feature = "bench-metrics")] + let mut old_stats = stats_snapshot(ctx, range.clone()); for (instruction, backtrace) in operations { #[cfg(feature = "bench-metrics")] - let old_stats = stats_snapshot(ctx, range.clone()); + if self.collect_metrics { + old_stats = stats_snapshot(ctx, range.clone()); + } let res = catch_unwind(AssertUnwindSafe(|| { match instruction { DslIr::ImmV(a, b) => { @@ -420,7 +430,7 @@ impl Halo2ConstraintCompiler { res.unwrap(); } #[cfg(feature = "bench-metrics")] - { + if self.collect_metrics { let mut new_stats = stats_snapshot(ctx, range.clone()); new_stats.diff(&old_stats); new_stats.increment(cell_tracker.get_full_name()); diff --git a/extensions/native/recursion/src/halo2/mod.rs b/extensions/native/recursion/src/halo2/mod.rs index 62c95e18e5..91b16c51cd 100644 --- a/extensions/native/recursion/src/halo2/mod.rs +++ b/extensions/native/recursion/src/halo2/mod.rs @@ -5,7 +5,7 @@ pub mod testing_utils; #[cfg(test)] mod tests; -use std::{fmt::Debug, fs::File}; +use std::fmt::Debug; use axvm_native_compiler::{ constraints::halo2::compiler::{Halo2ConstraintCompiler, Halo2State}, @@ -67,6 +67,7 @@ impl Halo2Prover { builder: BaseCircuitBuilder, dsl_operations: DslOperations, witness: Witness, + #[allow(unused_variables)] collect_metrics: bool, ) -> BaseCircuitBuilder { let mut state = Halo2State { builder, @@ -75,6 +76,12 @@ impl Halo2Prover { state.load_witness(witness); let backend = Halo2ConstraintCompiler::::new(dsl_operations.num_public_values); + #[cfg(feature = "bench-metrics")] + let backend = if collect_metrics { + backend.with_collect_metrics() + } else { + backend + }; backend.constrain_halo2(&mut state, dsl_operations.operations); state.builder @@ -91,7 +98,7 @@ impl Halo2Prover { witness: Witness, ) -> Vec> { let builder = Self::builder(CircuitBuilderStage::Mock, k); - let mut builder = Self::populate(builder, dsl_operations, witness); + let mut builder = Self::populate(builder, dsl_operations, witness, true); let public_instances = builder.instances(); println!("Public instances: {:?}", public_instances); @@ -113,7 +120,7 @@ impl Halo2Prover { witness: Witness, ) -> Halo2ProvingPinning { let builder = Self::builder(CircuitBuilderStage::Keygen, k); - let mut builder = Self::populate(builder, dsl_operations, witness); + let mut builder = Self::populate(builder, dsl_operations, witness, true); builder.calculate_params(Some(20)); let params = read_params(k as u32); @@ -141,8 +148,8 @@ impl Halo2Prover { .map(|x| x.len()) .collect_vec(); - let file = File::create("halo2_final.json").unwrap(); - serde_json::to_writer(file, &break_points).unwrap(); + // let file = File::create("halo2_final.json").unwrap(); + // serde_json::to_writer(file, &break_points).unwrap(); Halo2ProvingPinning { pk, config_params, @@ -161,10 +168,13 @@ impl Halo2Prover { witness: Witness, ) -> Snark { let k = config_params.k; + let params = read_params(k as u32); + #[cfg(feature = "bench-metrics")] + let start = std::time::Instant::now(); let builder = Self::builder(CircuitBuilderStage::Prover, k) .use_params(config_params) .use_break_points(break_points); - let builder = Self::populate(builder, dsl_operations, witness); + let builder = Self::populate(builder, dsl_operations, witness, false); #[cfg(feature = "bench-metrics")] { let stats = builder.statistics(); @@ -173,12 +183,6 @@ impl Halo2Prover { let total_cell = total_advices + total_lookups + stats.gate.total_fixed; metrics::gauge!("halo2_total_cells").set(total_cell as f64); } - - let params = read_params(k as u32); - - #[cfg(feature = "bench-metrics")] - let start = std::time::Instant::now(); - let snark = gen_snark_shplonk(¶ms, pk, builder, None::<&str>); #[cfg(feature = "bench-metrics")]