Skip to content

Commit

Permalink
feat!: Pass ACIR to ACVM by reference rather than passing ownership (n…
Browse files Browse the repository at this point in the history
…oir-lang#2872)

Co-authored-by: Tom French <[email protected]>
Co-authored-by: kevaundray <[email protected]>
  • Loading branch information
3 people authored and Sakapoi committed Oct 19, 2023
1 parent f4bdbb2 commit 87f9f86
Show file tree
Hide file tree
Showing 130 changed files with 183 additions and 391 deletions.
197 changes: 0 additions & 197 deletions acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,43 +674,9 @@ namespace Circuit {
static BrilligOutputs bincodeDeserialize(std::vector<uint8_t>);
};

struct ForeignCallParam {

struct Single {
Circuit::Value value;

friend bool operator==(const Single&, const Single&);
std::vector<uint8_t> bincodeSerialize() const;
static Single bincodeDeserialize(std::vector<uint8_t>);
};

struct Array {
std::vector<Circuit::Value> value;

friend bool operator==(const Array&, const Array&);
std::vector<uint8_t> bincodeSerialize() const;
static Array bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array> value;

friend bool operator==(const ForeignCallParam&, const ForeignCallParam&);
std::vector<uint8_t> bincodeSerialize() const;
static ForeignCallParam bincodeDeserialize(std::vector<uint8_t>);
};

struct ForeignCallResult {
std::vector<Circuit::ForeignCallParam> values;

friend bool operator==(const ForeignCallResult&, const ForeignCallResult&);
std::vector<uint8_t> bincodeSerialize() const;
static ForeignCallResult bincodeDeserialize(std::vector<uint8_t>);
};

struct Brillig {
std::vector<Circuit::BrilligInputs> inputs;
std::vector<Circuit::BrilligOutputs> outputs;
std::vector<Circuit::ForeignCallResult> foreign_call_results;
std::vector<Circuit::BrilligOpcode> bytecode;
std::optional<Circuit::Expression> predicate;

Expand Down Expand Up @@ -2761,7 +2727,6 @@ namespace Circuit {
inline bool operator==(const Brillig &lhs, const Brillig &rhs) {
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.outputs == rhs.outputs)) { return false; }
if (!(lhs.foreign_call_results == rhs.foreign_call_results)) { return false; }
if (!(lhs.bytecode == rhs.bytecode)) { return false; }
if (!(lhs.predicate == rhs.predicate)) { return false; }
return true;
Expand Down Expand Up @@ -2790,7 +2755,6 @@ void serde::Serializable<Circuit::Brillig>::serialize(const Circuit::Brillig &ob
serializer.increase_container_depth();
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
serde::Serializable<decltype(obj.foreign_call_results)>::serialize(obj.foreign_call_results, serializer);
serde::Serializable<decltype(obj.bytecode)>::serialize(obj.bytecode, serializer);
serde::Serializable<decltype(obj.predicate)>::serialize(obj.predicate, serializer);
serializer.decrease_container_depth();
Expand All @@ -2803,7 +2767,6 @@ Circuit::Brillig serde::Deserializable<Circuit::Brillig>::deserialize(Deserializ
Circuit::Brillig obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
obj.foreign_call_results = serde::Deserializable<decltype(obj.foreign_call_results)>::deserialize(deserializer);
obj.bytecode = serde::Deserializable<decltype(obj.bytecode)>::deserialize(deserializer);
obj.predicate = serde::Deserializable<decltype(obj.predicate)>::deserialize(deserializer);
deserializer.decrease_container_depth();
Expand Down Expand Up @@ -3970,166 +3933,6 @@ Circuit::Expression serde::Deserializable<Circuit::Expression>::deserialize(Dese
return obj;
}

namespace Circuit {

inline bool operator==(const ForeignCallParam &lhs, const ForeignCallParam &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> ForeignCallParam::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<ForeignCallParam>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline ForeignCallParam ForeignCallParam::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<ForeignCallParam>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::ForeignCallParam>::serialize(const Circuit::ForeignCallParam &obj, Serializer &serializer) {
serializer.increase_container_depth();
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
serializer.decrease_container_depth();
}

template <>
template <typename Deserializer>
Circuit::ForeignCallParam serde::Deserializable<Circuit::ForeignCallParam>::deserialize(Deserializer &deserializer) {
deserializer.increase_container_depth();
Circuit::ForeignCallParam obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}

namespace Circuit {

inline bool operator==(const ForeignCallParam::Single &lhs, const ForeignCallParam::Single &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> ForeignCallParam::Single::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<ForeignCallParam::Single>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline ForeignCallParam::Single ForeignCallParam::Single::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<ForeignCallParam::Single>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::ForeignCallParam::Single>::serialize(const Circuit::ForeignCallParam::Single &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Circuit::ForeignCallParam::Single serde::Deserializable<Circuit::ForeignCallParam::Single>::deserialize(Deserializer &deserializer) {
Circuit::ForeignCallParam::Single obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const ForeignCallParam::Array &lhs, const ForeignCallParam::Array &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> ForeignCallParam::Array::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<ForeignCallParam::Array>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline ForeignCallParam::Array ForeignCallParam::Array::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<ForeignCallParam::Array>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::ForeignCallParam::Array>::serialize(const Circuit::ForeignCallParam::Array &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Circuit::ForeignCallParam::Array serde::Deserializable<Circuit::ForeignCallParam::Array>::deserialize(Deserializer &deserializer) {
Circuit::ForeignCallParam::Array obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const ForeignCallResult &lhs, const ForeignCallResult &rhs) {
if (!(lhs.values == rhs.values)) { return false; }
return true;
}

inline std::vector<uint8_t> ForeignCallResult::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<ForeignCallResult>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline ForeignCallResult ForeignCallResult::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<ForeignCallResult>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::ForeignCallResult>::serialize(const Circuit::ForeignCallResult &obj, Serializer &serializer) {
serializer.increase_container_depth();
serde::Serializable<decltype(obj.values)>::serialize(obj.values, serializer);
serializer.decrease_container_depth();
}

template <>
template <typename Deserializer>
Circuit::ForeignCallResult serde::Deserializable<Circuit::ForeignCallResult>::deserialize(Deserializer &deserializer) {
deserializer.increase_container_depth();
Circuit::ForeignCallResult obj;
obj.values = serde::Deserializable<decltype(obj.values)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}

namespace Circuit {

inline bool operator==(const FunctionInput &lhs, const FunctionInput &rhs) {
Expand Down
4 changes: 0 additions & 4 deletions acvm-repo/acir/src/circuit/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::native_types::{Expression, Witness};
use brillig::ForeignCallResult;
use brillig::Opcode as BrilligOpcode;
use serde::{Deserialize, Serialize};

Expand All @@ -23,9 +22,6 @@ pub enum BrilligOutputs {
pub struct Brillig {
pub inputs: Vec<BrilligInputs>,
pub outputs: Vec<BrilligOutputs>,
/// Results of oracles/functions external to brillig like a database read.
// Each element of this vector corresponds to a single foreign call but may contain several values.
pub foreign_call_results: Vec<ForeignCallResult>,
/// The Brillig VM bytecode to be executed by this ACIR opcode.
pub bytecode: Vec<BrilligOpcode>,
/// Predicate of the Brillig execution - indicates if it should be skipped
Expand Down
5 changes: 1 addition & 4 deletions acvm-repo/acir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ mod reflection {
};

use brillig::{
BinaryFieldOp, BinaryIntOp, BlackBoxOp, ForeignCallParam, ForeignCallResult,
Opcode as BrilligOpcode, RegisterOrMemory,
BinaryFieldOp, BinaryIntOp, BlackBoxOp, Opcode as BrilligOpcode, RegisterOrMemory,
};
use serde_reflection::{Tracer, TracerConfig};

Expand Down Expand Up @@ -70,8 +69,6 @@ mod reflection {
tracer.trace_simple_type::<BinaryIntOp>().unwrap();
tracer.trace_simple_type::<BlackBoxOp>().unwrap();
tracer.trace_simple_type::<Directive>().unwrap();
tracer.trace_simple_type::<ForeignCallParam>().unwrap();
tracer.trace_simple_type::<ForeignCallResult>().unwrap();
tracer.trace_simple_type::<RegisterOrMemory>().unwrap();

let registry = tracer.registry().unwrap();
Expand Down
27 changes: 11 additions & 16 deletions acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,6 @@ fn simple_brillig_foreign_call() {
outputs: vec![
BrilligOutputs::Simple(w_inverted), // Output Register 1
],
// stack of foreign call/oracle resolutions, starts empty
foreign_call_results: vec![],
bytecode: vec![brillig::Opcode::ForeignCall {
function: "invert".into(),
destinations: vec![RegisterOrMemory::RegisterIndex(RegisterIndex::from(0))],
Expand All @@ -203,11 +201,10 @@ fn simple_brillig_foreign_call() {
circuit.write(&mut bytes).unwrap();

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 81, 10, 0, 16, 16, 68, 199, 42, 57, 14, 55,
112, 25, 31, 126, 124, 72, 206, 79, 161, 86, 225, 135, 87, 219, 78, 187, 53, 205, 104, 0,
2, 29, 201, 52, 103, 222, 220, 216, 230, 13, 43, 254, 121, 25, 158, 151, 54, 153, 117, 27,
53, 116, 136, 197, 167, 124, 107, 184, 64, 236, 73, 56, 83, 1, 18, 139, 122, 157, 67, 1, 0,
0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 49, 10, 64, 33, 12, 67, 99, 63, 124, 60, 142,
222, 192, 203, 56, 184, 56, 136, 120, 126, 5, 21, 226, 160, 139, 62, 40, 13, 45, 132, 68,
3, 80, 232, 124, 164, 153, 121, 115, 99, 155, 59, 172, 122, 231, 101, 56, 175, 80, 86, 221,
230, 31, 58, 196, 226, 83, 62, 53, 91, 16, 122, 10, 246, 84, 99, 243, 0, 30, 59, 1, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down Expand Up @@ -248,8 +245,6 @@ fn complex_brillig_foreign_call() {
BrilligOutputs::Simple(a_plus_b_plus_c), // Output Register 1
BrilligOutputs::Simple(a_plus_b_plus_c_times_2), // Output Register 2
],
// stack of foreign call/oracle resolutions, starts empty
foreign_call_results: vec![],
bytecode: vec![
// Oracles are named 'foreign calls' in brillig
brillig::Opcode::ForeignCall {
Expand Down Expand Up @@ -280,13 +275,13 @@ fn complex_brillig_foreign_call() {
circuit.write(&mut bytes).unwrap();

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 219, 10, 128, 48, 8, 245, 210, 101, 159, 179,
254, 160, 127, 137, 222, 138, 122, 236, 243, 27, 228, 64, 44, 232, 33, 7, 237, 128, 56,
157, 147, 131, 103, 6, 0, 64, 184, 192, 201, 72, 206, 40, 177, 70, 174, 27, 197, 199, 111,
24, 208, 175, 87, 44, 197, 145, 42, 224, 200, 5, 56, 230, 255, 240, 83, 189, 61, 117, 113,
157, 31, 63, 236, 79, 147, 172, 77, 214, 73, 220, 139, 15, 106, 214, 168, 114, 249, 126,
218, 214, 125, 153, 15, 54, 37, 90, 26, 155, 39, 227, 95, 223, 232, 230, 4, 247, 157, 215,
56, 1, 153, 86, 63, 138, 44, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 219, 10, 128, 48, 8, 117, 174, 139, 159, 179,
254, 160, 127, 137, 222, 138, 122, 236, 243, 19, 114, 32, 22, 244, 144, 131, 118, 64, 156,
178, 29, 14, 59, 74, 0, 16, 224, 66, 228, 64, 57, 7, 169, 53, 242, 189, 81, 114, 250, 134,
33, 248, 113, 165, 82, 26, 177, 2, 141, 177, 128, 198, 60, 15, 63, 245, 219, 211, 23, 215,
255, 139, 15, 251, 211, 112, 180, 28, 157, 212, 189, 100, 82, 179, 64, 170, 63, 109, 235,
190, 204, 135, 166, 178, 150, 216, 62, 154, 252, 250, 70, 147, 35, 220, 119, 93, 227, 4,
182, 131, 81, 25, 36, 4, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down
5 changes: 3 additions & 2 deletions acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use acir::{
brillig::{ForeignCallParam, RegisterIndex, Value},
brillig::{ForeignCallParam, ForeignCallResult, RegisterIndex, Value},
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
OpcodeLocation,
Expand All @@ -20,6 +20,7 @@ impl BrilligSolver {
pub(super) fn solve<B: BlackBoxFunctionSolver>(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
foreign_call_results: Vec<ForeignCallResult>,
bb_solver: &B,
acir_index: usize,
) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
Expand Down Expand Up @@ -80,7 +81,7 @@ impl BrilligSolver {
input_registers,
input_memory,
&brillig.bytecode,
brillig.foreign_call_results.clone(),
foreign_call_results,
bb_solver,
);

Expand Down
Loading

0 comments on commit 87f9f86

Please sign in to comment.