Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

feat!: Brillig oracle array inputs #199

Merged
merged 1 commit into from
Apr 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions acvm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -358,7 +358,9 @@ mod test {

use acir::{
brillig_bytecode,
brillig_bytecode::{BinaryOp, Comparison, RegisterIndex, RegisterMemIndex, Typ},
brillig_bytecode::{
BinaryOp, Comparison, OracleInput, RegisterIndex, RegisterMemIndex, Typ,
},
circuit::{
directives::Directive,
opcodes::{BlackBoxFuncCall, Brillig, BrilligInputs, BrilligOutputs, OracleData},
@@ -488,9 +490,14 @@ mod test {
result: RegisterIndex(3),
};

let invert_oracle_input = OracleInput {
register_mem_index: RegisterMemIndex::Register(RegisterIndex(0)),
length: 0,
};

let invert_oracle = brillig_bytecode::Opcode::Oracle(brillig_bytecode::OracleData {
name: "invert".into(),
inputs: vec![RegisterMemIndex::Register(RegisterIndex(0))],
inputs: vec![invert_oracle_input],
input_values: vec![],
output: RegisterIndex(1),
output_values: vec![],
@@ -610,9 +617,14 @@ mod test {
result: RegisterIndex(3),
};

let invert_oracle_input = OracleInput {
register_mem_index: RegisterMemIndex::Register(RegisterIndex(0)),
length: 0,
};

let invert_oracle = brillig_bytecode::Opcode::Oracle(brillig_bytecode::OracleData {
name: "invert".into(),
inputs: vec![RegisterMemIndex::Register(RegisterIndex(0))],
inputs: vec![invert_oracle_input],
input_values: vec![],
output: RegisterIndex(1),
output_values: vec![],
20 changes: 14 additions & 6 deletions acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
@@ -126,12 +126,20 @@ impl BrilligSolver {
}
};

let input_values = data
.clone()
.inputs
.into_iter()
.map(|register_mem_index| registers.get(register_mem_index).inner)
.collect::<Vec<_>>();
let mut input_values = Vec::new();
for oracle_input in data.clone().inputs {
if oracle_input.length == 0 {
let x = registers.get(oracle_input.register_mem_index).inner;
input_values.push(x);
} else {
let array_id = registers.get(oracle_input.register_mem_index);
let array = memory[&array_id].clone();
let heap_fields =
array.memory_map.into_values().map(|value| value.inner).collect::<Vec<_>>();
input_values.extend(heap_fields);
}
}

data.input_values = input_values;

return Ok(OpcodeResolution::InProgressBrillig(OracleWaitInfo {
94 changes: 86 additions & 8 deletions brillig_bytecode/src/lib.rs
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ use std::collections::BTreeMap;

use acir_field::FieldElement;
pub use opcodes::RegisterMemIndex;
pub use opcodes::{BinaryOp, Comparison, Opcode, OracleData};
pub use opcodes::{BinaryOp, Comparison, Opcode, OracleData, OracleInput};
pub use registers::{RegisterIndex, Registers};
pub use value::Typ;
pub use value::Value;
@@ -623,6 +623,8 @@ fn store_opcode() {

#[test]
fn oracle_array_output() {
use crate::opcodes::OracleInput;

let input_registers = Registers::load(vec![
Value::from(2u128),
Value::from(2u128),
@@ -633,9 +635,12 @@ fn oracle_array_output() {
Value::from(0u128),
]);

let oracle_input =
OracleInput { register_mem_index: RegisterMemIndex::Register(RegisterIndex(0)), length: 0 };

let mut oracle_data = OracleData {
name: "get_notes".to_owned(),
inputs: vec![RegisterMemIndex::Register(RegisterIndex(0))],
inputs: vec![oracle_input],
input_values: vec![],
output: RegisterIndex(3),
output_values: vec![],
@@ -650,12 +655,19 @@ fn oracle_array_output() {
let output_state = vm.process_opcodes();
assert_eq!(output_state.status, VMStatus::OracleWait);

let input_values = oracle_data
.clone()
.inputs
.into_iter()
.map(|register_mem_index| output_state.registers.get(register_mem_index).inner)
.collect::<Vec<_>>();
let mut input_values = Vec::new();
for oracle_input in oracle_data.clone().inputs {
if oracle_input.length == 0 {
let x = output_state.registers.get(oracle_input.register_mem_index).inner;
input_values.push(x);
} else {
let array_id = output_state.registers.get(oracle_input.register_mem_index);
let array = output_state.memory[&array_id].clone();
let heap_fields =
array.memory_map.into_values().map(|value| value.inner).collect::<Vec<_>>();
input_values.extend(heap_fields);
}
}

oracle_data.input_values = input_values;
oracle_data.output_values = vec![FieldElement::from(10_u128), FieldElement::from(2_u128)];
@@ -669,3 +681,69 @@ fn oracle_array_output() {
assert_eq!(mem_array.memory_map[&0], Value::from(10_u128));
assert_eq!(mem_array.memory_map[&1], Value::from(2_u128));
}

#[test]
fn oracle_array_input() {
use crate::opcodes::OracleInput;

let input_registers = Registers::load(vec![
Value::from(2u128),
Value::from(2u128),
Value::from(0u128),
Value::from(5u128),
Value::from(0u128),
Value::from(6u128),
Value::from(0u128),
]);

let oracle_input =
OracleInput { register_mem_index: RegisterMemIndex::Register(RegisterIndex(3)), length: 2 };

let mut oracle_data = OracleData {
name: "call_private_function_oracle".to_owned(),
inputs: vec![oracle_input.clone()],
input_values: vec![],
output: RegisterIndex(6),
output_values: vec![],
};

let oracle_opcode = Opcode::Oracle(oracle_data.clone());

let mut initial_memory = BTreeMap::new();
let initial_heap = ArrayHeap {
memory_map: BTreeMap::from([(0 as usize, Value::from(5u128)), (1, Value::from(6u128))]),
};
initial_memory.insert(Value::from(5u128), initial_heap);

let vm = VM::new(input_registers.clone(), initial_memory, vec![oracle_opcode]);

let output_state = vm.process_opcodes();
assert_eq!(output_state.status, VMStatus::OracleWait);

let mut input_values = Vec::new();
for oracle_input in oracle_data.clone().inputs {
if oracle_input.length == 0 {
let x = output_state.registers.get(oracle_input.register_mem_index).inner;
input_values.push(x);
} else {
let array_id = output_state.registers.get(oracle_input.register_mem_index);
let array = output_state.memory[&array_id].clone();
let heap_fields =
array.memory_map.into_values().map(|value| value.inner).collect::<Vec<_>>();
input_values.extend(heap_fields);
}
}
assert_eq!(input_values.len(), oracle_input.length);

oracle_data.input_values = input_values;
oracle_data.output_values = vec![FieldElement::from(5_u128)];
let updated_oracle_opcode = Opcode::Oracle(oracle_data);

let vm = VM::new(input_registers, output_state.memory, vec![updated_oracle_opcode]);
let output_state = vm.process_opcodes();
assert_eq!(output_state.status, VMStatus::Halted);

let mem_array = output_state.memory[&Value::from(5u128)].clone();
assert_eq!(mem_array.memory_map[&0], Value::from(5_u128));
assert_eq!(mem_array.memory_map[&1], Value::from(6_u128));
}
8 changes: 7 additions & 1 deletion brillig_bytecode/src/opcodes.rs
Original file line number Diff line number Diff line change
@@ -108,7 +108,7 @@ pub struct OracleData {
/// Name of the oracle
pub name: String,
/// Input registers
pub inputs: Vec<RegisterMemIndex>,
pub inputs: Vec<OracleInput>,
/// Input values
pub input_values: Vec<FieldElement>,
/// Output register
@@ -117,6 +117,12 @@ pub struct OracleData {
pub output_values: Vec<FieldElement>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OracleInput {
pub register_mem_index: RegisterMemIndex,
pub length: usize,
}
Comment on lines +121 to +124
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of this representation to handle registers and arrays. A register value is not a 0-length array, it is a fundamentally different type so we should represent it as such. This should be an enum that is either Register or Array. This representation would allow easier extension in the future and prevent future bugs from forgetting to check the length field.


#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOp {
Add,