Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Save Brillig execution state in ACVM #3026

Merged
merged 9 commits into from
Oct 12, 2023
156 changes: 109 additions & 47 deletions acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use acir::{
brillig::{ForeignCallParam, RegisterIndex, Value},
brillig::{ForeignCallParam, ForeignCallResult, RegisterIndex, Value},
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
OpcodeLocation,
Expand All @@ -14,28 +14,73 @@ use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError};

use super::{get_value, insert_value};

pub(super) struct BrilligSolver;
pub(super) enum BrilligSolverStatus {
Finished,
InProgress,
ForeignCallWait(ForeignCallWaitInfo),
}

impl BrilligSolver {
pub(super) fn solve<B: BlackBoxFunctionSolver>(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
bb_solver: &B,
pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> {
vm: VM<'b, B>,
acir_index: usize,
}

impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
/// Constructs a solver for a Brillig block given the bytecode and initial
/// witness. If the block should be skipped entirely because its predicate
/// evaluates to false, zero out the block outputs and return Ok(None).
pub(super) fn build_or_skip<'w>(
initial_witness: &'w mut WitnessMap,
brillig: &'w Brillig,
ggiraldez marked this conversation as resolved.
Show resolved Hide resolved
bb_solver: &'b B,
acir_index: usize,
) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
) -> Result<Option<Self>, OpcodeResolutionError> {
ggiraldez marked this conversation as resolved.
Show resolved Hide resolved
if Self::should_skip(initial_witness, brillig)? {
Self::zero_out_brillig_outputs(initial_witness, brillig)?;
return Ok(None);
}

let vm = Self::build_vm(initial_witness, brillig, bb_solver)?;
Ok(Some(Self { vm, acir_index }))
}

fn should_skip(witness: &WitnessMap, brillig: &Brillig) -> Result<bool, OpcodeResolutionError> {
// If the predicate is `None`, then we simply return the value 1
// If the predicate is `Some` but we cannot find a value, then we return stalled
let pred_value = match &brillig.predicate {
Some(pred) => get_value(pred, initial_witness),
Some(pred) => get_value(pred, witness),
None => Ok(FieldElement::one()),
}?;

// A zero predicate indicates the oracle should be skipped, and its outputs zeroed.
if pred_value.is_zero() {
Self::zero_out_brillig_outputs(initial_witness, brillig)?;
return Ok(None);
Ok(pred_value.is_zero())
}

/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr {
insert_value(witness, FieldElement::zero(), initial_witness)?;
}
}
}
}
Ok(())
}

fn build_vm(
witness: &WitnessMap,
brillig: &Brillig,
bb_solver: &'b B,
) -> Result<VM<'b, B>, OpcodeResolutionError> {
// Set input values
let mut input_register_values: Vec<Value> = Vec::new();
let mut input_memory: Vec<Value> = Vec::new();
Expand All @@ -45,7 +90,7 @@ impl BrilligSolver {
// If a certain expression is not solvable, we stall the ACVM and do not proceed with Brillig VM execution.
for input in &brillig.inputs {
match input {
BrilligInputs::Single(expr) => match get_value(expr, initial_witness) {
BrilligInputs::Single(expr) => match get_value(expr, witness) {
Ok(value) => input_register_values.push(value.into()),
Err(_) => {
return Err(OpcodeResolutionError::OpcodeNotSolvable(
Expand All @@ -57,7 +102,7 @@ impl BrilligSolver {
// Attempt to fetch all array input values
let memory_pointer = input_memory.len();
for expr in expr_arr.iter() {
match get_value(expr, initial_witness) {
match get_value(expr, witness) {
Ok(value) => input_memory.push(value.into()),
Err(_) => {
return Err(OpcodeResolutionError::OpcodeNotSolvable(
Expand All @@ -76,78 +121,95 @@ impl BrilligSolver {
// Instantiate a Brillig VM given the solved input registers and memory
// along with the Brillig bytecode, and any present foreign call results.
let input_registers = Registers::load(input_register_values);
let mut vm = VM::new(
Ok(VM::new(
input_registers,
input_memory,
brillig.bytecode.clone(),
brillig.foreign_call_results.clone(),
bb_solver,
);
))
}

// Run the Brillig VM on these inputs, bytecode, etc!
let vm_status = vm.process_opcodes();
pub(super) fn solve(&mut self) -> Result<BrilligSolverStatus, OpcodeResolutionError> {
let status = self.vm.process_opcodes();
self.handle_vm_status(status)
}

// Check the status of the Brillig VM.
fn handle_vm_status(
&self,
vm_status: VMStatus,
) -> Result<BrilligSolverStatus, OpcodeResolutionError> {
// Check the status of the Brillig VM and return a resolution.
// It may be finished, in-progress, failed, or may be waiting for results of a foreign call.
// Return the "resolution" to the caller who may choose to make subsequent calls
// (when it gets foreign call results for example).
match vm_status {
VMStatus::Finished => {
for (i, output) in brillig.outputs.iter().enumerate() {
let register_value = vm.get_registers().get(RegisterIndex::from(i));
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, register_value.to_field(), initial_witness)?;
}
BrilligOutputs::Array(witness_arr) => {
// Treat the register value as a pointer to memory
for (i, witness) in witness_arr.iter().enumerate() {
let value = &vm.get_memory()[register_value.to_usize() + i];
insert_value(witness, value.to_field(), initial_witness)?;
}
}
}
}
Ok(None)
}
VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"),
VMStatus::Finished => Ok(BrilligSolverStatus::Finished),
VMStatus::InProgress => Ok(BrilligSolverStatus::InProgress),
VMStatus::Failure { message, call_stack } => {
Err(OpcodeResolutionError::BrilligFunctionFailed {
message,
call_stack: call_stack
.iter()
.map(|brillig_index| OpcodeLocation::Brillig {
acir_index,
acir_index: self.acir_index,
brillig_index: *brillig_index,
})
.collect(),
})
}
VMStatus::ForeignCallWait { function, inputs } => {
Ok(Some(ForeignCallWaitInfo { function, inputs }))
Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs }))
}
}
}

/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
pub(super) fn finalize(
self,
witness: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
// Finish the Brillig execution by writing the outputs to the witness map
let vm_status = self.vm.get_status();
match vm_status {
VMStatus::Finished => {
self.write_brillig_outputs(witness, brillig)?;
Ok(())
}
_ => panic!("Brillig VM has not completed execution"),
}
}

fn write_brillig_outputs(
&self,
witness_map: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
// Write VM execution results into the witness map
for (i, output) in brillig.outputs.iter().enumerate() {
let register_value = self.vm.get_registers().get(RegisterIndex::from(i));
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
insert_value(witness, register_value.to_field(), witness_map)?;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr {
insert_value(witness, FieldElement::zero(), initial_witness)?;
// Treat the register value as a pointer to memory
for (i, witness) in witness_arr.iter().enumerate() {
let value = &self.vm.get_memory()[register_value.to_usize() + i];
insert_value(witness, value.to_field(), witness_map)?;
}
}
}
}
Ok(())
}

pub(super) fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) {
match self.vm.get_status() {
VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result),
_ => unreachable!("Brillig VM is not waiting for a foreign call"),
}
}
}

/// Encapsulates a request from a Brillig VM process that encounters a [foreign call opcode][acir::brillig_vm::Opcode::ForeignCall]
Expand Down
51 changes: 37 additions & 14 deletions acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use acir::{
use acvm_blackbox_solver::BlackBoxResolutionError;

use self::{
arithmetic::ArithmeticSolver, brillig::BrilligSolver, directives::solve_directives,
arithmetic::ArithmeticSolver,
brillig::{BrilligSolver, BrilligSolverStatus},
directives::solve_directives,
memory_op::MemoryOpSolver,
};
use crate::{BlackBoxFunctionSolver, Language};
Expand Down Expand Up @@ -140,6 +142,8 @@ pub struct ACVM<'backend, B: BlackBoxFunctionSolver> {
instruction_pointer: usize,

witness_map: WitnessMap,

brillig_solver: Option<BrilligSolver<'backend, B>>,
}

impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> {
Expand All @@ -152,6 +156,7 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> {
opcodes,
instruction_pointer: 0,
witness_map: initial_witness,
brillig_solver: None,
}
}

Expand Down Expand Up @@ -216,12 +221,8 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> {
panic!("ACVM is not expecting a foreign call response as no call was made");
}

// We want to inject the foreign call result into the brillig opcode which initiated the call.
let opcode = &mut self.opcodes[self.instruction_pointer];
let Opcode::Brillig(brillig) = opcode else {
unreachable!("ACVM can only enter `RequiresForeignCall` state on a Brillig opcode");
};
brillig.foreign_call_results.push(foreign_call_result);
let brillig_solver = self.brillig_solver.as_mut().expect("No active Brillig solver");
brillig_solver.resolve_pending_foreign_call(foreign_call_result);

// Now that the foreign call has been resolved then we can resume execution.
self.status(ACVMStatus::InProgress);
Expand Down Expand Up @@ -258,13 +259,35 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> {
solver.solve_memory_op(op, &mut self.witness_map, predicate)
}
Opcode::Brillig(brillig) => {
match BrilligSolver::solve(
&mut self.witness_map,
brillig,
self.backend,
self.instruction_pointer,
) {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
let witness = &mut self.witness_map;
// get the active Brillig solver, or try to build one if necessary
// (Brillig execution maybe bypassed by constraints)
let maybe_solver = match self.brillig_solver.as_mut() {
Some(solver) => Ok(Some(solver)),
None => BrilligSolver::build_or_skip(
witness,
brillig,
self.backend,
self.instruction_pointer,
)
.map(|optional_solver| {
optional_solver.map(|solver| self.brillig_solver.insert(solver))
}),
};
match maybe_solver {
Ok(Some(solver)) => match solver.solve() {
Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => {
return self.wait_for_foreign_call(foreign_call);
}
Ok(BrilligSolverStatus::InProgress) => {
unreachable!("Brillig solver still in progress")
}
Ok(BrilligSolverStatus::Finished) => {
// clear active Brillig solver and write execution outputs
self.brillig_solver.take().unwrap().finalize(witness, brillig)
}
res => res.map(|_| ()),
},
res => res.map(|_| ()),
}
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
}
Expand Down
22 changes: 17 additions & 5 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> {
status
}

pub fn get_status(&self) -> VMStatus {
self.status.clone()
}

/// Sets the current status of the VM to Finished (completed execution).
fn finish(&mut self) -> VMStatus {
self.status(VMStatus::Finished)
Expand All @@ -127,6 +131,14 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> {
self.status(VMStatus::ForeignCallWait { function, inputs })
}

pub fn resolve_foreign_call(&mut self, foreign_call_result: ForeignCallResult) {
if self.foreign_call_counter < self.foreign_call_results.len() {
panic!("No unresolved foreign calls");
}
self.foreign_call_results.push(foreign_call_result);
self.status(VMStatus::InProgress);
}
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved

/// Sets the current status of the VM to `fail`.
/// Indicating that the VM encountered a `Trap` Opcode
/// or an invalid state.
Expand Down Expand Up @@ -929,7 +941,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(
vm.resolve_foreign_call(
Value::from(10u128).into(), // Result of doubling 5u128
);

Expand Down Expand Up @@ -990,7 +1002,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down Expand Up @@ -1063,7 +1075,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(ForeignCallResult {
vm.resolve_foreign_call(ForeignCallResult {
values: vec![ForeignCallParam::Array(output_string.clone())],
});

Expand Down Expand Up @@ -1125,7 +1137,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down Expand Up @@ -1210,7 +1222,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down