diff --git a/acir/src/circuit/mod.rs b/acir/src/circuit/mod.rs index b3e22462c..3c5710d2a 100644 --- a/acir/src/circuit/mod.rs +++ b/acir/src/circuit/mod.rs @@ -181,10 +181,10 @@ mod test { use std::collections::BTreeSet; use super::{ - opcodes::{BlackBoxFuncCall, FunctionInput}, + opcodes::{BlackBoxFuncCall, FunctionInput, OracleData}, Circuit, Opcode, PublicInputs, }; - use crate::native_types::Witness; + use crate::native_types::{Expression, Witness}; use acir_field::FieldElement; fn and_opcode() -> Opcode { @@ -204,12 +204,25 @@ mod test { outputs: vec![], }) } + fn oracle_opcode() -> Opcode { + Opcode::Oracle(OracleData { + name: String::from("oracle-name"), + inputs: vec![Expression { + mul_terms: vec![(FieldElement::from(123u128), Witness(1), Witness(2))], + linear_combinations: vec![(FieldElement::from(456u128), Witness(34))], + q_c: FieldElement::from(12345678u128), + }], + input_values: vec![], + outputs: vec![Witness(1), Witness(2), Witness(3)], + output_values: vec![], + }) + } #[test] fn serialization_roundtrip() { let circuit = Circuit { current_witness_index: 5, - opcodes: vec![and_opcode(), range_opcode()], + opcodes: vec![and_opcode(), range_opcode(), oracle_opcode()], public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(12)])), return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(4), Witness(12)])), }; @@ -237,6 +250,7 @@ mod test { }), range_opcode(), and_opcode(), + oracle_opcode(), ], public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])), return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])), @@ -261,6 +275,7 @@ mod test { }), range_opcode(), and_opcode(), + oracle_opcode(), ], public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])), return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])), diff --git a/acir/src/circuit/opcodes.rs b/acir/src/circuit/opcodes.rs index c9273b4ff..12ee35c55 100644 --- a/acir/src/circuit/opcodes.rs +++ b/acir/src/circuit/opcodes.rs @@ -2,7 +2,9 @@ use std::io::{Read, Write}; use super::directives::{Directive, LogInfo}; use crate::native_types::{Expression, Witness}; -use crate::serialization::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}; +use crate::serialization::{ + read_bytes, read_field_element, read_n, read_u16, read_u32, write_bytes, write_u16, write_u32, +}; use crate::BlackBoxFunc; use acir_field::FieldElement; use serde::{Deserialize, Serialize}; @@ -85,6 +87,114 @@ impl MemoryBlock { } } +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct OracleData { + /// Name of the oracle + pub name: String, + /// Inputs + pub inputs: Vec, + /// Input values - they are progressively computed by the pwg + pub input_values: Vec, + /// Output witness + pub outputs: Vec, + /// Output values - they are computed by the (external) oracle once the input_values are known + pub output_values: Vec, +} + +impl OracleData { + pub(crate) fn write(&self, mut writer: W) -> std::io::Result<()> { + let name_as_bytes = self.name.as_bytes(); + let name_len = name_as_bytes.len(); + write_u32(&mut writer, name_len as u32)?; + write_bytes(&mut writer, name_as_bytes)?; + + let inputs_len = self.inputs.len() as u32; + write_u32(&mut writer, inputs_len)?; + for input in &self.inputs { + input.write(&mut writer)? + } + + let outputs_len = self.outputs.len() as u32; + write_u32(&mut writer, outputs_len)?; + for output in &self.outputs { + write_u32(&mut writer, output.witness_index())?; + } + + let inputs_len = self.input_values.len() as u32; + write_u32(&mut writer, inputs_len)?; + for input in &self.input_values { + write_bytes(&mut writer, &input.to_be_bytes())?; + } + + let outputs_len = self.output_values.len() as u32; + write_u32(&mut writer, outputs_len)?; + for output in &self.output_values { + write_bytes(&mut writer, &output.to_be_bytes())?; + } + Ok(()) + } + + pub(crate) fn read(mut reader: R) -> std::io::Result { + let name_len = read_u32(&mut reader)?; + let name_as_bytes = read_bytes(&mut reader, name_len as usize)?; + let name: String = String::from_utf8(name_as_bytes) + .map_err(|_| std::io::Error::from(std::io::ErrorKind::InvalidData))?; + + let inputs_len = read_u32(&mut reader)?; + let mut inputs = Vec::with_capacity(inputs_len as usize); + for _ in 0..inputs_len { + let input = Expression::read(&mut reader)?; + inputs.push(input); + } + + let outputs_len = read_u32(&mut reader)?; + let mut outputs = Vec::with_capacity(outputs_len as usize); + for _ in 0..outputs_len { + let witness_index = read_u32(&mut reader)?; + outputs.push(Witness(witness_index)); + } + + const FIELD_ELEMENT_NUM_BYTES: usize = FieldElement::max_num_bytes() as usize; + let inputs_len = read_u32(&mut reader)?; + let mut input_values = Vec::with_capacity(inputs_len as usize); + for _ in 0..inputs_len { + let value = read_field_element::(&mut reader)?; + input_values.push(value); + } + + let outputs_len = read_u32(&mut reader)?; + let mut output_values = Vec::with_capacity(outputs_len as usize); + for _ in 0..outputs_len { + let value = read_field_element::(&mut reader)?; + output_values.push(value); + } + + Ok(OracleData { name, inputs, outputs, input_values, output_values }) + } +} + +impl std::fmt::Display for OracleData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ORACLE: {}", self.name)?; + let solved = if self.input_values.len() == self.inputs.len() { "solved" } else { "" }; + + write!( + f, + "Inputs: _{}..._{}{solved}", + self.inputs.first().unwrap(), + self.inputs.last().unwrap() + )?; + + let solved = if self.output_values.len() == self.outputs.len() { "solved" } else { "" }; + write!( + f, + "Outputs: _{}..._{}{solved}", + self.outputs.first().unwrap().witness_index(), + self.outputs.last().unwrap().witness_index() + ) + } +} + #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Opcode { Arithmetic(Expression), @@ -104,6 +214,7 @@ pub enum Opcode { /// - after MemoryBlock.len, all operations are constant expressions (0 or 1) /// RAM is required for Aztec Backend as dynamic memory implementation in Barrentenberg requires an intialisation phase and can only handle constant values for operations. RAM(MemoryBlock), + Oracle(OracleData), } impl Opcode { @@ -117,6 +228,7 @@ impl Opcode { Opcode::Block(_) => "block", Opcode::RAM(_) => "ram", Opcode::ROM(_) => "rom", + Opcode::Oracle(data) => &data.name, } } @@ -130,6 +242,7 @@ impl Opcode { Opcode::Block(_) => 3, Opcode::ROM(_) => 4, Opcode::RAM(_) => 5, + Opcode::Oracle { .. } => 6, } } @@ -154,6 +267,7 @@ impl Opcode { Opcode::Block(mem_block) | Opcode::ROM(mem_block) | Opcode::RAM(mem_block) => { mem_block.write(writer) } + Opcode::Oracle(data) => data.write(writer), } } pub fn read(mut reader: R) -> std::io::Result { @@ -187,6 +301,10 @@ impl Opcode { let block = MemoryBlock::read(reader)?; Ok(Opcode::RAM(block)) } + 6 => { + let data = OracleData::read(reader)?; + Ok(Opcode::Oracle(data)) + } _ => Err(std::io::ErrorKind::InvalidData.into()), } } @@ -273,6 +391,10 @@ impl std::fmt::Display for Opcode { write!(f, "RAM ")?; write!(f, "(id: {}, len: {}) ", block.id.0, block.trace.len()) } + Opcode::Oracle(data) => { + write!(f, "ORACLE: ")?; + write!(f, "{data}") + } } } } diff --git a/acir/src/serialization.rs b/acir/src/serialization.rs index 4c948e09c..01abcff49 100644 --- a/acir/src/serialization.rs +++ b/acir/src/serialization.rs @@ -20,6 +20,11 @@ pub(crate) fn write_n( pub(crate) fn write_bytes(mut w: W, bytes: &[u8]) -> std::io::Result { w.write(bytes) } +pub(crate) fn read_bytes(mut r: R, num_bytes: usize) -> std::io::Result> { + let mut bytes = vec![0u8; num_bytes]; + r.read_exact(&mut bytes[..])?; + Ok(bytes) +} pub(crate) fn write_u16(w: W, num: u16) -> std::io::Result { let bytes = num.to_le_bytes(); diff --git a/acvm/src/compiler/transformers/fallback.rs b/acvm/src/compiler/transformers/fallback.rs index a2550fbc0..33c8c031e 100644 --- a/acvm/src/compiler/transformers/fallback.rs +++ b/acvm/src/compiler/transformers/fallback.rs @@ -26,8 +26,10 @@ impl FallbackTransformer { | Opcode::Directive(_) | Opcode::Block(_) | Opcode::ROM(_) - | Opcode::RAM(_) => { - // directive, arithmetic expression or blocks are handled by acvm + | Opcode::RAM(_) + | Opcode::Oracle { .. } => { + // directive, arithmetic expression or block are handled by acvm + // The oracle opcode is assumed to be supported. acir_supported_opcodes.push(opcode); continue; } diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index ac31e02b6..e25d96939 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -7,9 +7,13 @@ pub mod compiler; pub mod pwg; -use crate::pwg::arithmetic::ArithmeticSolver; +use crate::pwg::{arithmetic::ArithmeticSolver, oracle::OracleSolver}; use acir::{ - circuit::{directives::Directive, opcodes::BlackBoxFuncCall, Circuit, Opcode}, + circuit::{ + directives::Directive, + opcodes::{BlackBoxFuncCall, OracleData}, + Circuit, Opcode, + }, native_types::{Expression, Witness}, BlackBoxFunc, }; @@ -70,15 +74,17 @@ pub trait PartialWitnessGenerator { fn solve( &self, initial_witness: &mut BTreeMap, + blocks: &mut Blocks, mut opcode_to_solve: Vec, - ) -> Result<(), OpcodeResolutionError> { + ) -> Result<(Vec, Vec), OpcodeResolutionError> { let mut unresolved_opcodes: Vec = Vec::new(); - let mut blocks = Blocks::default(); - while !opcode_to_solve.is_empty() { + let mut unresolved_oracles: Vec = Vec::new(); + while !opcode_to_solve.is_empty() || !unresolved_oracles.is_empty() { unresolved_opcodes.clear(); let mut stalled = true; let mut opcode_not_solvable = None; for opcode in &opcode_to_solve { + let mut solved_oracle_data = None; let resolution = match opcode { Opcode::Arithmetic(expr) => ArithmeticSolver::solve(initial_witness, expr), Opcode::BlackBoxFuncCall(bb_func) => { @@ -90,6 +96,12 @@ pub trait PartialWitnessGenerator { Opcode::Block(block) | Opcode::ROM(block) | Opcode::RAM(block) => { blocks.solve(block.id, &block.trace, initial_witness) } + Opcode::Oracle(data) => { + let mut data_clone = data.clone(); + let result = OracleSolver::solve(initial_witness, &mut data_clone)?; + solved_oracle_data = Some(data_clone); + Ok(result) + } }; match resolution { Ok(OpcodeResolution::Solved) => { @@ -97,7 +109,12 @@ pub trait PartialWitnessGenerator { } Ok(OpcodeResolution::InProgress) => { stalled = false; - unresolved_opcodes.push(opcode.clone()); + // InProgress Oracles must be externally re-solved + if let Some(oracle) = solved_oracle_data { + unresolved_oracles.push(oracle); + } else { + unresolved_opcodes.push(opcode.clone()); + } } Ok(OpcodeResolution::Stalled(not_solvable)) => { if opcode_not_solvable.is_none() { @@ -107,7 +124,10 @@ pub trait PartialWitnessGenerator { // We push those opcodes not solvable to the back as // it could be because the opcodes are out of order, i.e. this assignment // relies on a later opcodes' results - unresolved_opcodes.push(opcode.clone()); + unresolved_opcodes.push(match solved_oracle_data { + Some(oracle_data) => Opcode::Oracle(oracle_data), + None => opcode.clone(), + }); } Err(OpcodeResolutionError::OpcodeNotSolvable(_)) => { unreachable!("ICE - Result should have been converted to GateResolution") @@ -115,6 +135,11 @@ pub trait PartialWitnessGenerator { Err(err) => return Err(err), } } + // We have oracles that must be externally resolved + if !unresolved_oracles.is_empty() { + return Ok((unresolved_opcodes, unresolved_oracles)); + } + // We are stalled because of an opcode being bad if stalled && !unresolved_opcodes.is_empty() { return Err(OpcodeResolutionError::OpcodeNotSolvable( opcode_not_solvable @@ -123,7 +148,7 @@ pub trait PartialWitnessGenerator { } std::mem::swap(&mut opcode_to_solve, &mut unresolved_opcodes); } - Ok(()) + Ok((Vec::new(), Vec::new())) } fn solve_black_box_function_call( @@ -272,3 +297,103 @@ pub fn default_is_opcode_supported( Language::PLONKCSat { .. } => plonk_is_supported, } } + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + + use acir::{ + circuit::{ + directives::Directive, + opcodes::{BlackBoxFuncCall, OracleData}, + Opcode, + }, + native_types::{Expression, Witness}, + FieldElement, + }; + + use crate::{ + pwg::block::Blocks, OpcodeResolution, OpcodeResolutionError, PartialWitnessGenerator, + }; + + struct StubbedPwg; + + impl PartialWitnessGenerator for StubbedPwg { + fn solve_black_box_function_call( + _initial_witness: &mut BTreeMap, + _func_call: &BlackBoxFuncCall, + ) -> Result { + panic!("Path not trodden by this test") + } + } + + #[test] + fn inversion_oracle_equivalence() { + // Opcodes below describe the following: + // fn main(x : Field, y : pub Field) { + // let z = x + y; + // constrain 1/z == Oracle("inverse", x + y); + // } + let fe_0 = FieldElement::zero(); + let fe_1 = FieldElement::one(); + let w_x = Witness(1); + let w_y = Witness(2); + let w_oracle = Witness(3); + let w_z = Witness(4); + let w_z_inverse = Witness(5); + let opcodes = vec![ + Opcode::Oracle(OracleData { + name: "invert".into(), + inputs: vec![Expression { + mul_terms: vec![], + linear_combinations: vec![(fe_1, w_x), (fe_1, w_y)], + q_c: fe_0, + }], + input_values: vec![], + outputs: vec![w_oracle], + output_values: vec![], + }), + Opcode::Arithmetic(Expression { + mul_terms: vec![], + linear_combinations: vec![(fe_1, w_x), (fe_1, w_y), (-fe_1, w_z)], + q_c: fe_0, + }), + Opcode::Directive(Directive::Invert { x: w_z, result: w_z_inverse }), + Opcode::Arithmetic(Expression { + mul_terms: vec![(fe_1, w_z, w_z_inverse)], + linear_combinations: vec![], + q_c: -fe_1, + }), + Opcode::Arithmetic(Expression { + mul_terms: vec![], + linear_combinations: vec![(-fe_1, w_oracle), (fe_1, w_z_inverse)], + q_c: fe_0, + }), + ]; + + let pwg = StubbedPwg; + + let mut witness_assignments = BTreeMap::from([ + (Witness(1), FieldElement::from(2u128)), + (Witness(2), FieldElement::from(3u128)), + ]); + let mut blocks = Blocks::default(); + let (unsolved_opcodes, mut unresolved_oracles) = pwg + .solve(&mut witness_assignments, &mut blocks, opcodes) + .expect("should stall on oracle"); + assert!(unsolved_opcodes.is_empty(), "oracle should be removed"); + assert_eq!(unresolved_oracles.len(), 1, "should have an oracle request"); + let mut oracle_data = unresolved_oracles.remove(0); + assert_eq!(oracle_data.input_values.len(), 1, "Should have solved a single input"); + + // Filling data request and continue solving + oracle_data.output_values = vec![oracle_data.input_values.last().unwrap().inverse()]; + let mut next_opcodes_for_solving = vec![Opcode::Oracle(oracle_data)]; + next_opcodes_for_solving.extend_from_slice(&unsolved_opcodes[..]); + let (unsolved_opcodes, unresolved_oracles) = pwg + .solve(&mut witness_assignments, &mut blocks, next_opcodes_for_solving) + .expect("should be solvable"); + assert!(unsolved_opcodes.is_empty(), "should be fully solved"); + assert!(unresolved_oracles.is_empty(), "should have no unresolved oracles"); + } +} diff --git a/acvm/src/pwg.rs b/acvm/src/pwg.rs index 1d8c66ffa..bf7bbdab5 100644 --- a/acvm/src/pwg.rs +++ b/acvm/src/pwg.rs @@ -17,6 +17,7 @@ pub mod directives; pub mod block; pub mod hash; pub mod logic; +pub mod oracle; pub mod range; pub mod signature; pub mod sorting; diff --git a/acvm/src/pwg/oracle.rs b/acvm/src/pwg/oracle.rs new file mode 100644 index 000000000..7d427aac7 --- /dev/null +++ b/acvm/src/pwg/oracle.rs @@ -0,0 +1,47 @@ +use std::collections::BTreeMap; + +use acir::{circuit::opcodes::OracleData, native_types::Witness, FieldElement}; + +use crate::{OpcodeNotSolvable, OpcodeResolution, OpcodeResolutionError}; + +use super::{arithmetic::ArithmeticSolver, directives::insert_witness}; + +pub struct OracleSolver; + +impl OracleSolver { + /// Derives the rest of the witness based on the initial low level variables + pub fn solve( + initial_witness: &mut BTreeMap, + data: &mut OracleData, + ) -> Result { + // Set input values + for input in data.inputs.iter().skip(data.input_values.len()) { + let solve = ArithmeticSolver::evaluate(input, initial_witness); + if let Some(value) = solve.to_const() { + data.input_values.push(value); + } else { + break; + } + } + + // If all of the inputs to the oracle have assignments + if data.input_values.len() == data.inputs.len() { + if data.output_values.len() == data.outputs.len() { + for (out, value) in data.outputs.iter().zip(data.output_values.iter()) { + insert_witness(*out, *value, initial_witness)?; + } + Ok(OpcodeResolution::Solved) + } else { + // Missing output values + Ok(OpcodeResolution::InProgress) + } + } else { + Ok(OpcodeResolution::Stalled(OpcodeNotSolvable::ExpressionHasTooManyUnknowns( + data.inputs + .last() + .expect("Infallible: cannot reach this point if no inputs") + .clone(), + ))) + } + } +}