Skip to content

Commit

Permalink
Add internal wrapper logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nyunyunyunyu committed Dec 3, 2024
1 parent 0289239 commit b438faa
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 18 deletions.
69 changes: 53 additions & 16 deletions crates/axvm-sdk/src/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub use root::*;

const DEFAULT_NUM_CHILDREN_LEAF: usize = 2;
const DEFAULT_NUM_CHILDREN_INTERNAL: usize = 2;
const DEFAULT_MAX_INTERNAL_WRAPPER_LAYERS: usize = 3;

pub struct StarkProver<VC> {
pub app_pk: AppProvingKey<VC>,
Expand All @@ -50,6 +51,7 @@ pub struct StarkProver<VC> {

pub num_children_leaf: usize,
pub num_children_internal: usize,
pub max_internal_wrapper_layers: usize,
}

impl<VC: VmConfig<F>> StarkProver<VC>
Expand All @@ -69,6 +71,7 @@ where
leaf_committed_exe: None,
num_children_leaf: DEFAULT_NUM_CHILDREN_LEAF,
num_children_internal: DEFAULT_NUM_CHILDREN_INTERNAL,
max_internal_wrapper_layers: DEFAULT_MAX_INTERNAL_WRAPPER_LAYERS,
app_prover,
leaf_prover: None,
internal_prover: None,
Expand Down Expand Up @@ -117,6 +120,11 @@ where
self
}

pub fn with_max_internal_wrapper_layers(mut self, max_internal_wrapper_layers: usize) -> Self {
self.max_internal_wrapper_layers = max_internal_wrapper_layers;
self
}

pub fn agg_pk(&self) -> &AggProvingKey {
assert!(self.agg_pk.is_some(), "Aggregation has not been configured");
self.agg_pk.as_ref().unwrap()
Expand All @@ -134,8 +142,12 @@ where
assert!(self.agg_pk.is_some(), "Aggregation has not been configured");
let app_proofs = self.generate_app_proof(input);
let leaf_proofs = self.generate_leaf_proof_impl(&app_proofs);
let internal_proof = self.generate_internal_proof_impl(leaf_proofs);
self.generate_root_proof_impl(app_proofs, internal_proof)
let public_values = app_proofs.user_public_values.public_values;
let internal_proof = self.generate_internal_proof_impl(leaf_proofs, &public_values);
self.generate_root_proof_impl(RootVmVerifierInput {
proofs: vec![internal_proof],
public_values,
})
}

pub fn generate_app_proof(&self, input: Vec<Vec<F>>) -> ContinuationVmProof<SC> {
Expand All @@ -162,11 +174,15 @@ where
.absolute(self.agg_pk().leaf_vm_pk.fri_params.log_blowup as u64);
self.generate_leaf_proof_impl(&app_proofs)
});
let internal_proof = self.generate_internal_proof_impl(leaf_proofs);
let public_values = app_proofs.user_public_values.public_values;
let internal_proof = self.generate_internal_proof_impl(leaf_proofs, &public_values);
info_span!("root verifier", group = "root_verifier").in_scope(|| {
counter!("fri.log_blowup")
.absolute(self.agg_pk().root_verifier_pk.vm_pk.fri_params.log_blowup as u64);
self.generate_root_proof_impl(app_proofs, internal_proof)
self.generate_root_proof_impl(RootVmVerifierInput {
proofs: vec![internal_proof],
public_values,
})
})
}

Expand Down Expand Up @@ -199,11 +215,36 @@ where
.collect::<Vec<_>>()
}

fn generate_internal_proof_impl(&self, leaf_proofs: Vec<Proof<SC>>) -> Proof<SC> {
fn generate_internal_proof_impl(
&self,
leaf_proofs: Vec<Proof<SC>>,
public_values: &[F],
) -> Proof<SC> {
let mut internal_node_idx = -1;
let mut internal_node_height = 0;
let mut proofs = leaf_proofs;
while proofs.len() > 1 {
let mut wrapper_layers = 0;
loop {
// TODO: what's a good test case for the wrapping logic?
if proofs.len() == 1 {
let root_prover = self.root_prover.as_ref().unwrap();
// TODO: record execution time as a part of root verifier execution time.
let actual_air_heights = root_prover.execute_for_air_heights(RootVmVerifierInput {
proofs: vec![proofs[0].clone()],
public_values: public_values.to_vec(),
});
// Root verifier can handle the internal proof. We can stop here.
if heights_le(
&actual_air_heights,
&root_prover.root_verifier_pk.air_heights,
) {
break;
}
if wrapper_layers >= self.max_internal_wrapper_layers {
panic!("The heights of the root verifier still exceed the required heights after {} wrapper layers", self.max_internal_wrapper_layers);
}
wrapper_layers += 1;
}
let internal_inputs = InternalVmVerifierInput::chunk_leaf_or_internal_proofs(
self.agg_pk()
.internal_committed_exe
Expand Down Expand Up @@ -239,16 +280,7 @@ where
proofs.pop().unwrap()
}

fn generate_root_proof_impl(
&self,
app_proofs: ContinuationVmProof<SC>,
internal_proof: Proof<SC>,
) -> Proof<OuterSC> {
// TODO: wrap internal verifier if heights exceed
let root_input = RootVmVerifierInput {
proofs: vec![internal_proof],
public_values: app_proofs.user_public_values.public_values,
};
fn generate_root_proof_impl(&self, root_input: RootVmVerifierInput<SC>) -> Proof<OuterSC> {
let input = root_input.write();
let root_prover = self.root_prover.as_ref().unwrap();
#[cfg(feature = "bench-metrics")]
Expand Down Expand Up @@ -293,3 +325,8 @@ fn execute_app_exe_for_metrics_collection<VC: VmConfig<F>>(
vm.execute_segments(app_committed_exe.exe.clone(), input)
.unwrap();
}

fn heights_le(a: &[usize], b: &[usize]) -> bool {
assert_eq!(a.len(), b.len());
a.iter().zip(b.iter()).all(|(a, b)| a <= b)
}
24 changes: 22 additions & 2 deletions crates/axvm-sdk/src/prover/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,37 @@ use axvm_circuit::{
arch::SingleSegmentVmExecutor,
prover::{AsyncSingleSegmentVmProver, SingleSegmentVmProver},
};
use axvm_native_circuit::NativeConfig;
use axvm_native_recursion::hints::Hintable;

use crate::{keygen::RootVerifierProvingKey, OuterSC, F};
use crate::{
keygen::RootVerifierProvingKey, verifier::root::types::RootVmVerifierInput, OuterSC, F, SC,
};

/// Local prover for a root verifier.
pub struct RootVerifierLocalProver {
pub root_verifier_pk: RootVerifierProvingKey,
executor_for_heights: SingleSegmentVmExecutor<F, NativeConfig>,
}

impl RootVerifierLocalProver {
pub fn new(root_verifier_pk: RootVerifierProvingKey) -> Self {
Self { root_verifier_pk }
let executor_for_heights =
SingleSegmentVmExecutor::<F, _>::new(root_verifier_pk.vm_pk.vm_config.clone());
Self {
root_verifier_pk,
executor_for_heights,
}
}
pub fn execute_for_air_heights(&self, input: RootVmVerifierInput<SC>) -> Vec<usize> {
let result = self
.executor_for_heights
.execute(
self.root_verifier_pk.root_committed_exe.exe.clone(),
input.write(),
)
.unwrap();
result.air_heights
}
}

Expand Down

0 comments on commit b438faa

Please sign in to comment.