Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: move implementation of bitwise operations into blackbox_solver #5209

Merged
merged 2 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 0 additions & 81 deletions acvm-repo/acir_field/src/field_element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,6 @@ impl<F: PrimeField> FieldElement<F> {
Some(FieldElement(fr))
}

// mask_to methods will not remove any bytes from the field
// they are simply zeroed out
// Whereas truncate_to will remove those bits and make the byte array smaller
fn mask_to_be_bytes(&self, num_bits: u32) -> Vec<u8> {
let mut bytes = self.to_be_bytes();
mask_vector_le(&mut bytes, num_bits as usize);
bytes
}

fn bits(&self) -> Vec<bool> {
fn byte_to_bit(byte: u8) -> Vec<bool> {
let mut bits = Vec::with_capacity(8);
Expand All @@ -220,29 +211,6 @@ impl<F: PrimeField> FieldElement<F> {
}
bits
}

fn and_xor(&self, rhs: &FieldElement<F>, num_bits: u32, is_xor: bool) -> FieldElement<F> {
// XXX: Gadgets like SHA256 need to have their input be a multiple of 8
// This is not a restriction caused by SHA256, as it works on bits
// but most backends assume bytes.
// We could implicitly pad, however this may not be intuitive for users.
// assert!(
// num_bits % 8 == 0,
// "num_bits is not a multiple of 8, it is {}",
// num_bits
// );

let lhs_bytes = self.mask_to_be_bytes(num_bits);
let rhs_bytes = rhs.mask_to_be_bytes(num_bits);

let and_byte_arr: Vec<_> = lhs_bytes
.into_iter()
.zip(rhs_bytes)
.map(|(lhs, rhs)| if is_xor { lhs ^ rhs } else { lhs & rhs })
.collect();

FieldElement::from_be_bytes_reduce(&and_byte_arr)
}
}

impl<F: PrimeField> AcirField for FieldElement<F> {
Expand Down Expand Up @@ -376,13 +344,6 @@ impl<F: PrimeField> AcirField for FieldElement<F> {

bytes[0..num_elements].to_vec()
}

fn and(&self, rhs: &FieldElement<F>, num_bits: u32) -> FieldElement<F> {
self.and_xor(rhs, num_bits, false)
}
fn xor(&self, rhs: &FieldElement<F>, num_bits: u32) -> FieldElement<F> {
self.and_xor(rhs, num_bits, true)
}
}

impl<F: PrimeField> Neg for FieldElement<F> {
Expand Down Expand Up @@ -433,35 +394,6 @@ impl<F: PrimeField> SubAssign for FieldElement<F> {
}
}

fn mask_vector_le(bytes: &mut [u8], num_bits: usize) {
// reverse to big endian format
bytes.reverse();

let mask_power = num_bits % 8;
let array_mask_index = num_bits / 8;

for (index, byte) in bytes.iter_mut().enumerate() {
match index.cmp(&array_mask_index) {
std::cmp::Ordering::Less => {
// do nothing if the current index is less than
// the array index.
}
std::cmp::Ordering::Equal => {
let mask = 2u8.pow(mask_power as u32) - 1;
// mask the byte
*byte &= mask;
}
std::cmp::Ordering::Greater => {
// Anything greater than the array index
// will be set to zero
*byte = 0;
}
}
}
// reverse back to little endian
bytes.reverse();
}

// For pretty printing powers
fn superscript(n: u64) -> String {
if n == 0 {
Expand Down Expand Up @@ -495,19 +427,6 @@ fn superscript(n: u64) -> String {
mod tests {
use super::{AcirField, FieldElement};

#[test]
fn and() {
let max = 10_000u32;

let num_bits = (std::mem::size_of::<u32>() * 8) as u32 - max.leading_zeros();

for x in 0..max {
let x = FieldElement::<ark_bn254::Fr>::from(x as i128);
let res = x.and(&x, num_bits);
assert_eq!(res.to_be_bytes(), x.to_be_bytes());
}
}

#[test]
fn serialize_fixed_test_vectors() {
// Serialized field elements from of 0, -1, -2, -3
Expand Down
3 changes: 0 additions & 3 deletions acvm-repo/acir_field/src/generic_ark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,4 @@ pub trait AcirField:
/// Returns the closest number of bytes to the bits specified
/// This method truncates
fn fetch_nearest_bytes(&self, num_bits: usize) -> Vec<u8>;

fn and(&self, rhs: &Self, num_bits: u32) -> Self;
fn xor(&self, rhs: &Self, num_bits: u32) -> Self;
}
9 changes: 5 additions & 4 deletions acvm-repo/acvm/src/pwg/blackbox/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use acir::{
native_types::{Witness, WitnessMap},
AcirField,
};
use acvm_blackbox_solver::{bit_and, bit_xor};

/// Solves a [`BlackBoxFunc::And`][acir::circuit::black_box_functions::BlackBoxFunc::AND] opcode and inserts
/// the result into the supplied witness map
Expand All @@ -19,7 +20,7 @@ pub(super) fn and<F: AcirField>(
"number of bits specified for each input must be the same"
);
solve_logic_opcode(initial_witness, &lhs.witness, &rhs.witness, *output, |left, right| {
left.and(right, lhs.num_bits)
bit_and(left, right, lhs.num_bits)
})
}

Expand All @@ -36,7 +37,7 @@ pub(super) fn xor<F: AcirField>(
"number of bits specified for each input must be the same"
);
solve_logic_opcode(initial_witness, &lhs.witness, &rhs.witness, *output, |left, right| {
left.xor(right, lhs.num_bits)
bit_xor(left, right, lhs.num_bits)
})
}

Expand All @@ -46,11 +47,11 @@ fn solve_logic_opcode<F: AcirField>(
a: &Witness,
b: &Witness,
result: Witness,
logic_op: impl Fn(&F, &F) -> F,
logic_op: impl Fn(F, F) -> F,
) -> Result<(), OpcodeResolutionError<F>> {
let w_l_value = witness_to_value(initial_witness, *a)?;
let w_r_value = witness_to_value(initial_witness, *b)?;
let assignment = logic_op(w_l_value, w_r_value);
let assignment = logic_op(*w_l_value, *w_r_value);

insert_value(&result, assignment, initial_witness)
}
4 changes: 2 additions & 2 deletions acvm-repo/acvm_js/src/black_box_solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use acvm::{acir::AcirField, FieldElement};
pub fn and(lhs: JsString, rhs: JsString) -> JsString {
let lhs = js_value_to_field_element(lhs.into()).unwrap();
let rhs = js_value_to_field_element(rhs.into()).unwrap();
let result = lhs.and(&rhs, FieldElement::max_num_bits());
let result = acvm::blackbox_solver::bit_and(lhs, rhs, FieldElement::max_num_bits());
field_element_to_js_string(&result)
}

Expand All @@ -18,7 +18,7 @@ pub fn and(lhs: JsString, rhs: JsString) -> JsString {
pub fn xor(lhs: JsString, rhs: JsString) -> JsString {
let lhs = js_value_to_field_element(lhs.into()).unwrap();
let rhs = js_value_to_field_element(rhs.into()).unwrap();
let result = lhs.xor(&rhs, FieldElement::max_num_bits());
let result = acvm::blackbox_solver::bit_xor(lhs, rhs, FieldElement::max_num_bits());
field_element_to_js_string(&result)
}

Expand Down
2 changes: 2 additions & 0 deletions acvm-repo/blackbox_solver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ mod bigint;
mod curve_specific_solver;
mod ecdsa;
mod hash;
mod logic;

pub use aes128::aes128_encrypt;
pub use bigint::BigIntSolver;
pub use curve_specific_solver::{BlackBoxFunctionSolver, StubbedBlackBoxSolver};
pub use ecdsa::{ecdsa_secp256k1_verify, ecdsa_secp256r1_verify};
pub use hash::{blake2s, blake3, keccak256, keccakf1600, sha256, sha256compression};
pub use logic::{bit_and, bit_xor};

#[derive(Clone, PartialEq, Eq, Debug, Error)]
pub enum BlackBoxResolutionError {
Expand Down
87 changes: 87 additions & 0 deletions acvm-repo/blackbox_solver/src/logic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use acir::AcirField;

pub fn bit_and<F: AcirField>(lhs: F, rhs: F, num_bits: u32) -> F {
bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte & rhs_byte)
}

pub fn bit_xor<F: AcirField>(lhs: F, rhs: F, num_bits: u32) -> F {
bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte ^ rhs_byte)
}

fn bitwise_op<F: AcirField>(lhs: F, rhs: F, num_bits: u32, op: fn(u8, u8) -> u8) -> F {
// XXX: Gadgets like SHA256 need to have their input be a multiple of 8
// This is not a restriction caused by SHA256, as it works on bits
// but most backends assume bytes.
// We could implicitly pad, however this may not be intuitive for users.
// assert!(
// num_bits % 8 == 0,
// "num_bits is not a multiple of 8, it is {}",
// num_bits
// );

let lhs_bytes = mask_to_be_bytes(lhs, num_bits);
let rhs_bytes = mask_to_be_bytes(rhs, num_bits);

let and_byte_arr: Vec<_> =
lhs_bytes.into_iter().zip(rhs_bytes).map(|(left, right)| op(left, right)).collect();

F::from_be_bytes_reduce(&and_byte_arr)
}

// mask_to methods will not remove any bytes from the field
// they are simply zeroed out
// Whereas truncate_to will remove those bits and make the byte array smaller
fn mask_to_be_bytes<F: AcirField>(field: F, num_bits: u32) -> Vec<u8> {
let mut bytes = field.to_be_bytes();
mask_vector_le(&mut bytes, num_bits as usize);
bytes
}

fn mask_vector_le(bytes: &mut [u8], num_bits: usize) {
// reverse to big endian format
bytes.reverse();

let mask_power = num_bits % 8;
let array_mask_index = num_bits / 8;

for (index, byte) in bytes.iter_mut().enumerate() {
match index.cmp(&array_mask_index) {
std::cmp::Ordering::Less => {
// do nothing if the current index is less than
// the array index.
}
std::cmp::Ordering::Equal => {
let mask = 2u8.pow(mask_power as u32) - 1;
// mask the byte
*byte &= mask;
}
std::cmp::Ordering::Greater => {
// Anything greater than the array index
// will be set to zero
*byte = 0;
}
}
}
// reverse back to little endian
bytes.reverse();
}

#[cfg(test)]
mod tests {
use acir::{AcirField, FieldElement};

use crate::bit_and;

#[test]
fn and() {
let max = 10_000u32;

let num_bits = (std::mem::size_of::<u32>() * 8) as u32 - max.leading_zeros();

for x in 0..max {
let x = FieldElement::from(x as i128);
let res = bit_and(x, x, num_bits);
assert_eq!(res.to_be_bytes(), x.to_be_bytes());
}
}
}
Loading