Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

chore: Enforce proper conversion of memory into fixed length array #163

Merged
merged 4 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/acvm_interop/smart_contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,10 @@ impl SmartContract for Barretenberg {

// We then need to read the pointer at `contract_ptr_ptr` to get the smart contract's location
// and then slice memory again at `contract_ptr_ptr` to get the smart contract string.
let contract_ptr = self.slice_memory(contract_ptr_ptr, POINTER_BYTES);
let contract_ptr: usize =
u32::from_le_bytes(contract_ptr[0..POINTER_BYTES].try_into().unwrap()) as usize;
let contract_ptr: [u8; POINTER_BYTES] = self.read_memory(contract_ptr_ptr);
let contract_ptr: usize = u32::from_le_bytes(contract_ptr) as usize;

let sc_as_bytes = self.slice_memory(contract_ptr, contract_size);
let sc_as_bytes = self.read_memory_variable_length(contract_ptr, contract_size);

let verification_key_library: String = sc_as_bytes.iter().map(|b| *b as char).collect();
format!("{verification_key_library}{ULTRA_VERIFIER_CONTRACT}")
Expand Down
21 changes: 9 additions & 12 deletions src/composer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,10 @@ impl Composer for Barretenberg {

// We then need to read the pointer at `pk_ptr_ptr` to get the key's location
// and then slice memory again at `pk_ptr` to get the proving key.
let pk_ptr = self.slice_memory(pk_ptr_ptr, POINTER_BYTES);
let pk_ptr: usize =
u32::from_le_bytes(pk_ptr[0..POINTER_BYTES].try_into().unwrap()) as usize;
let pk_ptr: [u8; POINTER_BYTES] = self.read_memory(pk_ptr_ptr);
let pk_ptr: usize = u32::from_le_bytes(pk_ptr) as usize;

self.slice_memory(pk_ptr, pk_size)
self.read_memory_variable_length(pk_ptr, pk_size)
}

fn compute_verification_key(
Expand Down Expand Up @@ -290,11 +289,10 @@ impl Composer for Barretenberg {

// We then need to read the pointer at `vk_ptr_ptr` to get the key's location
// and then slice memory again at `vk_ptr` to get the verification key.
let vk_ptr = self.slice_memory(vk_ptr_ptr, POINTER_BYTES);
let vk_ptr: usize =
u32::from_le_bytes(vk_ptr[0..POINTER_BYTES].try_into().unwrap()) as usize;
let vk_ptr: [u8; POINTER_BYTES] = self.read_memory(vk_ptr_ptr);
let vk_ptr: usize = u32::from_le_bytes(vk_ptr) as usize;

self.slice_memory(vk_ptr, vk_size)
self.read_memory_variable_length(vk_ptr, vk_size)
}

fn create_proof_with_pk(
Expand Down Expand Up @@ -339,11 +337,10 @@ impl Composer for Barretenberg {

// We then need to read the pointer at `proof_ptr_ptr` to get the proof's location
// and then slice memory again at `proof_ptr` to get the proof data.
let proof_ptr = self.slice_memory(proof_ptr_ptr, POINTER_BYTES);
let proof_ptr: usize =
u32::from_le_bytes(proof_ptr[0..POINTER_BYTES].try_into().unwrap()) as usize;
let proof_ptr: [u8; POINTER_BYTES] = self.read_memory(proof_ptr_ptr);
let proof_ptr: usize = u32::from_le_bytes(proof_ptr) as usize;

let result = self.slice_memory(proof_ptr, proof_size);
let result = self.read_memory_variable_length(proof_ptr, proof_size);

// Barretenberg returns proofs which are prepended with the public inputs.
// This behavior is nonstandard so we strip the public inputs from the proof.
Expand Down
11 changes: 8 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,18 @@ mod wasm {
}
}

// XXX: change to read_mem
pub(super) fn slice_memory(&self, start: usize, length: usize) -> Vec<u8> {
pub(super) fn read_memory<const SIZE: usize>(&self, start: usize) -> [u8; SIZE] {
self.read_memory_variable_length(start, SIZE)
.try_into()
.expect("Read memory should be of the specified length")
}

pub(super) fn read_memory_variable_length(&self, start: usize, length: usize) -> Vec<u8> {
let memory = &self.memory;
let end = start + length;

#[cfg(feature = "js")]
return memory.uint8view().to_vec()[start as usize..end].to_vec();
return memory.uint8view().to_vec()[start..end].to_vec();

#[cfg(not(feature = "js"))]
return memory.view()[start..end]
Expand Down
6 changes: 3 additions & 3 deletions src/pedersen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl Pedersen for Barretenberg {
vec![&lhs_ptr.into(), &rhs_ptr.into(), &result_ptr.into()],
);

let result_bytes = self.slice_memory(result_ptr, FIELD_BYTES);
let result_bytes: [u8; FIELD_BYTES] = self.read_memory(result_ptr);
FieldElement::from_be_bytes_reduce(&result_bytes)
}

Expand All @@ -83,7 +83,7 @@ impl Pedersen for Barretenberg {
vec![&input_ptr, &result_ptr.into()],
);

let result_bytes = self.slice_memory(result_ptr, FIELD_BYTES);
let result_bytes: [u8; FIELD_BYTES] = self.read_memory(result_ptr);
FieldElement::from_be_bytes_reduce(&result_bytes)
}

Expand All @@ -100,7 +100,7 @@ impl Pedersen for Barretenberg {
vec![&input_ptr, &result_ptr.into()],
);

let result_bytes = self.slice_memory(result_ptr, 2 * FIELD_BYTES);
let result_bytes: [u8; 2 * FIELD_BYTES] = self.read_memory(result_ptr);
let (point_x_bytes, point_y_bytes) = result_bytes.split_at(FIELD_BYTES);

let point_x = FieldElement::from_be_bytes_reduce(point_x_bytes);
Expand Down
2 changes: 1 addition & 1 deletion src/scalar_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl ScalarMul for Barretenberg {
vec![&lhs_ptr.into(), &result_ptr.into()],
);

let result_bytes = self.slice_memory(result_ptr, 2 * FIELD_BYTES);
let result_bytes: [u8; 2 * FIELD_BYTES] = self.read_memory(result_ptr);
let (pubkey_x_bytes, pubkey_y_bytes) = result_bytes.split_at(FIELD_BYTES);

assert!(pubkey_x_bytes.len() == FIELD_BYTES);
Expand Down
10 changes: 3 additions & 7 deletions src/schnorr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ impl SchnorrSig for Barretenberg {
],
);

let sig_s_bytes = self.slice_memory(sig_s_ptr, FIELD_BYTES);
let sig_e_bytes = self.slice_memory(sig_e_ptr, FIELD_BYTES);
let sig_s: [u8; 32] = sig_s_bytes.try_into().unwrap();
let sig_e: [u8; 32] = sig_e_bytes.try_into().unwrap();
let sig_s: [u8; FIELD_BYTES] = self.read_memory(sig_s_ptr);
let sig_e: [u8; FIELD_BYTES] = self.read_memory(sig_e_ptr);

let sig_bytes: [u8; 64] = [sig_s, sig_e].concat().try_into().unwrap();
sig_bytes
Expand All @@ -82,9 +80,7 @@ impl SchnorrSig for Barretenberg {
vec![&private_key_ptr.into(), &result_ptr.into()],
);

self.slice_memory(result_ptr, 2 * FIELD_BYTES)
.try_into()
.unwrap()
self.read_memory(result_ptr)
}

fn verify_signature(&self, pub_key: [u8; 64], sig: [u8; 64], message: &[u8]) -> bool {
Expand Down