Skip to content

Commit

Permalink
Only collect halo2 metrics with spans during keygen (#935)
Browse files Browse the repository at this point in the history
  • Loading branch information
nyunyunyunyu authored Dec 4, 2024
1 parent 138157b commit d323cb8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
16 changes: 13 additions & 3 deletions extensions/native/compiler/src/constraints/halo2/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ use crate::{
#[derive(Debug, Clone)]
pub struct Halo2ConstraintCompiler<C: Config> {
pub num_public_values: usize,
#[allow(unused_variables)]
pub collect_metrics: bool,
pub phantom: PhantomData<C>,
}

Expand Down Expand Up @@ -68,9 +70,14 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
pub fn new(num_public_values: usize) -> Self {
Self {
num_public_values,
collect_metrics: false,
phantom: PhantomData,
}
}
pub fn with_collect_metrics(mut self) -> Self {
self.collect_metrics = true;
self
}
// Create halo2-lib constraints from a list of operations in the DSL.
// Assume: C::N = C::F = C::EF is type Fr
pub fn constrain_halo2(&self, halo2_state: &mut Halo2State<C>, operations: TracedVec<DslIr<C>>)
Expand All @@ -92,10 +99,13 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {

let mut vkey_hash = None;
let mut committed_values_digest = None;

#[cfg(feature = "bench-metrics")]
let mut old_stats = stats_snapshot(ctx, range.clone());
for (instruction, backtrace) in operations {
#[cfg(feature = "bench-metrics")]
let old_stats = stats_snapshot(ctx, range.clone());
if self.collect_metrics {
old_stats = stats_snapshot(ctx, range.clone());
}
let res = catch_unwind(AssertUnwindSafe(|| {
match instruction {
DslIr::ImmV(a, b) => {
Expand Down Expand Up @@ -420,7 +430,7 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
res.unwrap();
}
#[cfg(feature = "bench-metrics")]
{
if self.collect_metrics {
let mut new_stats = stats_snapshot(ctx, range.clone());
new_stats.diff(&old_stats);
new_stats.increment(cell_tracker.get_full_name());
Expand Down
28 changes: 16 additions & 12 deletions extensions/native/recursion/src/halo2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod testing_utils;
#[cfg(test)]
mod tests;

use std::{fmt::Debug, fs::File};
use std::fmt::Debug;

use axvm_native_compiler::{
constraints::halo2::compiler::{Halo2ConstraintCompiler, Halo2State},
Expand Down Expand Up @@ -67,6 +67,7 @@ impl Halo2Prover {
builder: BaseCircuitBuilder<Fr>,
dsl_operations: DslOperations<C>,
witness: Witness<C>,
#[allow(unused_variables)] collect_metrics: bool,
) -> BaseCircuitBuilder<Fr> {
let mut state = Halo2State {
builder,
Expand All @@ -75,6 +76,12 @@ impl Halo2Prover {
state.load_witness(witness);

let backend = Halo2ConstraintCompiler::<C>::new(dsl_operations.num_public_values);
#[cfg(feature = "bench-metrics")]
let backend = if collect_metrics {
backend.with_collect_metrics()
} else {
backend
};
backend.constrain_halo2(&mut state, dsl_operations.operations);

state.builder
Expand All @@ -91,7 +98,7 @@ impl Halo2Prover {
witness: Witness<C>,
) -> Vec<Vec<Fr>> {
let builder = Self::builder(CircuitBuilderStage::Mock, k);
let mut builder = Self::populate(builder, dsl_operations, witness);
let mut builder = Self::populate(builder, dsl_operations, witness, true);

let public_instances = builder.instances();
println!("Public instances: {:?}", public_instances);
Expand All @@ -113,7 +120,7 @@ impl Halo2Prover {
witness: Witness<C>,
) -> Halo2ProvingPinning {
let builder = Self::builder(CircuitBuilderStage::Keygen, k);
let mut builder = Self::populate(builder, dsl_operations, witness);
let mut builder = Self::populate(builder, dsl_operations, witness, true);
builder.calculate_params(Some(20));

let params = read_params(k as u32);
Expand Down Expand Up @@ -141,8 +148,8 @@ impl Halo2Prover {
.map(|x| x.len())
.collect_vec();

let file = File::create("halo2_final.json").unwrap();
serde_json::to_writer(file, &break_points).unwrap();
// let file = File::create("halo2_final.json").unwrap();
// serde_json::to_writer(file, &break_points).unwrap();
Halo2ProvingPinning {
pk,
config_params,
Expand All @@ -161,10 +168,13 @@ impl Halo2Prover {
witness: Witness<C>,
) -> Snark {
let k = config_params.k;
let params = read_params(k as u32);
#[cfg(feature = "bench-metrics")]
let start = std::time::Instant::now();
let builder = Self::builder(CircuitBuilderStage::Prover, k)
.use_params(config_params)
.use_break_points(break_points);
let builder = Self::populate(builder, dsl_operations, witness);
let builder = Self::populate(builder, dsl_operations, witness, false);
#[cfg(feature = "bench-metrics")]
{
let stats = builder.statistics();
Expand All @@ -173,12 +183,6 @@ impl Halo2Prover {
let total_cell = total_advices + total_lookups + stats.gate.total_fixed;
metrics::gauge!("halo2_total_cells").set(total_cell as f64);
}

let params = read_params(k as u32);

#[cfg(feature = "bench-metrics")]
let start = std::time::Instant::now();

let snark = gen_snark_shplonk(&params, pk, builder, None::<&str>);

#[cfg(feature = "bench-metrics")]
Expand Down

0 comments on commit d323cb8

Please sign in to comment.