From 8eb4543608743d1398182e7461bdec76b8ab1bbc Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Sun, 15 Dec 2024 12:26:12 -0800 Subject: [PATCH] [refactor] Implement Chip for All Memory Chips (#1038) * Implement Chip for all memory chips * Parallize Map transformation --- crates/circuits/primitives/derive/src/lib.rs | 36 ++- crates/vm/src/arch/testing/mod.rs | 26 +- crates/vm/src/system/memory/adapter/mod.rs | 88 +++---- .../vm/src/system/memory/manager/interface.rs | 1 + crates/vm/src/system/memory/manager/mod.rs | 208 +++++++--------- crates/vm/src/system/memory/merkle/mod.rs | 21 +- .../vm/src/system/memory/merkle/tests/mod.rs | 84 +++---- .../vm/src/system/memory/merkle/tests/util.rs | 17 +- crates/vm/src/system/memory/merkle/trace.rs | 67 ++++- crates/vm/src/system/memory/persistent.rs | 230 +++++++++++++----- crates/vm/src/system/memory/volatile/mod.rs | 110 +++++---- crates/vm/src/system/memory/volatile/tests.rs | 52 ++-- 12 files changed, 548 insertions(+), 392 deletions(-) diff --git a/crates/circuits/primitives/derive/src/lib.rs b/crates/circuits/primitives/derive/src/lib.rs index 1e1ba1aa52..18db7d2970 100644 --- a/crates/circuits/primitives/derive/src/lib.rs +++ b/crates/circuits/primitives/derive/src/lib.rs @@ -5,7 +5,7 @@ extern crate proc_macro; use itertools::multiunzip; use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam}; +use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta}; #[proc_macro_derive(AlignedBorrow)] pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { @@ -72,8 +72,9 @@ pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { TokenStream::from(methods) } -#[proc_macro_derive(Chip)] +#[proc_macro_derive(Chip, attributes(chip))] pub fn chip_derive(input: TokenStream) -> TokenStream { + // Parse the attributes from the struct or enum let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; @@ -160,6 +161,37 @@ pub fn chip_derive(input: TokenStream) -> TokenStream { let where_clause = new_generics.make_where_clause(); where_clause.predicates.push(syn::parse_quote! { openvm_stark_backend::config::Domain: openvm_stark_backend::p3_commit::PolynomialSpace }); + let attributes = ast.attrs.iter().find(|&attr| attr.path().is_ident("chip")); + if let Some(attr) = attributes { + let mut fail_flag = false; + + match &attr.meta { + Meta::List(meta_list) => { + meta_list + .parse_nested_meta(|meta| { + if meta.path.is_ident("where") { + let value = meta.value()?; // this parses the `=` + let s: LitStr = value.parse()?; + let where_value = s.value(); + where_clause.predicates.push(syn::parse_str(&where_value)?); + } else { + fail_flag = true; + } + Ok(()) + }) + .unwrap(); + } + _ => fail_flag = true, + } + if fail_flag { + return syn::Error::new( + name.span(), + "Only `#[chip(where = ...)]` format is supported", + ) + .to_compile_error() + .into(); + } + } quote! { impl #impl_generics openvm_stark_backend::Chip for #name #ty_generics #where_clause { diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index c960d5ab18..4ad33640ba 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -1,16 +1,12 @@ use std::{cell::RefCell, rc::Rc, sync::Arc}; -use itertools::izip; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_instructions::instruction::Instruction; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::VerificationData, p3_field::PrimeField32, - p3_matrix::{ - dense::{DenseMatrix, RowMajorMatrix}, - Matrix, - }, + p3_matrix::dense::{DenseMatrix, RowMajorMatrix}, prover::types::AirProofInput, verifier::VerificationError, Chip, @@ -267,21 +263,15 @@ where let range_checker = memory_controller.borrow().range_checker.clone(); self = self.load(memory_tester); // dummy memory interactions { - let memory = memory_controller.borrow(); - let public_values = memory.generate_public_values_per_air(); - let airs = memory.airs(); - drop(memory); - let traces = Rc::try_unwrap(memory_controller) + let air_proof_inputs = Rc::try_unwrap(memory_controller) .unwrap() .into_inner() - .generate_traces(); - - for (pvs, air, trace) in izip!(public_values, airs, traces) { - if trace.height() > 0 { - self.air_proof_inputs - .push(AirProofInput::simple(air, trace, pvs)); - } - } + .generate_air_proof_inputs(); + self.air_proof_inputs.extend( + air_proof_inputs + .into_iter() + .filter(|api| api.main_trace_height() > 0), + ); } self = self.load(range_checker); // this must be last because other trace generation mutates its state } diff --git a/crates/vm/src/system/memory/adapter/mod.rs b/crates/vm/src/system/memory/adapter/mod.rs index 652d918cbc..ff015a19c8 100644 --- a/crates/vm/src/system/memory/adapter/mod.rs +++ b/crates/vm/src/system/memory/adapter/mod.rs @@ -7,7 +7,7 @@ use openvm_circuit_primitives::{ is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero, var_range::VariableRangeCheckerChip, TraceSubRowGenerator, }; -use openvm_circuit_primitives_derive::ChipUsageGetter; +use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig, Val}, p3_air::BaseAir, @@ -31,6 +31,7 @@ mod tests; #[derive(Debug, Clone)] pub struct AccessAdapterInventory { chips: Vec>, + air_names: Vec, } impl AccessAdapterInventory { @@ -44,19 +45,19 @@ impl AccessAdapterInventory { let mb = memory_bus; let cmb = clk_max_bits; let maan = max_access_adapter_n; - Self { - chips: [ - Self::create_access_adapter_chip::<2>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<4>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<8>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<16>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<32>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<64>(rc.clone(), mb, cmb, maan), - ] - .into_iter() - .flatten() - .collect(), - } + let chips: Vec<_> = [ + Self::create_access_adapter_chip::<2>(rc.clone(), mb, cmb, maan), + Self::create_access_adapter_chip::<4>(rc.clone(), mb, cmb, maan), + Self::create_access_adapter_chip::<8>(rc.clone(), mb, cmb, maan), + Self::create_access_adapter_chip::<16>(rc.clone(), mb, cmb, maan), + Self::create_access_adapter_chip::<32>(rc.clone(), mb, cmb, maan), + Self::create_access_adapter_chip::<64>(rc.clone(), mb, cmb, maan), + ] + .into_iter() + .flatten() + .collect(); + let air_names = (0..chips.len()).map(|i| air_name(1 << (i + 1))).collect(); + Self { chips, air_names } } pub fn num_access_adapters(&self) -> usize { self.chips.len() @@ -80,9 +81,16 @@ impl AccessAdapterInventory { .map(|chip| chip.current_trace_height()) .collect() } + #[allow(dead_code)] pub fn get_widths(&self) -> Vec { self.chips.iter().map(|chip| chip.trace_width()).collect() } + pub fn get_cells(&self) -> Vec { + self.chips + .iter() + .map(|chip| chip.current_trace_cells()) + .collect() + } pub fn airs(&self) -> Vec>> where F: PrimeField32, @@ -90,23 +98,16 @@ impl AccessAdapterInventory { { self.chips.iter().map(|chip| chip.air()).collect() } - pub fn generate_traces(self) -> Vec> - where - F: PrimeField32, - { - self.chips - .into_par_iter() - .map(|chip| chip.generate_trace()) - .collect() + pub fn air_names(&self) -> Vec { + self.air_names.clone() } - #[allow(dead_code)] - pub fn generate_air_proof_input(self) -> Vec> + pub fn generate_air_proof_inputs(self) -> Vec> where F: PrimeField32, Domain: PolynomialSpace, { self.chips - .into_par_iter() + .into_iter() .map(|chip| chip.generate_air_proof_input()) .collect() } @@ -157,8 +158,9 @@ pub trait GenericAccessAdapterChipTrait { F: PrimeField32; } -#[derive(Debug, Clone, ChipUsageGetter)] +#[derive(Debug, Clone, Chip, ChipUsageGetter)] #[enum_dispatch(GenericAccessAdapterChipTrait)] +#[chip(where = "F: PrimeField32")] enum GenericAccessAdapterChip { N2(AccessAdapterChip), N4(AccessAdapterChip), @@ -168,33 +170,6 @@ enum GenericAccessAdapterChip { N64(AccessAdapterChip), } -impl Chip for GenericAccessAdapterChip> -where - Val: PrimeField32, -{ - fn air(&self) -> Arc> { - match self { - GenericAccessAdapterChip::N2(chip) => chip.air(), - GenericAccessAdapterChip::N4(chip) => chip.air(), - GenericAccessAdapterChip::N8(chip) => chip.air(), - GenericAccessAdapterChip::N16(chip) => chip.air(), - GenericAccessAdapterChip::N32(chip) => chip.air(), - GenericAccessAdapterChip::N64(chip) => chip.air(), - } - } - - fn generate_air_proof_input(self) -> AirProofInput { - match self { - GenericAccessAdapterChip::N2(chip) => chip.generate_air_proof_input(), - GenericAccessAdapterChip::N4(chip) => chip.generate_air_proof_input(), - GenericAccessAdapterChip::N8(chip) => chip.generate_air_proof_input(), - GenericAccessAdapterChip::N16(chip) => chip.generate_air_proof_input(), - GenericAccessAdapterChip::N32(chip) => chip.generate_air_proof_input(), - GenericAccessAdapterChip::N64(chip) => chip.generate_air_proof_input(), - } - } -} - impl GenericAccessAdapterChip { fn new( range_checker: Arc, @@ -313,7 +288,7 @@ where impl ChipUsageGetter for AccessAdapterChip { fn air_name(&self) -> String { - format!("AccessAdapter<{}>", N) + air_name(N) } fn current_trace_height(&self) -> usize { @@ -324,3 +299,8 @@ impl ChipUsageGetter for AccessAdapterChip { BaseAir::::width(&self.air) } } + +#[inline] +fn air_name(n: usize) -> String { + format!("AccessAdapter<{}>", n) +} diff --git a/crates/vm/src/system/memory/manager/interface.rs b/crates/vm/src/system/memory/manager/interface.rs index 1ef03726bb..a3a69d8b1a 100644 --- a/crates/vm/src/system/memory/manager/interface.rs +++ b/crates/vm/src/system/memory/manager/interface.rs @@ -7,6 +7,7 @@ use crate::system::memory::{ Equipartition, CHUNK, }; +#[allow(clippy::large_enum_variant)] #[derive(Debug)] pub enum MemoryInterface { Volatile { diff --git a/crates/vm/src/system/memory/manager/mod.rs b/crates/vm/src/system/memory/manager/mod.rs index 33cb989b49..595257a3c4 100644 --- a/crates/vm/src/system/memory/manager/mod.rs +++ b/crates/vm/src/system/memory/manager/mod.rs @@ -9,7 +9,6 @@ use std::{ }; use getset::Getters; -use itertools::{izip, zip_eq}; pub use memory::{MemoryReadRecord, MemoryWriteRecord}; use openvm_circuit_primitives::{ assert_less_than::{AssertLtSubAir, LessThanAuxCols}, @@ -21,13 +20,13 @@ use openvm_circuit_primitives::{ use openvm_instructions::exe::MemoryImage; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, - p3_air::BaseAir, p3_commit::PolynomialSpace, p3_field::PrimeField32, - p3_matrix::dense::RowMajorMatrix, + p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator}, p3_util::log2_strict_usize, prover::types::AirProofInput, rap::AnyRap, + Chip, ChipUsageGetter, }; use serde::{Deserialize, Serialize}; @@ -66,12 +65,6 @@ pub struct TimestampedValues { pub values: [T; N], } -#[derive(Clone, Debug)] -pub struct MemoryControllerResult { - traces: Vec>, - public_values: Vec>, -} - pub type MemoryControllerRef = Rc>>; /// A equipartition of memory, with timestamps and values. @@ -106,11 +99,26 @@ pub struct MemoryController { memory: Memory, access_adapters: AccessAdapterInventory, - /// If set, the height of the traces will be overridden. - overridden_heights: Option, // Filled during finalization. - result: Option>, + final_state: Option>, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +enum FinalState { + Volatile(VolatileFinalState), + #[allow(dead_code)] + Persistent(PersistentFinalState), +} +#[derive(Debug, Default)] +struct VolatileFinalState { + _marker: PhantomData, +} +#[allow(dead_code)] +#[derive(Debug)] +struct PersistentFinalState { + final_memory: Equipartition, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -236,8 +244,7 @@ impl MemoryController { ), range_checker, range_checker_bus, - result: None, - overridden_heights: None, + final_state: None, } } @@ -279,29 +286,34 @@ impl MemoryController { ), range_checker, range_checker_bus, - result: None, - overridden_heights: None, + final_state: None, } } pub fn set_override_trace_heights(&mut self, overridden_heights: MemoryTraceHeights) { - match &self.interface_chip { - MemoryInterface::Volatile { .. } => match &overridden_heights { + match &mut self.interface_chip { + MemoryInterface::Volatile { boundary_chip } => match overridden_heights { MemoryTraceHeights::Volatile(oh) => { + boundary_chip.set_overridden_height(oh.boundary); self.access_adapters - .set_override_trace_heights(oh.access_adapters.clone()); + .set_override_trace_heights(oh.access_adapters); } _ => panic!("Expect overridden_heights to be MemoryTraceHeights::Volatile"), }, - MemoryInterface::Persistent { .. } => match &overridden_heights { + MemoryInterface::Persistent { + boundary_chip, + merkle_chip, + .. + } => match overridden_heights { MemoryTraceHeights::Persistent(oh) => { + boundary_chip.set_overridden_height(oh.boundary); + merkle_chip.set_overridden_height(oh.merkle); self.access_adapters - .set_override_trace_heights(oh.access_adapters.clone()); + .set_override_trace_heights(oh.access_adapters); } _ => panic!("Expect overridden_heights to be MemoryTraceHeights::Persistent"), }, } - self.overridden_heights = Some(overridden_heights); } pub fn set_initial_memory(&mut self, memory: Equipartition) { @@ -451,27 +463,15 @@ impl MemoryController { &mut self, hasher: Option<&mut impl HasherChip>, ) -> Option> { - if self.result.is_some() { + if self.final_state.is_some() { panic!("Cannot finalize more than once"); } - let mut traces = vec![]; - let mut pvs = vec![]; let (records, final_memory) = match &mut self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { - let overridden_heights = self.overridden_heights.as_ref().map(|oh| match oh { - MemoryTraceHeights::Volatile(oh) => oh, - _ => unreachable!(), - }); let (final_memory, records) = self.memory.finalize::<1>(); - debug_assert_eq!(traces.len(), BOUNDARY_AIR_OFFSET); - traces.push( - boundary_chip - .generate_trace(&final_memory, overridden_heights.map(|oh| oh.boundary)), - ); - debug_assert_eq!(pvs.len(), BOUNDARY_AIR_OFFSET); - pvs.push(vec![]); - + boundary_chip.finalize(final_memory); + self.final_state = Some(FinalState::Volatile(VolatileFinalState::default())); (records, None) } MemoryInterface::Persistent { @@ -479,45 +479,24 @@ impl MemoryController { boundary_chip, initial_memory, } => { - let overridden_heights = self.overridden_heights.as_ref().map(|oh| match oh { - MemoryTraceHeights::Persistent(oh) => oh, - _ => unreachable!(), - }); let hasher = hasher.unwrap(); - let (final_partition, records) = self.memory.finalize::<8>(); - traces.push(boundary_chip.generate_trace( - initial_memory, - &final_partition, - hasher, - overridden_heights.map(|oh| oh.boundary), - )); - pvs.push(vec![]); - + let (final_partition, records) = self.memory.finalize::(); + boundary_chip.finalize(initial_memory, &final_partition, hasher); let final_memory_values = final_partition - .iter() - .map(|(key, value)| (*key, value.values)) + .into_par_iter() + .map(|(key, value)| (key, value.values)) .collect(); - let initial_node = MemoryNode::tree_from_memory( merkle_chip.air.memory_dimensions, initial_memory, hasher, ); - let (expand_trace, final_node) = merkle_chip.generate_trace_and_final_tree( - &initial_node, - &final_memory_values, - hasher, - overridden_heights.map(|oh| oh.merkle), - ); - - debug_assert_eq!(traces.len(), MERKLE_AIR_OFFSET); - traces.push(expand_trace); - let mut expand_pvs = vec![]; - expand_pvs.extend(initial_node.hash()); - expand_pvs.extend(final_node.hash()); - debug_assert_eq!(pvs.len(), MERKLE_AIR_OFFSET); - pvs.push(expand_pvs); + merkle_chip.finalize(&initial_node, &final_memory_values, hasher); + self.final_state = Some(FinalState::Persistent(PersistentFinalState { + final_memory: final_memory_values.clone(), + })); + // FIXME: avoid clone here. (records, Some(final_memory_values)) } }; @@ -525,17 +504,6 @@ impl MemoryController { self.access_adapters.add_record(record); } - // FIXME: avoid clone. - let aa_traces = self.access_adapters.clone().generate_traces(); - let aa_pvs = vec![vec![]; aa_traces.len()]; - traces.extend(aa_traces); - pvs.extend(aa_pvs); - - self.result = Some(MemoryControllerResult { - traces, - public_values: pvs, - }); - final_memory } @@ -543,18 +511,30 @@ impl MemoryController { where Domain: PolynomialSpace, { - let airs = self.airs(); - let MemoryControllerResult { - traces, - public_values, - } = self.result.unwrap(); - izip!(airs, traces, public_values) - .map(|(air, trace, pvs)| AirProofInput::simple(air, trace, pvs)) - .collect() - } + let mut ret = Vec::new(); - pub fn generate_traces(self) -> Vec> { - self.result.unwrap().traces + let Self { + interface_chip, + access_adapters, + .. + } = self; + match interface_chip { + MemoryInterface::Volatile { boundary_chip } => { + ret.push(boundary_chip.generate_air_proof_input()); + } + MemoryInterface::Persistent { + merkle_chip, + boundary_chip, + .. + } => { + debug_assert_eq!(ret.len(), BOUNDARY_AIR_OFFSET); + ret.push(boundary_chip.generate_air_proof_input()); + debug_assert_eq!(ret.len(), MERKLE_AIR_OFFSET); + ret.push(merkle_chip.generate_air_proof_input()); + } + } + ret.extend(access_adapters.generate_air_proof_inputs()); + ret } pub fn airs(&self) -> Vec>> @@ -566,7 +546,7 @@ impl MemoryController { match &self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { debug_assert_eq!(airs.len(), BOUNDARY_AIR_OFFSET); - airs.push(Arc::new(boundary_chip.air.clone())) + airs.push(boundary_chip.air()) } MemoryInterface::Persistent { boundary_chip, @@ -574,9 +554,9 @@ impl MemoryController { .. } => { debug_assert_eq!(airs.len(), BOUNDARY_AIR_OFFSET); - airs.push(Arc::new(boundary_chip.air.clone())); + airs.push(boundary_chip.air()); debug_assert_eq!(airs.len(), MERKLE_AIR_OFFSET); - airs.push(Arc::new(merkle_chip.air.clone())); + airs.push(merkle_chip.air()); } } airs.extend(self.access_adapters.airs()); @@ -590,11 +570,7 @@ impl MemoryController { if self.continuation_enabled() { num_airs += 1; } - for n in [2, 4, 8, 16, 32, 64] { - if self.mem_config.max_access_adapter_n >= n { - num_airs += 1; - } - } + num_airs += self.access_adapters.num_access_adapters(); num_airs } @@ -603,11 +579,7 @@ impl MemoryController { if self.continuation_enabled() { air_names.push("Merkle".to_string()); } - for n in [2, 4, 8, 16, 32, 64] { - if self.mem_config.max_access_adapter_n >= n { - air_names.push(format!("AccessAdapter<{}>", n)); - } - } + air_names.extend(self.access_adapters.air_names()); air_names } @@ -620,7 +592,7 @@ impl MemoryController { match &self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { MemoryTraceHeights::Volatile(VolatileMemoryTraceHeights { - boundary: boundary_chip.current_height(), + boundary: boundary_chip.current_trace_height(), access_adapters, }) } @@ -629,8 +601,8 @@ impl MemoryController { merkle_chip, .. } => MemoryTraceHeights::Persistent(PersistentMemoryTraceHeights { - boundary: boundary_chip.current_height(), - merkle: merkle_chip.current_height(), + boundary: boundary_chip.current_trace_height(), + merkle: merkle_chip.current_trace_height(), access_adapters, }), } @@ -654,33 +626,23 @@ impl MemoryController { } } - fn trace_widths(&self) -> Vec { - let mut widths = vec![]; + pub fn current_trace_cells(&self) -> Vec { + let mut ret = Vec::new(); match &self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { - widths.push(BaseAir::::width(&boundary_chip.air)); + ret.push(boundary_chip.current_trace_cells()) } MemoryInterface::Persistent { boundary_chip, merkle_chip, .. } => { - widths.push(BaseAir::::width(&boundary_chip.air)); - widths.push(BaseAir::::width(&merkle_chip.air)); + ret.push(boundary_chip.current_trace_cells()); + ret.push(merkle_chip.current_trace_cells()); } - }; - widths.extend(self.access_adapters.get_widths()); - widths - } - - pub fn current_trace_cells(&self) -> Vec { - zip_eq(self.current_trace_heights(), self.trace_widths()) - .map(|(h, w)| h * w) - .collect() - } - - pub fn generate_public_values_per_air(&self) -> Vec> { - self.result.as_ref().unwrap().public_values.clone() + } + ret.extend(self.access_adapters.get_cells()); + ret } } diff --git a/crates/vm/src/system/memory/merkle/mod.rs b/crates/vm/src/system/memory/merkle/mod.rs index 654c86b03b..ad01e25c8b 100644 --- a/crates/vm/src/system/memory/merkle/mod.rs +++ b/crates/vm/src/system/memory/merkle/mod.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use openvm_stark_backend::p3_field::PrimeField32; use rustc_hash::FxHashSet; @@ -21,7 +19,14 @@ pub struct MemoryMerkleChip { pub air: MemoryMerkleAir, touched_nodes: FxHashSet<(usize, usize, usize)>, num_touched_nonleaves: usize, - _marker: PhantomData, + final_state: Option>, + overridden_height: Option, +} +#[derive(Debug)] +struct FinalState { + rows: Vec>, + init_root: [F; CHUNK], + final_root: [F; CHUNK], } impl MemoryMerkleChip { @@ -43,9 +48,13 @@ impl MemoryMerkleChip { }, touched_nodes, num_touched_nonleaves: 1, - _marker: PhantomData, + final_state: None, + overridden_height: None, } } + pub fn set_overridden_height(&mut self, override_height: usize) { + self.overridden_height = Some(override_height); + } fn touch_node(&mut self, height: usize, as_label: usize, address_label: usize) { if self.touched_nodes.insert((height, as_label, address_label)) { @@ -68,8 +77,4 @@ impl MemoryMerkleChip { (address.as_canonical_u32() as usize) / CHUNK, ); } - - pub fn current_height(&self) -> usize { - 2 * self.num_touched_nonleaves - } } diff --git a/crates/vm/src/system/memory/merkle/tests/mod.rs b/crates/vm/src/system/memory/merkle/tests/mod.rs index fc63937a3f..fb800c2d13 100644 --- a/crates/vm/src/system/memory/merkle/tests/mod.rs +++ b/crates/vm/src/system/memory/merkle/tests/mod.rs @@ -2,15 +2,18 @@ use std::{ array, borrow::BorrowMut, collections::{BTreeMap, BTreeSet, HashSet}, + sync::Arc, }; use openvm_stark_backend::{ interaction::InteractionType, p3_field::{AbstractField, PrimeField32}, p3_matrix::dense::RowMajorMatrix, + prover::types::AirProofInput, + Chip, ChipUsageGetter, }; use openvm_stark_sdk::{ - any_rap_arc_vec, config::baby_bear_poseidon2::BabyBearPoseidon2Engine, + config::baby_bear_poseidon2::BabyBearPoseidon2Engine, dummy_airs::interaction::dummy_interaction_air::DummyInteractionAir, engine::StarkFriEngine, p3_baby_bear::BabyBear, utils::create_seeded_rng, }; @@ -80,11 +83,13 @@ fn test( } } - println!("trace height = {}", chip.current_height()); - let (trace, final_tree) = - chip.generate_trace_and_final_tree(&initial_tree, final_memory, &mut hash_test_chip, None); - - assert_eq!(final_tree, final_tree_check); + println!("trace height = {}", chip.current_trace_height()); + chip.finalize(&initial_tree, final_memory, &mut hash_test_chip); + assert_eq!( + chip.final_state.as_ref().unwrap().final_root, + final_tree_check.hash() + ); + let chip_api = chip.generate_air_proof_input(); let dummy_interaction_air = DummyInteractionAir::new(4 + CHUNK, true, merkle_bus.0); let mut dummy_interaction_trace_rows = vec![]; @@ -145,17 +150,14 @@ fn test( dummy_interaction_trace_rows, dummy_interaction_air.field_width() + 1, ); - - let mut public_values = vec![vec![]; 3]; - public_values[0].extend(initial_tree.hash()); - public_values[0].extend(final_tree_check.hash()); - - let hash_test_chip_air = hash_test_chip.air(); - BabyBearPoseidon2Engine::run_simple_test_fast( - any_rap_arc_vec![chip.air, dummy_interaction_air, hash_test_chip_air], - vec![trace, dummy_interaction_trace, hash_test_chip.trace()], - public_values, - ) + let dummy_interaction_api = + AirProofInput::simple_no_pis(Arc::new(dummy_interaction_air), dummy_interaction_trace); + + BabyBearPoseidon2Engine::run_test_fast(vec![ + chip_api, + dummy_interaction_api, + hash_test_chip.generate_air_proof_input(), + ]) .expect("Verification failed"); } @@ -251,18 +253,11 @@ fn expand_test_no_accesses() { COMPRESSION_BUS, ); - let (trace, _) = chip.generate_trace_and_final_tree(&tree, &memory, &mut hash_test_chip, None); - - let mut public_values = vec![vec![]; 2]; - public_values[0].extend(tree.hash()); - public_values[0].extend(tree.hash()); - - let hash_test_chip_air = hash_test_chip.air(); - BabyBearPoseidon2Engine::run_simple_test_fast( - any_rap_arc_vec![chip.air, hash_test_chip_air], - vec![trace, hash_test_chip.trace()], - public_values, - ) + chip.finalize(&tree, &memory, &mut hash_test_chip); + BabyBearPoseidon2Engine::run_test_fast(vec![ + chip.generate_air_proof_input(), + hash_test_chip.generate_air_proof_input(), + ]) .expect("This should occur"); } @@ -290,25 +285,22 @@ fn expand_test_negative() { COMPRESSION_BUS, ); - let (mut trace, _) = - chip.generate_trace_and_final_tree(&tree, &memory, &mut hash_test_chip, None); - for row in trace.rows_mut() { - let row: &mut MemoryMerkleCols<_, DEFAULT_CHUNK> = row.borrow_mut(); - if row.expand_direction == BabyBear::NEG_ONE { - row.left_direction_different = BabyBear::ZERO; - row.right_direction_different = BabyBear::ZERO; + chip.finalize(&tree, &memory, &mut hash_test_chip); + let mut chip_api = chip.generate_air_proof_input(); + { + let trace = chip_api.raw.common_main.as_mut().unwrap(); + for row in trace.rows_mut() { + let row: &mut MemoryMerkleCols<_, DEFAULT_CHUNK> = row.borrow_mut(); + if row.expand_direction == BabyBear::NEG_ONE { + row.left_direction_different = BabyBear::ZERO; + row.right_direction_different = BabyBear::ZERO; + } } } - let mut public_values = vec![vec![]; 2]; - public_values[0].extend(tree.hash()); - public_values[0].extend(tree.hash()); - - let hash_test_chip_air = hash_test_chip.air(); - BabyBearPoseidon2Engine::run_simple_test_fast( - any_rap_arc_vec![chip.air, hash_test_chip_air], - vec![trace, hash_test_chip.trace()], - public_values, - ) + BabyBearPoseidon2Engine::run_test_fast(vec![ + chip_api, + hash_test_chip.generate_air_proof_input(), + ]) .expect("This should occur"); } diff --git a/crates/vm/src/system/memory/merkle/tests/util.rs b/crates/vm/src/system/memory/merkle/tests/util.rs index f5104fda98..3bd7a500e3 100644 --- a/crates/vm/src/system/memory/merkle/tests/util.rs +++ b/crates/vm/src/system/memory/merkle/tests/util.rs @@ -1,6 +1,13 @@ -use std::array::from_fn; +use std::{array::from_fn, sync::Arc}; -use openvm_stark_backend::{p3_air::BaseAir, p3_field::Field, p3_matrix::dense::RowMajorMatrix}; +use openvm_stark_backend::{ + config::{Domain, StarkGenericConfig}, + p3_air::BaseAir, + p3_commit::PolynomialSpace, + p3_field::Field, + p3_matrix::dense::RowMajorMatrix, + prover::types::AirProofInput, +}; use openvm_stark_sdk::dummy_airs::interaction::dummy_interaction_air::DummyInteractionAir; use crate::arch::{ @@ -40,6 +47,12 @@ impl HashTestChip { } RowMajorMatrix::new(rows, width) } + pub fn generate_air_proof_input(&self) -> AirProofInput + where + Domain: PolynomialSpace, + { + AirProofInput::simple_no_pis(Arc::new(self.air()), self.trace()) + } } impl Hasher for HashTestChip { diff --git a/crates/vm/src/system/memory/merkle/trace.rs b/crates/vm/src/system/memory/merkle/trace.rs index 6c9f21fa8d..8045e7fc92 100644 --- a/crates/vm/src/system/memory/merkle/trace.rs +++ b/crates/vm/src/system/memory/merkle/trace.rs @@ -1,26 +1,33 @@ use std::{borrow::BorrowMut, cmp::Reverse, sync::Arc}; -use openvm_stark_backend::{p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::{AbstractField, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, + prover::types::AirProofInput, + rap::AnyRap, + Chip, ChipUsageGetter, +}; use rustc_hash::FxHashSet; use crate::{ arch::hasher::HasherChip, system::memory::{ manager::dimensions::MemoryDimensions, - merkle::{MemoryMerkleChip, MemoryMerkleCols}, + merkle::{FinalState, MemoryMerkleChip, MemoryMerkleCols}, tree::MemoryNode::{self, NonLeaf}, Equipartition, }, }; impl MemoryMerkleChip { - pub fn generate_trace_and_final_tree( + pub fn finalize( &mut self, initial_tree: &MemoryNode, final_memory: &Equipartition, hasher: &mut impl HasherChip, - overridden_height: Option, - ) -> (RowMajorMatrix, MemoryNode) { + ) { + assert!(self.final_state.is_none(), "Merkle chip already finalized"); // there needs to be a touched node with `height_section` = 0 // shouldn't be a leaf because // trace generation will expect an interaction from MemoryInterfaceChip in that case @@ -42,13 +49,41 @@ impl MemoryMerkleChip { 0, hasher, ); + self.final_state = Some(FinalState { + rows, + init_root: initial_tree.hash(), + final_root: final_tree.hash(), + }); + } +} + +impl Chip for MemoryMerkleChip> +where + Val: PrimeField32, +{ + fn air(&self) -> Arc> { + Arc::new(self.air.clone()) + } + + fn generate_air_proof_input(self) -> AirProofInput { + let air = Arc::new(self.air); + assert!( + self.final_state.is_some(), + "Merkle chip must finalize before trace generation" + ); + let FinalState { + mut rows, + init_root, + final_root, + } = self.final_state.unwrap(); // important that this sort be stable, // because we need the initial root to be first and the final root to be second + // TODO: do we only need find all height == 0 instead of sorting? rows.sort_by_key(|row| Reverse(row.parent_height)); - let width = MemoryMerkleCols::::width(); + let width = MemoryMerkleCols::, CHUNK>::width(); let mut height = rows.len().next_power_of_two(); - if let Some(mut oh) = overridden_height { + if let Some(mut oh) = self.overridden_height { oh = oh.next_power_of_two(); assert!( oh >= height, @@ -56,14 +91,28 @@ impl MemoryMerkleChip { ); height = oh; } - let mut trace = F::zero_vec(width * height); + let mut trace = Val::::zero_vec(width * height); for (trace_row, row) in trace.chunks_exact_mut(width).zip(rows) { *trace_row.borrow_mut() = row; } let trace = RowMajorMatrix::new(trace, width); - (trace, final_tree) + let pvs = init_root.into_iter().chain(final_root).collect(); + AirProofInput::simple(air, trace, pvs) + } +} +impl ChipUsageGetter for MemoryMerkleChip { + fn air_name(&self) -> String { + "Merkle".to_string() + } + + fn current_trace_height(&self) -> usize { + 2 * self.num_touched_nonleaves + } + + fn trace_width(&self) -> usize { + MemoryMerkleCols::::width() } } diff --git a/crates/vm/src/system/memory/persistent.rs b/crates/vm/src/system/memory/persistent.rs index e854e591af..57eb71d633 100644 --- a/crates/vm/src/system/memory/persistent.rs +++ b/crates/vm/src/system/memory/persistent.rs @@ -1,15 +1,22 @@ use std::{ borrow::{Borrow, BorrowMut}, iter, + sync::Arc, }; use openvm_circuit_primitives_derive::AlignedBorrow; +#[allow(unused_imports)] +use openvm_stark_backend::p3_maybe_rayon::prelude::IndexedParallelIterator; use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, BaseAir}, p3_field::{AbstractField, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator, ParallelSliceMut}, + prover::types::AirProofInput, + rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, }; use rustc_hash::FxHashSet; @@ -115,7 +122,49 @@ impl Air for PersistentBoundaryA #[derive(Debug)] pub struct PersistentBoundaryChip { pub air: PersistentBoundaryAir, - touched_labels: FxHashSet<(F, usize)>, + touched_labels: TouchedLabels, + overridden_height: Option, +} + +#[derive(Debug)] +enum TouchedLabels { + Running(FxHashSet<(F, usize)>), + Final(Vec>), +} + +#[derive(Debug)] +struct FinalTouchedLabel { + address_space: F, + label: usize, + init_values: [F; CHUNK], + final_values: [F; CHUNK], + init_exists: bool, + init_hash: [F; CHUNK], + final_hash: [F; CHUNK], + final_timestamp: u32, +} + +impl Default for TouchedLabels { + fn default() -> Self { + Self::Running(FxHashSet::default()) + } +} + +impl TouchedLabels { + fn touch(&mut self, address_space: F, label: usize) { + match self { + TouchedLabels::Running(touched_labels) => { + touched_labels.insert((address_space, label)); + } + _ => panic!("Cannot touch after finalization"), + } + } + fn len(&self) -> usize { + match self { + TouchedLabels::Running(touched_labels) => touched_labels.len(), + TouchedLabels::Final(touched_labels) => touched_labels.len(), + } + } } impl PersistentBoundaryChip { @@ -132,78 +181,133 @@ impl PersistentBoundaryChip { merkle_bus, compression_bus, }, - touched_labels: FxHashSet::default(), + touched_labels: Default::default(), + overridden_height: None, } } - pub fn touch_address(&mut self, address_space: F, pointer: F) { - let label = pointer.as_canonical_u32() as usize / CHUNK; - self.touched_labels.insert((address_space, label)); + pub fn set_overridden_height(&mut self, overridden_height: usize) { + self.overridden_height = Some(overridden_height); } - pub fn current_height(&self) -> usize { - 2 * self.touched_labels.len() + pub fn touch_address(&mut self, address_space: F, pointer: F) { + let label = pointer.as_canonical_u32() as usize / CHUNK; + self.touched_labels.touch(address_space, label); } - pub fn generate_trace( - &self, + pub fn finalize( + &mut self, initial_memory: &Equipartition, final_memory: &TimestampedEquipartition, hasher: &mut impl HasherChip, - overridden_height: Option, - ) -> RowMajorMatrix { - let width = PersistentBoundaryCols::::width(); - // Boundary AIR should always present in order to fix the AIR ID of merkle AIR. - let mut height = (2 * self.touched_labels.len()).next_power_of_two(); - if let Some(mut oh) = overridden_height { - oh = oh.next_power_of_two(); - assert!( - oh >= height, - "Overridden height is less than the required height" - ); - height = oh; + ) { + match &mut self.touched_labels { + TouchedLabels::Running(touched_labels) => { + // TODO: parallelize this. + let final_touched_labels = touched_labels + .iter() + .map(|touched_label| { + let (init_exists, initial_hash, init_values) = + match initial_memory.get(touched_label) { + Some(values) => (true, hasher.hash_and_record(values), *values), + None => ( + true, + hasher.hash_and_record(&[F::ZERO; CHUNK]), + [F::ZERO; CHUNK], + ), + }; + let timestamped_values = final_memory.get(touched_label).unwrap(); + let final_hash = hasher.hash_and_record(×tamped_values.values); + FinalTouchedLabel { + address_space: touched_label.0, + label: touched_label.1, + init_values, + final_values: timestamped_values.values, + init_exists, + init_hash: initial_hash, + final_hash, + final_timestamp: timestamped_values.timestamp, + } + }) + .collect(); + self.touched_labels = TouchedLabels::Final(final_touched_labels); + } + _ => panic!("Cannot finalize after finalization"), } - let mut rows = F::zero_vec(height * width); - - for (row, &(address_space, label)) in - rows.chunks_mut(2 * width).zip(self.touched_labels.iter()) - { - let (initial_row, final_row) = row.split_at_mut(width); - *initial_row.borrow_mut() = match initial_memory.get(&(address_space, label)) { - Some(values) => { - let initial_hash = hasher.hash_and_record(values); - PersistentBoundaryCols { - expand_direction: F::ONE, - address_space, - leaf_label: F::from_canonical_usize(label), - values: *values, - hash: initial_hash, - timestamp: F::from_canonical_u32(INITIAL_TIMESTAMP), - } - } - None => { - let initial_hash = hasher.hash_and_record(&[F::ZERO; CHUNK]); - PersistentBoundaryCols { - expand_direction: F::ONE, - address_space, - leaf_label: F::from_canonical_usize(label), - values: [F::ZERO; CHUNK], - hash: initial_hash, - timestamp: F::ZERO, - } - } - }; - let timestamped_values = final_memory.get(&(address_space, label)).unwrap(); - let final_hash = hasher.hash_and_record(×tamped_values.values); - *final_row.borrow_mut() = PersistentBoundaryCols { - expand_direction: F::NEG_ONE, - address_space, - leaf_label: F::from_canonical_usize(label), - values: timestamped_values.values, - hash: final_hash, - timestamp: F::from_canonical_u32(timestamped_values.timestamp), + } +} + +impl Chip for PersistentBoundaryChip, CHUNK> +where + Val: PrimeField32, +{ + fn air(&self) -> Arc> { + Arc::new(self.air.clone()) + } + + fn generate_air_proof_input(self) -> AirProofInput { + let air = Arc::new(self.air); + let trace = { + let width = PersistentBoundaryCols::, CHUNK>::width(); + // Boundary AIR should always present in order to fix the AIR ID of merkle AIR. + let mut height = (2 * self.touched_labels.len()).next_power_of_two(); + if let Some(mut oh) = self.overridden_height { + oh = oh.next_power_of_two(); + assert!( + oh >= height, + "Overridden height is less than the required height" + ); + height = oh; + } + let mut rows = Val::::zero_vec(height * width); + + let touched_labels = match self.touched_labels { + TouchedLabels::Final(touched_labels) => touched_labels, + _ => panic!("Cannot generate trace before finalization"), }; - } - RowMajorMatrix::new(rows, width) + + rows.par_chunks_mut(2 * width) + .zip(touched_labels.into_par_iter()) + .for_each(|(row, touched_label)| { + let (initial_row, final_row) = row.split_at_mut(width); + *initial_row.borrow_mut() = PersistentBoundaryCols { + expand_direction: Val::::ONE, + address_space: touched_label.address_space, + leaf_label: Val::::from_canonical_usize(touched_label.label), + values: touched_label.init_values, + hash: touched_label.init_hash, + timestamp: if touched_label.init_exists { + Val::::from_canonical_u32(INITIAL_TIMESTAMP) + } else { + Val::::ZERO + }, + }; + + *final_row.borrow_mut() = PersistentBoundaryCols { + expand_direction: Val::::NEG_ONE, + address_space: touched_label.address_space, + leaf_label: Val::::from_canonical_usize(touched_label.label), + values: touched_label.final_values, + hash: touched_label.final_hash, + timestamp: Val::::from_canonical_u32(touched_label.final_timestamp), + }; + }); + RowMajorMatrix::new(rows, width) + }; + AirProofInput::simple_no_pis(air, trace) + } +} + +impl ChipUsageGetter for PersistentBoundaryChip { + fn air_name(&self) -> String { + "Boundary".to_string() + } + + fn current_trace_height(&self) -> usize { + 2 * self.touched_labels.len() + } + + fn trace_width(&self) -> usize { + PersistentBoundaryCols::::width() } } diff --git a/crates/vm/src/system/memory/volatile/mod.rs b/crates/vm/src/system/memory/volatile/mod.rs index b29532d975..04e0c68579 100644 --- a/crates/vm/src/system/memory/volatile/mod.rs +++ b/crates/vm/src/system/memory/volatile/mod.rs @@ -14,12 +14,15 @@ use openvm_circuit_primitives::{ }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{AbstractField, Field, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + prover::types::AirProofInput, + rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, }; use super::TimestampedEquipartition; @@ -132,6 +135,8 @@ pub struct VolatileBoundaryChip { pub air: VolatileBoundaryAir, touched_addresses: HashSet<(F, F)>, range_checker: Arc, + overridden_height: Option, + final_memory: Option>, } impl VolatileBoundaryChip { @@ -151,6 +156,8 @@ impl VolatileBoundaryChip { ), touched_addresses: HashSet::new(), range_checker, + overridden_height: None, + final_memory: None, } } @@ -161,21 +168,36 @@ impl VolatileBoundaryChip { pub fn all_addresses(&self) -> Vec<(F, F)> { self.touched_addresses.iter().cloned().collect() } - - pub fn current_height(&self) -> usize { - self.touched_addresses.len() - } } impl VolatileBoundaryChip { + pub fn set_overridden_height(&mut self, overridden_height: usize) { + self.overridden_height = Some(overridden_height); + } /// Volatile memory requires the starting and final memory to be in equipartition with block size `1`. /// When block size is `1`, then the `label` is the same as the address pointer. - pub fn generate_trace( - &self, - final_memory: &TimestampedEquipartition, - overridden_height: Option, - ) -> RowMajorMatrix { - let trace_height = if let Some(height) = overridden_height { + pub fn finalize(&mut self, final_memory: TimestampedEquipartition) { + self.final_memory = Some(final_memory); + } +} + +impl Chip for VolatileBoundaryChip> +where + Val: PrimeField32, +{ + fn air(&self) -> Arc> { + Arc::new(self.air.clone()) + } + + fn generate_air_proof_input(self) -> AirProofInput { + // Volatile memory requires the starting and final memory to be in equipartition with block size `1`. + // When block size is `1`, then the `label` is the same as the address pointer. + let width = self.trace_width(); + let air = Arc::new(self.air); + let final_memory = self + .final_memory + .expect("Trace generation should be called after finalize"); + let trace_height = if let Some(height) = self.overridden_height { assert!( height >= final_memory.len(), "Overridden height is less than the required height" @@ -184,65 +206,71 @@ impl VolatileBoundaryChip { } else { final_memory.len() }; - self.generate_trace_with_height(final_memory, trace_height.next_power_of_two()) - } - - fn generate_trace_with_height( - &self, - final_memory: &TimestampedEquipartition, - trace_height: usize, - ) -> RowMajorMatrix { - assert!(trace_height.is_power_of_two()); - let width = BaseAir::::width(&self.air); + let trace_height = trace_height.next_power_of_two(); // Collect into Vec to sort from BTreeMap and also so we can look at adjacent entries - let sorted_final_memory: Vec<_> = final_memory.iter().collect(); - assert!(sorted_final_memory.len() <= trace_height); + let sorted_final_memory: Vec<_> = final_memory.into_par_iter().collect(); + let memory_len = sorted_final_memory.len(); - let mut rows = F::zero_vec(trace_height * width); + let mut rows = Val::::zero_vec(trace_height * width); rows.par_chunks_mut(width) - .zip(&sorted_final_memory) + .zip(sorted_final_memory.par_iter()) .enumerate() .for_each(|(i, (row, ((addr_space, ptr), timestamped_values)))| { // `pointer` is the same as `label` since the equipartition has block size 1 let [data] = timestamped_values.values; let row: &mut VolatileBoundaryCols<_> = row.borrow_mut(); row.addr_space = *addr_space; - row.pointer = F::from_canonical_usize(*ptr); - row.initial_data = F::ZERO; + row.pointer = Val::::from_canonical_usize(*ptr); + row.initial_data = Val::::ZERO; row.final_data = data; - row.final_timestamp = F::from_canonical_u32(timestamped_values.timestamp); - row.is_valid = F::ONE; + row.final_timestamp = Val::::from_canonical_u32(timestamped_values.timestamp); + row.is_valid = Val::::ONE; // If next.is_valid == 1: - if i != sorted_final_memory.len() - 1 { - let (next_addr_space, next_ptr) = *sorted_final_memory[i + 1].0; - let mut out = F::ZERO; - self.air.addr_lt_air.0.generate_subrow( + if i != memory_len - 1 { + let (next_addr_space, next_ptr) = sorted_final_memory[i + 1].0; + let mut out = Val::::ZERO; + air.addr_lt_air.0.generate_subrow( ( &self.range_checker, &[row.addr_space, row.pointer], - &[next_addr_space, F::from_canonical_usize(next_ptr)], + &[next_addr_space, Val::::from_canonical_usize(next_ptr)], ), ((&mut row.addr_lt_aux).into(), &mut out), ); - debug_assert_eq!(out, F::ONE, "Addresses are not sorted"); + debug_assert_eq!(out, Val::::ONE, "Addresses are not sorted"); } }); // Always do a dummy range check on the last row due to wraparound - if !sorted_final_memory.is_empty() { - let mut out = F::ZERO; + if memory_len > 0 { + let mut out = Val::::ZERO; let row: &mut VolatileBoundaryCols<_> = rows[width * (trace_height - 1)..].borrow_mut(); - self.air.addr_lt_air.0.generate_subrow( + air.addr_lt_air.0.generate_subrow( ( &self.range_checker, - &[F::ZERO, F::ZERO], - &[F::ZERO, F::ZERO], + &[Val::::ZERO, Val::::ZERO], + &[Val::::ZERO, Val::::ZERO], ), ((&mut row.addr_lt_aux).into(), &mut out), ); } - RowMajorMatrix::new(rows, width) + let trace = RowMajorMatrix::new(rows, width); + AirProofInput::simple_no_pis(air, trace) + } +} + +impl ChipUsageGetter for VolatileBoundaryChip { + fn air_name(&self) -> String { + "Boundary".to_string() + } + + fn current_trace_height(&self) -> usize { + self.touched_addresses.len() + } + + fn trace_width(&self) -> usize { + VolatileBoundaryCols::::width() } } diff --git a/crates/vm/src/system/memory/volatile/tests.rs b/crates/vm/src/system/memory/volatile/tests.rs index 6b00e01a9b..eaaf3bb674 100644 --- a/crates/vm/src/system/memory/volatile/tests.rs +++ b/crates/vm/src/system/memory/volatile/tests.rs @@ -3,12 +3,16 @@ use std::{collections::HashSet, iter, sync::Arc}; use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_stark_backend::{ p3_field::{AbstractField, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::dense::RowMajorMatrix, + prover::types::AirProofInput, + Chip, }; use openvm_stark_sdk::{ - any_rap_arc_vec, config::baby_bear_poseidon2::BabyBearPoseidon2Engine, - dummy_airs::interaction::dummy_interaction_air::DummyInteractionAir, engine::StarkFriEngine, - p3_baby_bear::BabyBear, utils::create_seeded_rng, + config::baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, + dummy_airs::interaction::dummy_interaction_air::DummyInteractionAir, + engine::StarkFriEngine, + p3_baby_bear::BabyBear, + utils::create_seeded_rng, }; use rand::Rng; use test_log::test; @@ -42,7 +46,8 @@ fn boundary_air_test() { let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, DECOMP); let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); - let boundary_chip = VolatileBoundaryChip::new(memory_bus, 2, LIMB_BITS, range_checker.clone()); + let mut boundary_chip = + VolatileBoundaryChip::new(memory_bus, 2, LIMB_BITS, range_checker.clone()); let mut final_memory = TimestampedEquipartition::new(); @@ -104,35 +109,30 @@ fn boundary_air_test() { 6, ); - let boundary_trace = boundary_chip.generate_trace(&final_memory, None); + boundary_chip.finalize(final_memory.clone()); + let boundary_api: AirProofInput = + boundary_chip.generate_air_proof_input(); // test trace height override { - let overridden_height = boundary_trace.height() * 2; + let overridden_height = boundary_api.main_trace_height() * 2; let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); - let boundary_chip = + let mut boundary_chip = VolatileBoundaryChip::new(memory_bus, 2, LIMB_BITS, range_checker.clone()); - let boundary_trace = boundary_chip.generate_trace(&final_memory, Some(overridden_height)); + boundary_chip.set_overridden_height(overridden_height); + boundary_chip.finalize(final_memory.clone()); + let boundary_api: AirProofInput = + boundary_chip.generate_air_proof_input(); assert_eq!( - boundary_trace.height(), + boundary_api.main_trace_height(), overridden_height.next_power_of_two() ); } - let range_checker_trace = range_checker.generate_trace(); - - BabyBearPoseidon2Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![ - boundary_chip.air, - range_checker.air, - init_memory_dummy_air, - final_memory_dummy_air - ], - vec![ - boundary_trace, - range_checker_trace, - init_memory_trace, - final_memory_trace, - ], - ) + BabyBearPoseidon2Engine::run_test_fast(vec![ + boundary_api, + range_checker.generate_air_proof_input(), + AirProofInput::simple_no_pis(Arc::new(init_memory_dummy_air), init_memory_trace), + AirProofInput::simple_no_pis(Arc::new(final_memory_dummy_air), final_memory_trace), + ]) .expect("Verification failed"); }