diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index bbd2ecfc2..dc6b55f95 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -9,9 +9,10 @@ pub mod pwg; use crate::pwg::{arithmetic::ArithmeticSolver, brillig::BrilligSolver, oracle::OracleSolver}; use acir::{ + brillig_bytecode, circuit::{ directives::Directive, - opcodes::{BlackBoxFuncCall, OracleData}, + opcodes::{BlackBoxFuncCall, Brillig, OracleData}, Circuit, Opcode, }, native_types::{Expression, Witness}, @@ -49,7 +50,7 @@ pub enum OpcodeResolutionError { #[error("could not satisfy all constraints")] UnsatisfiedConstrain, #[error("unexpected opcode, expected {0}, but got {1}")] - UnexpectedOpcode(&'static str, BlackBoxFunc), + UnexpectedOpcode(&'static str, &'static str), #[error("expected {0} inputs for function {1}, but got {2}")] IncorrectNumFunctionArguments(usize, BlackBoxFunc, usize), } @@ -62,6 +63,8 @@ pub enum OpcodeResolution { Stalled(OpcodeNotSolvable), /// The opcode is not solvable but could resolved some witness InProgress, + /// The brillig oracle opcode is not solved but could be resolved given some values + InProgessBrillig(brillig_bytecode::OracleData), } pub trait Backend: SmartContract + ProofSystemCompiler + PartialWitnessGenerator {} @@ -75,15 +78,17 @@ pub trait PartialWitnessGenerator { initial_witness: &mut BTreeMap, blocks: &mut Blocks, mut opcode_to_solve: Vec, - ) -> Result<(Vec, Vec), OpcodeResolutionError> { + ) -> Result { let mut unresolved_opcodes: Vec = Vec::new(); let mut unresolved_oracles: Vec = Vec::new(); + let mut unresolved_brillig_oracles: Vec = Vec::new(); while !opcode_to_solve.is_empty() || !unresolved_oracles.is_empty() { unresolved_opcodes.clear(); let mut stalled = true; let mut opcode_not_solvable = None; for opcode in &opcode_to_solve { let mut solved_oracle_data = None; + let mut solved_brillig_data = None; let resolution = match opcode { Opcode::Arithmetic(expr) => ArithmeticSolver::solve(initial_witness, expr), Opcode::BlackBoxFuncCall(bb_func) => { @@ -112,7 +117,7 @@ pub trait PartialWitnessGenerator { Opcode::Brillig(brillig) => { let mut brillig_clone = brillig.clone(); let result = BrilligSolver::solve(initial_witness, &mut brillig_clone)?; - // TODO: add oracle logic + solved_brillig_data = Some(brillig_clone); Ok(result) } }; @@ -129,6 +134,11 @@ pub trait PartialWitnessGenerator { unresolved_opcodes.push(opcode.clone()); } } + Ok(OpcodeResolution::InProgessBrillig(oracle_data)) => { + stalled = false; + // InProgressBrillig Oracles must be externally re-solved + unresolved_brillig_oracles.push(oracle_data); + } Ok(OpcodeResolution::Stalled(not_solvable)) => { if opcode_not_solvable.is_none() { // we keep track of the first unsolvable opcode @@ -141,6 +151,10 @@ pub trait PartialWitnessGenerator { Some(oracle_data) => Opcode::Oracle(oracle_data), None => opcode.clone(), }); + unresolved_opcodes.push(match solved_brillig_data { + Some(brillig) => Opcode::Brillig(brillig), + None => opcode.clone(), + }) } Err(OpcodeResolutionError::OpcodeNotSolvable(_)) => { unreachable!("ICE - Result should have been converted to GateResolution") @@ -149,8 +163,12 @@ pub trait PartialWitnessGenerator { } } // We have oracles that must be externally resolved - if !unresolved_oracles.is_empty() { - return Ok((unresolved_opcodes, unresolved_oracles)); + if !unresolved_oracles.is_empty() | !unresolved_brillig_oracles.is_empty() { + return Ok(UnresolvedData { + unresolved_opcodes, + unresolved_oracles, + unresolved_brillig_oracles, + }); } // We are stalled because of an opcode being bad if stalled && !unresolved_opcodes.is_empty() { @@ -161,7 +179,11 @@ pub trait PartialWitnessGenerator { } std::mem::swap(&mut opcode_to_solve, &mut unresolved_opcodes); } - Ok((Vec::new(), Vec::new())) + Ok(UnresolvedData { + unresolved_opcodes: Vec::new(), + unresolved_oracles: Vec::new(), + unresolved_brillig_oracles: Vec::new(), + }) } fn solve_black_box_function_call( @@ -198,6 +220,12 @@ pub trait PartialWitnessGenerator { } } +pub struct UnresolvedData { + pub unresolved_opcodes: Vec, + pub unresolved_oracles: Vec, + pub unresolved_brillig_oracles: Vec, +} + pub trait SmartContract { // TODO: Allow a backend to support multiple smart contract platforms @@ -320,9 +348,11 @@ mod test { use std::collections::BTreeMap; use acir::{ + brillig_bytecode, + brillig_bytecode::{RegisterIndex, RegisterMemIndex}, circuit::{ directives::Directive, - opcodes::{BlackBoxFuncCall, OracleData}, + opcodes::{BlackBoxFuncCall, Brillig, OracleData}, Opcode, }, native_types::{Expression, Witness}, @@ -331,6 +361,7 @@ mod test { use crate::{ pwg::block::Blocks, OpcodeResolution, OpcodeResolutionError, PartialWitnessGenerator, + UnresolvedData, }; struct StubbedPwg; @@ -395,10 +426,10 @@ mod test { (Witness(2), FieldElement::from(3u128)), ]); let mut blocks = Blocks::default(); - let (unsolved_opcodes, mut unresolved_oracles) = pwg + let UnresolvedData { unresolved_opcodes, mut unresolved_oracles, .. } = pwg .solve(&mut witness_assignments, &mut blocks, opcodes) .expect("should stall on oracle"); - assert!(unsolved_opcodes.is_empty(), "oracle should be removed"); + assert!(unresolved_opcodes.is_empty(), "oracle should be removed"); assert_eq!(unresolved_oracles.len(), 1, "should have an oracle request"); let mut oracle_data = unresolved_oracles.remove(0); assert_eq!(oracle_data.input_values.len(), 1, "Should have solved a single input"); @@ -406,11 +437,112 @@ mod test { // Filling data request and continue solving oracle_data.output_values = vec![oracle_data.input_values.last().unwrap().inverse()]; let mut next_opcodes_for_solving = vec![Opcode::Oracle(oracle_data)]; - next_opcodes_for_solving.extend_from_slice(&unsolved_opcodes[..]); - let (unsolved_opcodes, unresolved_oracles) = pwg + next_opcodes_for_solving.extend_from_slice(&unresolved_opcodes[..]); + let UnresolvedData { unresolved_opcodes, .. } = pwg .solve(&mut witness_assignments, &mut blocks, next_opcodes_for_solving) - .expect("should be solvable"); - assert!(unsolved_opcodes.is_empty(), "should be fully solved"); + .expect("should not stall on oracle"); + assert!(unresolved_opcodes.is_empty(), "should be fully solved"); assert!(unresolved_oracles.is_empty(), "should have no unresolved oracles"); } + + #[test] + fn inversion_brillig_oracle_equivalence() { + // Opcodes below describe the following: + // fn main(x : Field, y : pub Field) { + // let z = x + y; + // constrain 1/z == Oracle("inverse", x + y); + // } + let fe_0 = FieldElement::zero(); + let fe_1 = FieldElement::one(); + let w_x = Witness(1); + let w_y = Witness(2); + let w_oracle = Witness(3); + let w_z = Witness(4); + let w_z_inverse = Witness(5); + let w_x_plus_y = Witness(6); + + let brillig_bytecode = brillig_bytecode::Opcode::Oracle(brillig_bytecode::OracleData { + name: "invert".into(), + inputs: vec![RegisterMemIndex::Register(RegisterIndex(0))], + input_values: vec![], + outputs: vec![RegisterIndex(1)], + output_values: vec![], + }); + + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + Expression { + mul_terms: vec![], + linear_combinations: vec![(fe_1, w_x), (fe_1, w_y)], + q_c: fe_0, + }, + Expression::default(), + ], + outputs: vec![w_x_plus_y, w_oracle], + bytecode: vec![brillig_bytecode], + }); + + let opcodes = vec![ + brillig_opcode, + Opcode::Arithmetic(Expression { + mul_terms: vec![], + linear_combinations: vec![(fe_1, w_x), (fe_1, w_y), (-fe_1, w_z)], + q_c: fe_0, + }), + Opcode::Directive(Directive::Invert { x: w_z, result: w_z_inverse }), + Opcode::Arithmetic(Expression { + mul_terms: vec![(fe_1, w_z, w_z_inverse)], + linear_combinations: vec![], + q_c: -fe_1, + }), + Opcode::Arithmetic(Expression { + mul_terms: vec![], + linear_combinations: vec![(-fe_1, w_oracle), (fe_1, w_z_inverse)], + q_c: fe_0, + }), + ]; + + let pwg = StubbedPwg; + + let mut witness_assignments = BTreeMap::from([ + (Witness(1), FieldElement::from(2u128)), + (Witness(2), FieldElement::from(3u128)), + (Witness(6), FieldElement::from(5u128)), + ]); + let mut blocks = Blocks::default(); + let UnresolvedData { unresolved_opcodes, mut unresolved_brillig_oracles, .. } = pwg + .solve(&mut witness_assignments, &mut blocks, opcodes) + .expect("should stall on oracle"); + + assert!(unresolved_opcodes.is_empty(), "opcode should be removed"); + assert_eq!(unresolved_brillig_oracles.len(), 1, "should have a brillig oracle request"); + + let mut oracle_data = unresolved_brillig_oracles.remove(0); + assert_eq!(oracle_data.inputs.len(), 1, "Should have solved a single input"); + + // Filling data request and continue solving + oracle_data.output_values = vec![oracle_data.input_values.last().unwrap().inverse()]; + let brillig_bytecode = brillig_bytecode::Opcode::Oracle(oracle_data); + + let mut next_opcodes_for_solving = vec![Opcode::Brillig(Brillig { + inputs: vec![ + Expression { + mul_terms: vec![], + linear_combinations: vec![(fe_1, w_x), (fe_1, w_y)], + q_c: fe_0, + }, + Expression::default(), + ], + outputs: vec![w_x_plus_y, w_oracle], + bytecode: vec![brillig_bytecode], + })]; + + next_opcodes_for_solving.extend_from_slice(&unresolved_opcodes[..]); + let UnresolvedData { unresolved_opcodes, unresolved_brillig_oracles, .. } = pwg + .solve(&mut witness_assignments, &mut blocks, next_opcodes_for_solving) + .expect("should not stall on oracle"); + + assert!(unresolved_opcodes.is_empty(), "should be fully solved"); + assert!(unresolved_brillig_oracles.is_empty(), "should have no unresolved oracles"); + } } diff --git a/acvm/src/pwg/brillig.rs b/acvm/src/pwg/brillig.rs index 7e4248a99..02696413b 100644 --- a/acvm/src/pwg/brillig.rs +++ b/acvm/src/pwg/brillig.rs @@ -1,8 +1,15 @@ use std::collections::BTreeMap; -use acir::{circuit::opcodes::Brillig, native_types::Witness, FieldElement}; +use acir::{ + brillig_bytecode::{Opcode, RegisterMemIndex, Registers, VMStatus, Value, VM}, + circuit::opcodes::Brillig, + native_types::Witness, + FieldElement, +}; -use crate::{OpcodeNotSolvable, OpcodeResolution, OpcodeResolutionError}; +use crate::{ + pwg::arithmetic::ArithmeticSolver, OpcodeNotSolvable, OpcodeResolution, OpcodeResolutionError, +}; use super::{directives::insert_witness, get_value}; @@ -13,16 +20,55 @@ impl BrilligSolver { initial_witness: &mut BTreeMap, brillig: &mut Brillig, ) -> Result { - let mut input_register_values: Vec = Vec::new(); + // Set input values + let mut input_register_values: Vec = Vec::new(); for expr in &brillig.inputs { - let expr_value = get_value(expr, initial_witness)?; - input_register_values.push(expr_value.into()) + // Break from setting the inputs values if unable to solve the arithmetic expression inputs + let solve = ArithmeticSolver::evaluate(expr, initial_witness); + if let Some(value) = solve.to_const() { + input_register_values.push(value.into()) + } else { + break; + } } - let input_registers = acir::brillig_bytecode::Registers { inner: input_register_values }; - let vm = acir::brillig_bytecode::VM::new(input_registers, brillig.bytecode.clone()); + if input_register_values.len() != brillig.inputs.len() { + return Ok(OpcodeResolution::Stalled(OpcodeNotSolvable::ExpressionHasTooManyUnknowns( + brillig + .inputs + .last() + .expect("Infallible: cannot reach this point if no inputs") + .clone(), + ))); + } + + let input_registers = Registers { inner: input_register_values }; + let vm = VM::new(input_registers, brillig.bytecode.clone()); + + let (output_registers, status) = vm.clone().process_opcodes(); - let output_registers = vm.process_opcodes(); + if status == VMStatus::OracleWait { + let pc = vm.program_counter(); + let current_opcode = &brillig.bytecode[pc]; + let mut data = match current_opcode.clone() { + Opcode::Oracle(data) => data, + _ => { + return Err(OpcodeResolutionError::UnexpectedOpcode( + "brillig oracle", + current_opcode.name(), + )) + } + }; + let input_values = data + .clone() + .inputs + .into_iter() + .map(|register_mem_index| output_registers.get(register_mem_index).inner) + .collect::>(); + data.input_values = input_values; + + return Ok(OpcodeResolution::InProgessBrillig(data.clone())); + } let output_register_values: Vec = output_registers.inner.into_iter().map(|v| v.inner).collect::>(); diff --git a/acvm/src/pwg/logic.rs b/acvm/src/pwg/logic.rs index ba030901c..230fd27b9 100644 --- a/acvm/src/pwg/logic.rs +++ b/acvm/src/pwg/logic.rs @@ -10,7 +10,7 @@ pub fn solve_logic_opcode( match func_call.name { BlackBoxFunc::AND => LogicSolver::solve_and_gate(initial_witness, func_call), BlackBoxFunc::XOR => LogicSolver::solve_xor_gate(initial_witness, func_call), - _ => Err(OpcodeResolutionError::UnexpectedOpcode("logic opcode", func_call.name)), + _ => Err(OpcodeResolutionError::UnexpectedOpcode("logic opcode", func_call.name.name())), } } diff --git a/brillig_bytecode/src/lib.rs b/brillig_bytecode/src/lib.rs index 8da18099d..b47a6be89 100644 --- a/brillig_bytecode/src/lib.rs +++ b/brillig_bytecode/src/lib.rs @@ -12,7 +12,7 @@ mod value; use acir_field::FieldElement; pub use opcodes::RegisterMemIndex; -pub use opcodes::{BinaryOp, Comparison, Opcode}; +pub use opcodes::{BinaryOp, Comparison, Opcode, OracleData}; pub use registers::{RegisterIndex, Registers}; pub use value::Typ; pub use value::Value; @@ -22,6 +22,7 @@ pub enum VMStatus { Halted, InProgress, Failure, + OracleWait, } #[derive(Debug, PartialEq, Eq, Clone)] @@ -55,8 +56,11 @@ impl VM { } /// Loop over the bytecode and update the program counter - pub fn process_opcodes(mut self) -> Registers { - while !matches!(self.process_opcode(), VMStatus::Halted | VMStatus::Failure) {} + pub fn process_opcodes(mut self) -> (Registers, VMStatus) { + while !matches!( + self.process_opcode(), + VMStatus::Halted | VMStatus::Failure | VMStatus::OracleWait + ) {} self.finish() } // Process a single opcode and modify the program counter @@ -86,7 +90,17 @@ impl VM { } Opcode::Call => todo!(), Opcode::Intrinsics => todo!(), - Opcode::Oracle { inputs, destination } => todo!(), + Opcode::Oracle(data) => { + if data.outputs.len() == data.output_values.len() { + for (index, value) in data.outputs.iter().zip(data.output_values.iter()) { + self.registers.set(*index, (*value).into()) + } + } else { + self.status = VMStatus::OracleWait; + return VMStatus::OracleWait; + } + self.increment_program_counter() + } Opcode::Mov { destination, source } => { let source_value = self.registers.get(*source); @@ -106,6 +120,10 @@ impl VM { } } + pub fn program_counter(self) -> usize { + self.program_counter + } + /// Increments the program counter by 1. fn increment_program_counter(&mut self) -> VMStatus { self.set_program_counter(self.program_counter + 1) @@ -144,8 +162,8 @@ impl VM { /// Returns the state of the registers. /// This consumes ownership of the VM and is conventionally /// called when all of the bytecode has been processed. - fn finish(self) -> Registers { - self.registers + fn finish(self) -> (Registers, VMStatus) { + (self.registers, self.status) } } @@ -178,7 +196,7 @@ fn add_single_step_smoke() { // The register at index `2` should have the value of 3 since we had an // add opcode - let registers = vm.finish(); + let (registers, _) = vm.finish(); let output_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); assert_eq!(output_value, Value::from(3u128)) @@ -270,7 +288,7 @@ fn test_jmpifnot_opcode() { assert_eq!(status, VMStatus::Failure); // The register at index `2` should have not changed as we jumped over the add opcode - let registers = vm.finish(); + let (registers, status) = vm.finish(); let output_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); assert_eq!(output_value, Value::from(false)); } @@ -290,7 +308,7 @@ fn test_mov_opcode() { let status = vm.process_opcode(); assert_eq!(status, VMStatus::Halted); - let registers = vm.finish(); + let (registers, status) = vm.finish(); let destination_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); assert_eq!(destination_value, Value::from(1u128)); diff --git a/brillig_bytecode/src/opcodes.rs b/brillig_bytecode/src/opcodes.rs index f1e31b125..9dc98e9f3 100644 --- a/brillig_bytecode/src/opcodes.rs +++ b/brillig_bytecode/src/opcodes.rs @@ -56,10 +56,7 @@ pub enum Opcode { // TODO:These are special functions like sha256 Intrinsics, // TODO:This will be used to get data from an outside source - Oracle { - inputs: Vec, - destination: Vec, - }, + Oracle(OracleData), Mov { destination: RegisterMemIndex, source: RegisterMemIndex, @@ -72,6 +69,37 @@ pub enum Opcode { }, } +impl Opcode { + pub fn name(&self) -> &'static str { + match self { + Opcode::BinaryOp { .. } => "binary_op", + Opcode::JMPIFNOT { .. } => "jmpifnot", + Opcode::JMPIF { .. } => "jmpif", + Opcode::JMP { .. } => "jmp", + Opcode::Call => "call", + Opcode::Intrinsics => "intrinsics", + Opcode::Oracle(_) => "oracle", + Opcode::Mov { .. } => "mov", + Opcode::Trap => "trap", + Opcode::Bootstrap { .. } => "bootstrap", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct OracleData { + /// Name of the oracle + pub name: String, + /// Inputs + pub inputs: Vec, + /// Input values + pub input_values: Vec, + /// Output witness + pub outputs: Vec, + /// Output values - they are computed by the (external) oracle once the inputs are known + pub output_values: Vec, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum BinaryOp { Add,