diff --git a/brillig_vm/Cargo.toml b/brillig_vm/Cargo.toml index 38974eff6..3b774ba06 100644 --- a/brillig_vm/Cargo.toml +++ b/brillig_vm/Cargo.toml @@ -13,6 +13,8 @@ repository.workspace = true [dependencies] acir.workspace = true blackbox_solver.workspace = true +num-bigint.workspace = true +num-traits.workspace = true [features] default = ["bn254"] diff --git a/brillig_vm/src/arithmetic.rs b/brillig_vm/src/arithmetic.rs index 5f8cb3092..51ab86604 100644 --- a/brillig_vm/src/arithmetic.rs +++ b/brillig_vm/src/arithmetic.rs @@ -1,5 +1,7 @@ use acir::brillig::{BinaryFieldOp, BinaryIntOp}; use acir::FieldElement; +use num_bigint::{BigInt, BigUint}; +use num_traits::{One, ToPrimitive, Zero}; /// Evaluate a binary operation on two FieldElements and return the result as a FieldElement. pub(crate) fn evaluate_binary_field_op( @@ -17,50 +19,81 @@ pub(crate) fn evaluate_binary_field_op( } } -/// Evaluate a binary operation on two unsigned integers (u128) with a given bit size and return the result as a u128. -pub(crate) fn evaluate_binary_int_op(op: &BinaryIntOp, a: u128, b: u128, bit_size: u32) -> u128 { - let bit_modulo = 1_u128 << bit_size; +/// Evaluate a binary operation on two unsigned big integers with a given bit size and return the result as a big integer. +pub(crate) fn evaluate_binary_bigint_op( + op: &BinaryIntOp, + a: BigUint, + b: BigUint, + bit_size: u32, +) -> BigUint { + let bit_modulo = &(BigUint::one() << bit_size); match op { // Perform addition, subtraction, and multiplication, applying a modulo operation to keep the result within the bit size. - BinaryIntOp::Add => a.wrapping_add(b) % bit_modulo, - BinaryIntOp::Sub => a.wrapping_sub(b) % bit_modulo, - BinaryIntOp::Mul => a.wrapping_mul(b) % bit_modulo, + BinaryIntOp::Add => (a + b) % bit_modulo, + BinaryIntOp::Sub => (bit_modulo + a - b) % bit_modulo, + BinaryIntOp::Mul => (a * b) % bit_modulo, // Perform unsigned division using the modulo operation on a and b. BinaryIntOp::UnsignedDiv => (a % bit_modulo) / (b % bit_modulo), // Perform signed division by first converting a and b to signed integers and then back to unsigned after the operation. BinaryIntOp::SignedDiv => { - to_unsigned(to_signed(a, bit_size) / to_signed(b, bit_size), bit_size) + let signed_div = to_big_signed(a, bit_size) / to_big_signed(b, bit_size); + to_big_unsigned(signed_div, bit_size) } // Perform a == operation, returning 0 or 1 - BinaryIntOp::Equals => ((a % bit_modulo) == (b % bit_modulo)).into(), + BinaryIntOp::Equals => { + if (a % bit_modulo) == (b % bit_modulo) { + BigUint::one() + } else { + BigUint::zero() + } + } // Perform a < operation, returning 0 or 1 - BinaryIntOp::LessThan => ((a % bit_modulo) < (b % bit_modulo)).into(), + BinaryIntOp::LessThan => { + if (a % bit_modulo) < (b % bit_modulo) { + BigUint::one() + } else { + BigUint::zero() + } + } // Perform a <= operation, returning 0 or 1 - BinaryIntOp::LessThanEquals => ((a % bit_modulo) <= (b % bit_modulo)).into(), + BinaryIntOp::LessThanEquals => { + if (a % bit_modulo) <= (b % bit_modulo) { + BigUint::one() + } else { + BigUint::zero() + } + } // Perform bitwise AND, OR, XOR, left shift, and right shift operations, applying a modulo operation to keep the result within the bit size. BinaryIntOp::And => (a & b) % bit_modulo, BinaryIntOp::Or => (a | b) % bit_modulo, BinaryIntOp::Xor => (a ^ b) % bit_modulo, - BinaryIntOp::Shl => (a << b) % bit_modulo, - BinaryIntOp::Shr => (a >> b) % bit_modulo, + BinaryIntOp::Shl => { + assert!(bit_size <= 128, "unsupported bit size for right shift"); + let b = b.to_u128().unwrap(); + (a << b) % bit_modulo + } + BinaryIntOp::Shr => { + assert!(bit_size <= 128, "unsupported bit size for right shift"); + let b = b.to_u128().unwrap(); + (a >> b) % bit_modulo + } } } -fn to_signed(a: u128, bit_size: u32) -> i128 { - assert!(bit_size < 128); - let pow_2 = 2_u128.pow(bit_size - 1); +fn to_big_signed(a: BigUint, bit_size: u32) -> BigInt { + let pow_2 = BigUint::from(2_u32).pow(bit_size - 1); if a < pow_2 { - a as i128 + BigInt::from(a) } else { - (a.wrapping_sub(2 * pow_2)) as i128 + BigInt::from(a) - 2 * BigInt::from(pow_2) } } -fn to_unsigned(a: i128, bit_size: u32) -> u128 { - if a >= 0 { - a as u128 +fn to_big_unsigned(a: BigInt, bit_size: u32) -> BigUint { + if a >= BigInt::zero() { + BigUint::from_bytes_le(&a.to_bytes_le().1) } else { - (a + 2_i128.pow(bit_size)) as u128 + BigUint::from(2_u32).pow(bit_size) - BigUint::from_bytes_le(&a.to_bytes_le().1) } } @@ -74,6 +107,33 @@ mod tests { result: u128, } + fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: u32) -> u128 { + // Convert to big integers + let lhs_big = BigUint::from(a); + let rhs_big = BigUint::from(b); + let result_value = evaluate_binary_bigint_op(op, lhs_big, rhs_big, bit_size); + // Convert back to u128 + result_value.to_u128().unwrap() + } + + fn to_signed(a: u128, bit_size: u32) -> i128 { + assert!(bit_size < 128); + let pow_2 = 2_u128.pow(bit_size - 1); + if a < pow_2 { + a as i128 + } else { + (a.wrapping_sub(2 * pow_2)) as i128 + } + } + + fn to_unsigned(a: i128, bit_size: u32) -> u128 { + if a >= 0 { + a as u128 + } else { + (a + 2_i128.pow(bit_size)) as u128 + } + } + fn to_negative(a: u128, bit_size: u32) -> u128 { assert!(a > 0); let two_pow = 2_u128.pow(bit_size); @@ -82,7 +142,7 @@ mod tests { fn evaluate_int_ops(test_params: Vec, op: BinaryIntOp, bit_size: u32) { for test in test_params { - assert_eq!(evaluate_binary_int_op(&op, test.a, test.b, bit_size), test.result); + assert_eq!(evaluate_u128(&op, test.a, test.b, bit_size), test.result); } } @@ -140,7 +200,7 @@ mod tests { let b = 3; // ( 2**(n-1) - 1 ) * 3 = 2*2**(n-1) - 2 + (2**(n-1) - 1) => wraps to (2**(n-1) - 1) - 2 - assert_eq!(evaluate_binary_int_op(&BinaryIntOp::Mul, a, b, bit_size), a - 2); + assert_eq!(evaluate_u128(&BinaryIntOp::Mul, a, b, bit_size), a - 2); } #[test] diff --git a/brillig_vm/src/lib.rs b/brillig_vm/src/lib.rs index 19d0630e2..d0eb6ac25 100644 --- a/brillig_vm/src/lib.rs +++ b/brillig_vm/src/lib.rs @@ -22,11 +22,12 @@ mod black_box; mod memory; mod registers; -use arithmetic::{evaluate_binary_field_op, evaluate_binary_int_op}; +use arithmetic::{evaluate_binary_bigint_op, evaluate_binary_field_op}; use black_box::evaluate_black_box; use blackbox_solver::{BlackBoxFunctionSolver, BlackBoxResolutionError}; pub use memory::Memory; +use num_bigint::BigUint; pub use registers::Registers; #[derive(Debug, PartialEq, Eq, Clone)] @@ -371,9 +372,13 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { let lhs_value = self.registers.get(lhs); let rhs_value = self.registers.get(rhs); - let result_value = - evaluate_binary_int_op(&op, lhs_value.to_u128(), rhs_value.to_u128(), bit_size); - self.registers.set(result, result_value.into()); + // Convert to big integers + let lhs_big = BigUint::from_bytes_be(&lhs_value.to_field().to_be_bytes()); + let rhs_big = BigUint::from_bytes_be(&rhs_value.to_field().to_be_bytes()); + let result_value = evaluate_binary_bigint_op(&op, lhs_big, rhs_big, bit_size); + // Convert back to field element + self.registers + .set(result, FieldElement::from_be_bytes_reduce(&result_value.to_bytes_be()).into()); } }