Skip to content

Commit

Permalink
[refactor] Implement Chip for All Memory Chips (#1038)
Browse files Browse the repository at this point in the history
* Implement Chip for all memory chips

* Parallize Map transformation
  • Loading branch information
nyunyunyunyu authored Dec 15, 2024
1 parent 65494db commit 8eb4543
Show file tree
Hide file tree
Showing 12 changed files with 548 additions and 392 deletions.
36 changes: 34 additions & 2 deletions crates/circuits/primitives/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ extern crate proc_macro;
use itertools::multiunzip;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam};
use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta};

#[proc_macro_derive(AlignedBorrow)]
pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -72,8 +72,9 @@ pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
TokenStream::from(methods)
}

#[proc_macro_derive(Chip)]
#[proc_macro_derive(Chip, attributes(chip))]
pub fn chip_derive(input: TokenStream) -> TokenStream {
// Parse the attributes from the struct or enum
let ast: syn::DeriveInput = syn::parse(input).unwrap();

let name = &ast.ident;
Expand Down Expand Up @@ -160,6 +161,37 @@ pub fn chip_derive(input: TokenStream) -> TokenStream {
let where_clause = new_generics.make_where_clause();
where_clause.predicates.push(syn::parse_quote! { openvm_stark_backend::config::Domain<SC>: openvm_stark_backend::p3_commit::PolynomialSpace<Val = F>
});
let attributes = ast.attrs.iter().find(|&attr| attr.path().is_ident("chip"));
if let Some(attr) = attributes {
let mut fail_flag = false;

match &attr.meta {
Meta::List(meta_list) => {
meta_list
.parse_nested_meta(|meta| {
if meta.path.is_ident("where") {
let value = meta.value()?; // this parses the `=`
let s: LitStr = value.parse()?;
let where_value = s.value();
where_clause.predicates.push(syn::parse_str(&where_value)?);
} else {
fail_flag = true;
}
Ok(())
})
.unwrap();
}
_ => fail_flag = true,
}
if fail_flag {
return syn::Error::new(
name.span(),
"Only `#[chip(where = ...)]` format is supported",
)
.to_compile_error()
.into();
}
}

quote! {
impl #impl_generics openvm_stark_backend::Chip<SC> for #name #ty_generics #where_clause {
Expand Down
26 changes: 8 additions & 18 deletions crates/vm/src/arch/testing/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
use std::{cell::RefCell, rc::Rc, sync::Arc};

use itertools::izip;
use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip};
use openvm_instructions::instruction::Instruction;
use openvm_stark_backend::{
config::{StarkGenericConfig, Val},
engine::VerificationData,
p3_field::PrimeField32,
p3_matrix::{
dense::{DenseMatrix, RowMajorMatrix},
Matrix,
},
p3_matrix::dense::{DenseMatrix, RowMajorMatrix},
prover::types::AirProofInput,
verifier::VerificationError,
Chip,
Expand Down Expand Up @@ -267,21 +263,15 @@ where
let range_checker = memory_controller.borrow().range_checker.clone();
self = self.load(memory_tester); // dummy memory interactions
{
let memory = memory_controller.borrow();
let public_values = memory.generate_public_values_per_air();
let airs = memory.airs();
drop(memory);
let traces = Rc::try_unwrap(memory_controller)
let air_proof_inputs = Rc::try_unwrap(memory_controller)
.unwrap()
.into_inner()
.generate_traces();

for (pvs, air, trace) in izip!(public_values, airs, traces) {
if trace.height() > 0 {
self.air_proof_inputs
.push(AirProofInput::simple(air, trace, pvs));
}
}
.generate_air_proof_inputs();
self.air_proof_inputs.extend(
air_proof_inputs
.into_iter()
.filter(|api| api.main_trace_height() > 0),
);
}
self = self.load(range_checker); // this must be last because other trace generation mutates its state
}
Expand Down
88 changes: 34 additions & 54 deletions crates/vm/src/system/memory/adapter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use openvm_circuit_primitives::{
is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero,
var_range::VariableRangeCheckerChip, TraceSubRowGenerator,
};
use openvm_circuit_primitives_derive::ChipUsageGetter;
use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
use openvm_stark_backend::{
config::{Domain, StarkGenericConfig, Val},
p3_air::BaseAir,
Expand All @@ -31,6 +31,7 @@ mod tests;
#[derive(Debug, Clone)]
pub struct AccessAdapterInventory<F> {
chips: Vec<GenericAccessAdapterChip<F>>,
air_names: Vec<String>,
}

impl<F> AccessAdapterInventory<F> {
Expand All @@ -44,19 +45,19 @@ impl<F> AccessAdapterInventory<F> {
let mb = memory_bus;
let cmb = clk_max_bits;
let maan = max_access_adapter_n;
Self {
chips: [
Self::create_access_adapter_chip::<2>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<4>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<8>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<16>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<32>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<64>(rc.clone(), mb, cmb, maan),
]
.into_iter()
.flatten()
.collect(),
}
let chips: Vec<_> = [
Self::create_access_adapter_chip::<2>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<4>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<8>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<16>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<32>(rc.clone(), mb, cmb, maan),
Self::create_access_adapter_chip::<64>(rc.clone(), mb, cmb, maan),
]
.into_iter()
.flatten()
.collect();
let air_names = (0..chips.len()).map(|i| air_name(1 << (i + 1))).collect();
Self { chips, air_names }
}
pub fn num_access_adapters(&self) -> usize {
self.chips.len()
Expand All @@ -80,33 +81,33 @@ impl<F> AccessAdapterInventory<F> {
.map(|chip| chip.current_trace_height())
.collect()
}
#[allow(dead_code)]
pub fn get_widths(&self) -> Vec<usize> {
self.chips.iter().map(|chip| chip.trace_width()).collect()
}
pub fn get_cells(&self) -> Vec<usize> {
self.chips
.iter()
.map(|chip| chip.current_trace_cells())
.collect()
}
pub fn airs<SC: StarkGenericConfig>(&self) -> Vec<Arc<dyn AnyRap<SC>>>
where
F: PrimeField32,
Domain<SC>: PolynomialSpace<Val = F>,
{
self.chips.iter().map(|chip| chip.air()).collect()
}
pub fn generate_traces(self) -> Vec<RowMajorMatrix<F>>
where
F: PrimeField32,
{
self.chips
.into_par_iter()
.map(|chip| chip.generate_trace())
.collect()
pub fn air_names(&self) -> Vec<String> {
self.air_names.clone()
}
#[allow(dead_code)]
pub fn generate_air_proof_input<SC: StarkGenericConfig>(self) -> Vec<AirProofInput<SC>>
pub fn generate_air_proof_inputs<SC: StarkGenericConfig>(self) -> Vec<AirProofInput<SC>>
where
F: PrimeField32,
Domain<SC>: PolynomialSpace<Val = F>,
{
self.chips
.into_par_iter()
.into_iter()
.map(|chip| chip.generate_air_proof_input())
.collect()
}
Expand Down Expand Up @@ -157,8 +158,9 @@ pub trait GenericAccessAdapterChipTrait<F> {
F: PrimeField32;
}

#[derive(Debug, Clone, ChipUsageGetter)]
#[derive(Debug, Clone, Chip, ChipUsageGetter)]
#[enum_dispatch(GenericAccessAdapterChipTrait<F>)]
#[chip(where = "F: PrimeField32")]
enum GenericAccessAdapterChip<F> {
N2(AccessAdapterChip<F, 2>),
N4(AccessAdapterChip<F, 4>),
Expand All @@ -168,33 +170,6 @@ enum GenericAccessAdapterChip<F> {
N64(AccessAdapterChip<F, 64>),
}

impl<SC: StarkGenericConfig> Chip<SC> for GenericAccessAdapterChip<Val<SC>>
where
Val<SC>: PrimeField32,
{
fn air(&self) -> Arc<dyn AnyRap<SC>> {
match self {
GenericAccessAdapterChip::N2(chip) => chip.air(),
GenericAccessAdapterChip::N4(chip) => chip.air(),
GenericAccessAdapterChip::N8(chip) => chip.air(),
GenericAccessAdapterChip::N16(chip) => chip.air(),
GenericAccessAdapterChip::N32(chip) => chip.air(),
GenericAccessAdapterChip::N64(chip) => chip.air(),
}
}

fn generate_air_proof_input(self) -> AirProofInput<SC> {
match self {
GenericAccessAdapterChip::N2(chip) => chip.generate_air_proof_input(),
GenericAccessAdapterChip::N4(chip) => chip.generate_air_proof_input(),
GenericAccessAdapterChip::N8(chip) => chip.generate_air_proof_input(),
GenericAccessAdapterChip::N16(chip) => chip.generate_air_proof_input(),
GenericAccessAdapterChip::N32(chip) => chip.generate_air_proof_input(),
GenericAccessAdapterChip::N64(chip) => chip.generate_air_proof_input(),
}
}
}

impl<F> GenericAccessAdapterChip<F> {
fn new<const N: usize>(
range_checker: Arc<VariableRangeCheckerChip>,
Expand Down Expand Up @@ -313,7 +288,7 @@ where

impl<F, const N: usize> ChipUsageGetter for AccessAdapterChip<F, N> {
fn air_name(&self) -> String {
format!("AccessAdapter<{}>", N)
air_name(N)
}

fn current_trace_height(&self) -> usize {
Expand All @@ -324,3 +299,8 @@ impl<F, const N: usize> ChipUsageGetter for AccessAdapterChip<F, N> {
BaseAir::<F>::width(&self.air)
}
}

#[inline]
fn air_name(n: usize) -> String {
format!("AccessAdapter<{}>", n)
}
1 change: 1 addition & 0 deletions crates/vm/src/system/memory/manager/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::system::memory::{
Equipartition, CHUNK,
};

#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum MemoryInterface<F> {
Volatile {
Expand Down
Loading

0 comments on commit 8eb4543

Please sign in to comment.