diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 456fb54b34..053fca1dac 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -56,7 +56,7 @@ pprof = { version = "0.13", features = [ [features] default = ["parallel", "mimalloc", "bench-metrics"] -bench-metrics = ["openvm-native-recursion/bench-metrics"] +bench-metrics = ["openvm-native-recursion/bench-metrics", "openvm-native-compiler/bench-metrics"] profiling = ["openvm-sdk/profiling"] aggregation = [] static-verifier = ["openvm-native-recursion/static-verifier"] diff --git a/benchmarks/src/utils.rs b/benchmarks/src/utils.rs index c8a1f9d96c..46d8449864 100644 --- a/benchmarks/src/utils.rs +++ b/benchmarks/src/utils.rs @@ -122,6 +122,7 @@ impl BenchmarkCli { halo2_config: Halo2Config { verifier_k: self.halo2_outer_k.unwrap_or(24), wrapper_k: self.halo2_wrapper_k, + profiling: self.profiling, }, } } diff --git a/ci/scripts/metric_unify/flamegraph.py b/ci/scripts/metric_unify/flamegraph.py index 43e2cfe2f6..f98a2e1e8f 100644 --- a/ci/scripts/metric_unify/flamegraph.py +++ b/ci/scripts/metric_unify/flamegraph.py @@ -6,20 +6,25 @@ from utils import FLAMEGRAPHS_DIR, get_git_root -def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name): +def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None): """ Filters a metrics_dict obtained from json for entries that look like: [ { labels: [["key1", "span1;span2"], ["key2", "span3"]], "metric": metric_name, "value": 2 } ] It will find entries that have all of stack_keys as present in the labels and then concatenate the corresponding values into a single flat stack entry and then add the value at the end. It will write a file with one line each for flamegraph.pl or inferno-flamegraph to consume. + If sum_metrics is not None, instead of searching for metric_name, it will sum the values of the metrics in sum_metrics. """ lines = [] + stack_sums = {} + non_zero = False # Process counters for counter in metrics_dict.get('counter', []): - if counter['metric'] != metric_name: + if (sum_metrics is not None and counter['metric'] not in sum_metrics) or \ + (sum_metrics is None and counter['metric'] != metric_name): continue + # list of pairs -> dict labels = dict(counter['labels']) filter = False @@ -41,15 +46,21 @@ def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name): stack = ';'.join(stack_values) value = int(counter['value']) + stack_sums[stack] = stack_sums.get(stack, 0) + value + + if value != 0: + non_zero = True - lines.append(f"{stack} {value}") + lines = [f"{stack} {value}" for stack, value in stack_sums.items() if value != 0] # Currently cycle tracker does not use gauge - return lines + return lines if non_zero else [] -def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, reverse=False): - lines = get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name) +def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None, reverse=False): + lines = get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics) + if not lines: + return suffixes = [key for key in stack_keys if key != "cycle_tracker_span"] @@ -74,7 +85,7 @@ def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name print(f"Created flamegraph at {flamegraph_path}") -def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, reverse=False): +def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, sum_metrics=None, reverse=False): fname_prefix = os.path.splitext(os.path.basename(metrics_file))[0] with open(metrics_file, 'r') as f: @@ -92,7 +103,7 @@ def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, reverse= for group_by_values in group_by_values_list: group_by_kvs = list(zip(group_by, group_by_values)) fname = fname_prefix + '-' + '-'.join(group_by_values) - create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, reverse=reverse) + create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics, reverse=reverse) def create_custom_flamegraphs(metrics_file, group_by=["group"]): @@ -101,6 +112,9 @@ def create_custom_flamegraphs(metrics_file, group_by=["group"]): reverse=reverse) create_flamegraphs(metrics_file, group_by, ["cycle_tracker_span", "dsl_ir", "opcode", "air_name"], "cells_used", reverse=reverse) + create_flamegraphs(metrics_file, group_by, ["cell_tracker_span"], "cells_used", + sum_metrics=["simple_advice_cells", "fixed_cells", "lookup_advice_cells"], + reverse=reverse) def main(): diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 9ba8045373..3cd82adaf9 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -50,6 +50,7 @@ default = ["parallel"] bench-metrics = [ "openvm-circuit/bench-metrics", "openvm-native-recursion/bench-metrics", + "openvm-native-compiler/bench-metrics", ] profiling = ["openvm-circuit/function-span", "openvm-transpiler/function-span"] parallel = ["openvm-circuit/parallel"] diff --git a/crates/sdk/src/config/mod.rs b/crates/sdk/src/config/mod.rs index ae7fd059f9..b99adeff41 100644 --- a/crates/sdk/src/config/mod.rs +++ b/crates/sdk/src/config/mod.rs @@ -49,6 +49,8 @@ pub struct Halo2Config { pub verifier_k: usize, /// If not specified, keygen will tune wrapper_k automatically. pub wrapper_k: Option, + /// Sets the profiling mode of halo2 VM + pub profiling: bool, } impl AppConfig { @@ -101,6 +103,7 @@ impl Default for AggConfig { halo2_config: Halo2Config { verifier_k: 24, wrapper_k: None, + profiling: false, }, } } diff --git a/crates/sdk/src/keygen/mod.rs b/crates/sdk/src/keygen/mod.rs index e92b3286f2..2ce2eb915f 100644 --- a/crates/sdk/src/keygen/mod.rs +++ b/crates/sdk/src/keygen/mod.rs @@ -76,6 +76,8 @@ pub struct Halo2ProvingKey { pub verifier: Halo2VerifierProvingKey, /// Wrapper circuit to verify static verifier and reduce the verification costs in the final proof. pub wrapper: Halo2WrapperProvingKey, + /// Whether to collect detailed profiling metrics + pub profiling: bool, } impl> AppProvingKey @@ -315,7 +317,11 @@ impl AggProvingKey { } else { Halo2WrapperProvingKey::keygen_auto_tune(reader, dummy_snark) }; - let halo2_pk = Halo2ProvingKey { verifier, wrapper }; + let halo2_pk = Halo2ProvingKey { + verifier, + wrapper, + profiling: halo2_config.profiling, + }; Self { agg_stark_pk, halo2_pk, diff --git a/crates/sdk/src/prover/halo2.rs b/crates/sdk/src/prover/halo2.rs index e6d44640aa..d1f7a4f823 100644 --- a/crates/sdk/src/prover/halo2.rs +++ b/crates/sdk/src/prover/halo2.rs @@ -30,8 +30,11 @@ impl Halo2Prover { pub fn prove_for_evm(&self, root_proof: &Proof) -> EvmProof { let mut witness = Witness::default(); root_proof.write(&mut witness); - let snark = info_span!("prove", group = "halo2_outer") - .in_scope(|| self.halo2_pk.verifier.prove(&self.verifier_srs, witness)); + let snark = info_span!("prove", group = "halo2_outer").in_scope(|| { + self.halo2_pk + .verifier + .prove(&self.verifier_srs, witness, self.halo2_pk.profiling) + }); info_span!("prove_for_evm", group = "halo2_wrapper").in_scope(|| { self.halo2_pk .wrapper diff --git a/crates/sdk/tests/integration_test.rs b/crates/sdk/tests/integration_test.rs index 4f17009010..9c0f7853c1 100644 --- a/crates/sdk/tests/integration_test.rs +++ b/crates/sdk/tests/integration_test.rs @@ -90,6 +90,7 @@ fn agg_config_for_test() -> AggConfig { halo2_config: Halo2Config { verifier_k: 24, wrapper_k: None, + profiling: false, }, } } diff --git a/extensions/native/recursion/Cargo.toml b/extensions/native/recursion/Cargo.toml index 1e3439f38a..81d7f7b556 100644 --- a/extensions/native/recursion/Cargo.toml +++ b/extensions/native/recursion/Cargo.toml @@ -46,7 +46,7 @@ static-verifier = [ "dep:once_cell", ] test-utils = ["openvm-circuit/test-utils"] -bench-metrics = ["dep:metrics", "openvm-circuit/bench-metrics"] +bench-metrics = ["dep:metrics", "openvm-circuit/bench-metrics", "openvm-native-compiler/bench-metrics"] mimalloc = ["openvm-stark-backend/mimalloc"] jemalloc = ["openvm-stark-backend/jemalloc"] nightly-features = ["openvm-circuit/nightly-features"] diff --git a/extensions/native/recursion/src/halo2/mod.rs b/extensions/native/recursion/src/halo2/mod.rs index 33d6b45491..a3e402db7f 100644 --- a/extensions/native/recursion/src/halo2/mod.rs +++ b/extensions/native/recursion/src/halo2/mod.rs @@ -199,6 +199,7 @@ impl Halo2Prover { pk: &ProvingKey, dsl_operations: DslOperations, witness: Witness, + profiling: bool, ) -> Snark { let k = config_params.k; #[cfg(feature = "bench-metrics")] @@ -206,7 +207,7 @@ impl Halo2Prover { let builder = Self::builder(CircuitBuilderStage::Prover, k) .use_params(config_params) .use_break_points(break_points); - let builder = Self::populate(builder, dsl_operations, witness, false); + let builder = Self::populate(builder, dsl_operations, witness, profiling); #[cfg(feature = "bench-metrics")] { let stats = builder.statistics(); diff --git a/extensions/native/recursion/src/halo2/testing_utils.rs b/extensions/native/recursion/src/halo2/testing_utils.rs index baf6cd42d7..41e87cca4d 100644 --- a/extensions/native/recursion/src/halo2/testing_utils.rs +++ b/extensions/native/recursion/src/halo2/testing_utils.rs @@ -57,7 +57,7 @@ pub fn run_static_verifier_test( .entered(); let mut witness = Witness::default(); vparams.data.proof.write(&mut witness); - let static_verifier_snark = stark_verifier_circuit.prove(params, witness); + let static_verifier_snark = stark_verifier_circuit.prove(params, witness, false); info_span.exit(); (stark_verifier_circuit, static_verifier_snark) } diff --git a/extensions/native/recursion/src/halo2/verifier.rs b/extensions/native/recursion/src/halo2/verifier.rs index 6ad47581b7..b7e0c4c224 100644 --- a/extensions/native/recursion/src/halo2/verifier.rs +++ b/extensions/native/recursion/src/halo2/verifier.rs @@ -39,7 +39,12 @@ pub fn generate_halo2_verifier_proving_key( } impl Halo2VerifierProvingKey { - pub fn prove(&self, params: &Halo2Params, witness: Witness) -> Snark { + pub fn prove( + &self, + params: &Halo2Params, + witness: Witness, + profiling: bool, + ) -> Snark { Halo2Prover::prove( params, self.pinning.metadata.config_params.clone(), @@ -47,6 +52,7 @@ impl Halo2VerifierProvingKey { &self.pinning.pk, self.dsl_ops.clone(), witness, + profiling, ) } // TODO: Add verify method