From 7ba98fcdab10c182a8d90d4e91af9de736fa1b36 Mon Sep 17 00:00:00 2001 From: jfecher Date: Mon, 17 Apr 2023 13:59:14 -0500 Subject: [PATCH] fix: Change OracleInput to an enum (#200) * Fix OracleInput representation * Fix tests --- acvm/src/lib.rs | 12 +- acvm/src/pwg/brillig.rs | 29 +- brillig_bytecode/src/lib.rs | 963 ++++++++++++++++---------------- brillig_bytecode/src/memory.rs | 1 + brillig_bytecode/src/opcodes.rs | 6 +- brillig_bytecode/src/value.rs | 4 +- 6 files changed, 506 insertions(+), 509 deletions(-) diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index 5a0f49f76..1096779ef 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -490,10 +490,8 @@ mod test { result: RegisterIndex(3), }; - let invert_oracle_input = OracleInput { - register_mem_index: RegisterMemIndex::Register(RegisterIndex(0)), - length: 0, - }; + let invert_oracle_input = + OracleInput::RegisterMemIndex(RegisterMemIndex::Register(RegisterIndex(0))); let invert_oracle = brillig_bytecode::Opcode::Oracle(brillig_bytecode::OracleData { name: "invert".into(), @@ -617,10 +615,8 @@ mod test { result: RegisterIndex(3), }; - let invert_oracle_input = OracleInput { - register_mem_index: RegisterMemIndex::Register(RegisterIndex(0)), - length: 0, - }; + let invert_oracle_input = + OracleInput::RegisterMemIndex(RegisterMemIndex::Register(RegisterIndex(0))); let invert_oracle = brillig_bytecode::Opcode::Oracle(brillig_bytecode::OracleData { name: "invert".into(), diff --git a/acvm/src/pwg/brillig.rs b/acvm/src/pwg/brillig.rs index f37fe16d7..135e5e10d 100644 --- a/acvm/src/pwg/brillig.rs +++ b/acvm/src/pwg/brillig.rs @@ -1,9 +1,7 @@ use std::collections::BTreeMap; use acir::{ - brillig_bytecode::{ - ArrayHeap, Opcode, OracleData, Registers, Typ, VMOutputState, VMStatus, Value, VM, - }, + brillig_bytecode::{ArrayHeap, Opcode, OracleData, Registers, Typ, VMStatus, Value, VM}, circuit::opcodes::{Brillig, BrilligInputs, BrilligOutputs}, native_types::Witness, FieldElement, @@ -112,9 +110,11 @@ impl BrilligSolver { let input_registers = Registers { inner: input_register_values }; let vm = VM::new(input_registers, input_memory, brillig.bytecode.clone()); - let VMOutputState { registers, program_counter, status, memory } = vm.process_opcodes(); + let vm_output = vm.process_opcodes(); + + if vm_output.status == VMStatus::OracleWait { + let program_counter = vm_output.program_counter; - if status == VMStatus::OracleWait { let current_opcode = &brillig.bytecode[program_counter]; let mut data = match current_opcode.clone() { Opcode::Oracle(data) => data, @@ -126,20 +126,7 @@ impl BrilligSolver { } }; - let mut input_values = Vec::new(); - for oracle_input in data.clone().inputs { - if oracle_input.length == 0 { - let x = registers.get(oracle_input.register_mem_index).inner; - input_values.push(x); - } else { - let array_id = registers.get(oracle_input.register_mem_index); - let array = memory[&array_id].clone(); - let heap_fields = - array.memory_map.into_values().map(|value| value.inner).collect::>(); - input_values.extend(heap_fields); - } - } - + let input_values = vm_output.map_input_values(&data); data.input_values = input_values; return Ok(OpcodeResolution::InProgressBrillig(OracleWaitInfo { @@ -148,13 +135,13 @@ impl BrilligSolver { })); } - for (output, register_value) in brillig.outputs.iter().zip(registers) { + for (output, register_value) in brillig.outputs.iter().zip(vm_output.registers) { match output { BrilligOutputs::Simple(witness) => { insert_witness(*witness, register_value.inner, initial_witness)?; } BrilligOutputs::Array(witness_arr) => { - let array = memory[®ister_value].memory_map.values(); + let array = vm_output.memory[®ister_value].memory_map.values(); for (witness, value) in witness_arr.iter().zip(array) { insert_witness(*witness, value.inner, initial_witness)?; } diff --git a/brillig_bytecode/src/lib.rs b/brillig_bytecode/src/lib.rs index fefebd12b..3c1882826 100644 --- a/brillig_bytecode/src/lib.rs +++ b/brillig_bytecode/src/lib.rs @@ -210,7 +210,7 @@ impl VM { lhs: RegisterMemIndex, rhs: RegisterMemIndex, result: RegisterIndex, - result_type: Typ, + _result_type: Typ, ) { let lhs_value = self.registers.get(lhs); let rhs_value = self.registers.get(rhs); @@ -240,510 +240,523 @@ pub struct VMOutputState { pub memory: BTreeMap, } -#[test] -fn add_single_step_smoke() { - // Load values into registers and initialize the registers that - // will be used during bytecode processing - let input_registers = - Registers::load(vec![Value::from(1u128), Value::from(2u128), Value::from(0u128)]); - - // Add opcode to add the value in register `0` and `1` - // and place the output in register `2` - let opcode = Opcode::BinaryOp { - op: BinaryOp::Add, - lhs: RegisterMemIndex::Register(RegisterIndex(0)), - rhs: RegisterMemIndex::Register(RegisterIndex(1)), - result: RegisterIndex(2), - result_type: Typ::Field, - }; - - // Start VM - let mut vm = VM::new(input_registers, BTreeMap::new(), vec![opcode]); - - // Process a single VM opcode - // - // After processing a single opcode, we should have - // the vm status as halted since there is only one opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::Halted); - - // The register at index `2` should have the value of 3 since we had an - // add opcode - let VMOutputState { registers, .. } = vm.finish(); - let output_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - - assert_eq!(output_value, Value::from(3u128)) +impl VMOutputState { + pub fn map_input_values(&self, oracle_data: &OracleData) -> Vec { + let mut input_values = vec![]; + for oracle_input in &oracle_data.inputs { + match oracle_input { + OracleInput::RegisterMemIndex(register_index) => { + let register = self.registers.get(*register_index); + input_values.push(register.inner); + } + OracleInput::Array { start, length } => { + let array_id = self.registers.get(*start); + let array = &self.memory[&array_id]; + let heap_fields = array.memory_map.values().map(|value| value.inner.clone()); + + assert_eq!(heap_fields.len(), *length); + input_values.extend(heap_fields); + } + } + } + input_values + } } -#[test] -fn jmpif_opcode() { - let input_registers = - Registers::load(vec![Value::from(2u128), Value::from(2u128), Value::from(0u128)]); +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn add_single_step_smoke() { + // Load values into registers and initialize the registers that + // will be used during bytecode processing + let input_registers = + Registers::load(vec![Value::from(1u128), Value::from(2u128), Value::from(0u128)]); + + // Add opcode to add the value in register `0` and `1` + // and place the output in register `2` + let opcode = Opcode::BinaryOp { + op: BinaryOp::Add, + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(1)), + result: RegisterIndex(2), + result_type: Typ::Field, + }; - let equal_cmp_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Eq), - lhs: RegisterMemIndex::Register(RegisterIndex(0)), - rhs: RegisterMemIndex::Register(RegisterIndex(1)), - result: RegisterIndex(2), - }; + // Start VM + let mut vm = VM::new(input_registers, BTreeMap::new(), vec![opcode]); - let jump_opcode = Opcode::JMP { destination: 2 }; + // Process a single VM opcode + // + // After processing a single opcode, we should have + // the vm status as halted since there is only one opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Halted); - let jump_if_opcode = - Opcode::JMPIF { condition: RegisterMemIndex::Register(RegisterIndex(2)), destination: 3 }; + // The register at index `2` should have the value of 3 since we had an + // add opcode + let VMOutputState { registers, .. } = vm.finish(); + let output_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - let mut vm = VM::new( - input_registers, - BTreeMap::new(), - vec![equal_cmp_opcode, jump_opcode, jump_if_opcode], - ); + assert_eq!(output_value, Value::from(3u128)) + } - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); + #[test] + fn jmpif_opcode() { + let input_registers = + Registers::load(vec![Value::from(2u128), Value::from(2u128), Value::from(0u128)]); + + let equal_cmp_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(1)), + result: RegisterIndex(2), + }; - let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(output_cmp_value, Value::from(true)); + let jump_opcode = Opcode::JMP { destination: 2 }; - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); + let jump_if_opcode = Opcode::JMPIF { + condition: RegisterMemIndex::Register(RegisterIndex(2)), + destination: 3, + }; - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::Halted); + let mut vm = VM::new( + input_registers, + BTreeMap::new(), + vec![equal_cmp_opcode, jump_opcode, jump_if_opcode], + ); - vm.finish(); -} + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); -#[test] -fn jmpifnot_opcode() { - let input_registers = - Registers::load(vec![Value::from(1u128), Value::from(2u128), Value::from(0u128)]); - - let trap_opcode = Opcode::Trap; - - let not_equal_cmp_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Eq), - lhs: RegisterMemIndex::Register(RegisterIndex(0)), - rhs: RegisterMemIndex::Register(RegisterIndex(1)), - result: RegisterIndex(2), - }; - - let jump_opcode = Opcode::JMP { destination: 2 }; - - let jump_if_not_opcode = Opcode::JMPIFNOT { - condition: RegisterMemIndex::Register(RegisterIndex(2)), - destination: 1, - }; - - let add_opcode = Opcode::BinaryOp { - op: BinaryOp::Add, - lhs: RegisterMemIndex::Register(RegisterIndex(0)), - rhs: RegisterMemIndex::Register(RegisterIndex(1)), - result: RegisterIndex(2), - result_type: Typ::Field, - }; - - let mut vm = VM::new( - input_registers, - BTreeMap::new(), - vec![jump_opcode, trap_opcode, not_equal_cmp_opcode, jump_if_not_opcode, add_opcode], - ); - - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(output_cmp_value, Value::from(false)); - - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::Failure); - - // The register at index `2` should have not changed as we jumped over the add opcode - let VMOutputState { registers, .. } = vm.finish(); - let output_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(output_value, Value::from(false)); -} + let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_cmp_value, Value::from(true)); -#[test] -fn mov_opcode() { - let input_registers = - Registers::load(vec![Value::from(1u128), Value::from(2u128), Value::from(3u128)]); + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); - let mov_opcode = Opcode::Mov { - destination: RegisterMemIndex::Register(RegisterIndex(2)), - source: RegisterMemIndex::Register(RegisterIndex(0)), - }; + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Halted); - let mut vm = VM::new(input_registers, BTreeMap::new(), vec![mov_opcode]); + vm.finish(); + } - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::Halted); + #[test] + fn jmpifnot_opcode() { + let input_registers = + Registers::load(vec![Value::from(1u128), Value::from(2u128), Value::from(0u128)]); - let VMOutputState { registers, .. } = vm.finish(); + let trap_opcode = Opcode::Trap; - let destination_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(destination_value, Value::from(1u128)); + let not_equal_cmp_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(1)), + result: RegisterIndex(2), + }; - let source_value = registers.get(RegisterMemIndex::Register(RegisterIndex(0))); - assert_eq!(source_value, Value::from(1u128)); -} + let jump_opcode = Opcode::JMP { destination: 2 }; -#[test] -fn cmp_binary_ops() { - let input_registers = Registers::load(vec![ - Value::from(2u128), - Value::from(2u128), - Value::from(0u128), - Value::from(5u128), - Value::from(6u128), - ]); - - let equal_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Eq), - lhs: RegisterMemIndex::Register(RegisterIndex(0)), - rhs: RegisterMemIndex::Register(RegisterIndex(1)), - result: RegisterIndex(2), - }; - - let not_equal_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Eq), - lhs: RegisterMemIndex::Register(RegisterIndex(0)), - rhs: RegisterMemIndex::Register(RegisterIndex(3)), - result: RegisterIndex(2), - }; - - let less_than_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Lt), - lhs: RegisterMemIndex::Register(RegisterIndex(3)), - rhs: RegisterMemIndex::Register(RegisterIndex(4)), - result: RegisterIndex(2), - }; - - let less_than_equal_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Lte), - lhs: RegisterMemIndex::Register(RegisterIndex(3)), - rhs: RegisterMemIndex::Register(RegisterIndex(4)), - result: RegisterIndex(2), - }; - - let mut vm = VM::new( - input_registers, - BTreeMap::new(), - vec![equal_opcode, not_equal_opcode, less_than_opcode, less_than_equal_opcode], - ); - - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let output_eq_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(output_eq_value, Value::from(true)); - - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let output_neq_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(output_neq_value, Value::from(false)); - - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let lt_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(lt_value, Value::from(true)); - - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::Halted); - - let lte_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(lte_value, Value::from(true)); - - vm.finish(); -} + let jump_if_not_opcode = Opcode::JMPIFNOT { + condition: RegisterMemIndex::Register(RegisterIndex(2)), + destination: 1, + }; -#[test] -fn load_opcode() { - let input_registers = Registers::load(vec![ - Value::from(2u128), - Value::from(2u128), - Value::from(0u128), - Value::from(5u128), - Value::from(0u128), - Value::from(6u128), - Value::from(0u128), - ]); - - let equal_cmp_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Eq), - lhs: RegisterMemIndex::Register(RegisterIndex(0)), - rhs: RegisterMemIndex::Register(RegisterIndex(1)), - result: RegisterIndex(2), - }; - - let jump_opcode = Opcode::JMP { destination: 3 }; - - let jump_if_opcode = - Opcode::JMPIF { condition: RegisterMemIndex::Register(RegisterIndex(2)), destination: 10 }; - - let load_opcode = Opcode::Load { - destination: RegisterMemIndex::Register(RegisterIndex(4)), - array_id_reg: RegisterMemIndex::Register(RegisterIndex(3)), - index: RegisterMemIndex::Register(RegisterIndex(2)), - }; - - let mem_equal_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Eq), - lhs: RegisterMemIndex::Register(RegisterIndex(4)), - rhs: RegisterMemIndex::Register(RegisterIndex(5)), - result: RegisterIndex(6), - }; - - let mut initial_memory = BTreeMap::new(); - let initial_heap = ArrayHeap { - memory_map: BTreeMap::from([(0 as usize, Value::from(5u128)), (1, Value::from(6u128))]), - }; - initial_memory.insert(Value::from(5u128), initial_heap); - - let mut vm = VM::new( - input_registers, - initial_memory, - vec![equal_cmp_opcode, load_opcode, jump_opcode, mem_equal_opcode, jump_if_opcode], - ); - - // equal_cmp_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(output_cmp_value, Value::from(true)); - - // load_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(4))); - assert_eq!(output_cmp_value, Value::from(6u128)); - - // jump_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - // mem_equal_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(6))); - assert_eq!(output_cmp_value, Value::from(true)); - - // jump_if_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::Halted); - - vm.finish(); -} + let add_opcode = Opcode::BinaryOp { + op: BinaryOp::Add, + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(1)), + result: RegisterIndex(2), + result_type: Typ::Field, + }; -#[test] -fn store_opcode() { - let input_registers = Registers::load(vec![ - Value::from(2u128), - Value::from(2u128), - Value::from(0u128), - Value::from(5u128), - Value::from(0u128), - Value::from(6u128), - Value::from(0u128), - ]); - - let equal_cmp_opcode = Opcode::BinaryOp { - result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Eq), - lhs: RegisterMemIndex::Register(RegisterIndex(0)), - rhs: RegisterMemIndex::Register(RegisterIndex(1)), - result: RegisterIndex(2), - }; - - let jump_opcode = Opcode::JMP { destination: 3 }; - - let jump_if_opcode = - Opcode::JMPIF { condition: RegisterMemIndex::Register(RegisterIndex(2)), destination: 10 }; - - let store_opcode = Opcode::Store { - source: RegisterMemIndex::Register(RegisterIndex(2)), - array_id_reg: RegisterMemIndex::Register(RegisterIndex(3)), - index: RegisterMemIndex::Constant(FieldElement::from(3_u128)), - }; - - let mut initial_memory = BTreeMap::new(); - let initial_heap = ArrayHeap { - memory_map: BTreeMap::from([(0 as usize, Value::from(5u128)), (1, Value::from(6u128))]), - }; - initial_memory.insert(Value::from(5u128), initial_heap); - - let mut vm = VM::new( - input_registers, - initial_memory, - vec![equal_cmp_opcode, store_opcode, jump_opcode, jump_if_opcode], - ); - - // equal_cmp_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); - assert_eq!(output_cmp_value, Value::from(true)); - - // store_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - let mem_array = vm.memory[&Value::from(5u128)].clone(); - assert_eq!(mem_array.memory_map[&3], Value::from(true)); - - // jump_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::InProgress); - - // jump_if_opcode - let status = vm.process_opcode(); - assert_eq!(status, VMStatus::Halted); - - vm.finish(); -} + let mut vm = VM::new( + input_registers, + BTreeMap::new(), + vec![jump_opcode, trap_opcode, not_equal_cmp_opcode, jump_if_not_opcode, add_opcode], + ); -#[test] -fn oracle_array_output() { - use crate::opcodes::OracleInput; - - let input_registers = Registers::load(vec![ - Value::from(2u128), - Value::from(2u128), - Value::from(0u128), - Value::from(5u128), - Value::from(0u128), - Value::from(6u128), - Value::from(0u128), - ]); - - let oracle_input = - OracleInput { register_mem_index: RegisterMemIndex::Register(RegisterIndex(0)), length: 0 }; - - let mut oracle_data = OracleData { - name: "get_notes".to_owned(), - inputs: vec![oracle_input], - input_values: vec![], - output: RegisterIndex(3), - output_values: vec![], - }; - - let oracle_opcode = Opcode::Oracle(oracle_data.clone()); - - let initial_memory = BTreeMap::new(); - - let vm = VM::new(input_registers.clone(), initial_memory, vec![oracle_opcode]); - - let output_state = vm.process_opcodes(); - assert_eq!(output_state.status, VMStatus::OracleWait); - - let mut input_values = Vec::new(); - for oracle_input in oracle_data.clone().inputs { - if oracle_input.length == 0 { - let x = output_state.registers.get(oracle_input.register_mem_index).inner; - input_values.push(x); - } else { - let array_id = output_state.registers.get(oracle_input.register_mem_index); - let array = output_state.memory[&array_id].clone(); - let heap_fields = - array.memory_map.into_values().map(|value| value.inner).collect::>(); - input_values.extend(heap_fields); - } + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_cmp_value, Value::from(false)); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Failure); + + // The register at index `2` should have not changed as we jumped over the add opcode + let VMOutputState { registers, .. } = vm.finish(); + let output_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_value, Value::from(false)); } - oracle_data.input_values = input_values; - oracle_data.output_values = vec![FieldElement::from(10_u128), FieldElement::from(2_u128)]; - let updated_oracle_opcode = Opcode::Oracle(oracle_data); + #[test] + fn mov_opcode() { + let input_registers = + Registers::load(vec![Value::from(1u128), Value::from(2u128), Value::from(3u128)]); - let vm = VM::new(input_registers, output_state.memory, vec![updated_oracle_opcode]); - let output_state = vm.process_opcodes(); - assert_eq!(output_state.status, VMStatus::Halted); + let mov_opcode = Opcode::Mov { + destination: RegisterMemIndex::Register(RegisterIndex(2)), + source: RegisterMemIndex::Register(RegisterIndex(0)), + }; - let mem_array = output_state.memory[&Value::from(5u128)].clone(); - assert_eq!(mem_array.memory_map[&0], Value::from(10_u128)); - assert_eq!(mem_array.memory_map[&1], Value::from(2_u128)); -} + let mut vm = VM::new(input_registers, BTreeMap::new(), vec![mov_opcode]); -#[test] -fn oracle_array_input() { - use crate::opcodes::OracleInput; - - let input_registers = Registers::load(vec![ - Value::from(2u128), - Value::from(2u128), - Value::from(0u128), - Value::from(5u128), - Value::from(0u128), - Value::from(6u128), - Value::from(0u128), - ]); - - let oracle_input = - OracleInput { register_mem_index: RegisterMemIndex::Register(RegisterIndex(3)), length: 2 }; - - let mut oracle_data = OracleData { - name: "call_private_function_oracle".to_owned(), - inputs: vec![oracle_input.clone()], - input_values: vec![], - output: RegisterIndex(6), - output_values: vec![], - }; - - let oracle_opcode = Opcode::Oracle(oracle_data.clone()); - - let mut initial_memory = BTreeMap::new(); - let initial_heap = ArrayHeap { - memory_map: BTreeMap::from([(0 as usize, Value::from(5u128)), (1, Value::from(6u128))]), - }; - initial_memory.insert(Value::from(5u128), initial_heap); - - let vm = VM::new(input_registers.clone(), initial_memory, vec![oracle_opcode]); - - let output_state = vm.process_opcodes(); - assert_eq!(output_state.status, VMStatus::OracleWait); - - let mut input_values = Vec::new(); - for oracle_input in oracle_data.clone().inputs { - if oracle_input.length == 0 { - let x = output_state.registers.get(oracle_input.register_mem_index).inner; - input_values.push(x); - } else { - let array_id = output_state.registers.get(oracle_input.register_mem_index); - let array = output_state.memory[&array_id].clone(); - let heap_fields = - array.memory_map.into_values().map(|value| value.inner).collect::>(); - input_values.extend(heap_fields); - } + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Halted); + + let VMOutputState { registers, .. } = vm.finish(); + + let destination_value = registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(destination_value, Value::from(1u128)); + + let source_value = registers.get(RegisterMemIndex::Register(RegisterIndex(0))); + assert_eq!(source_value, Value::from(1u128)); + } + + #[test] + fn cmp_binary_ops() { + let input_registers = Registers::load(vec![ + Value::from(2u128), + Value::from(2u128), + Value::from(0u128), + Value::from(5u128), + Value::from(6u128), + ]); + + let equal_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(1)), + result: RegisterIndex(2), + }; + + let not_equal_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(3)), + result: RegisterIndex(2), + }; + + let less_than_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Lt), + lhs: RegisterMemIndex::Register(RegisterIndex(3)), + rhs: RegisterMemIndex::Register(RegisterIndex(4)), + result: RegisterIndex(2), + }; + + let less_than_equal_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Lte), + lhs: RegisterMemIndex::Register(RegisterIndex(3)), + rhs: RegisterMemIndex::Register(RegisterIndex(4)), + result: RegisterIndex(2), + }; + + let mut vm = VM::new( + input_registers, + BTreeMap::new(), + vec![equal_opcode, not_equal_opcode, less_than_opcode, less_than_equal_opcode], + ); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_eq_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_eq_value, Value::from(true)); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_neq_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_neq_value, Value::from(false)); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let lt_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(lt_value, Value::from(true)); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Halted); + + let lte_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(lte_value, Value::from(true)); + + vm.finish(); } - assert_eq!(input_values.len(), oracle_input.length); - oracle_data.input_values = input_values; - oracle_data.output_values = vec![FieldElement::from(5_u128)]; - let updated_oracle_opcode = Opcode::Oracle(oracle_data); + #[test] + fn load_opcode() { + let input_registers = Registers::load(vec![ + Value::from(2u128), + Value::from(2u128), + Value::from(0u128), + Value::from(5u128), + Value::from(0u128), + Value::from(6u128), + Value::from(0u128), + ]); + + let equal_cmp_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(1)), + result: RegisterIndex(2), + }; + + let jump_opcode = Opcode::JMP { destination: 3 }; - let vm = VM::new(input_registers, output_state.memory, vec![updated_oracle_opcode]); - let output_state = vm.process_opcodes(); - assert_eq!(output_state.status, VMStatus::Halted); + let jump_if_opcode = Opcode::JMPIF { + condition: RegisterMemIndex::Register(RegisterIndex(2)), + destination: 10, + }; - let mem_array = output_state.memory[&Value::from(5u128)].clone(); - assert_eq!(mem_array.memory_map[&0], Value::from(5_u128)); - assert_eq!(mem_array.memory_map[&1], Value::from(6_u128)); + let load_opcode = Opcode::Load { + destination: RegisterMemIndex::Register(RegisterIndex(4)), + array_id_reg: RegisterMemIndex::Register(RegisterIndex(3)), + index: RegisterMemIndex::Register(RegisterIndex(2)), + }; + + let mem_equal_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(4)), + rhs: RegisterMemIndex::Register(RegisterIndex(5)), + result: RegisterIndex(6), + }; + + let mut initial_memory = BTreeMap::new(); + let initial_heap = ArrayHeap { + memory_map: BTreeMap::from([(0 as usize, Value::from(5u128)), (1, Value::from(6u128))]), + }; + initial_memory.insert(Value::from(5u128), initial_heap); + + let mut vm = VM::new( + input_registers, + initial_memory, + vec![equal_cmp_opcode, load_opcode, jump_opcode, mem_equal_opcode, jump_if_opcode], + ); + + // equal_cmp_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_cmp_value, Value::from(true)); + + // load_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(4))); + assert_eq!(output_cmp_value, Value::from(6u128)); + + // jump_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + // mem_equal_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(6))); + assert_eq!(output_cmp_value, Value::from(true)); + + // jump_if_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Halted); + + vm.finish(); + } + + #[test] + fn store_opcode() { + let input_registers = Registers::load(vec![ + Value::from(2u128), + Value::from(2u128), + Value::from(0u128), + Value::from(5u128), + Value::from(0u128), + Value::from(6u128), + Value::from(0u128), + ]); + + let equal_cmp_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(1)), + result: RegisterIndex(2), + }; + + let jump_opcode = Opcode::JMP { destination: 3 }; + + let jump_if_opcode = Opcode::JMPIF { + condition: RegisterMemIndex::Register(RegisterIndex(2)), + destination: 10, + }; + + let store_opcode = Opcode::Store { + source: RegisterMemIndex::Register(RegisterIndex(2)), + array_id_reg: RegisterMemIndex::Register(RegisterIndex(3)), + index: RegisterMemIndex::Constant(FieldElement::from(3_u128)), + }; + + let mut initial_memory = BTreeMap::new(); + let initial_heap = ArrayHeap { + memory_map: BTreeMap::from([(0 as usize, Value::from(5u128)), (1, Value::from(6u128))]), + }; + initial_memory.insert(Value::from(5u128), initial_heap); + + let mut vm = VM::new( + input_registers, + initial_memory, + vec![equal_cmp_opcode, store_opcode, jump_opcode, jump_if_opcode], + ); + + // equal_cmp_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_cmp_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_cmp_value, Value::from(true)); + + // store_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let mem_array = vm.memory[&Value::from(5u128)].clone(); + assert_eq!(mem_array.memory_map[&3], Value::from(true)); + + // jump_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + // jump_if_opcode + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Halted); + + vm.finish(); + } + + #[test] + fn oracle_array_output() { + use crate::opcodes::OracleInput; + + let input_registers = Registers::load(vec![ + Value::from(2u128), + Value::from(2u128), + Value::from(0u128), + Value::from(5u128), + Value::from(0u128), + Value::from(6u128), + Value::from(0u128), + ]); + + let oracle_input = + OracleInput::RegisterMemIndex(RegisterMemIndex::Register(RegisterIndex(0))); + + let mut oracle_data = OracleData { + name: "get_notes".to_owned(), + inputs: vec![oracle_input], + input_values: vec![], + output: RegisterIndex(3), + output_values: vec![], + }; + + let oracle_opcode = Opcode::Oracle(oracle_data.clone()); + + let initial_memory = BTreeMap::new(); + + let vm = VM::new(input_registers.clone(), initial_memory, vec![oracle_opcode]); + + let output_state = vm.process_opcodes(); + assert_eq!(output_state.status, VMStatus::OracleWait); + + let input_values = output_state.map_input_values(&oracle_data); + + oracle_data.input_values = input_values; + oracle_data.output_values = vec![FieldElement::from(10_u128), FieldElement::from(2_u128)]; + let updated_oracle_opcode = Opcode::Oracle(oracle_data); + + let vm = VM::new(input_registers, output_state.memory, vec![updated_oracle_opcode]); + let output_state = vm.process_opcodes(); + assert_eq!(output_state.status, VMStatus::Halted); + + let mem_array = output_state.memory[&Value::from(5u128)].clone(); + assert_eq!(mem_array.memory_map[&0], Value::from(10_u128)); + assert_eq!(mem_array.memory_map[&1], Value::from(2_u128)); + } + + #[test] + fn oracle_array_input() { + use crate::opcodes::OracleInput; + + let input_registers = Registers::load(vec![ + Value::from(2u128), + Value::from(2u128), + Value::from(0u128), + Value::from(5u128), + Value::from(0u128), + Value::from(6u128), + Value::from(0u128), + ]); + + let expected_length = 2; + let oracle_input = OracleInput::Array { + start: RegisterMemIndex::Register(RegisterIndex(3)), + length: expected_length, + }; + + let mut oracle_data = OracleData { + name: "call_private_function_oracle".to_owned(), + inputs: vec![oracle_input.clone()], + input_values: vec![], + output: RegisterIndex(6), + output_values: vec![], + }; + + let oracle_opcode = Opcode::Oracle(oracle_data.clone()); + + let mut initial_memory = BTreeMap::new(); + let initial_heap = ArrayHeap { + memory_map: BTreeMap::from([(0 as usize, Value::from(5u128)), (1, Value::from(6u128))]), + }; + initial_memory.insert(Value::from(5u128), initial_heap); + + let vm = VM::new(input_registers.clone(), initial_memory, vec![oracle_opcode]); + + let output_state = vm.process_opcodes(); + assert_eq!(output_state.status, VMStatus::OracleWait); + + let input_values = output_state.map_input_values(&oracle_data); + assert_eq!(input_values.len(), expected_length); + + oracle_data.input_values = input_values; + oracle_data.output_values = vec![FieldElement::from(5_u128)]; + let updated_oracle_opcode = Opcode::Oracle(oracle_data); + + let vm = VM::new(input_registers, output_state.memory, vec![updated_oracle_opcode]); + let output_state = vm.process_opcodes(); + assert_eq!(output_state.status, VMStatus::Halted); + + let mem_array = output_state.memory[&Value::from(5u128)].clone(); + assert_eq!(mem_array.memory_map[&0], Value::from(5_u128)); + assert_eq!(mem_array.memory_map[&1], Value::from(6_u128)); + } } diff --git a/brillig_bytecode/src/memory.rs b/brillig_bytecode/src/memory.rs index 78d216404..24b218dac 100644 --- a/brillig_bytecode/src/memory.rs +++ b/brillig_bytecode/src/memory.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; + /// Memory in the VM is used for storing arrays /// /// ArrayIndex will be used to reference an Array element. diff --git a/brillig_bytecode/src/opcodes.rs b/brillig_bytecode/src/opcodes.rs index 5bee76a4d..f206f2f6d 100644 --- a/brillig_bytecode/src/opcodes.rs +++ b/brillig_bytecode/src/opcodes.rs @@ -118,9 +118,9 @@ pub struct OracleData { } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct OracleInput { - pub register_mem_index: RegisterMemIndex, - pub length: usize, +pub enum OracleInput { + RegisterMemIndex(RegisterMemIndex), + Array { start: RegisterMemIndex, length: usize }, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] diff --git a/brillig_bytecode/src/value.rs b/brillig_bytecode/src/value.rs index d66313876..a59e8b02b 100644 --- a/brillig_bytecode/src/value.rs +++ b/brillig_bytecode/src/value.rs @@ -26,10 +26,10 @@ impl Value { pub fn inverse(&self) -> Value { let value = match self.typ { Typ::Field => self.inner.inverse(), - Typ::Unsigned { bit_size } => { + Typ::Unsigned { bit_size: _ } => { todo!("TODO") } - Typ::Signed { bit_size } => todo!("TODO"), + Typ::Signed { bit_size: _ } => todo!("TODO"), }; Value { typ: self.typ, inner: value } }