Skip to content

Commit

Permalink
feat: Remove legacy flatten/from_iter APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
zlangley committed Jan 7, 2025
1 parent 1885566 commit 1affc47
Showing 1 changed file with 27 additions and 220 deletions.
247 changes: 27 additions & 220 deletions crates/vm/src/system/memory/offline_checker/columns.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
//! Defines auxiliary columns for memory operations: `MemoryReadAuxCols`,
//! `MemoryReadWithImmediateAuxCols`, and `MemoryWriteAuxCols`.
use std::{array, borrow::Borrow, iter};

use openvm_circuit_primitives::is_less_than::LessThanAuxCols;
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32};
Expand All @@ -21,29 +19,6 @@ pub struct MemoryBaseAuxCols<T> {
pub(super) clk_lt_aux: LessThanAuxCols<T, AUX_LEN>,
}

impl<T: Clone> MemoryBaseAuxCols<T> {
/// TODO[arayi]: Since we have AlignedBorrow, should remove all from_slice, from_iterator, and flatten in a future PR.
pub fn from_slice(slc: &[T]) -> Self {
let base_aux_cols: &MemoryBaseAuxCols<T> = slc.borrow();
base_aux_cols.clone()
}

pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
let sm = iter.take(Self::width()).collect::<Vec<T>>();
let base_aux_cols: &MemoryBaseAuxCols<T> = sm[..].borrow();
base_aux_cols.clone()
}
}

impl<T> MemoryBaseAuxCols<T> {
pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(iter::once(self.prev_timestamp))
.chain(self.clk_lt_aux.lower_decomp)
.collect()
}
}

#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryWriteAuxCols<T, const N: usize> {
Expand All @@ -63,22 +38,7 @@ impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
}
}

impl<const N: usize, T: Clone> MemoryWriteAuxCols<T, N> {
pub fn from_slice(slc: &[T]) -> Self {
let width = MemoryBaseAuxCols::<T>::width();
Self {
base: MemoryBaseAuxCols::from_slice(&slc[..width]),
prev_data: array::from_fn(|i| slc[width + i].clone()),
}
}

pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
base: MemoryBaseAuxCols::from_iterator(iter),
prev_data: array::from_fn(|_| iter.next().unwrap()),
}
}

impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
pub fn from_base(base: MemoryBaseAuxCols<T>, prev_data: [T; N]) -> Self {
Self { base, prev_data }
}
Expand All @@ -88,19 +48,17 @@ impl<const N: usize, T: Clone> MemoryWriteAuxCols<T, N> {
}
}

impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(self.base.flatten())
.chain(self.prev_data)
.collect()
}
}

impl<const N: usize, F: FieldAlgebra + Copy> MemoryWriteAuxCols<F, N> {
impl<const N: usize, F: FieldAlgebra> MemoryWriteAuxCols<F, N> {
pub fn disabled() -> Self {
let width = MemoryWriteAuxCols::<F, N>::width();
MemoryWriteAuxCols::from_slice(&F::zero_vec(width))
Self {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
prev_data: [F::ZERO; N],
}
}
}

Expand All @@ -125,101 +83,17 @@ impl<const N: usize, F: PrimeField32> MemoryReadAuxCols<F, N> {
}
}

impl<const N: usize, T: Clone> MemoryReadAuxCols<T, N> {
pub fn from_slice(slc: &[T]) -> Self {
Self {
base: MemoryBaseAuxCols::from_slice(slc),
}
}

pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
base: MemoryBaseAuxCols::from_iterator(iter),
}
}
}

impl<const N: usize, T> MemoryReadAuxCols<T, N> {
pub fn flatten(self) -> Vec<T> {
self.base.flatten()
}
}

impl<const N: usize, F: FieldAlgebra + Copy> MemoryReadAuxCols<F, N> {
pub fn disabled() -> Self {
let width = MemoryReadAuxCols::<F, N>::width();
MemoryReadAuxCols::from_slice(&F::zero_vec(width))
}
}

#[repr(C)]
#[derive(Clone, Debug, AlignedBorrow)]
pub struct MemoryHeapReadAuxCols<T, const N: usize> {
pub address: MemoryReadAuxCols<T, 1>,
pub data: MemoryReadAuxCols<T, N>,
}

impl<const N: usize, T: Clone> MemoryHeapReadAuxCols<T, N> {
pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
address: MemoryReadAuxCols::from_iterator(iter),
data: MemoryReadAuxCols::from_iterator(iter),
}
}

pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(self.address.flatten())
.chain(self.data.flatten())
.collect()
}
}

impl<const N: usize, F: FieldAlgebra + Copy> MemoryHeapReadAuxCols<F, N> {
pub fn disabled() -> Self {
let width = MemoryReadAuxCols::<F, 1>::width();
let address = MemoryReadAuxCols::from_slice(&F::zero_vec(width));
let width = MemoryReadAuxCols::<F, N>::width();
let data = MemoryReadAuxCols::from_slice(&F::zero_vec(width));
MemoryHeapReadAuxCols { address, data }
}
}

#[repr(C)]
#[derive(Clone, Debug)]
pub struct MemoryHeapWriteAuxCols<T, const N: usize> {
pub address: MemoryReadAuxCols<T, 1>,
pub data: MemoryWriteAuxCols<T, N>,
}

impl<const N: usize, T: Clone> MemoryHeapWriteAuxCols<T, N> {
pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
address: MemoryReadAuxCols::from_iterator(iter),
data: MemoryWriteAuxCols::from_iterator(iter),
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
}
}

pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(self.address.flatten())
.chain(self.data.flatten())
.collect()
}

pub const fn width() -> usize {
MemoryReadAuxCols::<T, 1>::width() + MemoryWriteAuxCols::<T, N>::width()
}
}

impl<const N: usize, F: FieldAlgebra + Copy> MemoryHeapWriteAuxCols<F, N> {
pub fn disabled() -> Self {
let width = MemoryReadAuxCols::<F, 1>::width();
let address = MemoryReadAuxCols::from_slice(&F::zero_vec(width));
let width = MemoryWriteAuxCols::<F, N>::width();
let data = MemoryWriteAuxCols::from_slice(&F::zero_vec(width));
MemoryHeapWriteAuxCols { address, data }
}
}

#[repr(C)]
Expand Down Expand Up @@ -248,84 +122,17 @@ impl<T> MemoryReadOrImmediateAuxCols<T> {
}
}

impl<T: Clone> MemoryReadOrImmediateAuxCols<T> {
pub fn from_slice(slc: &[T]) -> Self {
let width = MemoryBaseAuxCols::<T>::width();
Self {
base: MemoryBaseAuxCols::from_slice(&slc[..width]),
is_immediate: slc[width].clone(),
is_zero_aux: slc[width + 1].clone(),
}
}

pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
base: MemoryBaseAuxCols::from_iterator(iter),
is_immediate: iter.next().unwrap(),
is_zero_aux: iter.next().unwrap(),
}
}
}

impl<T> MemoryReadOrImmediateAuxCols<T> {
pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(self.base.flatten())
.chain(iter::once(self.is_immediate))
.chain(iter::once(self.is_zero_aux))
.collect()
}
}

impl<F: FieldAlgebra + Copy> MemoryReadOrImmediateAuxCols<F> {
pub fn disabled() -> Self {
let width = MemoryReadOrImmediateAuxCols::<F>::width();
MemoryReadOrImmediateAuxCols::from_slice(&F::zero_vec(width))
}
}

#[cfg(test)]
mod tests {
use openvm_stark_sdk::p3_baby_bear::BabyBear;

use super::*;

#[test]
fn test_write_aux_cols_width() {
type F = BabyBear;

let disabled = MemoryWriteAuxCols::<F, 1>::disabled();
assert_eq!(
disabled.flatten().len(),
MemoryWriteAuxCols::<F, 1>::width()
);

let disabled = MemoryWriteAuxCols::<F, 4>::disabled();
assert_eq!(
disabled.flatten().len(),
MemoryWriteAuxCols::<F, 4>::width()
);
}

#[test]
fn test_read_aux_cols_width() {
type F = BabyBear;

let disabled = MemoryReadAuxCols::<F, 1>::disabled();
assert_eq!(disabled.flatten().len(), MemoryReadAuxCols::<F, 1>::width());

let disabled = MemoryReadAuxCols::<F, 4>::disabled();
assert_eq!(disabled.flatten().len(), MemoryReadAuxCols::<F, 4>::width());
}

#[test]
fn test_read_or_immediate_aux_cols_width() {
type F = BabyBear;

let disabled = MemoryReadOrImmediateAuxCols::<F>::disabled();
assert_eq!(
disabled.flatten().len(),
MemoryReadOrImmediateAuxCols::<F>::width()
);
MemoryReadOrImmediateAuxCols {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
is_immediate: F::ZERO,
is_zero_aux: F::ZERO,
}
}
}

0 comments on commit 1affc47

Please sign in to comment.