From 88682da87ffc9e26da5c9e4b5a4d8e62a6ee43c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Wed, 11 Oct 2023 20:33:42 -0400 Subject: [PATCH] feat: Save Brillig execution state in ACVM (#3026) Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> --- acvm-repo/acvm/src/pwg/brillig.rs | 154 +++++++++++++++++++----------- acvm-repo/acvm/src/pwg/mod.rs | 73 +++++++++----- acvm-repo/brillig_vm/src/lib.rs | 22 ++++- 3 files changed, 162 insertions(+), 87 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index 732d9c8c8e5..6fc54d42eab 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -14,29 +14,59 @@ use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError}; use super::{get_value, insert_value}; -pub(super) struct BrilligSolver; +pub(super) enum BrilligSolverStatus { + Finished, + InProgress, + ForeignCallWait(ForeignCallWaitInfo), +} -impl BrilligSolver { - pub(super) fn solve( - initial_witness: &mut WitnessMap, +pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> { + vm: VM<'b, B>, + acir_index: usize, +} + +impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { + /// Evaluates if the Brillig block should be skipped entirely + pub(super) fn should_skip( + witness: &WitnessMap, brillig: &Brillig, - foreign_call_results: Vec, - bb_solver: &B, - acir_index: usize, - ) -> Result, OpcodeResolutionError> { - // If the predicate is `None`, then we simply return the value 1 + ) -> Result { + // If the predicate is `None`, the block should never be skipped // If the predicate is `Some` but we cannot find a value, then we return stalled - let pred_value = match &brillig.predicate { - Some(pred) => get_value(pred, initial_witness), - None => Ok(FieldElement::one()), - }?; + match &brillig.predicate { + Some(pred) => Ok(get_value(pred, witness)?.is_zero()), + None => Ok(false), + } + } - // A zero predicate indicates the oracle should be skipped, and its outputs zeroed. - if pred_value.is_zero() { - Self::zero_out_brillig_outputs(initial_witness, brillig)?; - return Ok(None); + /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. + pub(super) fn zero_out_brillig_outputs( + initial_witness: &mut WitnessMap, + brillig: &Brillig, + ) -> Result<(), OpcodeResolutionError> { + for output in &brillig.outputs { + match output { + BrilligOutputs::Simple(witness) => { + insert_value(witness, FieldElement::zero(), initial_witness)?; + } + BrilligOutputs::Array(witness_arr) => { + for witness in witness_arr { + insert_value(witness, FieldElement::zero(), initial_witness)?; + } + } + } } + Ok(()) + } + /// Constructs a solver for a Brillig block given the bytecode and initial + /// witness. + pub(super) fn new( + initial_witness: &mut WitnessMap, + brillig: &'b Brillig, + bb_solver: &'b B, + acir_index: usize, + ) -> Result { // Set input values let mut input_register_values: Vec = Vec::new(); let mut input_memory: Vec = Vec::new(); @@ -75,80 +105,92 @@ impl BrilligSolver { } // Instantiate a Brillig VM given the solved input registers and memory - // along with the Brillig bytecode, and any present foreign call results. + // along with the Brillig bytecode. let input_registers = Registers::load(input_register_values); - let mut vm = VM::new( - input_registers, - input_memory, - &brillig.bytecode, - foreign_call_results, - bb_solver, - ); + let vm = VM::new(input_registers, input_memory, &brillig.bytecode, vec![], bb_solver); + Ok(Self { vm, acir_index }) + } - // Run the Brillig VM on these inputs, bytecode, etc! - let vm_status = vm.process_opcodes(); + pub(super) fn solve(&mut self) -> Result { + let status = self.vm.process_opcodes(); + self.handle_vm_status(status) + } - // Check the status of the Brillig VM. + fn handle_vm_status( + &self, + vm_status: VMStatus, + ) -> Result { + // Check the status of the Brillig VM and return a resolution. // It may be finished, in-progress, failed, or may be waiting for results of a foreign call. // Return the "resolution" to the caller who may choose to make subsequent calls // (when it gets foreign call results for example). match vm_status { - VMStatus::Finished => { - for (i, output) in brillig.outputs.iter().enumerate() { - let register_value = vm.get_registers().get(RegisterIndex::from(i)); - match output { - BrilligOutputs::Simple(witness) => { - insert_value(witness, register_value.to_field(), initial_witness)?; - } - BrilligOutputs::Array(witness_arr) => { - // Treat the register value as a pointer to memory - for (i, witness) in witness_arr.iter().enumerate() { - let value = &vm.get_memory()[register_value.to_usize() + i]; - insert_value(witness, value.to_field(), initial_witness)?; - } - } - } - } - Ok(None) - } - VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"), + VMStatus::Finished => Ok(BrilligSolverStatus::Finished), + VMStatus::InProgress => Ok(BrilligSolverStatus::InProgress), VMStatus::Failure { message, call_stack } => { Err(OpcodeResolutionError::BrilligFunctionFailed { message, call_stack: call_stack .iter() .map(|brillig_index| OpcodeLocation::Brillig { - acir_index, + acir_index: self.acir_index, brillig_index: *brillig_index, }) .collect(), }) } VMStatus::ForeignCallWait { function, inputs } => { - Ok(Some(ForeignCallWaitInfo { function, inputs })) + Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs })) } } } - /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. - fn zero_out_brillig_outputs( - initial_witness: &mut WitnessMap, + pub(super) fn finalize( + self, + witness: &mut WitnessMap, brillig: &Brillig, ) -> Result<(), OpcodeResolutionError> { - for output in &brillig.outputs { + // Finish the Brillig execution by writing the outputs to the witness map + let vm_status = self.vm.get_status(); + match vm_status { + VMStatus::Finished => { + self.write_brillig_outputs(witness, brillig)?; + Ok(()) + } + _ => panic!("Brillig VM has not completed execution"), + } + } + + fn write_brillig_outputs( + &self, + witness_map: &mut WitnessMap, + brillig: &Brillig, + ) -> Result<(), OpcodeResolutionError> { + // Write VM execution results into the witness map + for (i, output) in brillig.outputs.iter().enumerate() { + let register_value = self.vm.get_registers().get(RegisterIndex::from(i)); match output { BrilligOutputs::Simple(witness) => { - insert_value(witness, FieldElement::zero(), initial_witness)?; + insert_value(witness, register_value.to_field(), witness_map)?; } BrilligOutputs::Array(witness_arr) => { - for witness in witness_arr { - insert_value(witness, FieldElement::zero(), initial_witness)?; + // Treat the register value as a pointer to memory + for (i, witness) in witness_arr.iter().enumerate() { + let value = &self.vm.get_memory()[register_value.to_usize() + i]; + insert_value(witness, value.to_field(), witness_map)?; } } } } Ok(()) } + + pub(super) fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { + match self.vm.get_status() { + VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result), + _ => unreachable!("Brillig VM is not waiting for a foreign call"), + } + } } /// Encapsulates a request from a Brillig VM process that encounters a [foreign call opcode][acir::brillig_vm::Opcode::ForeignCall] diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 4fcd6b24a7b..057597e6392 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -11,7 +11,9 @@ use acir::{ use acvm_blackbox_solver::BlackBoxResolutionError; use self::{ - arithmetic::ArithmeticSolver, brillig::BrilligSolver, directives::solve_directives, + arithmetic::ArithmeticSolver, + brillig::{BrilligSolver, BrilligSolverStatus}, + directives::solve_directives, memory_op::MemoryOpSolver, }; use crate::{BlackBoxFunctionSolver, Language}; @@ -141,9 +143,7 @@ pub struct ACVM<'a, B: BlackBoxFunctionSolver> { witness_map: WitnessMap, - /// Results of oracles/functions external to brillig like a database read. - // Each element of this vector corresponds to a single foreign call but may contain several values. - foreign_call_results: HashMap>, + brillig_solver: Option>, } impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { @@ -156,7 +156,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { opcodes, instruction_pointer: 0, witness_map: initial_witness, - foreign_call_results: HashMap::default(), + brillig_solver: None, } } @@ -221,10 +221,8 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { panic!("ACVM is not expecting a foreign call response as no call was made"); } - // We want to inject the foreign call result into the brillig opcode which initiated the call. - let foreign_call_results = - self.foreign_call_results.entry(self.instruction_pointer).or_default(); - foreign_call_results.push(foreign_call_result); + let brillig_solver = self.brillig_solver.as_mut().expect("No active Brillig solver"); + brillig_solver.resolve_pending_foreign_call(foreign_call_result); // Now that the foreign call has been resolved then we can resume execution. self.status(ACVMStatus::InProgress); @@ -260,23 +258,10 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { let solver = self.block_solvers.entry(*block_id).or_default(); solver.solve_memory_op(op, &mut self.witness_map, predicate) } - Opcode::Brillig(brillig) => { - let foreign_call_results = self - .foreign_call_results - .get(&self.instruction_pointer) - .cloned() - .unwrap_or_default(); - match BrilligSolver::solve( - &mut self.witness_map, - brillig, - foreign_call_results, - self.backend, - self.instruction_pointer, - ) { - Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call), - res => res.map(|_| ()), - } - } + Opcode::Brillig(_) => match self.solve_brillig_opcode() { + Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call), + res => res.map(|_| ()), + }, }; match resolution { Ok(()) => { @@ -310,6 +295,42 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { } } } + + fn solve_brillig_opcode( + &mut self, + ) -> Result, OpcodeResolutionError> { + let Opcode::Brillig(brillig) = &self.opcodes[self.instruction_pointer] else { + unreachable!("Not executing a Brillig opcode"); + }; + let witness = &mut self.witness_map; + if BrilligSolver::::should_skip(witness, brillig)? { + BrilligSolver::::zero_out_brillig_outputs(witness, brillig).map(|_| None) + } else { + // If we're resuming execution after resolving a foreign call then + // there will be a cached `BrilligSolver` to avoid recomputation. + let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() { + Some(solver) => solver, + None => { + BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)? + } + }; + match solver.solve()? { + BrilligSolverStatus::ForeignCallWait(foreign_call) => { + // Cache the current state of the solver + self.brillig_solver = Some(solver); + Ok(Some(foreign_call)) + } + BrilligSolverStatus::InProgress => { + unreachable!("Brillig solver still in progress") + } + BrilligSolverStatus::Finished => { + // Write execution outputs + solver.finalize(witness, brillig)?; + Ok(None) + } + } + } + } } // Returns the concrete value for a particular witness diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 6da34c6a498..48f6bf5f1c4 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -112,6 +112,10 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { status } + pub fn get_status(&self) -> VMStatus { + self.status.clone() + } + /// Sets the current status of the VM to Finished (completed execution). fn finish(&mut self) -> VMStatus { self.status(VMStatus::Finished) @@ -127,6 +131,14 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { self.status(VMStatus::ForeignCallWait { function, inputs }) } + pub fn resolve_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { + if self.foreign_call_counter < self.foreign_call_results.len() { + panic!("No unresolved foreign calls"); + } + self.foreign_call_results.push(foreign_call_result); + self.status(VMStatus::InProgress); + } + /// Sets the current status of the VM to `fail`. /// Indicating that the VM encountered a `Trap` Opcode /// or an invalid state. @@ -926,7 +938,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push( + vm.resolve_foreign_call( Value::from(10u128).into(), // Result of doubling 5u128 ); @@ -987,7 +999,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(expected_result.clone().into()); + vm.resolve_foreign_call(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm); @@ -1060,7 +1072,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(ForeignCallResult { + vm.resolve_foreign_call(ForeignCallResult { values: vec![ForeignCallParam::Array(output_string.clone())], }); @@ -1122,7 +1134,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(expected_result.clone().into()); + vm.resolve_foreign_call(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm); @@ -1207,7 +1219,7 @@ mod tests { ); // Push result we're waiting for - vm.foreign_call_results.push(expected_result.clone().into()); + vm.resolve_foreign_call(expected_result.clone().into()); // Resume VM brillig_execute(&mut vm);