Skip to content

Commit

Permalink
feat: Save Brillig execution state in ACVM (#3026)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom French <[email protected]>
  • Loading branch information
ggiraldez and TomAFrench authored Oct 12, 2023
1 parent c0a082e commit 88682da
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 87 deletions.
154 changes: 98 additions & 56 deletions acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,59 @@ use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError};

use super::{get_value, insert_value};

pub(super) struct BrilligSolver;
pub(super) enum BrilligSolverStatus {
Finished,
InProgress,
ForeignCallWait(ForeignCallWaitInfo),
}

impl BrilligSolver {
pub(super) fn solve<B: BlackBoxFunctionSolver>(
initial_witness: &mut WitnessMap,
pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> {
vm: VM<'b, B>,
acir_index: usize,
}

impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
/// Evaluates if the Brillig block should be skipped entirely
pub(super) fn should_skip(
witness: &WitnessMap,
brillig: &Brillig,
foreign_call_results: Vec<ForeignCallResult>,
bb_solver: &B,
acir_index: usize,
) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
// If the predicate is `None`, then we simply return the value 1
) -> Result<bool, OpcodeResolutionError> {
// If the predicate is `None`, the block should never be skipped
// If the predicate is `Some` but we cannot find a value, then we return stalled
let pred_value = match &brillig.predicate {
Some(pred) => get_value(pred, initial_witness),
None => Ok(FieldElement::one()),
}?;
match &brillig.predicate {
Some(pred) => Ok(get_value(pred, witness)?.is_zero()),
None => Ok(false),
}
}

// A zero predicate indicates the oracle should be skipped, and its outputs zeroed.
if pred_value.is_zero() {
Self::zero_out_brillig_outputs(initial_witness, brillig)?;
return Ok(None);
/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
pub(super) fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr {
insert_value(witness, FieldElement::zero(), initial_witness)?;
}
}
}
}
Ok(())
}

/// Constructs a solver for a Brillig block given the bytecode and initial
/// witness.
pub(super) fn new(
initial_witness: &mut WitnessMap,
brillig: &'b Brillig,
bb_solver: &'b B,
acir_index: usize,
) -> Result<Self, OpcodeResolutionError> {
// Set input values
let mut input_register_values: Vec<Value> = Vec::new();
let mut input_memory: Vec<Value> = Vec::new();
Expand Down Expand Up @@ -75,80 +105,92 @@ impl BrilligSolver {
}

// Instantiate a Brillig VM given the solved input registers and memory
// along with the Brillig bytecode, and any present foreign call results.
// along with the Brillig bytecode.
let input_registers = Registers::load(input_register_values);
let mut vm = VM::new(
input_registers,
input_memory,
&brillig.bytecode,
foreign_call_results,
bb_solver,
);
let vm = VM::new(input_registers, input_memory, &brillig.bytecode, vec![], bb_solver);
Ok(Self { vm, acir_index })
}

// Run the Brillig VM on these inputs, bytecode, etc!
let vm_status = vm.process_opcodes();
pub(super) fn solve(&mut self) -> Result<BrilligSolverStatus, OpcodeResolutionError> {
let status = self.vm.process_opcodes();
self.handle_vm_status(status)
}

// Check the status of the Brillig VM.
fn handle_vm_status(
&self,
vm_status: VMStatus,
) -> Result<BrilligSolverStatus, OpcodeResolutionError> {
// Check the status of the Brillig VM and return a resolution.
// It may be finished, in-progress, failed, or may be waiting for results of a foreign call.
// Return the "resolution" to the caller who may choose to make subsequent calls
// (when it gets foreign call results for example).
match vm_status {
VMStatus::Finished => {
for (i, output) in brillig.outputs.iter().enumerate() {
let register_value = vm.get_registers().get(RegisterIndex::from(i));
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, register_value.to_field(), initial_witness)?;
}
BrilligOutputs::Array(witness_arr) => {
// Treat the register value as a pointer to memory
for (i, witness) in witness_arr.iter().enumerate() {
let value = &vm.get_memory()[register_value.to_usize() + i];
insert_value(witness, value.to_field(), initial_witness)?;
}
}
}
}
Ok(None)
}
VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"),
VMStatus::Finished => Ok(BrilligSolverStatus::Finished),
VMStatus::InProgress => Ok(BrilligSolverStatus::InProgress),
VMStatus::Failure { message, call_stack } => {
Err(OpcodeResolutionError::BrilligFunctionFailed {
message,
call_stack: call_stack
.iter()
.map(|brillig_index| OpcodeLocation::Brillig {
acir_index,
acir_index: self.acir_index,
brillig_index: *brillig_index,
})
.collect(),
})
}
VMStatus::ForeignCallWait { function, inputs } => {
Ok(Some(ForeignCallWaitInfo { function, inputs }))
Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs }))
}
}
}

/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
pub(super) fn finalize(
self,
witness: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
// Finish the Brillig execution by writing the outputs to the witness map
let vm_status = self.vm.get_status();
match vm_status {
VMStatus::Finished => {
self.write_brillig_outputs(witness, brillig)?;
Ok(())
}
_ => panic!("Brillig VM has not completed execution"),
}
}

fn write_brillig_outputs(
&self,
witness_map: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
// Write VM execution results into the witness map
for (i, output) in brillig.outputs.iter().enumerate() {
let register_value = self.vm.get_registers().get(RegisterIndex::from(i));
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
insert_value(witness, register_value.to_field(), witness_map)?;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr {
insert_value(witness, FieldElement::zero(), initial_witness)?;
// Treat the register value as a pointer to memory
for (i, witness) in witness_arr.iter().enumerate() {
let value = &self.vm.get_memory()[register_value.to_usize() + i];
insert_value(witness, value.to_field(), witness_map)?;
}
}
}
}
Ok(())
}

pub(super) fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) {
match self.vm.get_status() {
VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result),
_ => unreachable!("Brillig VM is not waiting for a foreign call"),
}
}
}

/// Encapsulates a request from a Brillig VM process that encounters a [foreign call opcode][acir::brillig_vm::Opcode::ForeignCall]
Expand Down
73 changes: 47 additions & 26 deletions acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use acir::{
use acvm_blackbox_solver::BlackBoxResolutionError;

use self::{
arithmetic::ArithmeticSolver, brillig::BrilligSolver, directives::solve_directives,
arithmetic::ArithmeticSolver,
brillig::{BrilligSolver, BrilligSolverStatus},
directives::solve_directives,
memory_op::MemoryOpSolver,
};
use crate::{BlackBoxFunctionSolver, Language};
Expand Down Expand Up @@ -141,9 +143,7 @@ pub struct ACVM<'a, B: BlackBoxFunctionSolver> {

witness_map: WitnessMap,

/// Results of oracles/functions external to brillig like a database read.
// Each element of this vector corresponds to a single foreign call but may contain several values.
foreign_call_results: HashMap<usize, Vec<ForeignCallResult>>,
brillig_solver: Option<BrilligSolver<'a, B>>,
}

impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
Expand All @@ -156,7 +156,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
opcodes,
instruction_pointer: 0,
witness_map: initial_witness,
foreign_call_results: HashMap::default(),
brillig_solver: None,
}
}

Expand Down Expand Up @@ -221,10 +221,8 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
panic!("ACVM is not expecting a foreign call response as no call was made");
}

// We want to inject the foreign call result into the brillig opcode which initiated the call.
let foreign_call_results =
self.foreign_call_results.entry(self.instruction_pointer).or_default();
foreign_call_results.push(foreign_call_result);
let brillig_solver = self.brillig_solver.as_mut().expect("No active Brillig solver");
brillig_solver.resolve_pending_foreign_call(foreign_call_result);

// Now that the foreign call has been resolved then we can resume execution.
self.status(ACVMStatus::InProgress);
Expand Down Expand Up @@ -260,23 +258,10 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
let solver = self.block_solvers.entry(*block_id).or_default();
solver.solve_memory_op(op, &mut self.witness_map, predicate)
}
Opcode::Brillig(brillig) => {
let foreign_call_results = self
.foreign_call_results
.get(&self.instruction_pointer)
.cloned()
.unwrap_or_default();
match BrilligSolver::solve(
&mut self.witness_map,
brillig,
foreign_call_results,
self.backend,
self.instruction_pointer,
) {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
res => res.map(|_| ()),
}
}
Opcode::Brillig(_) => match self.solve_brillig_opcode() {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
res => res.map(|_| ()),
},
};
match resolution {
Ok(()) => {
Expand Down Expand Up @@ -310,6 +295,42 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
}
}
}

fn solve_brillig_opcode(
&mut self,
) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
let Opcode::Brillig(brillig) = &self.opcodes[self.instruction_pointer] else {
unreachable!("Not executing a Brillig opcode");
};
let witness = &mut self.witness_map;
if BrilligSolver::<B>::should_skip(witness, brillig)? {
BrilligSolver::<B>::zero_out_brillig_outputs(witness, brillig).map(|_| None)
} else {
// If we're resuming execution after resolving a foreign call then
// there will be a cached `BrilligSolver` to avoid recomputation.
let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() {
Some(solver) => solver,
None => {
BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)?
}
};
match solver.solve()? {
BrilligSolverStatus::ForeignCallWait(foreign_call) => {
// Cache the current state of the solver
self.brillig_solver = Some(solver);
Ok(Some(foreign_call))
}
BrilligSolverStatus::InProgress => {
unreachable!("Brillig solver still in progress")
}
BrilligSolverStatus::Finished => {
// Write execution outputs
solver.finalize(witness, brillig)?;
Ok(None)
}
}
}
}
}

// Returns the concrete value for a particular witness
Expand Down
22 changes: 17 additions & 5 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
status
}

pub fn get_status(&self) -> VMStatus {
self.status.clone()
}

/// Sets the current status of the VM to Finished (completed execution).
fn finish(&mut self) -> VMStatus {
self.status(VMStatus::Finished)
Expand All @@ -127,6 +131,14 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
self.status(VMStatus::ForeignCallWait { function, inputs })
}

pub fn resolve_foreign_call(&mut self, foreign_call_result: ForeignCallResult) {
if self.foreign_call_counter < self.foreign_call_results.len() {
panic!("No unresolved foreign calls");
}
self.foreign_call_results.push(foreign_call_result);
self.status(VMStatus::InProgress);
}

/// Sets the current status of the VM to `fail`.
/// Indicating that the VM encountered a `Trap` Opcode
/// or an invalid state.
Expand Down Expand Up @@ -926,7 +938,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(
vm.resolve_foreign_call(
Value::from(10u128).into(), // Result of doubling 5u128
);

Expand Down Expand Up @@ -987,7 +999,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down Expand Up @@ -1060,7 +1072,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(ForeignCallResult {
vm.resolve_foreign_call(ForeignCallResult {
values: vec![ForeignCallParam::Array(output_string.clone())],
});

Expand Down Expand Up @@ -1122,7 +1134,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down Expand Up @@ -1207,7 +1219,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down

0 comments on commit 88682da

Please sign in to comment.