diff --git a/benchmarks/src/bin/fib_e2e.rs b/benchmarks/src/bin/fib_e2e.rs index 169e6ca08b..32cd9442a5 100644 --- a/benchmarks/src/bin/fib_e2e.rs +++ b/benchmarks/src/bin/fib_e2e.rs @@ -9,7 +9,9 @@ use openvm_rv32im_circuit::Rv32ImConfig; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_sdk::{commit::commit_app_exe, prover::ContinuationProver, Sdk, StdIn}; +use openvm_sdk::{ + commit::commit_app_exe, keygen::RootVerifierProvingKey, prover::ContinuationProver, Sdk, StdIn, +}; use openvm_stark_sdk::bench::run_with_metric_collection; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -35,7 +37,11 @@ async fn main() -> Result<()> { .unwrap_or(PathBuf::from(DEFAULT_PARAMS_DIR)), ); let app_pk = Arc::new(sdk.app_keygen(app_config)?); - let full_agg_pk = sdk.agg_keygen(agg_config, &halo2_params_reader)?; + let full_agg_pk = sdk.agg_keygen( + agg_config, + &halo2_params_reader, + None::<&RootVerifierProvingKey>, + )?; let elf = args.build_bench_program("fibonacci")?; let exe = VmExe::from_elf( elf, diff --git a/crates/cli/src/commands/setup.rs b/crates/cli/src/commands/setup.rs index d5fd7fd2b2..00bc625bb3 100644 --- a/crates/cli/src/commands/setup.rs +++ b/crates/cli/src/commands/setup.rs @@ -11,6 +11,7 @@ use openvm_native_recursion::halo2::utils::CacheHalo2ParamsReader; use openvm_sdk::{ config::AggConfig, fs::{write_agg_pk_to_file, write_evm_verifier_to_file}, + keygen::RootVerifierProvingKey, Sdk, }; @@ -41,7 +42,7 @@ impl EvmProvingSetupCmd { let agg_config = AggConfig::default(); println!("Generating proving key..."); - let agg_pk = Sdk.agg_keygen(agg_config, ¶ms_reader)?; + let agg_pk = Sdk.agg_keygen(agg_config, ¶ms_reader, None::<&RootVerifierProvingKey>)?; println!("Generating verifier contract..."); let verifier = Sdk.generate_snark_verifier_contract(¶ms_reader, &agg_pk)?; diff --git a/crates/sdk/examples/sdk_evm.rs b/crates/sdk/examples/sdk_evm.rs index 24bd6add55..41b91c71cf 100644 --- a/crates/sdk/examples/sdk_evm.rs +++ b/crates/sdk/examples/sdk_evm.rs @@ -7,6 +7,7 @@ use openvm_build::GuestOptions; use openvm_native_recursion::halo2::utils::CacheHalo2ParamsReader; use openvm_sdk::{ config::{AggConfig, AppConfig, SdkVmConfig}, + keygen::RootVerifierProvingKey, Sdk, StdIn, }; use openvm_stark_sdk::config::FriParameters; @@ -89,7 +90,11 @@ fn main() -> Result<(), Box> { const DEFAULT_PARAMS_DIR: &str = concat!(env!("HOME"), "/.openvm/params/"); let halo2_params_reader = CacheHalo2ParamsReader::new(DEFAULT_PARAMS_DIR); let agg_config = AggConfig::default(); - let agg_pk = sdk.agg_keygen(agg_config, &halo2_params_reader)?; + let agg_pk = sdk.agg_keygen( + agg_config, + &halo2_params_reader, + None::<&RootVerifierProvingKey>, + )?; // 9. Generate the SNARK verifier smart contract let verifier = sdk.generate_snark_verifier_contract(&halo2_params_reader, &agg_pk)?; diff --git a/crates/sdk/src/keygen/mod.rs b/crates/sdk/src/keygen/mod.rs index 2ce2eb915f..49cbfae48d 100644 --- a/crates/sdk/src/keygen/mod.rs +++ b/crates/sdk/src/keygen/mod.rs @@ -33,6 +33,7 @@ use crate::{ config::{AggConfig, AggStarkConfig, AppConfig}, keygen::perm::AirIdPermutation, prover::vm::types::VmProvingKey, + static_verifier::StaticVerifierPvHandler, verifier::{ internal::InternalVmVerifierConfig, leaf::LeafVmVerifierConfig, root::RootVmVerifierConfig, }, @@ -296,7 +297,11 @@ impl AggProvingKey { /// - This function is very expensive. Usually it requires >64GB memory and takes >10 minutes. /// - Please make sure SRS(KZG parameters) is already downloaded. #[tracing::instrument(level = "info", fields(group = "agg_keygen"), skip_all)] - pub fn keygen(config: AggConfig, reader: &impl Halo2ParamsReader) -> Self { + pub fn keygen( + config: AggConfig, + reader: &impl Halo2ParamsReader, + pv_handler: Option<&impl StaticVerifierPvHandler>, + ) -> Self { let AggConfig { agg_stark_config, halo2_config, @@ -310,6 +315,7 @@ impl AggProvingKey { let verifier = agg_stark_pk.root_verifier_pk.keygen_static_verifier( &reader.read_params(halo2_config.verifier_k), dummy_root_proof, + pv_handler, ); let dummy_snark = verifier.generate_dummy_snark(reader); let wrapper = if let Some(wrapper_k) = halo2_config.wrapper_k { diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index a4038e194a..417426487c 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -42,13 +42,13 @@ use prover::vm::ContinuationVmProof; pub mod commit; pub mod config; +pub mod keygen; pub mod prover; pub mod static_verifier; - -pub mod keygen; pub mod verifier; mod stdin; +use static_verifier::StaticVerifierPvHandler; pub use stdin::*; pub mod fs; @@ -178,8 +178,9 @@ impl Sdk { &self, config: AggConfig, reader: &impl Halo2ParamsReader, + pv_handler: Option<&impl StaticVerifierPvHandler>, ) -> Result { - let agg_pk = AggProvingKey::keygen(config, reader); + let agg_pk = AggProvingKey::keygen(config, reader, pv_handler); Ok(agg_pk) } diff --git a/crates/sdk/src/static_verifier/mod.rs b/crates/sdk/src/static_verifier/mod.rs index 82d989f082..bc0e1a9148 100644 --- a/crates/sdk/src/static_verifier/mod.rs +++ b/crates/sdk/src/static_verifier/mod.rs @@ -9,14 +9,11 @@ use openvm_native_recursion::{ hints::Hintable, stark::StarkVerifier, utils::const_fri_config, + vars::StarkProofVariable, witness::Witnessable, }; use openvm_stark_sdk::{ - openvm_stark_backend::{ - p3_field::{FieldAlgebra, PrimeField32}, - prover::types::Proof, - }, - p3_baby_bear::BabyBear, + openvm_stark_backend::{p3_field::FieldAlgebra, prover::types::Proof}, p3_bn254_fr::Bn254Fr, }; @@ -24,8 +21,11 @@ use crate::{ keygen::RootVerifierProvingKey, prover::{vm::SingleSegmentVmProver, RootVerifierLocalProver}, verifier::{ - common::assert_single_segment_vm_exit_successfully_with_connector_air_id, + common::{ + assert_single_segment_vm_exit_successfully_with_connector_air_id, types::SpecialAirIds, + }, root::types::{RootVmVerifierInput, RootVmVerifierPvs}, + utils::compress_babybear_var_to_bn254, }, RootSC, F, SC, }; @@ -36,10 +36,11 @@ impl RootVerifierProvingKey { &self, params: &Halo2Params, root_proof: Proof, + pv_handler: Option<&impl StaticVerifierPvHandler>, ) -> Halo2VerifierProvingKey { let mut witness = Witness::default(); root_proof.write(&mut witness); - let dsl_operations = build_static_verifier_operations(self, &root_proof); + let dsl_operations = build_static_verifier_operations(self, &root_proof, pv_handler); Halo2VerifierProvingKey { pinning: Halo2Prover::keygen(params, dsl_operations.clone(), witness), dsl_ops: dsl_operations, @@ -67,49 +68,26 @@ impl RootVerifierProvingKey { } } -fn build_static_verifier_operations( - root_verifier_pk: &RootVerifierProvingKey, - proof: &Proof, -) -> DslOperations { - let advice = new_from_outer_multi_vk(&root_verifier_pk.vm_pk.vm_pk.get_vk()); - let special_air_ids = root_verifier_pk.air_id_permutation().get_special_air_ids(); - let mut builder = Builder::::default(); - builder.flags.static_only = true; - let num_public_values = { - builder.cycle_tracker_start("VerifierProgram"); - let input = proof.read(&mut builder); - - let pcs = TwoAdicFriPcsVariable { - config: const_fri_config(&mut builder, &root_verifier_pk.vm_pk.fri_params), - }; - StarkVerifier::verify::>( - &mut builder, - &pcs, - &advice, - &input, - ); - { - // Program AIR is the only AIR with a cached trace. The cached trace index doesn't - // change after reordering. - let t_id = RVar::from(PROGRAM_CACHED_TRACE_INDEX); - let commit = builder.get(&input.commitments.main_trace, t_id); - let commit = if let DigestVariable::Var(commit_arr) = commit { - builder.get(&commit_arr, 0) - } else { - unreachable!() - }; - let expected_program_commit: [Bn254Fr; 1] = root_verifier_pk - .root_committed_exe - .get_program_commit() - .into(); - builder.assert_var_eq(commit, expected_program_commit[0]); - } - assert_single_segment_vm_exit_successfully_with_connector_air_id( - &mut builder, - &input, - special_air_ids.connector_air_id, - ); +/// Custom public values handler for static verifier. Implement this trait on a struct and pass it in to `RootVerifierProvingKey::keygen_static_verifier`. +/// If this trait is not implemented, `None` should be passed in for pv_handler to use the default handler. +pub trait StaticVerifierPvHandler { + fn handle_public_values( + &self, + builder: &mut Builder, + input: &StarkProofVariable, + root_verifier_pk: &RootVerifierProvingKey, + special_air_ids: &SpecialAirIds, + ) -> usize; +} +impl StaticVerifierPvHandler for RootVerifierProvingKey { + fn handle_public_values( + &self, + builder: &mut Builder, + input: &StarkProofVariable, + _root_verifier_pk: &RootVerifierProvingKey, + special_air_ids: &SpecialAirIds, + ) -> usize { let pv_air = builder.get(&input.per_air, special_air_ids.public_values_air_id); let public_values: Vec<_> = pv_air .public_values @@ -118,14 +96,45 @@ fn build_static_verifier_operations( .map(|x| builder.cast_felt_to_var(x)) .collect(); let pvs = RootVmVerifierPvs::from_flatten(public_values); - let exe_commit = compress_babybear_var_to_bn254(&mut builder, pvs.exe_commit); - let leaf_commit = compress_babybear_var_to_bn254(&mut builder, pvs.leaf_verifier_commit); + let exe_commit = compress_babybear_var_to_bn254(builder, pvs.exe_commit); + let leaf_commit = compress_babybear_var_to_bn254(builder, pvs.leaf_verifier_commit); let num_public_values = 2 + pvs.public_values.len(); builder.static_commit_public_value(0, exe_commit); builder.static_commit_public_value(1, leaf_commit); for (i, x) in pvs.public_values.into_iter().enumerate() { builder.static_commit_public_value(i + 2, x); } + num_public_values + } +} + +fn build_static_verifier_operations( + root_verifier_pk: &RootVerifierProvingKey, + proof: &Proof, + pv_handler: Option<&impl StaticVerifierPvHandler>, +) -> DslOperations { + let special_air_ids = root_verifier_pk.air_id_permutation().get_special_air_ids(); + let mut builder = Builder::::default(); + builder.flags.static_only = true; + let num_public_values = { + builder.cycle_tracker_start("VerifierProgram"); + let input = proof.read(&mut builder); + verify_root_proof(&mut builder, &input, root_verifier_pk, &special_air_ids); + + let num_public_values = match &pv_handler { + Some(handler) => handler.handle_public_values( + &mut builder, + &input, + root_verifier_pk, + &special_air_ids, + ), + None => root_verifier_pk.handle_public_values( + &mut builder, + &input, + root_verifier_pk, + &special_air_ids, + ), + }; builder.cycle_tracker_end("VerifierProgram"); num_public_values }; @@ -135,16 +144,36 @@ fn build_static_verifier_operations( } } -fn compress_babybear_var_to_bn254( +fn verify_root_proof( builder: &mut Builder, - var: [Var; DIGEST_SIZE], -) -> Var { - let mut ret = SymbolicVar::ZERO; - let order = Bn254Fr::from_canonical_u32(BabyBear::ORDER_U32); - let mut base = Bn254Fr::ONE; - var.iter().for_each(|&x| { - ret += x * base; - base *= order; - }); - builder.eval(ret) + input: &StarkProofVariable, + root_verifier_pk: &RootVerifierProvingKey, + special_air_ids: &SpecialAirIds, +) { + let advice = new_from_outer_multi_vk(&root_verifier_pk.vm_pk.vm_pk.get_vk()); + let pcs = TwoAdicFriPcsVariable { + config: const_fri_config(builder, &root_verifier_pk.vm_pk.fri_params), + }; + StarkVerifier::verify::>(builder, &pcs, &advice, input); + { + // Program AIR is the only AIR with a cached trace. The cached trace index doesn't + // change after reordering. + let t_id = RVar::from(PROGRAM_CACHED_TRACE_INDEX); + let commit = builder.get(&input.commitments.main_trace, t_id); + let commit = if let DigestVariable::Var(commit_arr) = commit { + builder.get(&commit_arr, 0) + } else { + unreachable!() + }; + let expected_program_commit: [Bn254Fr; 1] = root_verifier_pk + .root_committed_exe + .get_program_commit() + .into(); + builder.assert_var_eq(commit, expected_program_commit[0]); + } + assert_single_segment_vm_exit_successfully_with_connector_air_id( + builder, + input, + special_air_ids.connector_air_id, + ); } diff --git a/crates/sdk/src/verifier/mod.rs b/crates/sdk/src/verifier/mod.rs index d9b1bd664e..f037205353 100644 --- a/crates/sdk/src/verifier/mod.rs +++ b/crates/sdk/src/verifier/mod.rs @@ -8,7 +8,7 @@ pub mod common; pub mod internal; pub mod leaf; pub mod root; -pub(crate) mod utils; +pub mod utils; const SBOX_SIZE: usize = 7; diff --git a/crates/sdk/src/verifier/utils.rs b/crates/sdk/src/verifier/utils.rs index cc1e69c291..2420e0ea03 100644 --- a/crates/sdk/src/verifier/utils.rs +++ b/crates/sdk/src/verifier/utils.rs @@ -1,8 +1,25 @@ use std::array; use openvm_native_compiler::prelude::*; -use openvm_native_recursion::{hints::Hintable, types::InnerConfig}; -use openvm_stark_sdk::{openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear}; +use openvm_native_recursion::{config::outer::OuterConfig, hints::Hintable, types::InnerConfig}; +use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_sdk::{ + openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr, +}; + +pub fn compress_babybear_var_to_bn254( + builder: &mut Builder, + var: [Var; DIGEST_SIZE], +) -> Var { + let mut ret = SymbolicVar::ZERO; + let order = Bn254Fr::from_canonical_u32(BabyBear::ORDER_U32); + let mut base = Bn254Fr::ONE; + var.iter().for_each(|&x| { + ret += x * base; + base *= order; + }); + builder.eval(ret) +} pub(crate) fn assign_array_to_slice( builder: &mut Builder, diff --git a/crates/sdk/tests/integration_test.rs b/crates/sdk/tests/integration_test.rs index 9c0f7853c1..7c2e6ba4ec 100644 --- a/crates/sdk/tests/integration_test.rs +++ b/crates/sdk/tests/integration_test.rs @@ -10,14 +10,21 @@ use openvm_circuit::{ }; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::{conversion::CompilerOptions, prelude::*}; -use openvm_native_recursion::{halo2::utils::CacheHalo2ParamsReader, types::InnerConfig}; +use openvm_native_recursion::{ + config::outer::OuterConfig, halo2::utils::CacheHalo2ParamsReader, types::InnerConfig, + vars::StarkProofVariable, +}; use openvm_rv32im_transpiler::{Rv32ITranspilerExtension, Rv32MTranspilerExtension}; use openvm_sdk::{ + commit::AppExecutionCommit, config::{AggConfig, AggStarkConfig, AppConfig, Halo2Config}, - keygen::AppProvingKey, + keygen::{AppProvingKey, RootVerifierProvingKey}, + static_verifier::StaticVerifierPvHandler, verifier::{ - common::types::VmVerifierPvs, + common::types::{SpecialAirIds, VmVerifierPvs}, leaf::types::{LeafVmVerifierInput, UserPublicValuesRootProof}, + root::types::RootVmVerifierPvs, + utils::compress_babybear_var_to_bn254, }, Sdk, StdIn, }; @@ -29,6 +36,7 @@ use openvm_stark_sdk::{ engine::{StarkEngine, StarkFriEngine}, openvm_stark_backend::{p3_field::FieldAlgebra, Chip}, p3_baby_bear::BabyBear, + p3_bn254_fr::Bn254Fr, }; use openvm_transpiler::transpiler::Transpiler; @@ -75,7 +83,9 @@ fn app_committed_exe_for_test(app_log_blowup: usize) -> Arc> builder.assign(&b, c); }); builder.halt(); - builder.compile_isa() + let mut program = builder.compile_isa(); + program.max_num_public_values = NUM_PUB_VALUES; + program }; Sdk.commit_app_exe( standard_fri_params_with_100_bits_conjectured_security(app_log_blowup), @@ -120,7 +130,7 @@ fn small_test_app_config(app_log_blowup: usize) -> AppConfig { SystemConfig::default() .with_max_segment_len(200) .with_continuations() - .with_public_values(16), + .with_public_values(NUM_PUB_VALUES), Native, ), leaf_fri_params: standard_fri_params_with_100_bits_conjectured_security(LEAF_LOG_BLOWUP) @@ -254,6 +264,99 @@ fn test_public_values_and_leaf_verification() { } } +#[test] +fn test_static_verifier_custom_pv_handler() { + // Define custom public values handler and implement StaticVerifierPvHandler trait on it + pub struct CustomPvHandler { + pub exe_commit: Bn254Fr, + pub leaf_verifier_commit: Bn254Fr, + } + + impl StaticVerifierPvHandler for CustomPvHandler { + fn handle_public_values( + &self, + builder: &mut Builder, + input: &StarkProofVariable, + _root_verifier_pk: &RootVerifierProvingKey, + special_air_ids: &SpecialAirIds, + ) -> usize { + let pv_air = builder.get(&input.per_air, special_air_ids.public_values_air_id); + let public_values: Vec<_> = pv_air + .public_values + .vec() + .into_iter() + .map(|x| builder.cast_felt_to_var(x)) + .collect(); + let pvs = RootVmVerifierPvs::from_flatten(public_values); + let exe_commit = compress_babybear_var_to_bn254(builder, pvs.exe_commit); + let leaf_commit = compress_babybear_var_to_bn254(builder, pvs.leaf_verifier_commit); + let num_public_values = pvs.public_values.len(); + + println!("num_public_values: {}", num_public_values); + println!("self.exe_commit: {:?}", self.exe_commit); + println!("self.leaf_verifier_commit: {:?}", self.leaf_verifier_commit); + + let expected_exe_commit: Var = builder.constant(self.exe_commit); + let expected_leaf_commit: Var = builder.constant(self.leaf_verifier_commit); + + builder.assert_var_eq(exe_commit, expected_exe_commit); + builder.assert_var_eq(leaf_commit, expected_leaf_commit); + + num_public_values + } + } + + // Test setup + println!("test setup"); + let app_log_blowup = 1; + let app_config = small_test_app_config(app_log_blowup); + let app_pk = Sdk.app_keygen(app_config.clone()).unwrap(); + let app_committed_exe = app_committed_exe_for_test(app_log_blowup); + println!("app_config: {:?}", app_config.app_vm_config); + println!( + "app_committed_exe max_num_public_values: {:?}", + app_committed_exe.exe.program.max_num_public_values + ); + let params_reader = CacheHalo2ParamsReader::new_with_default_params_dir(); + + // Generate PK using custom PV handler + println!("generate PK using custom PV handler"); + let commits = AppExecutionCommit::compute( + &app_config.app_vm_config, + &app_committed_exe, + &app_pk.leaf_committed_exe, + ); + let exe_commit = commits.exe_commit_to_bn254(); + let leaf_verifier_commit = commits.app_config_commit_to_bn254(); + + let pv_handler = CustomPvHandler { + exe_commit, + leaf_verifier_commit, + }; + let agg_pk = Sdk + .agg_keygen(agg_config_for_test(), ¶ms_reader, Some(&pv_handler)) + .unwrap(); + + // Generate verifier contract + println!("generate verifier contract"); + let evm_verifier = Sdk + .generate_snark_verifier_contract(¶ms_reader, &agg_pk) + .unwrap(); + + // Generate and verify proof + println!("generate and verify proof"); + let evm_proof = Sdk + .generate_evm_proof( + ¶ms_reader, + Arc::new(app_pk), + app_committed_exe, + agg_pk, + StdIn::default(), + ) + .unwrap(); + assert!(Sdk.verify_evm_proof(&evm_verifier, &evm_proof)); +} + #[test] fn test_e2e_proof_generation_and_verification() { let app_log_blowup = 1; @@ -261,7 +364,11 @@ fn test_e2e_proof_generation_and_verification() { let app_pk = Sdk.app_keygen(app_config).unwrap(); let params_reader = CacheHalo2ParamsReader::new_with_default_params_dir(); let agg_pk = Sdk - .agg_keygen(agg_config_for_test(), ¶ms_reader) + .agg_keygen( + agg_config_for_test(), + ¶ms_reader, + None::<&RootVerifierProvingKey>, + ) .unwrap(); let evm_verifier = Sdk .generate_snark_verifier_contract(¶ms_reader, &agg_pk) diff --git a/extensions/native/recursion/src/halo2/wrapper.rs b/extensions/native/recursion/src/halo2/wrapper.rs index 9ca9a3187e..8ffa53f497 100644 --- a/extensions/native/recursion/src/halo2/wrapper.rs +++ b/extensions/native/recursion/src/halo2/wrapper.rs @@ -127,8 +127,7 @@ impl Halo2WrapperProvingKey { ); assert_eq!( self.pinning.metadata.num_pvs[0], - // 12 is the number of public values for the accumulator - snark_to_verify.instances[0].len() + 12 + snark_to_verify.instances[0].len() + 12, ); generate_wrapper_circuit_object(Prover, k, snark_to_verify) .use_params(