Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Customizable static verifier public values handler #1226

Merged
merged 15 commits into from
Jan 21, 2025
10 changes: 8 additions & 2 deletions benchmarks/src/bin/fib_e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use openvm_rv32im_circuit::Rv32ImConfig;
use openvm_rv32im_transpiler::{
Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension,
};
use openvm_sdk::{commit::commit_app_exe, prover::ContinuationProver, Sdk, StdIn};
use openvm_sdk::{
commit::commit_app_exe, keygen::RootVerifierProvingKey, prover::ContinuationProver, Sdk, StdIn,
};
use openvm_stark_sdk::bench::run_with_metric_collection;
use openvm_transpiler::{transpiler::Transpiler, FromElf};

Expand All @@ -35,7 +37,11 @@ async fn main() -> Result<()> {
.unwrap_or(PathBuf::from(DEFAULT_PARAMS_DIR)),
);
let app_pk = Arc::new(sdk.app_keygen(app_config)?);
let full_agg_pk = sdk.agg_keygen(agg_config, &halo2_params_reader)?;
let full_agg_pk = sdk.agg_keygen(
agg_config,
&halo2_params_reader,
None::<&RootVerifierProvingKey>,
)?;
let elf = args.build_bench_program("fibonacci")?;
let exe = VmExe::from_elf(
elf,
Expand Down
3 changes: 2 additions & 1 deletion crates/cli/src/commands/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use openvm_native_recursion::halo2::utils::CacheHalo2ParamsReader;
use openvm_sdk::{
config::AggConfig,
fs::{write_agg_pk_to_file, write_evm_verifier_to_file},
keygen::RootVerifierProvingKey,
Sdk,
};

Expand Down Expand Up @@ -41,7 +42,7 @@ impl EvmProvingSetupCmd {
let agg_config = AggConfig::default();

println!("Generating proving key...");
let agg_pk = Sdk.agg_keygen(agg_config, &params_reader)?;
let agg_pk = Sdk.agg_keygen(agg_config, &params_reader, None::<&RootVerifierProvingKey>)?;

println!("Generating verifier contract...");
let verifier = Sdk.generate_snark_verifier_contract(&params_reader, &agg_pk)?;
Expand Down
7 changes: 6 additions & 1 deletion crates/sdk/examples/sdk_evm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use openvm_build::GuestOptions;
use openvm_native_recursion::halo2::utils::CacheHalo2ParamsReader;
use openvm_sdk::{
config::{AggConfig, AppConfig, SdkVmConfig},
keygen::RootVerifierProvingKey,
Sdk, StdIn,
};
use openvm_stark_sdk::config::FriParameters;
Expand Down Expand Up @@ -89,7 +90,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
const DEFAULT_PARAMS_DIR: &str = concat!(env!("HOME"), "/.openvm/params/");
let halo2_params_reader = CacheHalo2ParamsReader::new(DEFAULT_PARAMS_DIR);
let agg_config = AggConfig::default();
let agg_pk = sdk.agg_keygen(agg_config, &halo2_params_reader)?;
let agg_pk = sdk.agg_keygen(
agg_config,
&halo2_params_reader,
None::<&RootVerifierProvingKey>,
)?;

// 9. Generate the SNARK verifier smart contract
let verifier = sdk.generate_snark_verifier_contract(&halo2_params_reader, &agg_pk)?;
Expand Down
8 changes: 7 additions & 1 deletion crates/sdk/src/keygen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use crate::{
config::{AggConfig, AggStarkConfig, AppConfig},
keygen::perm::AirIdPermutation,
prover::vm::types::VmProvingKey,
static_verifier::StaticVerifierPvHandler,
verifier::{
internal::InternalVmVerifierConfig, leaf::LeafVmVerifierConfig, root::RootVmVerifierConfig,
},
Expand Down Expand Up @@ -296,7 +297,11 @@ impl AggProvingKey {
/// - This function is very expensive. Usually it requires >64GB memory and takes >10 minutes.
/// - Please make sure SRS(KZG parameters) is already downloaded.
#[tracing::instrument(level = "info", fields(group = "agg_keygen"), skip_all)]
pub fn keygen(config: AggConfig, reader: &impl Halo2ParamsReader) -> Self {
pub fn keygen(
config: AggConfig,
reader: &impl Halo2ParamsReader,
pv_handler: Option<&impl StaticVerifierPvHandler>,
) -> Self {
let AggConfig {
agg_stark_config,
halo2_config,
Expand All @@ -310,6 +315,7 @@ impl AggProvingKey {
let verifier = agg_stark_pk.root_verifier_pk.keygen_static_verifier(
&reader.read_params(halo2_config.verifier_k),
dummy_root_proof,
pv_handler,
);
let dummy_snark = verifier.generate_dummy_snark(reader);
let wrapper = if let Some(wrapper_k) = halo2_config.wrapper_k {
Expand Down
7 changes: 4 additions & 3 deletions crates/sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ use prover::vm::ContinuationVmProof;

pub mod commit;
pub mod config;
pub mod keygen;
pub mod prover;
pub mod static_verifier;

pub mod keygen;
pub mod verifier;

mod stdin;
use static_verifier::StaticVerifierPvHandler;
pub use stdin::*;
pub mod fs;

Expand Down Expand Up @@ -178,8 +178,9 @@ impl Sdk {
&self,
config: AggConfig,
reader: &impl Halo2ParamsReader,
pv_handler: Option<&impl StaticVerifierPvHandler>,
) -> Result<AggProvingKey> {
let agg_pk = AggProvingKey::keygen(config, reader);
let agg_pk = AggProvingKey::keygen(config, reader, pv_handler);
Ok(agg_pk)
}

Expand Down
153 changes: 91 additions & 62 deletions crates/sdk/src/static_verifier/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,23 @@ use openvm_native_recursion::{
hints::Hintable,
stark::StarkVerifier,
utils::const_fri_config,
vars::StarkProofVariable,
witness::Witnessable,
};
use openvm_stark_sdk::{
openvm_stark_backend::{
p3_field::{FieldAlgebra, PrimeField32},
prover::types::Proof,
},
p3_baby_bear::BabyBear,
openvm_stark_backend::{p3_field::FieldAlgebra, prover::types::Proof},
p3_bn254_fr::Bn254Fr,
};

use crate::{
keygen::RootVerifierProvingKey,
prover::{vm::SingleSegmentVmProver, RootVerifierLocalProver},
verifier::{
common::assert_single_segment_vm_exit_successfully_with_connector_air_id,
common::{
assert_single_segment_vm_exit_successfully_with_connector_air_id, types::SpecialAirIds,
},
root::types::{RootVmVerifierInput, RootVmVerifierPvs},
utils::compress_babybear_var_to_bn254,
},
RootSC, F, SC,
};
Expand All @@ -36,10 +36,11 @@ impl RootVerifierProvingKey {
&self,
params: &Halo2Params,
root_proof: Proof<RootSC>,
pv_handler: Option<&impl StaticVerifierPvHandler>,
) -> Halo2VerifierProvingKey {
let mut witness = Witness::default();
root_proof.write(&mut witness);
let dsl_operations = build_static_verifier_operations(self, &root_proof);
let dsl_operations = build_static_verifier_operations(self, &root_proof, pv_handler);
Halo2VerifierProvingKey {
pinning: Halo2Prover::keygen(params, dsl_operations.clone(), witness),
dsl_ops: dsl_operations,
Expand Down Expand Up @@ -67,49 +68,26 @@ impl RootVerifierProvingKey {
}
}

fn build_static_verifier_operations(
root_verifier_pk: &RootVerifierProvingKey,
proof: &Proof<RootSC>,
) -> DslOperations<OuterConfig> {
let advice = new_from_outer_multi_vk(&root_verifier_pk.vm_pk.vm_pk.get_vk());
let special_air_ids = root_verifier_pk.air_id_permutation().get_special_air_ids();
let mut builder = Builder::<OuterConfig>::default();
builder.flags.static_only = true;
let num_public_values = {
builder.cycle_tracker_start("VerifierProgram");
let input = proof.read(&mut builder);

let pcs = TwoAdicFriPcsVariable {
config: const_fri_config(&mut builder, &root_verifier_pk.vm_pk.fri_params),
};
StarkVerifier::verify::<MultiField32ChallengerVariable<_>>(
&mut builder,
&pcs,
&advice,
&input,
);
{
// Program AIR is the only AIR with a cached trace. The cached trace index doesn't
// change after reordering.
let t_id = RVar::from(PROGRAM_CACHED_TRACE_INDEX);
let commit = builder.get(&input.commitments.main_trace, t_id);
let commit = if let DigestVariable::Var(commit_arr) = commit {
builder.get(&commit_arr, 0)
} else {
unreachable!()
};
let expected_program_commit: [Bn254Fr; 1] = root_verifier_pk
.root_committed_exe
.get_program_commit()
.into();
builder.assert_var_eq(commit, expected_program_commit[0]);
}
assert_single_segment_vm_exit_successfully_with_connector_air_id(
&mut builder,
&input,
special_air_ids.connector_air_id,
);
/// Custom public values handler for static verifier. Implement this trait on a struct and pass it in to `RootVerifierProvingKey::keygen_static_verifier`.
/// If this trait is not implemented, `None` should be passed in for pv_handler to use the default handler.
pub trait StaticVerifierPvHandler {
fn handle_public_values(
&self,
builder: &mut Builder<OuterConfig>,
input: &StarkProofVariable<OuterConfig>,
root_verifier_pk: &RootVerifierProvingKey,
special_air_ids: &SpecialAirIds,
) -> usize;
}

impl StaticVerifierPvHandler for RootVerifierProvingKey {
fn handle_public_values(
&self,
builder: &mut Builder<OuterConfig>,
input: &StarkProofVariable<OuterConfig>,
_root_verifier_pk: &RootVerifierProvingKey,
special_air_ids: &SpecialAirIds,
) -> usize {
let pv_air = builder.get(&input.per_air, special_air_ids.public_values_air_id);
let public_values: Vec<_> = pv_air
.public_values
Expand All @@ -118,14 +96,45 @@ fn build_static_verifier_operations(
.map(|x| builder.cast_felt_to_var(x))
.collect();
let pvs = RootVmVerifierPvs::from_flatten(public_values);
let exe_commit = compress_babybear_var_to_bn254(&mut builder, pvs.exe_commit);
let leaf_commit = compress_babybear_var_to_bn254(&mut builder, pvs.leaf_verifier_commit);
let exe_commit = compress_babybear_var_to_bn254(builder, pvs.exe_commit);
let leaf_commit = compress_babybear_var_to_bn254(builder, pvs.leaf_verifier_commit);
let num_public_values = 2 + pvs.public_values.len();
builder.static_commit_public_value(0, exe_commit);
builder.static_commit_public_value(1, leaf_commit);
for (i, x) in pvs.public_values.into_iter().enumerate() {
builder.static_commit_public_value(i + 2, x);
}
num_public_values
}
}

fn build_static_verifier_operations(
root_verifier_pk: &RootVerifierProvingKey,
proof: &Proof<RootSC>,
pv_handler: Option<&impl StaticVerifierPvHandler>,
) -> DslOperations<OuterConfig> {
let special_air_ids = root_verifier_pk.air_id_permutation().get_special_air_ids();
let mut builder = Builder::<OuterConfig>::default();
builder.flags.static_only = true;
let num_public_values = {
builder.cycle_tracker_start("VerifierProgram");
let input = proof.read(&mut builder);
verify_root_proof(&mut builder, &input, root_verifier_pk, &special_air_ids);

let num_public_values = match &pv_handler {
Some(handler) => handler.handle_public_values(
&mut builder,
&input,
root_verifier_pk,
&special_air_ids,
),
None => root_verifier_pk.handle_public_values(
&mut builder,
&input,
root_verifier_pk,
&special_air_ids,
),
};
builder.cycle_tracker_end("VerifierProgram");
num_public_values
};
Expand All @@ -135,16 +144,36 @@ fn build_static_verifier_operations(
}
}

fn compress_babybear_var_to_bn254(
fn verify_root_proof(
builder: &mut Builder<OuterConfig>,
var: [Var<Bn254Fr>; DIGEST_SIZE],
) -> Var<Bn254Fr> {
let mut ret = SymbolicVar::ZERO;
let order = Bn254Fr::from_canonical_u32(BabyBear::ORDER_U32);
let mut base = Bn254Fr::ONE;
var.iter().for_each(|&x| {
ret += x * base;
base *= order;
});
builder.eval(ret)
input: &StarkProofVariable<OuterConfig>,
root_verifier_pk: &RootVerifierProvingKey,
special_air_ids: &SpecialAirIds,
) {
let advice = new_from_outer_multi_vk(&root_verifier_pk.vm_pk.vm_pk.get_vk());
let pcs = TwoAdicFriPcsVariable {
config: const_fri_config(builder, &root_verifier_pk.vm_pk.fri_params),
};
StarkVerifier::verify::<MultiField32ChallengerVariable<_>>(builder, &pcs, &advice, input);
{
// Program AIR is the only AIR with a cached trace. The cached trace index doesn't
// change after reordering.
let t_id = RVar::from(PROGRAM_CACHED_TRACE_INDEX);
let commit = builder.get(&input.commitments.main_trace, t_id);
let commit = if let DigestVariable::Var(commit_arr) = commit {
builder.get(&commit_arr, 0)
} else {
unreachable!()
};
let expected_program_commit: [Bn254Fr; 1] = root_verifier_pk
.root_committed_exe
.get_program_commit()
.into();
builder.assert_var_eq(commit, expected_program_commit[0]);
}
nyunyunyunyu marked this conversation as resolved.
Show resolved Hide resolved
assert_single_segment_vm_exit_successfully_with_connector_air_id(
builder,
input,
special_air_ids.connector_air_id,
);
}
2 changes: 1 addition & 1 deletion crates/sdk/src/verifier/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub mod common;
pub mod internal;
pub mod leaf;
pub mod root;
pub(crate) mod utils;
pub mod utils;

const SBOX_SIZE: usize = 7;

Expand Down
21 changes: 19 additions & 2 deletions crates/sdk/src/verifier/utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
use std::array;

use openvm_native_compiler::prelude::*;
use openvm_native_recursion::{hints::Hintable, types::InnerConfig};
use openvm_stark_sdk::{openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear};
use openvm_native_recursion::{config::outer::OuterConfig, hints::Hintable, types::InnerConfig};
use openvm_stark_backend::p3_field::PrimeField32;
use openvm_stark_sdk::{
openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr,
};

pub fn compress_babybear_var_to_bn254(
builder: &mut Builder<OuterConfig>,
var: [Var<Bn254Fr>; DIGEST_SIZE],
) -> Var<Bn254Fr> {
let mut ret = SymbolicVar::ZERO;
let order = Bn254Fr::from_canonical_u32(BabyBear::ORDER_U32);
let mut base = Bn254Fr::ONE;
var.iter().for_each(|&x| {
ret += x * base;
base *= order;
});
builder.eval(ret)
}

pub(crate) fn assign_array_to_slice<C: Config>(
builder: &mut Builder<C>,
Expand Down
Loading