diff --git a/acvm/src/pwg/brillig.rs b/acvm/src/pwg/brillig.rs index 2bc52a26a..3009bc534 100644 --- a/acvm/src/pwg/brillig.rs +++ b/acvm/src/pwg/brillig.rs @@ -142,5 +142,5 @@ pub struct ForeignCallWaitInfo { /// An identifier interpreted by the caller process pub function: String, /// Resolved inputs to a foreign call computed in the previous steps of a Brillig VM process - pub inputs: Vec, + pub inputs: Vec>, } diff --git a/acvm/src/pwg/mod.rs b/acvm/src/pwg/mod.rs index 31da6fa31..0d455bad8 100644 --- a/acvm/src/pwg/mod.rs +++ b/acvm/src/pwg/mod.rs @@ -483,8 +483,8 @@ mod tests { // Oracles are named 'foreign calls' in brillig brillig_vm::Opcode::ForeignCall { function: "invert".into(), - destination: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1)), - input: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0)), + destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))], + inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))], }, ], predicate: None, @@ -535,8 +535,9 @@ mod tests { "Should be waiting for a single input" ); // As caller of VM, need to resolve foreign calls - let foreign_call_result = - vec![Value::from(foreign_call.foreign_call_wait_info.inputs[0].to_field().inverse())]; + let foreign_call_result = vec![Value::from( + foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse(), + )]; // Alter Brillig oracle opcode with foreign call resolution let brillig: Brillig = foreign_call.resolve(foreign_call_result.into()); let mut next_opcodes_for_solving = vec![Opcode::Brillig(brillig)]; @@ -610,13 +611,13 @@ mod tests { // Oracles are named 'foreign calls' in brillig brillig_vm::Opcode::ForeignCall { function: "invert".into(), - destination: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1)), - input: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0)), + destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))], + inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))], }, brillig_vm::Opcode::ForeignCall { function: "invert".into(), - destination: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(3)), - input: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(2)), + destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(3))], + inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(2))], }, ], predicate: None, @@ -669,7 +670,8 @@ mod tests { "Should be waiting for a single input" ); - let x_plus_y_inverse = foreign_call.foreign_call_wait_info.inputs[0].to_field().inverse(); + let x_plus_y_inverse = + foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse(); // Alter Brillig oracle opcode let brillig: Brillig = foreign_call.resolve(vec![Value::from(x_plus_y_inverse)].into()); @@ -693,7 +695,8 @@ mod tests { "Should be waiting for a single input" ); - let i_plus_j_inverse = foreign_call.foreign_call_wait_info.inputs[0].to_field().inverse(); + let i_plus_j_inverse = + foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse(); assert_ne!(x_plus_y_inverse, i_plus_j_inverse); // Alter Brillig oracle opcode let brillig = foreign_call.resolve(vec![Value::from(i_plus_j_inverse)].into()); @@ -756,8 +759,8 @@ mod tests { // Oracles are named 'foreign calls' in brillig brillig_vm::Opcode::ForeignCall { function: "invert".into(), - destination: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1)), - input: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0)), + destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))], + inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))], }, ], predicate: Some(Expression::default()), diff --git a/brillig_vm/src/lib.rs b/brillig_vm/src/lib.rs index 4f95b5e21..7712c7331 100644 --- a/brillig_vm/src/lib.rs +++ b/brillig_vm/src/lib.rs @@ -34,7 +34,8 @@ pub enum VMStatus { /// Interpreted by simulator context function: String, /// Input values - inputs: Vec, + /// Each input is a list of values as an input can be either a single value or a memory pointer + inputs: Vec>, }, } @@ -44,11 +45,18 @@ pub enum VMStatus { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)] pub struct ForeignCallResult { /// Resolved output values of the foreign call. - pub values: Vec, + /// Each output is its own list of values as an output can be either a single value or a memory pointer + pub values: Vec>, } impl From> for ForeignCallResult { fn from(values: Vec) -> Self { + ForeignCallResult { values: vec![values] } + } +} + +impl From>> for ForeignCallResult { + fn from(values: Vec>) -> Self { ForeignCallResult { values } } } @@ -110,7 +118,7 @@ impl VM { /// Sets the status of the VM to `ForeignCallWait`. /// Indicating that the VM is now waiting for a foreign call to be resolved. - fn wait_for_foreign_call(&mut self, function: String, inputs: Vec) -> VMStatus { + fn wait_for_foreign_call(&mut self, function: String, inputs: Vec>) -> VMStatus { self.status(VMStatus::ForeignCallWait { function, inputs }) } @@ -176,7 +184,7 @@ impl VM { self.fail("return opcode hit, but callstack already empty".to_string()) } } - Opcode::ForeignCall { function, destination, input } => { + Opcode::ForeignCall { function, destinations, inputs } => { if self.foreign_call_counter >= self.foreign_call_results.len() { // When this opcode is called, it is possible that the results of a foreign call are // not yet known (not enough entries in `foreign_call_results`). @@ -185,33 +193,45 @@ impl VM { // resolved inputs back to the caller. Once the caller pushes to `foreign_call_results`, // they can then make another call to the VM that starts at this opcode // but has the necessary results to proceed with execution. - let resolved_inputs = self.get_register_value_or_memory_values(*input); + let resolved_inputs = inputs + .iter() + .map(|input| self.get_register_value_or_memory_values(*input)) + .collect::>(); return self.wait_for_foreign_call(function.clone(), resolved_inputs); } let ForeignCallResult { values } = &self.foreign_call_results[self.foreign_call_counter]; - match destination { - RegisterValueOrArray::RegisterIndex(index) => { - assert_eq!( - values.len(), - 1, - "Function result size does not match brillig bytecode" - ); - self.registers.set(*index, values[0]) - } - RegisterValueOrArray::HeapArray(index, size) => { - let destination_value = self.registers.get(*index); - assert_eq!( - values.len(), - *size, - "Function result size does not match brillig bytecode" - ); - for (i, value) in values.iter().enumerate() { - self.memory[destination_value.to_usize() + i] = *value; + + for (destination, values) in destinations.iter().zip(values) { + match destination { + RegisterValueOrArray::RegisterIndex(index) => { + assert_eq!( + values.len(), + 1, + "Function result size does not match brillig bytecode" + ); + self.registers.set(*index, values[0]) + } + RegisterValueOrArray::HeapArray(index, size) => { + let destination_value = self.registers.get(*index); + assert_eq!( + values.len(), + *size, + "Function result size does not match brillig bytecode" + ); + for (i, value) in values.iter().enumerate() { + self.memory[destination_value.to_usize() + i] = *value; + } } } } + + // This check must come after resolving the foreign call outputs as `fail` uses a mutable reference + if destinations.len() != values.len() { + self.fail(format!("{} output values were provided as a foreign call result for {} destination slots", values.len(), destinations.len())); + } + self.foreign_call_counter += 1; self.increment_program_counter() } @@ -804,8 +824,8 @@ mod tests { // Call foreign function "double" with the input register Opcode::ForeignCall { function: "double".into(), - destination: RegisterValueOrArray::RegisterIndex(r_result), - input: RegisterValueOrArray::RegisterIndex(r_input), + destinations: vec![RegisterValueOrArray::RegisterIndex(r_result)], + inputs: vec![RegisterValueOrArray::RegisterIndex(r_input)], }, ]; @@ -816,13 +836,13 @@ mod tests { vm.status, VMStatus::ForeignCallWait { function: "double".into(), - inputs: vec![Value::from(5u128)] + inputs: vec![vec![Value::from(5u128)]] } ); // Push result we're waiting for vm.foreign_call_results.push(ForeignCallResult { - values: vec![Value::from(10u128)], // Result of doubling 5u128 + values: vec![vec![Value::from(10u128)]], // Result of doubling 5u128 }); // Resume VM @@ -859,8 +879,8 @@ mod tests { // *output = matrix_2x2_transpose(*input) Opcode::ForeignCall { function: "matrix_2x2_transpose".into(), - destination: RegisterValueOrArray::HeapArray(r_output, initial_matrix.len()), - input: RegisterValueOrArray::HeapArray(r_input, initial_matrix.len()), + destinations: vec![RegisterValueOrArray::HeapArray(r_output, initial_matrix.len())], + inputs: vec![RegisterValueOrArray::HeapArray(r_input, initial_matrix.len())], }, ]; @@ -871,12 +891,84 @@ mod tests { vm.status, VMStatus::ForeignCallWait { function: "matrix_2x2_transpose".into(), - inputs: initial_matrix + inputs: vec![initial_matrix] + } + ); + + // Push result we're waiting for + vm.foreign_call_results.push(ForeignCallResult { values: vec![expected_result.clone()] }); + + // Resume VM + brillig_execute(&mut vm); + + // Check that VM finished once resumed + assert_eq!(vm.status, VMStatus::Finished); + + // Check result in memory + let result_values = vm.memory[0..4].to_vec(); + assert_eq!(result_values, expected_result); + + // Ensure the foreign call counter has been incremented + assert_eq!(vm.foreign_call_counter, 1); + } + + #[test] + fn foreign_call_opcode_multiple_array_inputs_result() { + let r_input_a = RegisterIndex::from(0); + let r_input_b = RegisterIndex::from(1); + let r_output = RegisterIndex::from(2); + + // Define a simple 2x2 matrix in memory + let matrix_a = + vec![Value::from(1u128), Value::from(2u128), Value::from(3u128), Value::from(4u128)]; + + let matrix_b = vec![ + Value::from(10u128), + Value::from(11u128), + Value::from(12u128), + Value::from(13u128), + ]; + + // Transpose of the matrix (but arbitrary for this test, the 'correct value') + let expected_result = vec![ + Value::from(34u128), + Value::from(37u128), + Value::from(78u128), + Value::from(85u128), + ]; + + let matrix_mul_program = vec![ + // input = 0 + Opcode::Const { destination: r_input_a, value: Value::from(0u128) }, + // input = 0 + Opcode::Const { destination: r_input_b, value: Value::from(4u128) }, + // output = 0 + Opcode::Const { destination: r_output, value: Value::from(0u128) }, + // *output = matrix_2x2_transpose(*input) + Opcode::ForeignCall { + function: "matrix_2x2_transpose".into(), + destinations: vec![RegisterValueOrArray::HeapArray(r_output, matrix_a.len())], + inputs: vec![ + RegisterValueOrArray::HeapArray(r_input_a, matrix_a.len()), + RegisterValueOrArray::HeapArray(r_input_b, matrix_b.len()), + ], + }, + ]; + let mut initial_memory = matrix_a.clone(); + initial_memory.extend(matrix_b.clone()); + let mut vm = brillig_execute_and_get_vm(initial_memory, matrix_mul_program); + + // Check that VM is waiting + assert_eq!( + vm.status, + VMStatus::ForeignCallWait { + function: "matrix_2x2_transpose".into(), + inputs: vec![matrix_a, matrix_b] } ); // Push result we're waiting for - vm.foreign_call_results.push(ForeignCallResult { values: expected_result.clone() }); + vm.foreign_call_results.push(ForeignCallResult { values: vec![expected_result.clone()] }); // Resume VM brillig_execute(&mut vm); diff --git a/brillig_vm/src/opcodes.rs b/brillig_vm/src/opcodes.rs index 548b74975..565f48beb 100644 --- a/brillig_vm/src/opcodes.rs +++ b/brillig_vm/src/opcodes.rs @@ -63,10 +63,10 @@ pub enum Opcode { /// Interpreted by caller context, ie this will have different meanings depending on /// who the caller is. function: String, - /// Destination register (may be a memory pointer). - destination: RegisterValueOrArray, - /// Input register (may be a memory pointer). - input: RegisterValueOrArray, + /// Destination registers (may be single values or memory pointers). + destinations: Vec, + /// Input registers (may be single values or memory pointers). + inputs: Vec, }, Mov { destination: RegisterIndex,