Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Remove legacy flatten/from_iter APIs #1186

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 30 additions & 223 deletions crates/vm/src/system/memory/offline_checker/columns.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
//! Defines auxiliary columns for memory operations: `MemoryReadAuxCols`,
//! `MemoryReadWithImmediateAuxCols`, and `MemoryWriteAuxCols`.

use std::{array, borrow::Borrow, iter};

use openvm_circuit_primitives::is_less_than::LessThanAuxCols;
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32};
Expand All @@ -21,29 +19,6 @@ pub struct MemoryBaseAuxCols<T> {
pub(super) clk_lt_aux: LessThanAuxCols<T, AUX_LEN>,
}

impl<T: Clone> MemoryBaseAuxCols<T> {
// TODO[arayi]: Since we have AlignedBorrow, should remove all from_slice, from_iterator, and flatten in a future PR.
pub fn from_slice(slc: &[T]) -> Self {
let base_aux_cols: &MemoryBaseAuxCols<T> = slc.borrow();
base_aux_cols.clone()
}

pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
let sm = iter.take(Self::width()).collect::<Vec<T>>();
let base_aux_cols: &MemoryBaseAuxCols<T> = sm[..].borrow();
base_aux_cols.clone()
}
}

impl<T> MemoryBaseAuxCols<T> {
pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(iter::once(self.prev_timestamp))
.chain(self.clk_lt_aux.lower_decomp)
.collect()
}
}

#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct MemoryWriteAuxCols<T, const N: usize> {
Expand All @@ -63,22 +38,7 @@ impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
}
}

impl<const N: usize, T: Clone> MemoryWriteAuxCols<T, N> {
pub fn from_slice(slc: &[T]) -> Self {
let width = MemoryBaseAuxCols::<T>::width();
Self {
base: MemoryBaseAuxCols::from_slice(&slc[..width]),
prev_data: array::from_fn(|i| slc[width + i].clone()),
}
}

pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
base: MemoryBaseAuxCols::from_iterator(iter),
prev_data: array::from_fn(|_| iter.next().unwrap()),
}
}

impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
pub fn from_base(base: MemoryBaseAuxCols<T>, prev_data: [T; N]) -> Self {
Self { base, prev_data }
}
Expand All @@ -88,19 +48,17 @@ impl<const N: usize, T: Clone> MemoryWriteAuxCols<T, N> {
}
}

impl<const N: usize, T> MemoryWriteAuxCols<T, N> {
pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(self.base.flatten())
.chain(self.prev_data)
.collect()
}
}

impl<const N: usize, F: FieldAlgebra + Copy> MemoryWriteAuxCols<F, N> {
pub fn disabled() -> Self {
let width = MemoryWriteAuxCols::<F, N>::width();
MemoryWriteAuxCols::from_slice(&F::zero_vec(width))
impl<const N: usize, F: FieldAlgebra> MemoryWriteAuxCols<F, N> {
pub const fn disabled() -> Self {
Self {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
prev_data: [F::ZERO; N],
}
}
}

Expand All @@ -125,101 +83,17 @@ impl<const N: usize, F: PrimeField32> MemoryReadAuxCols<F, N> {
}
}

impl<const N: usize, T: Clone> MemoryReadAuxCols<T, N> {
pub fn from_slice(slc: &[T]) -> Self {
Self {
base: MemoryBaseAuxCols::from_slice(slc),
}
}

pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
base: MemoryBaseAuxCols::from_iterator(iter),
}
}
}

impl<const N: usize, T> MemoryReadAuxCols<T, N> {
pub fn flatten(self) -> Vec<T> {
self.base.flatten()
}
}

impl<const N: usize, F: FieldAlgebra + Copy> MemoryReadAuxCols<F, N> {
pub fn disabled() -> Self {
let width = MemoryReadAuxCols::<F, N>::width();
MemoryReadAuxCols::from_slice(&F::zero_vec(width))
}
}

#[repr(C)]
#[derive(Clone, Debug, AlignedBorrow)]
pub struct MemoryHeapReadAuxCols<T, const N: usize> {
pub address: MemoryReadAuxCols<T, 1>,
pub data: MemoryReadAuxCols<T, N>,
}

impl<const N: usize, T: Clone> MemoryHeapReadAuxCols<T, N> {
pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
pub const fn disabled() -> Self {
Self {
address: MemoryReadAuxCols::from_iterator(iter),
data: MemoryReadAuxCols::from_iterator(iter),
}
}

pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(self.address.flatten())
.chain(self.data.flatten())
.collect()
}
}

impl<const N: usize, F: FieldAlgebra + Copy> MemoryHeapReadAuxCols<F, N> {
pub fn disabled() -> Self {
let width = MemoryReadAuxCols::<F, 1>::width();
let address = MemoryReadAuxCols::from_slice(&F::zero_vec(width));
let width = MemoryReadAuxCols::<F, N>::width();
let data = MemoryReadAuxCols::from_slice(&F::zero_vec(width));
MemoryHeapReadAuxCols { address, data }
}
}

#[repr(C)]
#[derive(Clone, Debug)]
pub struct MemoryHeapWriteAuxCols<T, const N: usize> {
pub address: MemoryReadAuxCols<T, 1>,
pub data: MemoryWriteAuxCols<T, N>,
}

impl<const N: usize, T: Clone> MemoryHeapWriteAuxCols<T, N> {
pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
address: MemoryReadAuxCols::from_iterator(iter),
data: MemoryWriteAuxCols::from_iterator(iter),
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
}
}

pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(self.address.flatten())
.chain(self.data.flatten())
.collect()
}

pub const fn width() -> usize {
MemoryReadAuxCols::<T, 1>::width() + MemoryWriteAuxCols::<T, N>::width()
}
}

impl<const N: usize, F: FieldAlgebra + Copy> MemoryHeapWriteAuxCols<F, N> {
pub fn disabled() -> Self {
let width = MemoryReadAuxCols::<F, 1>::width();
let address = MemoryReadAuxCols::from_slice(&F::zero_vec(width));
let width = MemoryWriteAuxCols::<F, N>::width();
let data = MemoryWriteAuxCols::from_slice(&F::zero_vec(width));
MemoryHeapWriteAuxCols { address, data }
}
}

#[repr(C)]
Expand Down Expand Up @@ -248,84 +122,17 @@ impl<T> MemoryReadOrImmediateAuxCols<T> {
}
}

impl<T: Clone> MemoryReadOrImmediateAuxCols<T> {
pub fn from_slice(slc: &[T]) -> Self {
let width = MemoryBaseAuxCols::<T>::width();
Self {
base: MemoryBaseAuxCols::from_slice(&slc[..width]),
is_immediate: slc[width].clone(),
is_zero_aux: slc[width + 1].clone(),
}
}

pub fn from_iterator<I: Iterator<Item = T>>(iter: &mut I) -> Self {
Self {
base: MemoryBaseAuxCols::from_iterator(iter),
is_immediate: iter.next().unwrap(),
is_zero_aux: iter.next().unwrap(),
}
}
}

impl<T> MemoryReadOrImmediateAuxCols<T> {
pub fn flatten(self) -> Vec<T> {
iter::empty()
.chain(self.base.flatten())
.chain(iter::once(self.is_immediate))
.chain(iter::once(self.is_zero_aux))
.collect()
}
}

impl<F: FieldAlgebra + Copy> MemoryReadOrImmediateAuxCols<F> {
pub fn disabled() -> Self {
let width = MemoryReadOrImmediateAuxCols::<F>::width();
MemoryReadOrImmediateAuxCols::from_slice(&F::zero_vec(width))
}
}

#[cfg(test)]
mod tests {
use openvm_stark_sdk::p3_baby_bear::BabyBear;

use super::*;

#[test]
fn test_write_aux_cols_width() {
type F = BabyBear;

let disabled = MemoryWriteAuxCols::<F, 1>::disabled();
assert_eq!(
disabled.flatten().len(),
MemoryWriteAuxCols::<F, 1>::width()
);

let disabled = MemoryWriteAuxCols::<F, 4>::disabled();
assert_eq!(
disabled.flatten().len(),
MemoryWriteAuxCols::<F, 4>::width()
);
}

#[test]
fn test_read_aux_cols_width() {
type F = BabyBear;

let disabled = MemoryReadAuxCols::<F, 1>::disabled();
assert_eq!(disabled.flatten().len(), MemoryReadAuxCols::<F, 1>::width());

let disabled = MemoryReadAuxCols::<F, 4>::disabled();
assert_eq!(disabled.flatten().len(), MemoryReadAuxCols::<F, 4>::width());
}

#[test]
fn test_read_or_immediate_aux_cols_width() {
type F = BabyBear;

let disabled = MemoryReadOrImmediateAuxCols::<F>::disabled();
assert_eq!(
disabled.flatten().len(),
MemoryReadOrImmediateAuxCols::<F>::width()
);
pub const fn disabled() -> Self {
MemoryReadOrImmediateAuxCols {
base: MemoryBaseAuxCols {
prev_timestamp: F::ZERO,
clk_lt_aux: LessThanAuxCols {
lower_decomp: [F::ZERO; AUX_LEN],
},
},
is_immediate: F::ZERO,
is_zero_aux: F::ZERO,
}
}
}