Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

Commit

Permalink
feat(acir)!: Add predicate to MemoryOp (#503)
Browse files Browse the repository at this point in the history
  • Loading branch information
vezenovm authored Aug 30, 2023
1 parent ae65355 commit ca9eebe
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 14 deletions.
8 changes: 7 additions & 1 deletion acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub enum Opcode {
MemoryOp {
block_id: BlockId,
op: MemOp,
/// Predicate of the memory operation - indicates if it should be skipped
predicate: Option<Expression>,
},
MemoryInit {
block_id: BlockId,
Expand Down Expand Up @@ -158,8 +160,12 @@ impl std::fmt::Display for Opcode {
writeln!(f, "outputs: {:?}", brillig.outputs)?;
writeln!(f, "{:?}", brillig.bytecode)
}
Opcode::MemoryOp { block_id, op } => {
Opcode::MemoryOp { block_id, op, predicate } => {
write!(f, "MEM ")?;
if let Some(pred) = predicate {
writeln!(f, "PREDICATE = {pred}")?;
}

let is_read = op.operation.is_zero();
let is_write = op.operation == Expression::one();
if is_read {
Expand Down
39 changes: 38 additions & 1 deletion acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use acir::{
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
directives::Directive,
opcodes::{BlackBoxFuncCall, FunctionInput},
opcodes::{BlackBoxFuncCall, BlockId, FunctionInput, MemOp},
Circuit, Opcode, PublicInputs,
},
native_types::{Expression, Witness},
Expand Down Expand Up @@ -340,3 +340,40 @@ fn complex_brillig_foreign_call() {

assert_eq!(bytes, expected_serialization)
}

#[test]
fn memory_op_circuit() {
let init = vec![Witness(1), Witness(2)];

let memory_init = Opcode::MemoryInit { block_id: BlockId(0), init };
let write = Opcode::MemoryOp {
block_id: BlockId(0),
op: MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
predicate: None,
};
let read = Opcode::MemoryOp {
block_id: BlockId(0),
op: MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)),
predicate: None,
};

let circuit = Circuit {
current_witness_index: 5,
opcodes: vec![memory_init, write, read],
private_parameters: BTreeSet::from([Witness(1), Witness(2), Witness(3)]),
return_values: PublicInputs([Witness(4)].into()),
..Circuit::default()
};
let mut bytes = Vec::new();
circuit.write(&mut bytes).unwrap();

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 146, 49, 14, 0, 32, 8, 3, 139, 192, 127, 240, 7,
254, 255, 85, 198, 136, 9, 131, 155, 48, 216, 165, 76, 77, 57, 80, 0, 140, 45, 117, 111,
238, 228, 179, 224, 174, 225, 110, 111, 234, 213, 185, 148, 156, 203, 121, 89, 86, 13, 215,
126, 131, 43, 153, 187, 115, 40, 185, 62, 153, 3, 136, 83, 60, 30, 96, 2, 12, 235, 225,
124, 14, 3, 0, 0,
];

assert_eq!(bytes, expected_serialization)
}
101 changes: 93 additions & 8 deletions acvm/src/pwg/memory_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;

use acir::{
circuit::opcodes::MemOp,
native_types::{Witness, WitnessMap},
native_types::{Expression, Witness, WitnessMap},
FieldElement,
};

Expand Down Expand Up @@ -63,6 +63,7 @@ impl MemoryOpSolver {
&mut self,
op: &MemOp,
initial_witness: &mut WitnessMap,
predicate: &Option<Expression>,
) -> Result<(), OpcodeResolutionError> {
let operation = get_value(&op.operation, initial_witness)?;

Expand All @@ -79,6 +80,12 @@ impl MemoryOpSolver {
// `operation == 0` implies a read operation. (`operation == 1` implies write operation).
let is_read_operation = operation.is_zero();

// If the predicate is `None`, then we simply return the value 1
let pred_value = match predicate {
Some(pred) => get_value(pred, initial_witness),
None => Ok(FieldElement::one()),
}?;

if is_read_operation {
// `value_read = arr[memory_index]`
//
Expand All @@ -88,7 +95,13 @@ impl MemoryOpSolver {
"Memory must be read into a specified witness index, encountered an Expression",
);

let value_in_array = self.read_memory_index(memory_index)?;
// A zero predicate indicates that we should skip the read operation
// and zero out the operation's output.
let value_in_array = if pred_value.is_zero() {
FieldElement::zero()
} else {
self.read_memory_index(memory_index)?
};
insert_value(&value_read_witness, value_in_array, initial_witness)
} else {
// `arr[memory_index] = value_write`
Expand All @@ -97,9 +110,15 @@ impl MemoryOpSolver {
// into the memory block.
let value_write = value;

let value_to_write = get_value(&value_write, initial_witness)?;

self.write_memory_index(memory_index, value_to_write)
// A zero predicate indicates that we should skip the write operation.
if pred_value.is_zero() {
// We only want to write to already initialized memory.
// Do nothing if the predicate is zero.
return Ok(());
} else {
let value_to_write = get_value(&value_write, initial_witness)?;
self.write_memory_index(memory_index, value_to_write)
}
}
}
}
Expand All @@ -110,7 +129,7 @@ mod tests {

use acir::{
circuit::opcodes::MemOp,
native_types::{Witness, WitnessMap},
native_types::{Expression, Witness, WitnessMap},
FieldElement,
};

Expand All @@ -135,8 +154,9 @@ mod tests {
block_solver.init(&init, &initial_witness).unwrap();

for op in trace {
block_solver.solve_memory_op(&op, &mut initial_witness).unwrap();
block_solver.solve_memory_op(&op, &mut initial_witness, &None).unwrap();
}

assert_eq!(initial_witness[&Witness(4)], FieldElement::from(2u128));
}

Expand All @@ -159,9 +179,10 @@ mod tests {
let mut err = None;
for op in invalid_trace {
if err.is_none() {
err = block_solver.solve_memory_op(&op, &mut initial_witness).err();
err = block_solver.solve_memory_op(&op, &mut initial_witness, &None).err();
}
}

assert!(matches!(
err,
Some(crate::pwg::OpcodeResolutionError::IndexOutOfBounds {
Expand All @@ -171,4 +192,68 @@ mod tests {
})
));
}

#[test]
fn test_predicate_on_read() {
let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
(Witness(1), FieldElement::from(1u128)),
(Witness(2), FieldElement::from(1u128)),
(Witness(3), FieldElement::from(2u128)),
]));

let init = vec![Witness(1), Witness(2)];

let invalid_trace = vec![
MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
];
let mut block_solver = MemoryOpSolver::default();
block_solver.init(&init, &initial_witness).unwrap();
let mut err = None;
for op in invalid_trace {
if err.is_none() {
err = block_solver
.solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
.err();
}
}

// Should have no index out of bounds error where predicate is zero
assert_eq!(err, None);
// The result of a read under a zero predicate should be zero
assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128));
}

#[test]
fn test_predicate_on_write() {
let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
(Witness(1), FieldElement::from(1u128)),
(Witness(2), FieldElement::from(1u128)),
(Witness(3), FieldElement::from(2u128)),
]));

let init = vec![Witness(1), Witness(2)];

let invalid_trace = vec![
MemOp::write_to_mem_index(FieldElement::from(2u128).into(), Witness(3).into()),
MemOp::read_at_mem_index(FieldElement::from(0u128).into(), Witness(4).into()),
MemOp::read_at_mem_index(FieldElement::from(1u128).into(), Witness(5).into()),
];
let mut block_solver = MemoryOpSolver::default();
block_solver.init(&init, &initial_witness).unwrap();
let mut err = None;
for op in invalid_trace {
if err.is_none() {
err = block_solver
.solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
.err();
}
}

// Should have no index out of bounds error where predicate is zero
assert_eq!(err, None);
// The memory under a zero predicate should be zeroed out
assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128));
assert_eq!(initial_witness[&Witness(5)], FieldElement::from(0u128));
}
}
4 changes: 2 additions & 2 deletions acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> {
let solver = self.block_solvers.entry(*block_id).or_default();
solver.init(init, &self.witness_map)
}
Opcode::MemoryOp { block_id, op } => {
Opcode::MemoryOp { block_id, op, predicate } => {
let solver = self.block_solvers.entry(*block_id).or_default();
solver.solve_memory_op(op, &mut self.witness_map)
solver.solve_memory_op(op, &mut self.witness_map, predicate)
}
Opcode::Brillig(brillig) => {
match BrilligSolver::solve(&mut self.witness_map, brillig, self.backend) {
Expand Down
7 changes: 5 additions & 2 deletions acvm/tests/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,11 @@ fn memory_operations() {

let init = Opcode::MemoryInit { block_id, init: (1..6).map(Witness).collect() };

let read_op =
Opcode::MemoryOp { block_id, op: MemOp::read_at_mem_index(Witness(6).into(), Witness(7)) };
let read_op = Opcode::MemoryOp {
block_id,
op: MemOp::read_at_mem_index(Witness(6).into(), Witness(7)),
predicate: None,
};

let expression = Opcode::Arithmetic(Expression {
mul_terms: Vec::new(),
Expand Down
16 changes: 16 additions & 0 deletions acvm_js/test/browser/execute_circuit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,22 @@ it("successfully executes a SchnorrVerify opcode", async () => {
expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes a MemoryOp opcode", async () => {
const { bytecode, initialWitnessMap, expectedWitnessMap } = await import(
"../shared/memory_op"
);

const solvedWitness: WitnessMap = await executeCircuit(
bytecode,
initialWitnessMap,
() => {
throw Error("unexpected oracle");
}
);

expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes two circuits with same backend", async function () {
// chose pedersen op here because it is the one with slow initialization
// that led to the decision to pull backend initialization into a separate
Expand Down
16 changes: 16 additions & 0 deletions acvm_js/test/node/execute_circuit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ it("successfully executes a SchnorrVerify opcode", async () => {
expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes a MemoryOp opcode", async () => {
const { bytecode, initialWitnessMap, expectedWitnessMap } = await import(
"../shared/memory_op"
);

const solvedWitness: WitnessMap = await executeCircuit(
bytecode,
initialWitnessMap,
() => {
throw Error("unexpected oracle");
}
);

expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes two circuits with same backend", async function () {
this.timeout(10000);

Expand Down
21 changes: 21 additions & 0 deletions acvm_js/test/shared/memory_op.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// See `memory_op_circuit` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 146, 49, 14, 0, 32, 8, 3, 139, 192,
127, 240, 7, 254, 255, 85, 198, 136, 9, 131, 155, 48, 216, 165, 76, 77, 57,
80, 0, 140, 45, 117, 111, 238, 228, 179, 224, 174, 225, 110, 111, 234, 213,
185, 148, 156, 203, 121, 89, 86, 13, 215, 126, 131, 43, 153, 187, 115, 40,
185, 62, 153, 3, 136, 83, 60, 30, 96, 2, 12, 235, 225, 124, 14, 3, 0, 0,
]);

export const initialWitnessMap = new Map([
[1, "0x0000000000000000000000000000000000000000000000000000000000000001"],
[2, "0x0000000000000000000000000000000000000000000000000000000000000001"],
[3, "0x0000000000000000000000000000000000000000000000000000000000000002"],
]);

export const expectedWitnessMap = new Map([
[1, "0x0000000000000000000000000000000000000000000000000000000000000001"],
[2, "0x0000000000000000000000000000000000000000000000000000000000000001"],
[3, "0x0000000000000000000000000000000000000000000000000000000000000002"],
[4, "0x0000000000000000000000000000000000000000000000000000000000000002"],
]);

0 comments on commit ca9eebe

Please sign in to comment.