Skip to content

Commit

Permalink
feat: remove LOADW2 and STOREW2 (#1160)
Browse files Browse the repository at this point in the history
* feat: remove LOADW2 and STOREW2

* chore: remove LOADW2 and STOREW2 from docs

* chore: rename eDSL to native

* fix: benchmaks do not trigger on extensions
  • Loading branch information
yi-sun committed Jan 7, 2025
1 parent 4d18009 commit 71939d1
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 314 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ on:
types: [opened, synchronize, reopened, labeled]
branches: ["**"]
paths:
- "benchmarks/**"
- "crates/circuits/**"
- "crates/vm/**"
- "crates/toolchain/**"
- "crates/vm/**"
- "extensions/**"
- "benchmarks/**"
- ".github/workflows/benchmark-call.yml"
- ".github/workflows/benchmarks.yml"
workflow_dispatch:
Expand Down
10 changes: 4 additions & 6 deletions docs/specs/ISA.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

## Instruction format

Instructions are encoded as a global opcode (field element) followed by `NUM_OPERANDS = 7` operands (field elements): `opcode, a, b, c, d, e, f, g`. An instruction does not need to use all operands, and trailing unused operands should be set to zero.
Instructions are encoded as a global opcode (field element) followed by `NUM_OPERANDS = 6` operands (field elements): `opcode, a, b, c, d, e, f`. An instruction does not need to use all operands, and trailing unused operands should be set to zero.

## Program ROM

Expand Down Expand Up @@ -431,11 +431,9 @@ In some instructions below, `W` is a generic parameter for the block size.
| -------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| LOAD\<W\> | `a,b,c,d,e` | Set `[a:W]_d = [[c]_d + b:W]_e`. Both `d, e` must be non-zero. |
| STORE\<W\> | `a,b,c,d,e` | Set `[[c]_d + b:W]_e = [a:W]_d`. Both `d, e` must be non-zero. |
| LOAD2\<W\> | `a,b,c,d,e,f,g` | Set `[a:W]_d = [[c]_d + [f]_d * g + b:W]_e`. |
| STORE2\<W\> | `a,b,c,d,e,f,g` | Set `[[c]_d + [f]_d * g + b:W]_e = [a:W]_d`. |
| JAL | `a,b,c,d` | Jump to address and link: set `[a]_d = (pc + DEFAULT_PC_STEP)` and `pc = pc + b`. Here `d` must be non-zero. |
| BEQ\<W\> | `a,b,c,d,e` | If `[a:W]_d == [b:W]_e`, then set `pc = pc + c`. |
| BNE\<W\> | `a,b,c,d,e` | If `[a:W]_d != [b:W]_e`, then set `pc = pc + c`. |
| JAL | `a,b,c,d` | Jump to address and link: set `[a]_d = (pc + DEFAULT_PC_STEP)` and `pc = pc + b`. Here `d` must be non-zero. |
| BEQ\<W\> | `a,b,c,d,e` | If `[a:W]_d == [b:W]_e`, then set `pc = pc + c`. |
| BNE\<W\> | `a,b,c,d,e` | If `[a:W]_d != [b:W]_e`, then set `pc = pc + c`. |
| HINTSTORE\<W\> | `_,b,c,d,e` | Set `[[c]_d + b:W]_e = next W elements from hint stream`. Both `d, e` must be non-zero. |
| PUBLISH | `a,b,_,d,e` | Set the user public output at index `[a]_d` to equal `[b]_e`. Invalid if `[a]_d` is greater than or equal to the configured length of user public outputs. Only valid when continuations are disabled. |
| CASTF | `a,b,_,d,e` | Cast a field element represented as `u32` into four bytes in little-endian: Set `[a:4]_d` to the unique array such that `sum_{i=0}^3 [a + i]_d * 2^{8i} = [b]_e` where `[a + i]_d < 2^8` for `i = 0..2` and `[a + 3]_d < 2^6`. This opcode constrains that `[b]_e` must be at most 30-bits. Both `d, e` must be non-zero. |
Expand Down
100 changes: 23 additions & 77 deletions extensions/native/circuit/src/adapters/loadstore_native_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ pub struct NativeLoadStoreInstruction<T> {
// Absolute opcode number
pub opcode: T,
pub is_loadw: T,
pub is_loadw2: T,
pub is_storew: T,
pub is_storew2: T,
pub is_shintw: T,
}

Expand All @@ -43,7 +41,7 @@ impl<T, const NUM_CELLS: usize> VmAdapterInterface<T>
for NativeLoadStoreAdapterInterface<T, NUM_CELLS>
{
// TODO[yi]: Fix when vectorizing
type Reads = ([T; 2], T);
type Reads = (T, T);
type Writes = [T; NUM_CELLS];
type ProcessedInstruction = NativeLoadStoreInstruction<T>;
}
Expand Down Expand Up @@ -75,8 +73,7 @@ impl<F: PrimeField32, const NUM_CELLS: usize> NativeLoadStoreAdapterChip<F, NUM_

#[derive(Clone, Debug)]
pub struct NativeLoadStoreReadRecord<F: Field, const NUM_CELLS: usize> {
pub pointer1_read: RecordId,
pub pointer2_read: Option<RecordId>,
pub pointer_read: RecordId,
pub data_read: Option<RecordId>,
pub write_as: F,
pub write_ptr: F,
Expand All @@ -86,8 +83,6 @@ pub struct NativeLoadStoreReadRecord<F: Field, const NUM_CELLS: usize> {
pub c: F,
pub d: F,
pub e: F,
pub f: F,
pub g: F,
}

#[derive(Clone, Debug)]
Expand All @@ -105,16 +100,14 @@ pub struct NativeLoadStoreAdapterCols<T, const NUM_CELLS: usize> {
pub c: T,
pub d: T,
pub e: T,
pub f: T,
pub g: T,

pub data_read_as: T,
pub data_read_pointer: T,

pub data_write_as: T,
pub data_write_pointer: T,

pub pointer_read_aux_cols: [MemoryReadOrImmediateAuxCols<T>; 2],
pub pointer_read_aux_cols: MemoryReadOrImmediateAuxCols<T>,
pub data_read_aux_cols: MemoryReadOrImmediateAuxCols<T>,
// TODO[yi]: Fix when vectorizing
// pub data_read_aux_cols: MemoryReadAuxCols<T, NUM_CELLS>,
Expand Down Expand Up @@ -154,51 +147,32 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>
let is_valid = ctx.instruction.is_valid;
let is_loadw = ctx.instruction.is_loadw;
let is_storew = ctx.instruction.is_storew;
let is_loadw2 = ctx.instruction.is_loadw2;
let is_storew2 = ctx.instruction.is_storew2;
let is_shintw = ctx.instruction.is_shintw;

// first pointer read is always [c]_d
self.memory_bridge
.read_or_immediate(
MemoryAddress::new(cols.d, cols.c),
ctx.reads.0[0].clone(),
ctx.reads.0.clone(),
timestamp + timestamp_delta.clone(),
&cols.pointer_read_aux_cols[0],
&cols.pointer_read_aux_cols,
)
.eval(builder, is_valid.clone());
timestamp_delta += is_valid.clone();

// second pointer read is [f]_d if loadw2 or storew2, otherwise disabled
self.memory_bridge
.read_or_immediate(
MemoryAddress::new(cols.d, cols.f),
ctx.reads.0[1].clone(),
timestamp + timestamp_delta.clone(),
&cols.pointer_read_aux_cols[1],
)
.eval(
builder,
is_valid.clone() - is_shintw.clone() - is_loadw.clone() - is_storew.clone(),
);
timestamp_delta +=
is_valid.clone() - is_shintw.clone() - is_loadw.clone() - is_storew.clone();

// TODO[yi]: Remove when vectorizing
// read data, disabled if SHINTW
// data pointer = [c]_d + [f]_d * g + b, degree 2
builder
.when(is_valid.clone() - is_shintw.clone())
.assert_eq(
cols.data_read_as,
utils::select::<AB::Expr>(is_loadw.clone() + is_loadw2.clone(), cols.e, cols.d),
utils::select::<AB::Expr>(is_loadw.clone(), cols.e, cols.d),
);
// TODO[yi]: Do we need to check for overflow?
builder.assert_eq(
(is_valid.clone() - is_shintw.clone()) * cols.data_read_pointer,
(is_storew.clone() + is_storew2.clone()) * cols.a
+ (is_loadw.clone() + is_loadw2.clone())
* (ctx.reads.0[0].clone() + cols.b + ctx.reads.0[1].clone() * cols.g),
is_storew.clone() * cols.a + is_loadw.clone() * (ctx.reads.0.clone() + cols.b),
);
self.memory_bridge
.read_or_immediate(
Expand All @@ -213,14 +187,13 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>
// data write
builder.when(is_valid.clone()).assert_eq(
cols.data_write_as,
utils::select::<AB::Expr>(is_loadw.clone() + is_loadw2.clone(), cols.d, cols.e),
utils::select::<AB::Expr>(is_loadw.clone(), cols.d, cols.e),
);
// TODO[yi]: Do we need to check for overflow?
builder.assert_eq(
is_valid.clone() * cols.data_write_pointer,
(is_loadw.clone() + is_loadw2.clone()) * cols.a
+ (is_storew.clone() + is_storew2.clone() + is_shintw.clone())
* (ctx.reads.0[0].clone() + cols.b + ctx.reads.0[1].clone() * cols.g),
is_loadw.clone() * cols.a
+ (is_storew.clone() + is_shintw.clone()) * (ctx.reads.0.clone() + cols.b),
);
self.memory_bridge
.write(
Expand All @@ -235,7 +208,7 @@ impl<AB: InteractionBuilder, const NUM_CELLS: usize> VmAdapterAir<AB>
self.execution_bridge
.execute_and_increment_or_set_pc(
ctx.instruction.opcode,
[cols.a, cols.b, cols.c, cols.d, cols.e, cols.f, cols.g],
[cols.a, cols.b, cols.c, cols.d, cols.e],
cols.from_state,
timestamp_delta.clone(),
(DEFAULT_PC_STEP, ctx.to_pc),
Expand Down Expand Up @@ -273,36 +246,25 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
c,
d,
e,
f,
g,
..
} = *instruction;
let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset));

let read1_as = d;
let read1_ptr = c;
let read2_as = d;
let read2_ptr = f;

let read1_cell = memory.read_cell(read1_as, read1_ptr);
let read2_cell = match local_opcode {
LOADW2 | STOREW2 => Some(memory.read_cell(read2_as, read2_ptr)),
_ => None,
};
let read_as = d;
let read_ptr = c;
let read_cell = memory.read_cell(read_as, read_ptr);

let (data_read_as, data_write_as) = {
match local_opcode {
LOADW | LOADW2 => (e, d),
STOREW | STOREW2 | SHINTW => (d, e),
LOADW => (e, d),
STOREW | SHINTW => (d, e),
}
};
let (data_read_ptr, data_write_ptr) = {
match local_opcode {
LOADW => (read1_cell.1 + b, a),
LOADW2 => (read1_cell.1 + b + read2_cell.unwrap().1 * g, a),
STOREW => (a, read1_cell.1 + b),
STOREW2 => (a, read1_cell.1 + b + read2_cell.unwrap().1 * g),
SHINTW => (a, read1_cell.1 + b),
LOADW => (read_cell.1 + b, a),
STOREW => (a, read_cell.1 + b),
SHINTW => (a, read_cell.1 + b),
}
};

Expand All @@ -312,8 +274,7 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
_ => Some(memory.read::<1>(data_read_as, data_read_ptr)),
};
let record = NativeLoadStoreReadRecord {
pointer1_read: read1_cell.0,
pointer2_read: read2_cell.map(|x| x.0),
pointer_read: read_cell.0,
data_read: data_read.map(|x| x.0),
write_as: data_write_as,
write_ptr: data_write_ptr,
Expand All @@ -322,17 +283,9 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
c,
d,
e,
f,
g,
};

Ok((
(
[read1_cell.1, read2_cell.map_or(F::ZERO, |x| x.1)],
data_read.map_or(F::ZERO, |x| x.1[0]),
),
record,
))
Ok(((read_cell.1, data_read.map_or(F::ZERO, |x| x.1[0])), record))
}

fn postprocess(
Expand Down Expand Up @@ -372,8 +325,6 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
cols.c = read_record.c;
cols.d = read_record.d;
cols.e = read_record.e;
cols.f = read_record.f;
cols.g = read_record.g;

let data_read = read_record.data_read.map(|read| memory.record_by_id(read));
if let Some(data_read) = data_read {
Expand All @@ -388,13 +339,8 @@ impl<F: PrimeField32, const NUM_CELLS: usize> VmAdapterChip<F>
cols.data_write_as = write.address_space;
cols.data_write_pointer = write.pointer;

cols.pointer_read_aux_cols[0] = aux_cols_factory
.make_read_or_immediate_aux_cols(memory.record_by_id(read_record.pointer1_read));
cols.pointer_read_aux_cols[1] = read_record
.pointer2_read
.map_or_else(MemoryReadOrImmediateAuxCols::disabled, |read| {
aux_cols_factory.make_read_or_immediate_aux_cols(memory.record_by_id(read))
});
cols.pointer_read_aux_cols = aux_cols_factory
.make_read_or_immediate_aux_cols(memory.record_by_id(read_record.pointer_read));
cols.data_write_aux_cols = aux_cols_factory.make_write_aux_cols(write);
}

Expand Down
30 changes: 9 additions & 21 deletions extensions/native/circuit/src/loadstore/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@ use super::super::adapters::loadstore_native_adapter::NativeLoadStoreInstruction
pub struct NativeLoadStoreCoreCols<T, const NUM_CELLS: usize> {
pub is_loadw: T,
pub is_storew: T,
pub is_loadw2: T,
pub is_storew2: T,
pub is_shintw: T,

pub pointer_reads: [T; 2],
pub pointer_read: T,
pub data_read: T,
pub data_write: [T; NUM_CELLS],
}
Expand All @@ -39,7 +37,7 @@ pub struct NativeLoadStoreCoreCols<T, const NUM_CELLS: usize> {
pub struct NativeLoadStoreCoreRecord<F, const NUM_CELLS: usize> {
pub opcode: NativeLoadStoreOpcode,

pub pointer_reads: [F; 2],
pub pointer_read: F,
pub data_read: F,
pub data_write: [F; NUM_CELLS],
}
Expand All @@ -64,7 +62,7 @@ impl<AB, I, const NUM_CELLS: usize> VmCoreAir<AB, I> for NativeLoadStoreCoreAir<
where
AB: InteractionBuilder,
I: VmAdapterInterface<AB::Expr>,
I::Reads: From<([AB::Expr; 2], AB::Expr)>,
I::Reads: From<(AB::Expr, AB::Expr)>,
I::Writes: From<[AB::Expr; NUM_CELLS]>,
I::ProcessedInstruction: From<NativeLoadStoreInstruction<AB::Expr>>,
{
Expand All @@ -75,13 +73,7 @@ where
_from_pc: AB::Var,
) -> AdapterAirContext<AB::Expr, I> {
let cols: &NativeLoadStoreCoreCols<_, NUM_CELLS> = (*local_core).borrow();
let flags = [
cols.is_loadw,
cols.is_storew,
cols.is_loadw2,
cols.is_storew2,
cols.is_shintw,
];
let flags = [cols.is_loadw, cols.is_storew, cols.is_shintw];
let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
builder.assert_bool(flag);
acc + flag.into()
Expand All @@ -97,15 +89,13 @@ where

AdapterAirContext {
to_pc: None,
reads: (cols.pointer_reads.map(Into::into), cols.data_read.into()).into(),
reads: (cols.pointer_read.into(), cols.data_read.into()).into(),
writes: cols.data_write.map(Into::into).into(),
instruction: NativeLoadStoreInstruction {
is_valid,
opcode: expected_opcode,
is_loadw: cols.is_loadw.into(),
is_storew: cols.is_storew.into(),
is_loadw2: cols.is_loadw2.into(),
is_storew2: cols.is_storew2.into(),
is_shintw: cols.is_shintw.into(),
}
.into(),
Expand Down Expand Up @@ -134,7 +124,7 @@ impl<F: Field, const NUM_CELLS: usize> NativeLoadStoreCoreChip<F, NUM_CELLS> {
impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_CELLS: usize> VmCoreChip<F, I>
for NativeLoadStoreCoreChip<F, NUM_CELLS>
where
I::Reads: Into<([F; 2], F)>,
I::Reads: Into<(F, F)>,
I::Writes: From<[F; NUM_CELLS]>,
{
type Record = NativeLoadStoreCoreRecord<F, NUM_CELLS>;
Expand All @@ -149,7 +139,7 @@ where
let Instruction { opcode, .. } = *instruction;
let local_opcode =
NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
let (pointer_reads, data_read) = reads.into();
let (pointer_read, data_read) = reads.into();

let data_write = if local_opcode == NativeLoadStoreOpcode::SHINTW {
let mut streams = self.streams.get().unwrap().lock().unwrap();
Expand All @@ -164,7 +154,7 @@ where
let output = AdapterRuntimeContext::without_pc(data_write);
let record = NativeLoadStoreCoreRecord {
opcode: NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)),
pointer_reads,
pointer_read,
data_read,
data_write,
};
Expand All @@ -182,11 +172,9 @@ where
let cols: &mut NativeLoadStoreCoreCols<_, NUM_CELLS> = row_slice.borrow_mut();
cols.is_loadw = F::from_bool(record.opcode == NativeLoadStoreOpcode::LOADW);
cols.is_storew = F::from_bool(record.opcode == NativeLoadStoreOpcode::STOREW);
cols.is_loadw2 = F::from_bool(record.opcode == NativeLoadStoreOpcode::LOADW2);
cols.is_storew2 = F::from_bool(record.opcode == NativeLoadStoreOpcode::STOREW2);
cols.is_shintw = F::from_bool(record.opcode == NativeLoadStoreOpcode::SHINTW);

cols.pointer_reads = record.pointer_reads.map(Into::into);
cols.pointer_read = record.pointer_read;
cols.data_read = record.data_read;
cols.data_write = record.data_write.map(Into::into);
}
Expand Down
Loading

0 comments on commit 71939d1

Please sign in to comment.