Skip to content

Commit

Permalink
chore: move implementation of bitwise operations into `blackbox_solve…
Browse files Browse the repository at this point in the history
…r` (#5209)

# Description

## Problem\*

Related to #5055 
## Summary\*

We don't need the field element to know about AND and XOR as we just
need to be able to convert into byte arrays and back. I've then moved
this into the blackbox solver crate (the only place we were actually
using this) to simplify the `AcirField` trait.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
TomAFrench authored Jun 10, 2024
1 parent 8a32299 commit b1298b8
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 90 deletions.
81 changes: 0 additions & 81 deletions acvm-repo/acir_field/src/field_element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,6 @@ impl<F: PrimeField> FieldElement<F> {
Some(FieldElement(fr))
}

// mask_to methods will not remove any bytes from the field
// they are simply zeroed out
// Whereas truncate_to will remove those bits and make the byte array smaller
fn mask_to_be_bytes(&self, num_bits: u32) -> Vec<u8> {
let mut bytes = self.to_be_bytes();
mask_vector_le(&mut bytes, num_bits as usize);
bytes
}

fn bits(&self) -> Vec<bool> {
fn byte_to_bit(byte: u8) -> Vec<bool> {
let mut bits = Vec::with_capacity(8);
Expand All @@ -220,29 +211,6 @@ impl<F: PrimeField> FieldElement<F> {
}
bits
}

fn and_xor(&self, rhs: &FieldElement<F>, num_bits: u32, is_xor: bool) -> FieldElement<F> {
// XXX: Gadgets like SHA256 need to have their input be a multiple of 8
// This is not a restriction caused by SHA256, as it works on bits
// but most backends assume bytes.
// We could implicitly pad, however this may not be intuitive for users.
// assert!(
// num_bits % 8 == 0,
// "num_bits is not a multiple of 8, it is {}",
// num_bits
// );

let lhs_bytes = self.mask_to_be_bytes(num_bits);
let rhs_bytes = rhs.mask_to_be_bytes(num_bits);

let and_byte_arr: Vec<_> = lhs_bytes
.into_iter()
.zip(rhs_bytes)
.map(|(lhs, rhs)| if is_xor { lhs ^ rhs } else { lhs & rhs })
.collect();

FieldElement::from_be_bytes_reduce(&and_byte_arr)
}
}

impl<F: PrimeField> AcirField for FieldElement<F> {
Expand Down Expand Up @@ -376,13 +344,6 @@ impl<F: PrimeField> AcirField for FieldElement<F> {

bytes[0..num_elements].to_vec()
}

fn and(&self, rhs: &FieldElement<F>, num_bits: u32) -> FieldElement<F> {
self.and_xor(rhs, num_bits, false)
}
fn xor(&self, rhs: &FieldElement<F>, num_bits: u32) -> FieldElement<F> {
self.and_xor(rhs, num_bits, true)
}
}

impl<F: PrimeField> Neg for FieldElement<F> {
Expand Down Expand Up @@ -433,35 +394,6 @@ impl<F: PrimeField> SubAssign for FieldElement<F> {
}
}

fn mask_vector_le(bytes: &mut [u8], num_bits: usize) {
// reverse to big endian format
bytes.reverse();

let mask_power = num_bits % 8;
let array_mask_index = num_bits / 8;

for (index, byte) in bytes.iter_mut().enumerate() {
match index.cmp(&array_mask_index) {
std::cmp::Ordering::Less => {
// do nothing if the current index is less than
// the array index.
}
std::cmp::Ordering::Equal => {
let mask = 2u8.pow(mask_power as u32) - 1;
// mask the byte
*byte &= mask;
}
std::cmp::Ordering::Greater => {
// Anything greater than the array index
// will be set to zero
*byte = 0;
}
}
}
// reverse back to little endian
bytes.reverse();
}

// For pretty printing powers
fn superscript(n: u64) -> String {
if n == 0 {
Expand Down Expand Up @@ -495,19 +427,6 @@ fn superscript(n: u64) -> String {
mod tests {
use super::{AcirField, FieldElement};

#[test]
fn and() {
let max = 10_000u32;

let num_bits = (std::mem::size_of::<u32>() * 8) as u32 - max.leading_zeros();

for x in 0..max {
let x = FieldElement::<ark_bn254::Fr>::from(x as i128);
let res = x.and(&x, num_bits);
assert_eq!(res.to_be_bytes(), x.to_be_bytes());
}
}

#[test]
fn serialize_fixed_test_vectors() {
// Serialized field elements from of 0, -1, -2, -3
Expand Down
3 changes: 0 additions & 3 deletions acvm-repo/acir_field/src/generic_ark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,4 @@ pub trait AcirField:
/// Returns the closest number of bytes to the bits specified
/// This method truncates
fn fetch_nearest_bytes(&self, num_bits: usize) -> Vec<u8>;

fn and(&self, rhs: &Self, num_bits: u32) -> Self;
fn xor(&self, rhs: &Self, num_bits: u32) -> Self;
}
9 changes: 5 additions & 4 deletions acvm-repo/acvm/src/pwg/blackbox/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use acir::{
native_types::{Witness, WitnessMap},
AcirField,
};
use acvm_blackbox_solver::{bit_and, bit_xor};

/// Solves a [`BlackBoxFunc::And`][acir::circuit::black_box_functions::BlackBoxFunc::AND] opcode and inserts
/// the result into the supplied witness map
Expand All @@ -19,7 +20,7 @@ pub(super) fn and<F: AcirField>(
"number of bits specified for each input must be the same"
);
solve_logic_opcode(initial_witness, &lhs.witness, &rhs.witness, *output, |left, right| {
left.and(right, lhs.num_bits)
bit_and(left, right, lhs.num_bits)
})
}

Expand All @@ -36,7 +37,7 @@ pub(super) fn xor<F: AcirField>(
"number of bits specified for each input must be the same"
);
solve_logic_opcode(initial_witness, &lhs.witness, &rhs.witness, *output, |left, right| {
left.xor(right, lhs.num_bits)
bit_xor(left, right, lhs.num_bits)
})
}

Expand All @@ -46,11 +47,11 @@ fn solve_logic_opcode<F: AcirField>(
a: &Witness,
b: &Witness,
result: Witness,
logic_op: impl Fn(&F, &F) -> F,
logic_op: impl Fn(F, F) -> F,
) -> Result<(), OpcodeResolutionError<F>> {
let w_l_value = witness_to_value(initial_witness, *a)?;
let w_r_value = witness_to_value(initial_witness, *b)?;
let assignment = logic_op(w_l_value, w_r_value);
let assignment = logic_op(*w_l_value, *w_r_value);

insert_value(&result, assignment, initial_witness)
}
4 changes: 2 additions & 2 deletions acvm-repo/acvm_js/src/black_box_solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use acvm::{acir::AcirField, FieldElement};
pub fn and(lhs: JsString, rhs: JsString) -> JsString {
let lhs = js_value_to_field_element(lhs.into()).unwrap();
let rhs = js_value_to_field_element(rhs.into()).unwrap();
let result = lhs.and(&rhs, FieldElement::max_num_bits());
let result = acvm::blackbox_solver::bit_and(lhs, rhs, FieldElement::max_num_bits());
field_element_to_js_string(&result)
}

Expand All @@ -18,7 +18,7 @@ pub fn and(lhs: JsString, rhs: JsString) -> JsString {
pub fn xor(lhs: JsString, rhs: JsString) -> JsString {
let lhs = js_value_to_field_element(lhs.into()).unwrap();
let rhs = js_value_to_field_element(rhs.into()).unwrap();
let result = lhs.xor(&rhs, FieldElement::max_num_bits());
let result = acvm::blackbox_solver::bit_xor(lhs, rhs, FieldElement::max_num_bits());
field_element_to_js_string(&result)
}

Expand Down
2 changes: 2 additions & 0 deletions acvm-repo/blackbox_solver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ mod bigint;
mod curve_specific_solver;
mod ecdsa;
mod hash;
mod logic;

pub use aes128::aes128_encrypt;
pub use bigint::BigIntSolver;
pub use curve_specific_solver::{BlackBoxFunctionSolver, StubbedBlackBoxSolver};
pub use ecdsa::{ecdsa_secp256k1_verify, ecdsa_secp256r1_verify};
pub use hash::{blake2s, blake3, keccak256, keccakf1600, sha256, sha256compression};
pub use logic::{bit_and, bit_xor};

#[derive(Clone, PartialEq, Eq, Debug, Error)]
pub enum BlackBoxResolutionError {
Expand Down
87 changes: 87 additions & 0 deletions acvm-repo/blackbox_solver/src/logic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use acir::AcirField;

pub fn bit_and<F: AcirField>(lhs: F, rhs: F, num_bits: u32) -> F {
bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte & rhs_byte)
}

pub fn bit_xor<F: AcirField>(lhs: F, rhs: F, num_bits: u32) -> F {
bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte ^ rhs_byte)
}

fn bitwise_op<F: AcirField>(lhs: F, rhs: F, num_bits: u32, op: fn(u8, u8) -> u8) -> F {
// XXX: Gadgets like SHA256 need to have their input be a multiple of 8
// This is not a restriction caused by SHA256, as it works on bits
// but most backends assume bytes.
// We could implicitly pad, however this may not be intuitive for users.
// assert!(
// num_bits % 8 == 0,
// "num_bits is not a multiple of 8, it is {}",
// num_bits
// );

let lhs_bytes = mask_to_be_bytes(lhs, num_bits);
let rhs_bytes = mask_to_be_bytes(rhs, num_bits);

let and_byte_arr: Vec<_> =
lhs_bytes.into_iter().zip(rhs_bytes).map(|(left, right)| op(left, right)).collect();

F::from_be_bytes_reduce(&and_byte_arr)
}

// mask_to methods will not remove any bytes from the field
// they are simply zeroed out
// Whereas truncate_to will remove those bits and make the byte array smaller
fn mask_to_be_bytes<F: AcirField>(field: F, num_bits: u32) -> Vec<u8> {
let mut bytes = field.to_be_bytes();
mask_vector_le(&mut bytes, num_bits as usize);
bytes
}

fn mask_vector_le(bytes: &mut [u8], num_bits: usize) {
// reverse to big endian format
bytes.reverse();

let mask_power = num_bits % 8;
let array_mask_index = num_bits / 8;

for (index, byte) in bytes.iter_mut().enumerate() {
match index.cmp(&array_mask_index) {
std::cmp::Ordering::Less => {
// do nothing if the current index is less than
// the array index.
}
std::cmp::Ordering::Equal => {
let mask = 2u8.pow(mask_power as u32) - 1;
// mask the byte
*byte &= mask;
}
std::cmp::Ordering::Greater => {
// Anything greater than the array index
// will be set to zero
*byte = 0;
}
}
}
// reverse back to little endian
bytes.reverse();
}

#[cfg(test)]
mod tests {
use acir::{AcirField, FieldElement};

use crate::bit_and;

#[test]
fn and() {
let max = 10_000u32;

let num_bits = (std::mem::size_of::<u32>() * 8) as u32 - max.leading_zeros();

for x in 0..max {
let x = FieldElement::from(x as i128);
let res = bit_and(x, x, num_bits);
assert_eq!(res.to_be_bytes(), x.to_be_bytes());
}
}
}

0 comments on commit b1298b8

Please sign in to comment.