Skip to content

Commit

Permalink
chore(acir)!: Move is_recursive flag to be part of the circuit defi…
Browse files Browse the repository at this point in the history
…nition (AztecProtocol#4221)

Resolves AztecProtocol#4222

Currently in order to specify whether we want to use a prover that
produces SNARK recursion friendly proofs, we must pass a flag from the
tooling infrastructure. This PR moves it be part of the circuit
definition itself.

The flag now lives on the Builder and is set when we call
`create_circuit` in the acir format. The proof produced when this flag
is true should be friendly for recursive verification inside of another
SNARK. For example, a recursive friendly proof may use Blake3Pedersen
for hashing in its transcript, while we still want a prove that uses
Keccak for its transcript in order to be able to verify SNARKs on
Ethereum.

However, a verifier does not need a full circuit description and should
be able to verify a proof with just the verification key and the proof.
An `is_recursive_circuit` field was thus added to the verification key
as well so that we can specify the accurate verifier to use for a given
proof without the full circuit description.

---------

Signed-off-by: kevaundray <[email protected]>
Co-authored-by: ledwards2225 <[email protected]>
Co-authored-by: kevaundray <[email protected]>
  • Loading branch information
3 people authored Feb 1, 2024
1 parent 433b9eb commit 9c965a7
Show file tree
Hide file tree
Showing 53 changed files with 249 additions and 250 deletions.
3 changes: 1 addition & 2 deletions barretenberg/acir_tests/browser-test-app/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ async function runTest(
acirComposer,
bytecode,
witness,
true
);
debug(`verifying...`);
const verified = await api.acirVerifyProof(acirComposer, proof, true);
const verified = await api.acirVerifyProof(acirComposer, proof);
debug(`verified: ${verified}`);

await api.destroy();
Expand Down
2 changes: 1 addition & 1 deletion barretenberg/acir_tests/gen_inner_proof_inputs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export BRANCH

./clone_test_vectors.sh

cd acir_tests/assert_statement
cd acir_tests/assert_statement_recursive

PROOF_DIR=$PWD/proofs
PROOF_PATH=$PROOF_DIR/$PROOF_NAME
Expand Down
30 changes: 12 additions & 18 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ acir_format::AcirFormat get_constraint_system(std::string const& bytecode_path)
* @return true if the proof is valid
* @return false if the proof is invalid
*/
bool proveAndVerify(const std::string& bytecodePath, const std::string& witnessPath, bool recursive)
bool proveAndVerify(const std::string& bytecodePath, const std::string& witnessPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
auto witness = get_witness(witnessPath);
Expand All @@ -109,14 +109,14 @@ bool proveAndVerify(const std::string& bytecodePath, const std::string& witnessP
write_benchmark("subgroup_size", acir_composer.get_dyadic_circuit_size(), "acir_test", current_dir);

Timer proof_timer;
auto proof = acir_composer.create_proof(recursive);
auto proof = acir_composer.create_proof();
write_benchmark("proof_construction_time", proof_timer.milliseconds(), "acir_test", current_dir);

Timer vk_timer;
acir_composer.init_verification_key();
write_benchmark("vk_construction_time", vk_timer.milliseconds(), "acir_test", current_dir);

auto verified = acir_composer.verify_proof(proof, recursive);
auto verified = acir_composer.verify_proof(proof);

vinfo("verified: ", verified);
return verified;
Expand Down Expand Up @@ -172,9 +172,7 @@ bool accumulateAndVerifyGoblin(const std::string& bytecodePath, const std::strin
* @return true if the proof is valid
* @return false if the proof is invalid
*/
bool proveAndVerifyGoblin(const std::string& bytecodePath,
const std::string& witnessPath,
[[maybe_unused]] bool recursive)
bool proveAndVerifyGoblin(const std::string& bytecodePath, const std::string& witnessPath)
{
// Populate the acir constraint system and witness from gzipped data
auto constraint_system = get_constraint_system(bytecodePath);
Expand Down Expand Up @@ -212,10 +210,7 @@ bool proveAndVerifyGoblin(const std::string& bytecodePath,
* @param recursive Whether to use recursive proof generation of non-recursive
* @param outputPath Path to write the proof to
*/
void prove(const std::string& bytecodePath,
const std::string& witnessPath,
bool recursive,
const std::string& outputPath)
void prove(const std::string& bytecodePath, const std::string& witnessPath, const std::string& outputPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
auto witness = get_witness(witnessPath);
Expand All @@ -224,7 +219,7 @@ void prove(const std::string& bytecodePath,
acir_composer.create_circuit(constraint_system, witness);
init_bn254_crs(acir_composer.get_dyadic_circuit_size());
acir_composer.init_proving_key();
auto proof = acir_composer.create_proof(recursive);
auto proof = acir_composer.create_proof();

if (outputPath == "-") {
writeRawBytesToStdout(proof);
Expand Down Expand Up @@ -270,12 +265,12 @@ void gateCount(const std::string& bytecodePath)
* @return true If the proof is valid
* @return false If the proof is invalid
*/
bool verify(const std::string& proof_path, bool recursive, const std::string& vk_path)
bool verify(const std::string& proof_path, const std::string& vk_path)
{
auto acir_composer = verifier_init();
auto vk_data = from_buffer<plonk::verification_key_data>(read_file(vk_path));
acir_composer.load_verification_key(std::move(vk_data));
auto verified = acir_composer.verify_proof(read_file(proof_path), recursive);
auto verified = acir_composer.verify_proof(read_file(proof_path));

vinfo("verified: ", verified);
return verified;
Expand Down Expand Up @@ -491,7 +486,6 @@ int main(int argc, char* argv[])
std::string vk_path = get_option(args, "-k", "./target/vk");
std::string pk_path = get_option(args, "-r", "./target/pk");
CRS_PATH = get_option(args, "-c", CRS_PATH);
bool recursive = flag_present(args, "-r") || flag_present(args, "--recursive");

// Skip CRS initialization for any command which doesn't require the CRS.
if (command == "--version") {
Expand All @@ -504,21 +498,21 @@ int main(int argc, char* argv[])
return 0;
}
if (command == "prove_and_verify") {
return proveAndVerify(bytecode_path, witness_path, recursive) ? 0 : 1;
return proveAndVerify(bytecode_path, witness_path) ? 0 : 1;
}
if (command == "accumulate_and_verify_goblin") {
return accumulateAndVerifyGoblin(bytecode_path, witness_path) ? 0 : 1;
}
if (command == "prove_and_verify_goblin") {
return proveAndVerifyGoblin(bytecode_path, witness_path, recursive) ? 0 : 1;
return proveAndVerifyGoblin(bytecode_path, witness_path) ? 0 : 1;
}
if (command == "prove") {
std::string output_path = get_option(args, "-o", "./proofs/proof");
prove(bytecode_path, witness_path, recursive, output_path);
prove(bytecode_path, witness_path, output_path);
} else if (command == "gates") {
gateCount(bytecode_path);
} else if (command == "verify") {
return verify(proof_path, recursive, vk_path) ? 0 : 1;
return verify(proof_path, vk_path) ? 0 : 1;
} else if (command == "contract") {
std::string output_path = get_option(args, "-o", "./target/contract.sol");
contract(output_path, vk_path);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ void build_constraints(Builder& builder, AcirFormat const& constraint_system, bo
template <typename Builder>
Builder create_circuit(const AcirFormat& constraint_system, size_t size_hint, WitnessVector const& witness)
{
Builder builder{ size_hint, witness, constraint_system.public_inputs, constraint_system.varnum };
Builder builder{
size_hint, witness, constraint_system.public_inputs, constraint_system.varnum, constraint_system.recursive
};

bool has_valid_witness_assignments = !witness.empty();
build_constraints(builder, constraint_system, has_valid_witness_assignments);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ namespace acir_format {
struct AcirFormat {
// The number of witnesses in the circuit
uint32_t varnum;
// Specifies whether a prover that produces SNARK recursion friendly proofs should be used.
// The proof produced when this flag is true should be friendly for recursive verification inside
// of another SNARK. For example, a recursive friendly proof may use Blake3Pedersen for
// hashing in its transcript, while we still want a prove that uses Keccak for its transcript in order
// to be able to verify SNARKs on Ethereum.
bool recursive;

std::vector<uint32_t> public_inputs;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ TEST_F(AcirFormatTests, TestASingleConstraintNoPubInputs)

AcirFormat constraint_system{
.varnum = 4,
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down Expand Up @@ -141,6 +142,7 @@ TEST_F(AcirFormatTests, TestLogicGateFromNoirCircuit)
// EXPR [ (-1, _6) 1 ]

AcirFormat constraint_system{ .varnum = 6,
.recursive = false,
.public_inputs = { 1 },
.logic_constraints = { logic_constraint },
.range_constraints = { range_a, range_b },
Expand Down Expand Up @@ -205,6 +207,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifyPass)
.signature = signature,
};
AcirFormat constraint_system{ .varnum = 81,
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = range_constraints,
Expand Down Expand Up @@ -297,6 +300,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifySmallRange)
};
AcirFormat constraint_system{
.varnum = 81,
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = range_constraints,
Expand Down Expand Up @@ -408,6 +412,7 @@ TEST_F(AcirFormatTests, TestVarKeccak)

AcirFormat constraint_system{
.varnum = 36,
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = { range_a, range_b, range_c, range_d },
Expand Down Expand Up @@ -451,6 +456,7 @@ TEST_F(AcirFormatTests, TestKeccakPermutation)
};

AcirFormat constraint_system{ .varnum = 51,
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ AcirFormat circuit_buf_to_acir_format(std::vector<uint8_t> const& buf)
AcirFormat af;
// `varnum` is the true number of variables, thus we add one to the index which starts at zero
af.varnum = circuit.current_witness_index + 1;
af.recursive = circuit.recursive;
af.public_inputs = join({ map(circuit.public_parameters.value, [](auto e) { return e.value; }),
map(circuit.return_values.value, [](auto e) { return e.value; }) });
std::map<uint32_t, BlockConstraint> block_id_to_block_constraint;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ TEST_F(BigIntTests, TestBigIntConstraintDummy)

AcirFormat constraint_system{
.varnum = 4,
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ TEST_F(UltraPlonkRAM, TestBlockConstraint)
size_t num_variables = generate_block_constraint(block, witness_values);
AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ TEST_F(EcOperations, TestECOperations)

AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables + 1),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ TEST_F(ECDSASecp256k1, TestECDSAConstraintSucceed)
size_t num_variables = generate_ecdsa_constraint(ecdsa_k1_constraint, witness_values);
AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down Expand Up @@ -134,6 +135,7 @@ TEST_F(ECDSASecp256k1, TestECDSACompilesForVerifier)
size_t num_variables = generate_ecdsa_constraint(ecdsa_k1_constraint, witness_values);
AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down Expand Up @@ -174,6 +176,7 @@ TEST_F(ECDSASecp256k1, TestECDSAConstraintFail)

AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ TEST(ECDSASecp256r1, test_hardcoded)

AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down Expand Up @@ -170,6 +171,7 @@ TEST(ECDSASecp256r1, TestECDSAConstraintSucceed)
size_t num_variables = generate_ecdsa_constraint(ecdsa_r1_constraint, witness_values);
AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down Expand Up @@ -214,6 +216,7 @@ TEST(ECDSASecp256r1, TestECDSACompilesForVerifier)
size_t num_variables = generate_ecdsa_constraint(ecdsa_r1_constraint, witness_values);
AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down Expand Up @@ -253,6 +256,7 @@ TEST(ECDSASecp256r1, TestECDSAConstraintFail)

AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ Builder create_inner_circuit()
};

AcirFormat constraint_system{ .varnum = 6,
.recursive = true,
.public_inputs = { 1, 2 },
.logic_constraints = { logic_constraint },
.range_constraints = { range_a, range_b },
Expand Down Expand Up @@ -235,6 +236,7 @@ Builder create_outer_circuit(std::vector<Builder>& inner_circuits)
}

AcirFormat constraint_system{ .varnum = static_cast<uint32_t>(witness.size()),
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,7 @@ struct Circuit {
PublicInputs public_parameters;
PublicInputs return_values;
std::vector<std::tuple<OpcodeLocation, std::string>> assert_messages;
bool recursive;

friend bool operator==(const Circuit&, const Circuit&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -5938,6 +5939,9 @@ inline bool operator==(const Circuit& lhs, const Circuit& rhs)
if (!(lhs.assert_messages == rhs.assert_messages)) {
return false;
}
if (!(lhs.recursive == rhs.recursive)) {
return false;
}
return true;
}

Expand Down Expand Up @@ -5971,6 +5975,7 @@ void serde::Serializable<Circuit::Circuit>::serialize(const Circuit::Circuit& ob
serde::Serializable<decltype(obj.public_parameters)>::serialize(obj.public_parameters, serializer);
serde::Serializable<decltype(obj.return_values)>::serialize(obj.return_values, serializer);
serde::Serializable<decltype(obj.assert_messages)>::serialize(obj.assert_messages, serializer);
serde::Serializable<decltype(obj.recursive)>::serialize(obj.recursive, serializer);
serializer.decrease_container_depth();
}

Expand All @@ -5986,6 +5991,7 @@ Circuit::Circuit serde::Deserializable<Circuit::Circuit>::deserialize(Deserializ
obj.public_parameters = serde::Deserializable<decltype(obj.public_parameters)>::deserialize(deserializer);
obj.return_values = serde::Deserializable<decltype(obj.return_values)>::deserialize(deserializer);
obj.assert_messages = serde::Deserializable<decltype(obj.assert_messages)>::deserialize(deserializer);
obj.recursive = serde::Deserializable<decltype(obj.recursive)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}
Expand Down
Loading

0 comments on commit 9c965a7

Please sign in to comment.