Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: change MAX generic to instance field #40

Merged
merged 3 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions chips/src/range/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ use p3_matrix::dense::RowMajorMatrix;
use super::columns::NUM_RANGE_COLS;
use super::RangeCheckerChip;

impl<F: Field, const MAX: u32> BaseAir<F> for RangeCheckerChip<MAX> {
impl<F: Field> BaseAir<F> for RangeCheckerChip {
fn width(&self) -> usize {
NUM_RANGE_COLS
}

fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
let column = (0..MAX).map(F::from_canonical_u32).collect();
let column = (0..self.range_max).map(F::from_canonical_u32).collect();
Some(RowMajorMatrix::new_col(column))
}
}

impl<AB, const MAX: u32> Air<AB> for RangeCheckerChip<MAX>
impl<AB> Air<AB> for RangeCheckerChip
where
AB: AirBuilder,
{
Expand Down
2 changes: 1 addition & 1 deletion chips/src/range/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::{
RangeCheckerChip,
};

impl<F: PrimeField32, const MAX: u32> Chip<F> for RangeCheckerChip<MAX> {
impl<F: PrimeField32> Chip<F> for RangeCheckerChip {
fn receives(&self) -> Vec<Interaction<F>> {
vec![Interaction {
fields: vec![VirtualPairCol::single_preprocessed(
Expand Down
15 changes: 10 additions & 5 deletions chips/src/range/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@ pub mod columns;
pub mod trace;

#[derive(Default)]
pub struct RangeCheckerChip<const MAX: u32> {
pub struct RangeCheckerChip {
/// The index for the Range Checker bus.
bus_index: usize,
range_max: u32,
pub count: Vec<Arc<AtomicU32>>,
}

impl<const MAX: u32> RangeCheckerChip<MAX> {
pub fn new(bus_index: usize) -> Self {
impl RangeCheckerChip {
pub fn new(bus_index: usize, range_max: u32) -> Self {
let mut count = vec![];
for _ in 0..MAX {
for _ in 0..range_max {
count.push(Arc::new(AtomicU32::new(0)));
}
Self { bus_index, count }
Self {
bus_index,
range_max,
count,
}
}

pub fn bus_index(&self) -> usize {
Expand Down
4 changes: 2 additions & 2 deletions chips/src/range/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use super::{
RangeCheckerChip,
};

impl<const MAX: u32> RangeCheckerChip<MAX> {
impl RangeCheckerChip {
pub fn generate_trace<F: PrimeField32>(&self) -> RowMajorMatrix<F> {
let mut rows = vec![[F::zero(); NUM_RANGE_COLS]; MAX as usize];
let mut rows = vec![[F::zero(); NUM_RANGE_COLS]; self.range_max as usize];
for (n, row) in rows.iter_mut().enumerate() {
let cols: &mut RangeCols<F> = unsafe { transmute(row) };

Expand Down
4 changes: 2 additions & 2 deletions chips/src/range_gate/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use super::columns::RangeGateCols;
use super::columns::NUM_RANGE_GATE_COLS;
use super::RangeCheckerGateChip;

impl<F: Field, const MAX: u32> BaseAir<F> for RangeCheckerGateChip<MAX> {
impl<F: Field> BaseAir<F> for RangeCheckerGateChip {
fn width(&self) -> usize {
NUM_RANGE_GATE_COLS
}
}

impl<AB, const MAX: u32> Air<AB> for RangeCheckerGateChip<MAX>
impl<AB> Air<AB> for RangeCheckerGateChip
where
AB: AirBuilder,
{
Expand Down
2 changes: 1 addition & 1 deletion chips/src/range_gate/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use p3_field::PrimeField64;

use super::{columns::RANGE_GATE_COL_MAP, RangeCheckerGateChip};

impl<F: PrimeField64, const MAX: u32> Chip<F> for RangeCheckerGateChip<MAX> {
impl<F: PrimeField64> Chip<F> for RangeCheckerGateChip {
fn receives(&self) -> Vec<Interaction<F>> {
vec![Interaction {
fields: vec![VirtualPairCol::single_main(RANGE_GATE_COL_MAP.counter)],
Expand Down
15 changes: 10 additions & 5 deletions chips/src/range_gate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,24 @@ pub mod trace;
/// column. The counter column is generated using a gate, as opposed to
/// the other RangeCheckerChip.
#[derive(Default)]
pub struct RangeCheckerGateChip<const MAX: u32> {
pub struct RangeCheckerGateChip {
/// The index for the Range Checker bus.
bus_index: usize,
_range_max: u32,
pub count: Vec<Arc<AtomicU32>>,
}

impl<const MAX: u32> RangeCheckerGateChip<MAX> {
pub fn new(bus_index: usize) -> Self {
impl RangeCheckerGateChip {
pub fn new(bus_index: usize, range_max: u32) -> Self {
let mut count = vec![];
for _ in 0..MAX {
for _ in 0..range_max {
count.push(Arc::new(AtomicU32::new(0)));
}
Self { bus_index, count }
Self {
bus_index,
_range_max: range_max,
count,
}
}

pub fn bus_index(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion chips/src/range_gate/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use p3_matrix::dense::RowMajorMatrix;

use super::{columns::NUM_RANGE_GATE_COLS, RangeCheckerGateChip};

impl<const MAX: u32> RangeCheckerGateChip<MAX> {
impl RangeCheckerGateChip {
pub fn generate_trace<F: PrimeField64>(&self) -> RowMajorMatrix<F> {
let rows = self
.count
Expand Down
10 changes: 5 additions & 5 deletions chips/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn test_list_range_checker() {
const LIST_LEN: usize = 1 << LOG_TRACE_DEGREE_LIST;

// Creating a RangeCheckerChip
let range_checker = Arc::new(RangeCheckerChip::<MAX>::new(bus_index));
let range_checker = Arc::new(RangeCheckerChip::new(bus_index, MAX));

// Generating random lists
let num_lists = 10;
Expand All @@ -50,8 +50,8 @@ fn test_list_range_checker() {
// define a bunch of ListChips
let lists = lists_vals
.iter()
.map(|vals| ListChip::new(bus_index, vals.to_vec(), Arc::clone(&range_checker)))
.collect::<Vec<ListChip<MAX>>>();
.map(|vals| ListChip::new(bus_index, MAX, vals.to_vec(), Arc::clone(&range_checker)))
.collect::<Vec<ListChip>>();

let lists_traces = lists
.par_iter()
Expand Down Expand Up @@ -329,7 +329,7 @@ fn test_range_gate_chip() {
const LOG_LIST_LEN: usize = 6;
const LIST_LEN: usize = 1 << LOG_LIST_LEN;

let range_checker = RangeCheckerGateChip::<MAX>::new(bus_index);
let range_checker = RangeCheckerGateChip::new(bus_index, MAX);

// Generating random lists
let num_lists = 10;
Expand Down Expand Up @@ -387,7 +387,7 @@ fn negative_test_range_gate_chip() {
const N: usize = 3;
const MAX: u32 = 1 << N;

let range_checker = RangeCheckerGateChip::<MAX>::new(bus_index);
let range_checker = RangeCheckerGateChip::new(bus_index, MAX);

// generating a trace with a counter starting from 1
// instead of 0 to test the AIR constraints in range_checker
Expand Down
4 changes: 2 additions & 2 deletions chips/tests/list/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use p3_field::Field;
use super::columns::NUM_LIST_COLS;
use super::ListChip;

impl<F: Field, const MAX: u32> BaseAir<F> for ListChip<MAX> {
impl<F: Field> BaseAir<F> for ListChip {
fn width(&self) -> usize {
NUM_LIST_COLS
}
}

impl<AB, const MAX: u32> Air<AB> for ListChip<MAX>
impl<AB> Air<AB> for ListChip
where
AB: AirBuilder,
{
Expand Down
2 changes: 1 addition & 1 deletion chips/tests/list/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use p3_field::PrimeField32;

use super::{columns::LIST_COL_MAP, ListChip};

impl<F: PrimeField32, const MAX: u32> Chip<F> for ListChip<MAX> {
impl<F: PrimeField32> Chip<F> for ListChip {
fn sends(&self) -> Vec<Interaction<F>> {
vec![Interaction {
fields: vec![VirtualPairCol::single_main(LIST_COL_MAP.val)],
Expand Down
11 changes: 7 additions & 4 deletions chips/tests/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,25 @@ pub mod columns;
pub mod trace;

#[derive(Default)]
pub struct ListChip<const MAX: u32> {
pub struct ListChip {
/// The index for the Range Checker bus.
bus_index: usize,
_range_max: u32,
pub vals: Vec<u32>,

range_checker: Arc<RangeCheckerChip<MAX>>,
range_checker: Arc<RangeCheckerChip>,
}

impl<const MAX: u32> ListChip<MAX> {
impl ListChip {
pub fn new(
bus_index: usize,
range_max: u32,
vals: Vec<u32>,
range_checker: Arc<RangeCheckerChip<MAX>>,
range_checker: Arc<RangeCheckerChip>,
) -> Self {
Self {
bus_index,
_range_max: range_max,
vals,
range_checker,
}
Expand Down
2 changes: 1 addition & 1 deletion chips/tests/list/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use p3_matrix::dense::RowMajorMatrix;

use super::{columns::NUM_LIST_COLS, ListChip};

impl<const MAX: u32> ListChip<MAX> {
impl ListChip {
pub fn generate_trace<F: PrimeField32>(&self) -> RowMajorMatrix<F> {
let mut rows = vec![];
for val in self.vals.iter() {
Expand Down
Loading