Skip to content

Commit

Permalink
chore: optimize IsLessThanTupleCols (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfan05 authored Jun 13, 2024
1 parent 612ff0c commit 3549d03
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 108 deletions.
6 changes: 3 additions & 3 deletions chips/src/assert_sorted/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ impl<T: Clone> AssertSortedCols<T> {
width += num_limbs + 1;
}

// for the is_equal indicators
// prods
width += key_vec_len;

// for the inverses
width += key_vec_len;

// for the cumulative is_equal and less_than
width += 2 * key_vec_len;
// for the cumulative less_than
width += key_vec_len;

width
}
Expand Down
45 changes: 17 additions & 28 deletions chips/src/is_less_than_tuple/air.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use std::borrow::Borrow;

use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::Field;
use p3_field::{AbstractField, Field};
use p3_matrix::Matrix;

use crate::{
is_equal::{
columns::{IsEqualAuxCols, IsEqualCols, IsEqualIOCols},
IsEqualAir,
},
is_less_than::columns::{IsLessThanAuxCols, IsLessThanCols, IsLessThanIOCols},
sub_chip::{AirConfig, SubAir},
};
Expand Down Expand Up @@ -85,42 +81,35 @@ impl<AB: AirBuilder> SubAir<AB> for IsLessThanTupleAir {
);
}

// here, we constrain that is_equal is the indicator for whether diff == 0, i.e. x[i] = y[i]
let prods = aux.is_equal_vec_aux.prods.clone();
let invs = aux.is_equal_vec_aux.invs.clone();

// initialize prods[0] = is_equal(x[0], y[0])
builder.assert_eq(prods[0] + (x[0] - y[0]) * invs[0], AB::Expr::one());

for i in 0..x.len() {
let is_equal = aux.is_equal[i];
let inv = aux.is_equal_aux[i].inv;

let is_equal_cols = IsEqualCols {
io: IsEqualIOCols {
x: x[i],
y: y[i],
is_equal,
},
aux: IsEqualAuxCols { inv },
};
// constrain prods[i] = 0 if x[i] != y[i]
builder.assert_zero(prods[i] * (x[i] - y[i]));
}

SubAir::eval(&IsEqualAir, builder, is_equal_cols.io, is_equal_cols.aux);
for i in 0..x.len() - 1 {
// if prod[i] == 0 all after are 0
builder.assert_eq(prods[i] * prods[i + 1], prods[i + 1]);
// prods[i] == 1 forces prods[i+1] == is_equal(x[i+1], y[i+1])
builder.assert_eq(prods[i + 1] + (x[i + 1] - y[i + 1]) * invs[i + 1], prods[i]);
}

// here, we constrain that is_equal_cumulative and less_than_cumulative are the correct values
let is_equal_cumulative = aux.is_equal_cumulative.clone();
let less_than_cumulative = aux.less_than_cumulative.clone();

builder.assert_eq(is_equal_cumulative[0], aux.is_equal[0]);
builder.assert_eq(less_than_cumulative[0], aux.less_than[0]);

for i in 1..x.len() {
// this constrains that is_equal_cumulative[i] indicates whether the first i elements of x and y are equal
builder.assert_eq(
is_equal_cumulative[i],
is_equal_cumulative[i - 1] * aux.is_equal[i],
);
// this constrains that less_than_cumulative[i] indicates whether the first i elements of x are less than
// the first i elements of y, lexicographically
// note that less_than_cumulative[i - 1] and is_equal_cumulative[i - 1] are never both 1
// note that less_than_cumulative[i - 1] and prods[i - 1] are never both 1
builder.assert_eq(
less_than_cumulative[i],
less_than_cumulative[i - 1] + aux.less_than[i] * is_equal_cumulative[i - 1],
less_than_cumulative[i - 1] + aux.less_than[i] * prods[i - 1],
);
}

Expand Down
66 changes: 24 additions & 42 deletions chips/src/is_less_than_tuple/columns.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use afs_derive::AlignedBorrow;

use crate::{is_equal::columns::IsEqualAuxCols, is_less_than::columns::IsLessThanAuxCols};
use crate::{is_equal_vec::columns::IsEqualVecAuxCols, is_less_than::columns::IsLessThanAuxCols};

#[derive(Default, AlignedBorrow)]
pub struct IsLessThanTupleIOCols<T> {
Expand Down Expand Up @@ -34,10 +34,7 @@ impl<T: Clone> IsLessThanTupleIOCols<T> {
pub struct IsLessThanTupleAuxCols<T> {
pub less_than: Vec<T>,
pub less_than_aux: Vec<IsLessThanAuxCols<T>>,
pub is_equal: Vec<T>,
pub is_equal_aux: Vec<IsEqualAuxCols<T>>,

pub is_equal_cumulative: Vec<T>,
pub is_equal_vec_aux: IsEqualVecAuxCols<T>,
pub less_than_cumulative: Vec<T>,
}

Expand Down Expand Up @@ -74,48 +71,37 @@ impl<T: Clone> IsLessThanTupleAuxCols<T> {
curr_start_idx = curr_end_idx;
curr_end_idx += tuple_len;

// get whether y[i] - x[i] == 0
let is_equal = slc[curr_start_idx..curr_end_idx].to_vec();
// generate the less_than_aux columns
let mut less_than_aux: Vec<IsLessThanAuxCols<T>> = vec![];
for i in 0..tuple_len {
let less_than_col = IsLessThanAuxCols {
lower: lower_vec[i].clone(),
lower_decomp: lower_decomp_vec[i].clone(),
};

curr_start_idx = curr_end_idx;
curr_end_idx += tuple_len;
less_than_aux.push(less_than_col);
}

// get the inverses k such that k * (diff[i] + is_zero[i]) = 1
let inverses = slc[curr_start_idx..curr_end_idx].to_vec();
// prods[i] indicates whether x[i] == y[i] up to the i-th index
let prods = slc[curr_start_idx..curr_end_idx].to_vec();

curr_start_idx = curr_end_idx;
curr_end_idx += tuple_len;

let is_equal_cumulative = slc[curr_start_idx..curr_end_idx].to_vec();
// get invs
let invs = slc[curr_start_idx..curr_end_idx].to_vec();

curr_start_idx = curr_end_idx;
curr_end_idx += tuple_len;

let less_than_cumulative = slc[curr_start_idx..curr_end_idx].to_vec();
let is_equal_vec_aux = IsEqualVecAuxCols { prods, invs };

// generate the less_than_aux and is_equal_aux columns
let mut less_than_aux: Vec<IsLessThanAuxCols<T>> = vec![];
for i in 0..tuple_len {
let less_than_col = IsLessThanAuxCols {
lower: lower_vec[i].clone(),
lower_decomp: lower_decomp_vec[i].clone(),
};

less_than_aux.push(less_than_col);
}

let mut is_equal_aux: Vec<IsEqualAuxCols<T>> = vec![];
for inv in inverses.iter() {
let is_equal_col = IsEqualAuxCols { inv: inv.clone() };
is_equal_aux.push(is_equal_col);
}
let less_than_cumulative = slc[curr_start_idx..curr_end_idx].to_vec();

Self {
less_than,
less_than_aux,
is_equal,
is_equal_aux,
is_equal_cumulative,
is_equal_vec_aux,
less_than_cumulative,
}
}
Expand All @@ -133,13 +119,9 @@ impl<T: Clone> IsLessThanTupleAuxCols<T> {
flattened.extend_from_slice(&self.less_than_aux[i].lower_decomp);
}

flattened.extend_from_slice(&self.is_equal);
flattened.extend_from_slice(&self.is_equal_vec_aux.prods);
flattened.extend_from_slice(&self.is_equal_vec_aux.invs);

for i in 0..self.is_equal_aux.len() {
flattened.push(self.is_equal_aux[i].inv.clone());
}

flattened.extend_from_slice(&self.is_equal_cumulative);
flattened.extend_from_slice(&self.less_than_cumulative);

flattened
Expand All @@ -156,12 +138,12 @@ impl<T: Clone> IsLessThanTupleAuxCols<T> {
let num_limbs = (limb_bit + decomp - 1) / decomp;
width += num_limbs + 1;
}
// for the indicator whether difference is zero
// for the prods
width += tuple_len;
// for the invs
width += tuple_len;
// for the inverses k such that k * (diff[i] + is_zero[i]) = 1
// for the cumulative less_than
width += tuple_len;
// for the cumulative is_equal and less_than
width += 2 * tuple_len;

width
}
Expand Down
52 changes: 17 additions & 35 deletions chips/src/is_less_than_tuple/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use p3_field::PrimeField64;
use p3_matrix::dense::RowMajorMatrix;

use crate::{
is_equal::columns::IsEqualAuxCols,
is_equal_vec::columns::IsEqualVecAuxCols,
is_less_than::{columns::IsLessThanAuxCols, IsLessThanChip},
range_gate::RangeCheckerGateChip,
sub_chip::LocalTraceInstructions,
Expand Down Expand Up @@ -74,17 +74,27 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
lower_decomp_vec.push(curr_less_than_row[4..].to_vec());
}

// compute is_equal_cumulative
// compute prods and invs
let mut transition_index = 0;
while transition_index < x.len() && x[transition_index] == y[transition_index] {
transition_index += 1;
}

let is_equal_cumulative = std::iter::repeat(F::one())
let prods = std::iter::repeat(F::one())
.take(transition_index)
.chain(std::iter::repeat(F::zero()).take(x.len() - transition_index))
.collect::<Vec<F>>();

let mut invs = std::iter::repeat(F::zero())
.take(x.len())
.collect::<Vec<F>>();

if transition_index != x.len() {
invs[transition_index] = (F::from_canonical_u32(x[transition_index])
- F::from_canonical_u32(y[transition_index]))
.inverse();
}

let mut less_than_cumulative: Vec<F> = vec![];

// compute less_than_cumulative
Expand All @@ -95,7 +105,7 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
F::zero()
};

if x[i] < y[i] && (i == 0 || is_equal_cumulative[i - 1] == F::one()) {
if x[i] < y[i] && (i == 0 || prods[i - 1] == F::one()) {
less_than_curr = F::one();
}

Expand All @@ -108,29 +118,7 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
less_than_cumulative.push(less_than_curr);
}

// contains indicator whether difference is zero
let mut is_equal: Vec<F> = vec![];
// contains y such that y * (i + x) = 1
let mut inverses: Vec<F> = vec![];

// we compute the indicators, which only matter if the row is not the last
for (i, &val) in x.iter().enumerate() {
let next_val = y[i];

// the difference between the two limbs
let curr_diff = F::from_canonical_u32(val) - F::from_canonical_u32(next_val);

// compute the equal indicator and inverses
if next_val == val {
is_equal.push(F::one());
inverses.push((curr_diff + F::one()).inverse());
} else {
is_equal.push(F::zero());
inverses.push(curr_diff.inverse());
}
}

// compute less_than_aux and is_equal_aux
// compute less_than_aux and is_equal_vec_aux
let mut less_than_aux: Vec<IsLessThanAuxCols<F>> = vec![];
for i in 0..x.len() {
let less_than_col = IsLessThanAuxCols {
Expand All @@ -140,11 +128,7 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
less_than_aux.push(less_than_col);
}

let mut is_equal_aux: Vec<IsEqualAuxCols<F>> = vec![];
for inverse in &inverses {
let is_equal_col = IsEqualAuxCols { inv: *inverse };
is_equal_aux.push(is_equal_col);
}
let is_equal_vec_aux = IsEqualVecAuxCols { prods, invs };

let io = IsLessThanTupleIOCols {
x: x.into_iter().map(F::from_canonical_u32).collect(),
Expand All @@ -154,9 +138,7 @@ impl<F: PrimeField64> LocalTraceInstructions<F> for IsLessThanTupleAir {
let aux = IsLessThanTupleAuxCols {
less_than,
less_than_aux,
is_equal,
is_equal_aux,
is_equal_cumulative,
is_equal_vec_aux,
less_than_cumulative,
};

Expand Down

0 comments on commit 3549d03

Please sign in to comment.