Skip to content

Commit

Permalink
nonce translation
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-barrett committed Nov 11, 2024
1 parent da7e74a commit b05f909
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 47 deletions.
6 changes: 4 additions & 2 deletions src/air/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,17 @@ impl<AB: AirBuilder + MessageBuilder<AirInteraction<AB::Expr>>> LookupBuilder fo
pub struct Record {
pub nonce: u32,
pub count: u32,
pub query_index: u32,
// Original query that did the lookup. `None` is for the root lookup
pub query_index: Option<usize>,
}

impl Record {
/// Updates the provide record and returns the require record
pub fn new_lookup(&mut self, nonce: u32) -> Record {
pub fn new_lookup(&mut self, nonce: u32, query_index: usize) -> Record {
let require = *self;
self.nonce = nonce;
self.count += 1;
self.query_index = Some(query_index);
require
}

Expand Down
3 changes: 2 additions & 1 deletion src/core/big_num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ impl<F: PrimeField32> Chipset<F> for BigNum {
&self,
input: &[F],
nonce: u32,
query_index: usize,
queries: &mut QueryRecord<F>,
requires: &mut Vec<Record>,
) -> Vec<F> {
let in1: [F; 8] = input[0..8].try_into().unwrap();
let in2: [F; 8] = input[8..16].try_into().unwrap();
let bytes = &mut queries.bytes.context(nonce, requires);
let bytes = &mut queries.bytes.context(nonce, query_index, requires);
match self {
BigNum::LessThan => {
let mut witness = BigNumCompareWitness::<F>::default();
Expand Down
17 changes: 12 additions & 5 deletions src/core/chipset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,22 @@ impl Chipset<BabyBear> for LurkChip {
&self,
input: &[BabyBear],
nonce: u32,
query_index: usize,
queries: &mut QueryRecord<BabyBear>,
requires: &mut Vec<Record>,
) -> Vec<BabyBear> {
match self {
LurkChip::Hasher3(hasher) => hasher.execute(input, nonce, queries, requires),
LurkChip::Hasher4(hasher) => hasher.execute(input, nonce, queries, requires),
LurkChip::Hasher5(hasher) => hasher.execute(input, nonce, queries, requires),
LurkChip::U64(op) => op.execute(input, nonce, queries, requires),
LurkChip::BigNum(op) => op.execute(input, nonce, queries, requires),
LurkChip::Hasher3(hasher) => {
hasher.execute(input, nonce, query_index, queries, requires)
}
LurkChip::Hasher4(hasher) => {
hasher.execute(input, nonce, query_index, queries, requires)
}
LurkChip::Hasher5(hasher) => {
hasher.execute(input, nonce, query_index, queries, requires)
}
LurkChip::U64(op) => op.execute(input, nonce, query_index, queries, requires),
LurkChip::BigNum(op) => op.execute(input, nonce, query_index, queries, requires),
}
}

Expand Down
1 change: 1 addition & 0 deletions src/core/cli/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ impl<F: PrimeField32, C1: Chipset<F>, C2: Chipset<F>> Repl<F, C1, C2> {
bytes: Default::default(),
emitted: Default::default(),
debug_data: Default::default(),
nonce_map: Default::default(),
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/core/u64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<F: PrimeField32> Chipset<F> for U64 {
&self,
input: &[F],
nonce: u32,
query_index: usize,
queries: &mut QueryRecord<F>,
requires: &mut Vec<Record>,
) -> Vec<F> {
Expand All @@ -93,7 +94,7 @@ impl<F: PrimeField32> Chipset<F> for U64 {
U64::IsZero => 0, // unused
_ => into_u64(&input[8..16]),
};
let bytes = &mut queries.bytes.context(nonce, requires);
let bytes = &mut queries.bytes.context(nonce, query_index, requires);
match self {
U64::Add => {
let mut witness = Sum64::<F>::default();
Expand Down
21 changes: 15 additions & 6 deletions src/gadgets/bytes/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub struct ByteRecordWithContext<'a> {
pub nonce: u32,
pub requires: &'a mut Vec<Record>,
pub record: &'a mut BytesRecord,
pub query_index: usize,
}

/// For a given input byte pair, this structure records the nonce and count of the accesses to
Expand Down Expand Up @@ -76,12 +77,14 @@ impl BytesRecord {
pub fn context<'a>(
&'a mut self,
nonce: u32,
query_index: usize,
requires: &'a mut Vec<Record>,
) -> ByteRecordWithContext<'a> {
ByteRecordWithContext {
nonce,
record: self,
requires,
query_index,
}
}

Expand Down Expand Up @@ -112,46 +115,52 @@ impl BytesRecord {
impl ByteRecord for ByteRecordWithContext<'_> {
fn range_check_u8_pair(&mut self, i1: u8, i2: u8) {
let input = ByteInput::from_u8_pair(i1, i2);
let index = self.query_index;
let range_u8 = &mut self.record.get_mut(input).range_u8;
let require = range_u8.new_lookup(self.nonce);
let require = range_u8.new_lookup(self.nonce, index);
self.requires.push(require);
}

fn range_check_u16(&mut self, i: u16) {
let input = ByteInput::from_u16(i);
let index = self.query_index;
let range_u16 = &mut self.record.get_mut(input).range_u16;
let require = range_u16.new_lookup(self.nonce);
let require = range_u16.new_lookup(self.nonce, index);
self.requires.push(require);
}

fn less_than(&mut self, i1: u8, i2: u8) -> bool {
let input = ByteInput::from_u8_pair(i1, i2);
let index = self.query_index;
let less_than = &mut self.record.get_mut(input).less_than;
let require = less_than.new_lookup(self.nonce);
let require = less_than.new_lookup(self.nonce, index);
self.requires.push(require);
input.less_than()
}

fn and(&mut self, i1: u8, i2: u8) -> u8 {
let input = ByteInput::from_u8_pair(i1, i2);
let index = self.query_index;
let and = &mut self.record.get_mut(input).and;
let require = and.new_lookup(self.nonce);
let require = and.new_lookup(self.nonce, index);
self.requires.push(require);
input.and()
}

fn xor(&mut self, i1: u8, i2: u8) -> u8 {
let input = ByteInput::from_u8_pair(i1, i2);
let index = self.query_index;
let xor = &mut self.record.get_mut(input).xor;
let require = xor.new_lookup(self.nonce);
let require = xor.new_lookup(self.nonce, index);
self.requires.push(require);
input.xor()
}

fn or(&mut self, i1: u8, i2: u8) -> u8 {
let input = ByteInput::from_u8_pair(i1, i2);
let index = self.query_index;
let or = &mut self.record.get_mut(input).or;
let require = or.new_lookup(self.nonce);
let require = or.new_lookup(self.nonce, index);
self.requires.push(require);
input.or()
}
Expand Down
6 changes: 4 additions & 2 deletions src/lair/chipset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub trait Chipset<F>: Sync {
&self,
input: &[F],
_nonce: u32,
_query_index: usize,
_queries: &mut QueryRecord<F>,
_requires: &mut Vec<Record>,
) -> Vec<F>
Expand Down Expand Up @@ -65,15 +66,16 @@ impl<F, C1: Chipset<F>, C2: Chipset<F>> Chipset<F> for &Either<C1, C2> {
&self,
input: &[F],
nonce: u32,
query_index: usize,
queries: &mut QueryRecord<F>,
requires: &mut Vec<Record>,
) -> Vec<F>
where
F: PrimeField32,
{
match self {
Either::Left(c) => c.execute(input, nonce, queries, requires),
Either::Right(c) => c.execute(input, nonce, queries, requires),
Either::Left(c) => c.execute(input, nonce, query_index, queries, requires),
Either::Right(c) => c.execute(input, nonce, query_index, queries, requires),
}
}

Expand Down
50 changes: 43 additions & 7 deletions src/lair/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ use super::{
chipset::Chipset,
expr::ReturnGroup,
func_chip::FuncChip,
toplevel::Toplevel,
toplevel::{FuncStruct, Toplevel},
FxIndexMap, List,
};

type QueryMap<F> = FxIndexMap<List<F>, QueryResult<F>>;
type InvQueryMap<F> = FxHashMap<List<F>, List<F>>;
pub(crate) type MemMap<F> = FxIndexMap<List<F>, QueryResult<F>>;
type NonceMap = FxHashMap<(usize, u32), u32>;

#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct QueryResult<F> {
Expand All @@ -47,9 +48,8 @@ impl<F: PrimeField32> QueryResult<F> {
nonce: usize,
caller_requires: &mut Vec<Record>,
) {
let mut lookup = self.provide.new_lookup(nonce as u32);
lookup.query_index = index as u32;
caller_requires.push(lookup);
let old_lookup = self.provide.new_lookup(nonce as u32, index);
caller_requires.push(old_lookup);
}
}

Expand Down Expand Up @@ -82,6 +82,7 @@ pub struct QueryRecord<F: PrimeField32> {
pub(crate) bytes: BytesRecord,
pub(crate) emitted: Vec<List<F>>,
pub(crate) debug_data: DebugData,
pub(crate) nonce_map: NonceMap,
}

#[derive(Default, Clone, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -287,6 +288,7 @@ impl<F: PrimeField32> QueryRecord<F> {
bytes: BytesRecord::default(),
emitted: vec![],
debug_data: DebugData::default(),
nonce_map: FxHashMap::default(),
}
}

Expand Down Expand Up @@ -379,6 +381,28 @@ impl<F: PrimeField32> QueryRecord<F> {
pub fn expect_public_values(&self) -> &[F] {
self.public_values.as_ref().expect("Public values not set")
}

pub fn populate_nonce_map<C1: Chipset<F>, C2: Chipset<F>>(
&mut self,
toplevel: &Toplevel<F, C1, C2>,
) {
for FuncStruct {
full_func,
split_funcs,
} in toplevel.func_map.values()
{
let index = full_func.index;
let max = *split_funcs.keys().max().unwrap() + 1;
let mut nonces = vec![0; max as usize];
let queries = &self.func_queries[index];
for (i, result) in queries.values().enumerate() {
let nonce = nonces[result.return_group as usize];
println!("{:?}", (index, i as u32));
self.nonce_map.insert((index, i as u32), nonce);
nonces[result.return_group as usize] += 1
}
}
}
}

impl<F: PrimeField32, C1: Chipset<F>, C2: Chipset<F>> FuncChip<'_, F, C1, C2> {
Expand Down Expand Up @@ -412,6 +436,7 @@ impl<F: PrimeField32, C1: Chipset<F>, C2: Chipset<F>> Toplevel<F, C1, C2> {
public_values.extend(depth.to_le_bytes().map(F::from_canonical_u8));
}
queries.public_values = Some(public_values);
queries.populate_nonce_map(self);
Ok(out)
}

Expand Down Expand Up @@ -689,13 +714,21 @@ impl<F: PrimeField32> Func<F> {
ExecEntry::Op(Op::ExternCall(chip_idx, input)) => {
let input: List<_> = input.iter().map(|a| map[*a]).collect();
let chip = toplevel.chip_by_index(*chip_idx);
map.extend(chip.execute(&input, nonce as u32, queries, &mut requires));
map.extend(chip.execute(
&input,
nonce as u32,
func_index,
queries,
&mut requires,
));
}
ExecEntry::Op(Op::Emit(xs)) => {
queries.emitted.push(xs.iter().map(|a| map[*a]).collect())
}
ExecEntry::Op(Op::RangeU8(xs)) => {
let mut bytes = queries.bytes.context(nonce as u32, &mut requires);
let mut bytes = queries
.bytes
.context(nonce as u32, func_index, &mut requires);
let xs = xs.iter().map(|x| {
map[*x]
.as_canonical_u32()
Expand Down Expand Up @@ -724,7 +757,10 @@ impl<F: PrimeField32> Func<F> {
inv_map.insert(out_list.clone(), inp.clone());
}
if partial {
let mut bytes = queries.bytes.context(nonce as u32, &mut depth_requires);
let mut bytes =
queries
.bytes
.context(nonce as u32, func_index, &mut depth_requires);
let depth = depths.iter().map(|&a| a + 1).max().unwrap_or(0);
bytes.range_check_u8_iter(depth.to_le_bytes());
for dep_depth in depths.iter() {
Expand Down
Loading

0 comments on commit b05f909

Please sign in to comment.