Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Signed integer division and modulus in brillig gen #5279

Merged
merged 9 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn brillig_to_avm(brillig: &Brillig) -> Vec<u8> {
BinaryIntOp::Add => AvmOpcode::ADD,
BinaryIntOp::Sub => AvmOpcode::SUB,
BinaryIntOp::Mul => AvmOpcode::MUL,
BinaryIntOp::UnsignedDiv => AvmOpcode::DIV,
BinaryIntOp::Div => AvmOpcode::DIV,
BinaryIntOp::Equals => AvmOpcode::EQ,
BinaryIntOp::LessThan => AvmOpcode::LT,
BinaryIntOp::LessThanEquals => AvmOpcode::LTE,
Expand Down
74 changes: 13 additions & 61 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,10 @@ struct BinaryIntOp {
static Mul bincodeDeserialize(std::vector<uint8_t>);
};

struct SignedDiv {
friend bool operator==(const SignedDiv&, const SignedDiv&);
std::vector<uint8_t> bincodeSerialize() const;
static SignedDiv bincodeDeserialize(std::vector<uint8_t>);
};

struct UnsignedDiv {
friend bool operator==(const UnsignedDiv&, const UnsignedDiv&);
struct Div {
friend bool operator==(const Div&, const Div&);
std::vector<uint8_t> bincodeSerialize() const;
static UnsignedDiv bincodeDeserialize(std::vector<uint8_t>);
static Div bincodeDeserialize(std::vector<uint8_t>);
};

struct Equals {
Expand Down Expand Up @@ -142,7 +136,7 @@ struct BinaryIntOp {
static Shr bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Add, Sub, Mul, SignedDiv, UnsignedDiv, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;
std::variant<Add, Sub, Mul, Div, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;

friend bool operator==(const BinaryIntOp&, const BinaryIntOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -1768,63 +1762,22 @@ Circuit::BinaryIntOp::Mul serde::Deserializable<Circuit::BinaryIntOp::Mul>::dese

namespace Circuit {

inline bool operator==(const BinaryIntOp::SignedDiv& lhs, const BinaryIntOp::SignedDiv& rhs)
{
return true;
}

inline std::vector<uint8_t> BinaryIntOp::SignedDiv::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::SignedDiv>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::SignedDiv BinaryIntOp::SignedDiv::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::SignedDiv>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::SignedDiv>::serialize(const Circuit::BinaryIntOp::SignedDiv& obj,
Serializer& serializer)
{}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::SignedDiv serde::Deserializable<Circuit::BinaryIntOp::SignedDiv>::deserialize(
Deserializer& deserializer)
{
Circuit::BinaryIntOp::SignedDiv obj;
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryIntOp::UnsignedDiv& lhs, const BinaryIntOp::UnsignedDiv& rhs)
inline bool operator==(const BinaryIntOp::Div& lhs, const BinaryIntOp::Div& rhs)
{
return true;
}

inline std::vector<uint8_t> BinaryIntOp::UnsignedDiv::bincodeSerialize() const
inline std::vector<uint8_t> BinaryIntOp::Div::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::UnsignedDiv>::serialize(*this, serializer);
serde::Serializable<BinaryIntOp::Div>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::UnsignedDiv BinaryIntOp::UnsignedDiv::bincodeDeserialize(std::vector<uint8_t> input)
inline BinaryIntOp::Div BinaryIntOp::Div::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::UnsignedDiv>::deserialize(deserializer);
auto value = serde::Deserializable<BinaryIntOp::Div>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
Expand All @@ -1835,16 +1788,15 @@ inline BinaryIntOp::UnsignedDiv BinaryIntOp::UnsignedDiv::bincodeDeserialize(std

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::UnsignedDiv>::serialize(const Circuit::BinaryIntOp::UnsignedDiv& obj,
Serializer& serializer)
void serde::Serializable<Circuit::BinaryIntOp::Div>::serialize(const Circuit::BinaryIntOp::Div& obj,
Serializer& serializer)
{}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::UnsignedDiv serde::Deserializable<Circuit::BinaryIntOp::UnsignedDiv>::deserialize(
Deserializer& deserializer)
Circuit::BinaryIntOp::Div serde::Deserializable<Circuit::BinaryIntOp::Div>::deserialize(Deserializer& deserializer)
{
Circuit::BinaryIntOp::UnsignedDiv obj;
Circuit::BinaryIntOp::Div obj;
return obj;
}

Expand Down
65 changes: 12 additions & 53 deletions noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,10 @@ namespace Circuit {
static Mul bincodeDeserialize(std::vector<uint8_t>);
};

struct SignedDiv {
friend bool operator==(const SignedDiv&, const SignedDiv&);
std::vector<uint8_t> bincodeSerialize() const;
static SignedDiv bincodeDeserialize(std::vector<uint8_t>);
};

struct UnsignedDiv {
friend bool operator==(const UnsignedDiv&, const UnsignedDiv&);
struct Div {
friend bool operator==(const Div&, const Div&);
std::vector<uint8_t> bincodeSerialize() const;
static UnsignedDiv bincodeDeserialize(std::vector<uint8_t>);
static Div bincodeDeserialize(std::vector<uint8_t>);
};

struct Equals {
Expand Down Expand Up @@ -142,7 +136,7 @@ namespace Circuit {
static Shr bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Add, Sub, Mul, SignedDiv, UnsignedDiv, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;
std::variant<Add, Sub, Mul, Div, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;

friend bool operator==(const BinaryIntOp&, const BinaryIntOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -1634,54 +1628,19 @@ Circuit::BinaryIntOp::Mul serde::Deserializable<Circuit::BinaryIntOp::Mul>::dese

namespace Circuit {

inline bool operator==(const BinaryIntOp::SignedDiv &lhs, const BinaryIntOp::SignedDiv &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryIntOp::SignedDiv::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::SignedDiv>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::SignedDiv BinaryIntOp::SignedDiv::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::SignedDiv>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::SignedDiv>::serialize(const Circuit::BinaryIntOp::SignedDiv &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::SignedDiv serde::Deserializable<Circuit::BinaryIntOp::SignedDiv>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::SignedDiv obj;
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryIntOp::UnsignedDiv &lhs, const BinaryIntOp::UnsignedDiv &rhs) {
inline bool operator==(const BinaryIntOp::Div &lhs, const BinaryIntOp::Div &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryIntOp::UnsignedDiv::bincodeSerialize() const {
inline std::vector<uint8_t> BinaryIntOp::Div::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::UnsignedDiv>::serialize(*this, serializer);
serde::Serializable<BinaryIntOp::Div>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::UnsignedDiv BinaryIntOp::UnsignedDiv::bincodeDeserialize(std::vector<uint8_t> input) {
inline BinaryIntOp::Div BinaryIntOp::Div::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::UnsignedDiv>::deserialize(deserializer);
auto value = serde::Deserializable<BinaryIntOp::Div>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
Expand All @@ -1692,13 +1651,13 @@ namespace Circuit {

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::UnsignedDiv>::serialize(const Circuit::BinaryIntOp::UnsignedDiv &obj, Serializer &serializer) {
void serde::Serializable<Circuit::BinaryIntOp::Div>::serialize(const Circuit::BinaryIntOp::Div &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::UnsignedDiv serde::Deserializable<Circuit::BinaryIntOp::UnsignedDiv>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::UnsignedDiv obj;
Circuit::BinaryIntOp::Div serde::Deserializable<Circuit::BinaryIntOp::Div>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::Div obj;
return obj;
}

Expand Down
3 changes: 1 addition & 2 deletions noir/noir-repo/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ pub enum BinaryIntOp {
Add,
Sub,
Mul,
SignedDiv,
UnsignedDiv,
Div,
/// (==) equal
Equals,
/// (<) Field less than
Expand Down
71 changes: 3 additions & 68 deletions noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use acir::brillig::{BinaryFieldOp, BinaryIntOp};
use acir::FieldElement;
use num_bigint::{BigInt, BigUint};
use num_bigint::BigUint;
use num_traits::{One, ToPrimitive, Zero};

/// Evaluate a binary operation on two FieldElements and return the result as a FieldElement.
Expand Down Expand Up @@ -42,24 +42,14 @@ pub(crate) fn evaluate_binary_bigint_op(
BinaryIntOp::Sub => (bit_modulo + a - b) % bit_modulo,
BinaryIntOp::Mul => (a * b) % bit_modulo,
// Perform unsigned division using the modulo operation on a and b.
BinaryIntOp::UnsignedDiv => {
BinaryIntOp::Div => {
let b_mod = b % bit_modulo;
if b_mod.is_zero() {
BigUint::zero()
} else {
(a % bit_modulo) / b_mod
}
}
// Perform signed division by first converting a and b to signed integers and then back to unsigned after the operation.
BinaryIntOp::SignedDiv => {
let b_signed = to_big_signed(b, bit_size);
if b_signed.is_zero() {
BigUint::zero()
} else {
let signed_div = to_big_signed(a, bit_size) / b_signed;
to_big_unsigned(signed_div, bit_size)
}
}
// Perform a == operation, returning 0 or 1
BinaryIntOp::Equals => {
if (a % bit_modulo) == (b % bit_modulo) {
Expand Down Expand Up @@ -103,23 +93,6 @@ pub(crate) fn evaluate_binary_bigint_op(
Ok(result)
}

fn to_big_signed(a: BigUint, bit_size: u32) -> BigInt {
let pow_2 = BigUint::from(2_u32).pow(bit_size - 1);
if a < pow_2 {
BigInt::from(a)
} else {
BigInt::from(a) - 2 * BigInt::from(pow_2)
}
}

fn to_big_unsigned(a: BigInt, bit_size: u32) -> BigUint {
if a >= BigInt::zero() {
BigUint::from_bytes_le(&a.to_bytes_le().1)
} else {
BigUint::from(2_u32).pow(bit_size) - BigUint::from_bytes_le(&a.to_bytes_le().1)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -139,24 +112,6 @@ mod tests {
result_value.to_u128().unwrap()
}

fn to_signed(a: u128, bit_size: u32) -> i128 {
assert!(bit_size < 128);
let pow_2 = 2_u128.pow(bit_size - 1);
if a < pow_2 {
a as i128
} else {
(a.wrapping_sub(2 * pow_2)) as i128
}
}

fn to_unsigned(a: i128, bit_size: u32) -> u128 {
if a >= 0 {
a as u128
} else {
(a + 2_i128.pow(bit_size)) as u128
}
}

fn to_negative(a: u128, bit_size: u32) -> u128 {
assert!(a > 0);
let two_pow = 2_u128.pow(bit_size);
Expand Down Expand Up @@ -233,26 +188,6 @@ mod tests {
let test_ops =
vec![TestParams { a: 5, b: 3, result: 1 }, TestParams { a: 5, b: 10, result: 0 }];

evaluate_int_ops(test_ops, BinaryIntOp::UnsignedDiv, bit_size);
}

#[test]
fn to_signed_roundtrip() {
let bit_size = 32;
let minus_one = 2_u128.pow(bit_size) - 1;
assert_eq!(to_unsigned(to_signed(minus_one, bit_size), bit_size), minus_one);
}

#[test]
fn signed_div_test() {
let bit_size = 32;

let test_ops = vec![
TestParams { a: 5, b: to_negative(10, bit_size), result: 0 },
TestParams { a: 5, b: to_negative(1, bit_size), result: to_negative(5, bit_size) },
TestParams { a: to_negative(5, bit_size), b: to_negative(1, bit_size), result: 5 },
];

evaluate_int_ops(test_ops, BinaryIntOp::SignedDiv, bit_size);
evaluate_int_ops(test_ops, BinaryIntOp::Div, bit_size);
}
}
Loading