Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: optimize IsLessThanTupleCols #64

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions chips/src/assert_sorted/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ impl<T: Clone> AssertSortedCols<T> {
width += num_limbs + 1;
}

// for the is_equal indicators
// prods
width += key_vec_len;

// for the inverses
width += key_vec_len;

// for the cumulative is_equal and less_than
width += 2 * key_vec_len;
// for the cumulative less_than
width += key_vec_len;

width
}
Expand Down
45 changes: 17 additions & 28 deletions chips/src/is_less_than_tuple/air.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use std::borrow::Borrow;

use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::Field;
use p3_field::{AbstractField, Field};
use p3_matrix::Matrix;

use crate::{
is_equal::{
columns::{IsEqualAuxCols, IsEqualCols, IsEqualIOCols},
IsEqualAir,
},
is_less_than::columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols},
sub_chip::{AirConfig, SubAir},
};
Expand Down Expand Up @@ -85,42 +81,35 @@ impl<AB: AirBuilder> SubAir<AB> for IsLessThanTupleAir {
);
}

// here, we constrain that is_equal is the indicator for whether diff == 0, i.e. x[i] = y[i]
let prods = aux.is_equal_vec_aux.prods.clone();
let invs = aux.is_equal_vec_aux.invs.clone();

// initialize prods[0] = is_equal(x[0], y[0])
builder.assert_eq(prods[0] + (x[0] - y[0]) * invs[0], AB::Expr::one());

for i in 0..x.len() {
let is_equal = aux.is_equal[i];
let inv = aux.is_equal_aux[i].inv;

let is_equal_cols = IsEqualCols {
io: IsEqualIOCols {
x: x[i],
y: y[i],
is_equal,
},
aux: IsEqualAuxCols { inv },
};
// constrain prods[i] = 0 if x[i] != y[i]
builder.assert_zero(prods[i] * (x[i] - y[i]));
}

SubAir::eval(&IsEqualAir, builder, is_equal_cols.io, is_equal_cols.aux);
for i in 0..x.len() - 1 {
// if prod[i] == 0 all after are 0
builder.assert_eq(prods[i] * prods[i + 1], prods[i + 1]);
// prods[i] == 1 forces prods[i+1] == is_equal(x[i+1], y[i+1])
builder.assert_eq(prods[i + 1] + (x[i + 1] - y[i + 1]) * invs[i + 1], prods[i]);
}

// here, we constrain that is_equal_cumulative and less_than_cumulative are the correct values
let is_equal_cumulative = aux.is_equal_cumulative.clone();
let less_than_cumulative = aux.less_than_cumulative.clone();

builder.assert_eq(is_equal_cumulative[0], aux.is_equal[0]);
builder.assert_eq(less_than_cumulative[0], aux.less_than[0]);

for i in 1..x.len() {
// this constrains that is_equal_cumulative[i] indicates whether the first i elements of x and y are equal
builder.assert_eq(
is_equal_cumulative[i],
is_equal_cumulative[i - 1] * aux.is_equal[i],
);
// this constrains that less_than_cumulative[i] indicates whether the first i elements of x are less than
// the first i elements of y, lexicographically
// note that less_than_cumulative[i - 1] and is_equal_cumulative[i - 1] are never both 1
// note that less_than_cumulative[i - 1] and prods[i - 1] are never both 1
builder.assert_eq(
less_than_cumulative[i],
less_than_cumulative[i - 1] + aux.less_than[i] * is_equal_cumulative[i - 1],
less_than_cumulative[i - 1] + aux.less_than[i] * prods[i - 1],
);
}

Expand Down
66 changes: 24 additions & 42 deletions chips/src/is_less_than_tuple/columns.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use afs_derive::AlignedBorrow;

use crate::{is_equal::columns::IsEqualAuxCols, is_less_than::columns::IsLessThanAuxCols};
use crate::{is_equal_vec::columns::IsEqualVecAuxCols, is_less_than::columns::IsLessThanAuxCols};

#[derive(Default, AlignedBorrow)]
pub struct IsLessThanTupleIOCols<T> {
Expand Down Expand Up @@ -34,10 +34,7 @@ impl<T: Clone> IsLessThanTupleIOCols<T> {
pub struct IsLessThanTupleAuxCols<T> {
pub less_than: Vec<T>,
pub less_than_aux: Vec<IsLessThanAuxCols<T>>,
pub is_equal: Vec<T>,
pub is_equal_aux: Vec<IsEqualAuxCols<T>>,

pub is_equal_cumulative: Vec<T>,
pub is_equal_vec_aux: IsEqualVecAuxCols<T>,
pub less_than_cumulative: Vec<T>,
}

Expand Down Expand Up @@ -74,48 +71,37 @@ impl<T: Clone> IsLessThanTupleAuxCols<T> {
curr_start_idx = curr_end_idx;
curr_end_idx += tuple_len;

// get whether y[i] - x[i] == 0
let is_equal = slc[curr_start_idx..curr_end_idx].to_vec();
// generate the less_than_aux columns
let mut less_than_aux: Vec<IsLessThanAuxCols<T>> = vec![];
for i in 0..tuple_len {
let less_than_col = IsLessThanAuxCols {
lower: lower_vec[i].clone(),
lower_decomp: lower_decomp_vec[i].clone(),
};

curr_start_idx = curr_end_idx;
curr_end_idx += tuple_len;
less_than_aux.push(less_than_col);
}

// get the inverses k such that k * (diff[i] + is_zero[i]) = 1
let inverses = slc[curr_start_idx..curr_end_idx].to_vec();
// prods[i] indicates whether x[i] == y[i] up to the i-th index
let prods = slc[curr_start_idx..curr_end_idx].to_vec();

curr_start_idx = curr_end_idx;
curr_end_idx += tuple_len;

let is_equal_cumulative = slc[curr_start_idx..curr_end_idx].to_vec();
// get invs
let invs = slc[curr_start_idx..curr_end_idx].to_vec();

curr_start_idx = curr_end_idx;
curr_end_idx += tuple_len;

let less_than_cumulative = slc[curr_start_idx..curr_end_idx].to_vec();
let is_equal_vec_aux = IsEqualVecAuxCols { prods, invs };

// generate the less_than_aux and is_equal_aux columns
let mut less_than_aux: Vec<IsLessThanAuxCols<T>> = vec![];
for i in 0..tuple_len {
let less_than_col = IsLessThanAuxCols {
lower: lower_vec[i].clone(),
lower_decomp: lower_decomp_vec[i].clone(),
};

less_than_aux.push(less_than_col);
}

let mut is_equal_aux: Vec<IsEqualAuxCols<T>> = vec![];
for inv in inverses.iter() {
let is_equal_col = IsEqualAuxCols { inv: inv.clone() };
is_equal_aux.push(is_equal_col);
}
let less_than_cumulative = slc[curr_start_idx..curr_end_idx].to_vec();

Self {
less_than,
less_than_aux,
is_equal,
is_equal_aux,
is_equal_cumulative,
is_equal_vec_aux,
less_than_cumulative,
}
}
Expand All @@ -133,13 +119,9 @@ impl<T: Clone> IsLessThanTupleAuxCols<T> {
flattened.extend_from_slice(&self.less_than_aux[i].lower_decomp);
}

flattened.extend_from_slice(&self.is_equal);
flattened.extend_from_slice(&self.is_equal_vec_aux.prods);
flattened.extend_from_slice(&self.is_equal_vec_aux.invs);

for i in 0..self.is_equal_aux.len() {
flattened.push(self.is_equal_aux[i].inv.clone());
}

flattened.extend_from_slice(&self.is_equal_cumulative);
flattened.extend_from_slice(&self.less_than_cumulative);

flattened
Expand All @@ -156,12 +138,12 @@ impl<T: Clone> IsLessThanTupleAuxCols<T> {
let num_limbs = (limb_bit + decomp - 1) / decomp;
width += num_limbs + 1;
}
// for the indicator whether difference is zero
// for the prods
width += tuple_len;
// for the invs
width += tuple_len;
// for the inverses k such that k * (diff[i] + is_zero[i]) = 1
// for the cumulative less_than
width += tuple_len;
// for the cumulative is_equal and less_than
width += 2 * tuple_len;

width
}
Expand Down
52 changes: 17 additions & 35 deletions chips/src/is_less_than_tuple/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use p3_field::PrimeField64;
use p3_matrix::dense::RowMajorMatrix;

use crate::{
is_equal::columns::IsEqualAuxCols,
is_equal_vec::columns::IsEqualVecAuxCols,
is_less_than::{columns::IsLessThanAuxCols, IsLessThanChip},
range_gate::RangeCheckerGateChip,
sub_chip::LocalTraceInstructions,
Expand Down Expand Up @@ -74,17 +74,27 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
lower_decomp_vec.push(curr_less_than_row[4..].to_vec());
}

// compute is_equal_cumulative
// compute prods and invs
let mut transition_index = 0;
while transition_index < x.len() && x[transition_index] == y[transition_index] {
transition_index += 1;
}

let is_equal_cumulative = std::iter::repeat(F::one())
let prods = std::iter::repeat(F::one())
.take(transition_index)
.chain(std::iter::repeat(F::zero()).take(x.len() - transition_index))
.collect::<Vec<F>>();

let mut invs = std::iter::repeat(F::zero())
.take(x.len())
.collect::<Vec<F>>();

if transition_index != x.len() {
invs[transition_index] = (F::from_canonical_u32(x[transition_index])
- F::from_canonical_u32(y[transition_index]))
.inverse();
}

let mut less_than_cumulative: Vec<F> = vec![];

// compute less_than_cumulative
Expand All @@ -95,7 +105,7 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
F::zero()
};

if x[i] < y[i] && (i == 0 || is_equal_cumulative[i - 1] == F::one()) {
if x[i] < y[i] && (i == 0 || prods[i - 1] == F::one()) {
less_than_curr = F::one();
}

Expand All @@ -108,29 +118,7 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
less_than_cumulative.push(less_than_curr);
}

// contains indicator whether difference is zero
let mut is_equal: Vec<F> = vec![];
// contains y such that y * (i + x) = 1
let mut inverses: Vec<F> = vec![];

// we compute the indicators, which only matter if the row is not the last
for (i, &val) in x.iter().enumerate() {
let next_val = y[i];

// the difference between the two limbs
let curr_diff = F::from_canonical_u32(val) - F::from_canonical_u32(next_val);

// compute the equal indicator and inverses
if next_val == val {
is_equal.push(F::one());
inverses.push((curr_diff + F::one()).inverse());
} else {
is_equal.push(F::zero());
inverses.push(curr_diff.inverse());
}
}

// compute less_than_aux and is_equal_aux
// compute less_than_aux and is_equal_vec_aux
let mut less_than_aux: Vec<IsLessThanAuxCols<F>> = vec![];
for i in 0..x.len() {
let less_than_col = IsLessThanAuxCols {
Expand All @@ -140,11 +128,7 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
less_than_aux.push(less_than_col);
}

let mut is_equal_aux: Vec<IsEqualAuxCols<F>> = vec![];
for inverse in &inverses {
let is_equal_col = IsEqualAuxCols { inv: *inverse };
is_equal_aux.push(is_equal_col);
}
let is_equal_vec_aux = IsEqualVecAuxCols { prods, invs };

let io = IsLessThanTupleIOCols {
x: x.into_iter().map(F::from_canonical_u32).collect(),
Expand All @@ -154,9 +138,7 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
let aux = IsLessThanTupleAuxCols {
less_than,
less_than_aux,
is_equal,
is_equal_aux,
is_equal_cumulative,
is_equal_vec_aux,
less_than_cumulative,
};

Expand Down
Loading