Skip to content

Commit

Permalink
feat: E2E benchmark halo2 generate flamegraphs (#1203)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenh-axiom-xyz authored Jan 13, 2025
1 parent 2b28364 commit 0ea118f
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 16 deletions.
2 changes: 1 addition & 1 deletion benchmarks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pprof = { version = "0.13", features = [

[features]
default = ["parallel", "mimalloc", "bench-metrics"]
bench-metrics = ["openvm-native-recursion/bench-metrics"]
bench-metrics = ["openvm-native-recursion/bench-metrics", "openvm-native-compiler/bench-metrics"]
profiling = ["openvm-sdk/profiling"]
aggregation = []
static-verifier = ["openvm-native-recursion/static-verifier"]
Expand Down
1 change: 1 addition & 0 deletions benchmarks/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ impl BenchmarkCli {
halo2_config: Halo2Config {
verifier_k: self.halo2_outer_k.unwrap_or(24),
wrapper_k: self.halo2_wrapper_k,
profiling: self.profiling,
},
}
}
Expand Down
30 changes: 22 additions & 8 deletions ci/scripts/metric_unify/flamegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,25 @@

from utils import FLAMEGRAPHS_DIR, get_git_root

def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name):
def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None):
"""
Filters a metrics_dict obtained from json for entries that look like:
[ { labels: [["key1", "span1;span2"], ["key2", "span3"]], "metric": metric_name, "value": 2 } ]
It will find entries that have all of stack_keys as present in the labels and then concatenate the corresponding values into a single flat stack entry and then add the value at the end.
It will write a file with one line each for flamegraph.pl or inferno-flamegraph to consume.
If sum_metrics is not None, instead of searching for metric_name, it will sum the values of the metrics in sum_metrics.
"""
lines = []
stack_sums = {}
non_zero = False

# Process counters
for counter in metrics_dict.get('counter', []):
if counter['metric'] != metric_name:
if (sum_metrics is not None and counter['metric'] not in sum_metrics) or \
(sum_metrics is None and counter['metric'] != metric_name):
continue

# list of pairs -> dict
labels = dict(counter['labels'])
filter = False
Expand All @@ -41,15 +46,21 @@ def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name):

stack = ';'.join(stack_values)
value = int(counter['value'])
stack_sums[stack] = stack_sums.get(stack, 0) + value

if value != 0:
non_zero = True

lines.append(f"{stack} {value}")
lines = [f"{stack} {value}" for stack, value in stack_sums.items() if value != 0]

# Currently cycle tracker does not use gauge
return lines
return lines if non_zero else []


def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, reverse=False):
lines = get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name)
def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None, reverse=False):
lines = get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics)
if not lines:
return

suffixes = [key for key in stack_keys if key != "cycle_tracker_span"]

Expand All @@ -74,7 +85,7 @@ def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name
print(f"Created flamegraph at {flamegraph_path}")


def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, reverse=False):
def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, sum_metrics=None, reverse=False):
fname_prefix = os.path.splitext(os.path.basename(metrics_file))[0]

with open(metrics_file, 'r') as f:
Expand All @@ -92,7 +103,7 @@ def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, reverse=
for group_by_values in group_by_values_list:
group_by_kvs = list(zip(group_by, group_by_values))
fname = fname_prefix + '-' + '-'.join(group_by_values)
create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, reverse=reverse)
create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics, reverse=reverse)


def create_custom_flamegraphs(metrics_file, group_by=["group"]):
Expand All @@ -101,6 +112,9 @@ def create_custom_flamegraphs(metrics_file, group_by=["group"]):
reverse=reverse)
create_flamegraphs(metrics_file, group_by, ["cycle_tracker_span", "dsl_ir", "opcode", "air_name"], "cells_used",
reverse=reverse)
create_flamegraphs(metrics_file, group_by, ["cell_tracker_span"], "cells_used",
sum_metrics=["simple_advice_cells", "fixed_cells", "lookup_advice_cells"],
reverse=reverse)


def main():
Expand Down
1 change: 1 addition & 0 deletions crates/sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ default = ["parallel"]
bench-metrics = [
"openvm-circuit/bench-metrics",
"openvm-native-recursion/bench-metrics",
"openvm-native-compiler/bench-metrics",
]
profiling = ["openvm-circuit/function-span", "openvm-transpiler/function-span"]
parallel = ["openvm-circuit/parallel"]
Expand Down
3 changes: 3 additions & 0 deletions crates/sdk/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pub struct Halo2Config {
pub verifier_k: usize,
/// If not specified, keygen will tune wrapper_k automatically.
pub wrapper_k: Option<usize>,
/// Sets the profiling mode of halo2 VM
pub profiling: bool,
}

impl<VC> AppConfig<VC> {
Expand Down Expand Up @@ -101,6 +103,7 @@ impl Default for AggConfig {
halo2_config: Halo2Config {
verifier_k: 24,
wrapper_k: None,
profiling: false,
},
}
}
Expand Down
8 changes: 7 additions & 1 deletion crates/sdk/src/keygen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ pub struct Halo2ProvingKey {
pub verifier: Halo2VerifierProvingKey,
/// Wrapper circuit to verify static verifier and reduce the verification costs in the final proof.
pub wrapper: Halo2WrapperProvingKey,
/// Whether to collect detailed profiling metrics
pub profiling: bool,
}

impl<VC: VmConfig<F>> AppProvingKey<VC>
Expand Down Expand Up @@ -315,7 +317,11 @@ impl AggProvingKey {
} else {
Halo2WrapperProvingKey::keygen_auto_tune(reader, dummy_snark)
};
let halo2_pk = Halo2ProvingKey { verifier, wrapper };
let halo2_pk = Halo2ProvingKey {
verifier,
wrapper,
profiling: halo2_config.profiling,
};
Self {
agg_stark_pk,
halo2_pk,
Expand Down
7 changes: 5 additions & 2 deletions crates/sdk/src/prover/halo2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ impl Halo2Prover {
pub fn prove_for_evm(&self, root_proof: &Proof<RootSC>) -> EvmProof {
let mut witness = Witness::default();
root_proof.write(&mut witness);
let snark = info_span!("prove", group = "halo2_outer")
.in_scope(|| self.halo2_pk.verifier.prove(&self.verifier_srs, witness));
let snark = info_span!("prove", group = "halo2_outer").in_scope(|| {
self.halo2_pk
.verifier
.prove(&self.verifier_srs, witness, self.halo2_pk.profiling)
});
info_span!("prove_for_evm", group = "halo2_wrapper").in_scope(|| {
self.halo2_pk
.wrapper
Expand Down
1 change: 1 addition & 0 deletions crates/sdk/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ fn agg_config_for_test() -> AggConfig {
halo2_config: Halo2Config {
verifier_k: 24,
wrapper_k: None,
profiling: false,
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion extensions/native/recursion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ static-verifier = [
"dep:once_cell",
]
test-utils = ["openvm-circuit/test-utils"]
bench-metrics = ["dep:metrics", "openvm-circuit/bench-metrics"]
bench-metrics = ["dep:metrics", "openvm-circuit/bench-metrics", "openvm-native-compiler/bench-metrics"]
mimalloc = ["openvm-stark-backend/mimalloc"]
jemalloc = ["openvm-stark-backend/jemalloc"]
nightly-features = ["openvm-circuit/nightly-features"]
3 changes: 2 additions & 1 deletion extensions/native/recursion/src/halo2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,15 @@ impl Halo2Prover {
pk: &ProvingKey<G1Affine>,
dsl_operations: DslOperations<C>,
witness: Witness<C>,
profiling: bool,
) -> Snark {
let k = config_params.k;
#[cfg(feature = "bench-metrics")]
let start = std::time::Instant::now();
let builder = Self::builder(CircuitBuilderStage::Prover, k)
.use_params(config_params)
.use_break_points(break_points);
let builder = Self::populate(builder, dsl_operations, witness, false);
let builder = Self::populate(builder, dsl_operations, witness, profiling);
#[cfg(feature = "bench-metrics")]
{
let stats = builder.statistics();
Expand Down
2 changes: 1 addition & 1 deletion extensions/native/recursion/src/halo2/testing_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub fn run_static_verifier_test(
.entered();
let mut witness = Witness::default();
vparams.data.proof.write(&mut witness);
let static_verifier_snark = stark_verifier_circuit.prove(params, witness);
let static_verifier_snark = stark_verifier_circuit.prove(params, witness, false);
info_span.exit();
(stark_verifier_circuit, static_verifier_snark)
}
8 changes: 7 additions & 1 deletion extensions/native/recursion/src/halo2/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,20 @@ pub fn generate_halo2_verifier_proving_key(
}

impl Halo2VerifierProvingKey {
pub fn prove(&self, params: &Halo2Params, witness: Witness<OuterConfig>) -> Snark {
pub fn prove(
&self,
params: &Halo2Params,
witness: Witness<OuterConfig>,
profiling: bool,
) -> Snark {
Halo2Prover::prove(
params,
self.pinning.metadata.config_params.clone(),
self.pinning.metadata.break_points.clone(),
&self.pinning.pk,
self.dsl_ops.clone(),
witness,
profiling,
)
}
// TODO: Add verify method
Expand Down

0 comments on commit 0ea118f

Please sign in to comment.