Skip to content

Commit

Permalink
feat: Signed integer division and modulus in brillig gen (AztecProtoc…
Browse files Browse the repository at this point in the history
…ol/aztec-packages#5279)

Removes the signed div opcode since the AVM won't support it and
replaces it with the emulated version using unsigned division.
  • Loading branch information
AztecBot committed Mar 18, 2024
1 parent d4213a0 commit b85095b
Show file tree
Hide file tree
Showing 156 changed files with 3,664 additions and 1,843 deletions.
2 changes: 1 addition & 1 deletion .aztec-sync-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
aa90f6ed7bfae06bdf6990816d154bbd24993689
82f8cf5eba9deacdab43ad4ef95dbf27dd1c11c7
1 change: 1 addition & 0 deletions .github/scripts/noir-wasm-build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
set -eu

.github/scripts/wasm-pack-install.sh
yarn workspace @noir-lang/types build
yarn workspace @noir-lang/noir_wasm build
2 changes: 1 addition & 1 deletion .github/scripts/wasm-bindgen-install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ cd $(dirname "$0")
./cargo-binstall-install.sh

# Install wasm-bindgen-cli.
if [ "$(wasm-bindgen --version | cut -d' ' -f2)" != "0.2.86" ]; then
if [ "$(wasm-bindgen --version &> /dev/null | cut -d' ' -f2)" != "0.2.86" ]; then
echo "Building wasm-bindgen..."
cargo binstall [email protected] --force --no-confirm
fi
Expand Down
246 changes: 191 additions & 55 deletions acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,31 @@ namespace Circuit {
static Div bincodeDeserialize(std::vector<uint8_t>);
};

struct IntegerDiv {
friend bool operator==(const IntegerDiv&, const IntegerDiv&);
std::vector<uint8_t> bincodeSerialize() const;
static IntegerDiv bincodeDeserialize(std::vector<uint8_t>);
};

struct Equals {
friend bool operator==(const Equals&, const Equals&);
std::vector<uint8_t> bincodeSerialize() const;
static Equals bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Add, Sub, Mul, Div, Equals> value;
struct LessThan {
friend bool operator==(const LessThan&, const LessThan&);
std::vector<uint8_t> bincodeSerialize() const;
static LessThan bincodeDeserialize(std::vector<uint8_t>);
};

struct LessThanEquals {
friend bool operator==(const LessThanEquals&, const LessThanEquals&);
std::vector<uint8_t> bincodeSerialize() const;
static LessThanEquals bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Add, Sub, Mul, Div, IntegerDiv, Equals, LessThan, LessThanEquals> value;

friend bool operator==(const BinaryFieldOp&, const BinaryFieldOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand All @@ -64,16 +82,10 @@ namespace Circuit {
static Mul bincodeDeserialize(std::vector<uint8_t>);
};

struct SignedDiv {
friend bool operator==(const SignedDiv&, const SignedDiv&);
std::vector<uint8_t> bincodeSerialize() const;
static SignedDiv bincodeDeserialize(std::vector<uint8_t>);
};

struct UnsignedDiv {
friend bool operator==(const UnsignedDiv&, const UnsignedDiv&);
struct Div {
friend bool operator==(const Div&, const Div&);
std::vector<uint8_t> bincodeSerialize() const;
static UnsignedDiv bincodeDeserialize(std::vector<uint8_t>);
static Div bincodeDeserialize(std::vector<uint8_t>);
};

struct Equals {
Expand Down Expand Up @@ -124,7 +136,7 @@ namespace Circuit {
static Shr bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Add, Sub, Mul, SignedDiv, UnsignedDiv, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;
std::variant<Add, Sub, Mul, Div, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;

friend bool operator==(const BinaryIntOp&, const BinaryIntOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -1053,7 +1065,17 @@ namespace Circuit {
static MemoryInit bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AssertZero, BlackBoxFuncCall, Directive, Brillig, MemoryOp, MemoryInit> value;
struct Call {
uint32_t id;
std::vector<Circuit::Witness> inputs;
std::vector<Circuit::Witness> outputs;

friend bool operator==(const Call&, const Call&);
std::vector<uint8_t> bincodeSerialize() const;
static Call bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AssertZero, BlackBoxFuncCall, Directive, Brillig, MemoryOp, MemoryInit, Call> value;

friend bool operator==(const Opcode&, const Opcode&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -1317,6 +1339,41 @@ Circuit::BinaryFieldOp::Div serde::Deserializable<Circuit::BinaryFieldOp::Div>::
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryFieldOp::IntegerDiv &lhs, const BinaryFieldOp::IntegerDiv &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryFieldOp::IntegerDiv::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryFieldOp::IntegerDiv>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryFieldOp::IntegerDiv BinaryFieldOp::IntegerDiv::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryFieldOp::IntegerDiv>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryFieldOp::IntegerDiv>::serialize(const Circuit::BinaryFieldOp::IntegerDiv &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryFieldOp::IntegerDiv serde::Deserializable<Circuit::BinaryFieldOp::IntegerDiv>::deserialize(Deserializer &deserializer) {
Circuit::BinaryFieldOp::IntegerDiv obj;
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryFieldOp::Equals &lhs, const BinaryFieldOp::Equals &rhs) {
Expand Down Expand Up @@ -1352,6 +1409,76 @@ Circuit::BinaryFieldOp::Equals serde::Deserializable<Circuit::BinaryFieldOp::Equ
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryFieldOp::LessThan &lhs, const BinaryFieldOp::LessThan &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryFieldOp::LessThan::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryFieldOp::LessThan>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryFieldOp::LessThan BinaryFieldOp::LessThan::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryFieldOp::LessThan>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryFieldOp::LessThan>::serialize(const Circuit::BinaryFieldOp::LessThan &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryFieldOp::LessThan serde::Deserializable<Circuit::BinaryFieldOp::LessThan>::deserialize(Deserializer &deserializer) {
Circuit::BinaryFieldOp::LessThan obj;
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryFieldOp::LessThanEquals &lhs, const BinaryFieldOp::LessThanEquals &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryFieldOp::LessThanEquals::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryFieldOp::LessThanEquals>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryFieldOp::LessThanEquals BinaryFieldOp::LessThanEquals::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryFieldOp::LessThanEquals>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryFieldOp::LessThanEquals>::serialize(const Circuit::BinaryFieldOp::LessThanEquals &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryFieldOp::LessThanEquals serde::Deserializable<Circuit::BinaryFieldOp::LessThanEquals>::deserialize(Deserializer &deserializer) {
Circuit::BinaryFieldOp::LessThanEquals obj;
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryIntOp &lhs, const BinaryIntOp &rhs) {
Expand Down Expand Up @@ -1501,19 +1628,19 @@ Circuit::BinaryIntOp::Mul serde::Deserializable<Circuit::BinaryIntOp::Mul>::dese

namespace Circuit {

inline bool operator==(const BinaryIntOp::SignedDiv &lhs, const BinaryIntOp::SignedDiv &rhs) {
inline bool operator==(const BinaryIntOp::Div &lhs, const BinaryIntOp::Div &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryIntOp::SignedDiv::bincodeSerialize() const {
inline std::vector<uint8_t> BinaryIntOp::Div::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::SignedDiv>::serialize(*this, serializer);
serde::Serializable<BinaryIntOp::Div>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::SignedDiv BinaryIntOp::SignedDiv::bincodeDeserialize(std::vector<uint8_t> input) {
inline BinaryIntOp::Div BinaryIntOp::Div::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::SignedDiv>::deserialize(deserializer);
auto value = serde::Deserializable<BinaryIntOp::Div>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
Expand All @@ -1524,48 +1651,13 @@ namespace Circuit {

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::SignedDiv>::serialize(const Circuit::BinaryIntOp::SignedDiv &obj, Serializer &serializer) {
void serde::Serializable<Circuit::BinaryIntOp::Div>::serialize(const Circuit::BinaryIntOp::Div &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::SignedDiv serde::Deserializable<Circuit::BinaryIntOp::SignedDiv>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::SignedDiv obj;
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryIntOp::UnsignedDiv &lhs, const BinaryIntOp::UnsignedDiv &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryIntOp::UnsignedDiv::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::UnsignedDiv>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::UnsignedDiv BinaryIntOp::UnsignedDiv::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::UnsignedDiv>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::UnsignedDiv>::serialize(const Circuit::BinaryIntOp::UnsignedDiv &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::UnsignedDiv serde::Deserializable<Circuit::BinaryIntOp::UnsignedDiv>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::UnsignedDiv obj;
Circuit::BinaryIntOp::Div serde::Deserializable<Circuit::BinaryIntOp::Div>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::Div obj;
return obj;
}

Expand Down Expand Up @@ -6012,6 +6104,50 @@ Circuit::Opcode::MemoryInit serde::Deserializable<Circuit::Opcode::MemoryInit>::
return obj;
}

namespace Circuit {

inline bool operator==(const Opcode::Call &lhs, const Opcode::Call &rhs) {
if (!(lhs.id == rhs.id)) { return false; }
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.outputs == rhs.outputs)) { return false; }
return true;
}

inline std::vector<uint8_t> Opcode::Call::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<Opcode::Call>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline Opcode::Call Opcode::Call::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<Opcode::Call>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::Opcode::Call>::serialize(const Circuit::Opcode::Call &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.id)>::serialize(obj.id, serializer);
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
}

template <>
template <typename Deserializer>
Circuit::Opcode::Call serde::Deserializable<Circuit::Opcode::Call>::deserialize(Deserializer &deserializer) {
Circuit::Opcode::Call obj;
obj.id = serde::Deserializable<decltype(obj.id)>::deserialize(deserializer);
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const OpcodeLocation &lhs, const OpcodeLocation &rhs) {
Expand Down
16 changes: 16 additions & 0 deletions acvm-repo/acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ pub enum Opcode {
block_id: BlockId,
init: Vec<Witness>,
},
/// Calls to functions represented as a separate circuit. A call opcode allows us
/// to build a call stack when executing the outer-most circuit.
Call {
/// Id for the function being called. It is the responsibility of the executor
/// to fetch the appropriate circuit from this id.
id: u32,
/// Inputs to the function call
inputs: Vec<Witness>,
/// Outputs of the function call
outputs: Vec<Witness>,
},
}

impl std::fmt::Display for Opcode {
Expand Down Expand Up @@ -86,6 +97,11 @@ impl std::fmt::Display for Opcode {
write!(f, "INIT ")?;
write!(f, "(id: {}, len: {}) ", block_id.0, init.len())
}
Opcode::Call { id, inputs, outputs } => {
write!(f, "CALL func {}: ", id)?;
writeln!(f, "inputs: {:?}", inputs)?;
writeln!(f, "outputs: {:?}", outputs)
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions acvm-repo/acvm/src/compiler/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ pub(super) fn transform_internal(
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
}
Opcode::Call { .. } => todo!("Handle Call opcodes in the ACVM"),
}
}

Expand Down
1 change: 1 addition & 0 deletions acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
res => res.map(|_| ()),
},
Opcode::Call { .. } => todo!("Handle Call opcodes in the ACVM"),
};
self.handle_opcode_resolution(resolution)
}
Expand Down
Loading

0 comments on commit b85095b

Please sign in to comment.