Skip to content

Commit

Permalink
feat: honk flows exposed through wasm (AztecProtocol#6096)
Browse files Browse the repository at this point in the history
Adds the Honk prove, verify, and store_vk flows for WASM (through
bb.js).

Closes AztecProtocol/barretenberg#970.
  • Loading branch information
lucasxia01 authored May 3, 2024
1 parent ed84fe3 commit c9b3206
Show file tree
Hide file tree
Showing 16 changed files with 243 additions and 34 deletions.
6 changes: 5 additions & 1 deletion barretenberg/acir_tests/Dockerfile.bb.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ ENV VERBOSE=1
# Run double_verify_proof through bb.js on node to check 512k support.
RUN BIN=../ts/dest/node/main.js FLOW=prove_then_verify ./run_acir_tests.sh double_verify_proof
# Run a single arbitrary test not involving recursion through bb.js for UltraHonk
RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_ultra_honk ./run_acir_tests.sh 6_array
RUN BIN=../ts/dest/node/main.js FLOW=prove_then_verify_ultra_honk ./run_acir_tests.sh nested_array_dynamic
# Run a single arbitrary test not involving recursion through bb.js for Plonk
RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify ./run_acir_tests.sh poseidon_bn254_hash
# Run a single arbitrary test not involving recursion through bb.js for GoblinUltraHonk
RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_ultra_honk ./run_acir_tests.sh closures_mut_ref
# Run a single arbitrary test for separate prove and verify for UltraHonk
RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_goblin_ultra_honk ./run_acir_tests.sh 6_array
# Run a single arbitrary test not involving recursion through bb.js for full Goblin
RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_goblin ./run_acir_tests.sh 6_array
Expand Down
40 changes: 40 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,43 @@ WASM_EXPORT void acir_serialize_verification_key_into_fields(in_ptr acir_compose
*out_vkey = to_heap_buffer(vkey_as_fields);
write(out_key_hash, vk_hash);
}

WASM_EXPORT void acir_prove_ultra_honk(uint8_t const* acir_vec, uint8_t const* witness_vec, uint8_t** out)
{
auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer<std::vector<uint8_t>>(acir_vec));
auto witness = acir_format::witness_buf_to_witness_data(from_buffer<std::vector<uint8_t>>(witness_vec));

auto builder = acir_format::create_circuit<UltraCircuitBuilder>(constraint_system, 0, witness);

UltraProver prover{ builder };
auto proof = prover.construct_proof();
*out = to_heap_buffer(to_buffer</*include_size=*/true>(proof));
}

WASM_EXPORT void acir_verify_ultra_honk(uint8_t const* proof_buf, uint8_t const* vk_buf, bool* result)
{
using VerificationKey = UltraFlavor::VerificationKey;
using VerifierCommitmentKey = bb::VerifierCommitmentKey<curve::BN254>;
using Verifier = UltraVerifier_<UltraFlavor>;

auto proof = from_buffer<std::vector<bb::fr>>(from_buffer<std::vector<uint8_t>>(proof_buf));
auto verification_key = std::make_shared<VerificationKey>(from_buffer<VerificationKey>(vk_buf));
verification_key->pcs_verification_key = std::make_shared<VerifierCommitmentKey>();

Verifier verifier{ verification_key };

*result = verifier.verify_proof(proof);
}

WASM_EXPORT void acir_write_vk_ultra_honk(uint8_t const* acir_vec, uint8_t** out)
{
using ProverInstance = ProverInstance_<UltraFlavor>;
using VerificationKey = UltraFlavor::VerificationKey;

auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer<std::vector<uint8_t>>(acir_vec));
auto builder = acir_format::create_circuit<UltraCircuitBuilder>(constraint_system, 0, {});

ProverInstance prover_inst(builder);
VerificationKey vk(prover_inst.proving_key);
*out = to_heap_buffer(to_buffer(vk));
}
4 changes: 3 additions & 1 deletion barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,6 @@ WASM_EXPORT void acir_serialize_proof_into_fields(in_ptr acir_composer_ptr,

WASM_EXPORT void acir_serialize_verification_key_into_fields(in_ptr acir_composer_ptr,
fr::vec_out_buf out_vkey,
fr::out_buf out_key_hash);
fr::out_buf out_key_hash);

WASM_EXPORT void acir_prove_ultra_honk(uint8_t const* acir_vec, uint8_t const* witness_vec, uint8_t** out);
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/barretenberg/eccvm/eccvm_verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ bool ECCVMVerifier::verify_proof(const HonkProof& proof)
const size_t log_circuit_size = numeric::get_msb(circuit_size);
auto sumcheck = SumcheckVerifier<Flavor>(log_circuit_size, transcript);
FF alpha = transcript->template get_challenge<FF>("Sumcheck:alpha");
std::vector<FF> gate_challenges(numeric::get_msb(key->circuit_size));
std::vector<FF> gate_challenges(static_cast<size_t>(numeric::get_msb(key->circuit_size)));
for (size_t idx = 0; idx < gate_challenges.size(); idx++) {
gate_challenges[idx] = transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
}
Expand Down
8 changes: 4 additions & 4 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ namespace bb {
*/
class PrecomputedEntitiesBase {
public:
size_t circuit_size;
size_t log_circuit_size;
size_t num_public_inputs;
uint64_t circuit_size;
uint64_t log_circuit_size;
uint64_t num_public_inputs;
CircuitType circuit_type; // TODO(#392)
};

Expand Down Expand Up @@ -181,7 +181,7 @@ template <typename PrecomputedCommitments, typename VerifierCommitmentKey>
class VerificationKey_ : public PrecomputedCommitments {
public:
std::shared_ptr<VerifierCommitmentKey> pcs_verification_key;
size_t pub_inputs_offset = 0;
uint64_t pub_inputs_offset = 0;

VerificationKey_() = default;
VerificationKey_(const size_t circuit_size, const size_t num_public_inputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ template <typename Flavor> bool DeciderVerifier_<Flavor>::verify_proof(const Hon

VerifierCommitments commitments{ accumulator->verification_key, accumulator->witness_commitments };

auto sumcheck =
SumcheckVerifier<Flavor>(accumulator->verification_key->log_circuit_size, transcript, accumulator->target_sum);
auto sumcheck = SumcheckVerifier<Flavor>(
static_cast<size_t>(accumulator->verification_key->log_circuit_size), transcript, accumulator->target_sum);

auto [multivariate_challenge, claimed_evaluations, sumcheck_verified] =
sumcheck.verify(accumulator->relation_parameters, accumulator->alphas, accumulator->gate_challenges);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void ProtoGalaxyVerifier_<VerifierInstances>::prepare_for_folding(const std::vec
if (!inst->is_accumulator) {
receive_and_finalise_instance(inst, domain_separator);
inst->target_sum = 0;
inst->gate_challenges = std::vector<FF>(inst->verification_key->log_circuit_size, 0);
inst->gate_challenges = std::vector<FF>(static_cast<size_t>(inst->verification_key->log_circuit_size), 0);
}
index++;

Expand All @@ -45,11 +45,12 @@ std::shared_ptr<typename VerifierInstances::Instance> ProtoGalaxyVerifier_<Verif

auto delta = transcript->template get_challenge<FF>("delta");
auto accumulator = get_accumulator();
auto deltas = compute_round_challenge_pows(accumulator->verification_key->log_circuit_size, delta);
auto deltas =
compute_round_challenge_pows(static_cast<size_t>(accumulator->verification_key->log_circuit_size), delta);

std::vector<FF> perturbator_coeffs(accumulator->verification_key->log_circuit_size + 1, 0);
std::vector<FF> perturbator_coeffs(static_cast<size_t>(accumulator->verification_key->log_circuit_size) + 1, 0);
if (accumulator->is_accumulator) {
for (size_t idx = 1; idx <= accumulator->verification_key->log_circuit_size; idx++) {
for (size_t idx = 1; idx <= static_cast<size_t>(accumulator->verification_key->log_circuit_size); idx++) {
perturbator_coeffs[idx] =
transcript->template receive_from_prover<FF>("perturbator_" + std::to_string(idx));
}
Expand Down Expand Up @@ -112,7 +113,8 @@ std::shared_ptr<typename VerifierInstances::Instance> ProtoGalaxyVerifier_<Verif
vk_idx++;
}
next_accumulator->verification_key->num_public_inputs = accumulator->verification_key->num_public_inputs;
next_accumulator->public_inputs = std::vector<FF>(next_accumulator->verification_key->num_public_inputs, 0);
next_accumulator->public_inputs =
std::vector<FF>(static_cast<size_t>(next_accumulator->verification_key->num_public_inputs), 0);
size_t public_input_idx = 0;
for (auto& public_input : next_accumulator->public_inputs) {
size_t inst = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ std::array<typename Flavor::GroupElement, 2> DeciderRecursiveVerifier_<Flavor>::

VerifierCommitments commitments{ accumulator->verification_key, accumulator->witness_commitments };

auto sumcheck = Sumcheck(accumulator->verification_key->log_circuit_size, transcript, accumulator->target_sum);
auto sumcheck = Sumcheck(
static_cast<size_t>(accumulator->verification_key->log_circuit_size), transcript, accumulator->target_sum);

auto [multivariate_challenge, claimed_evaluations, sumcheck_verified] =
sumcheck.verify(accumulator->relation_parameters, accumulator->alphas, accumulator->gate_challenges);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ void ProtoGalaxyRecursiveVerifier_<VerifierInstances>::receive_and_finalise_inst
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.z_lookup);

// Compute correction terms for grand products
const FF public_input_delta = compute_public_input_delta<Flavor>(inst->public_inputs,
beta,
gamma,
inst->verification_key->circuit_size,
inst->verification_key->pub_inputs_offset);
const FF public_input_delta =
compute_public_input_delta<Flavor>(inst->public_inputs,
beta,
gamma,
inst->verification_key->circuit_size,
static_cast<size_t>(inst->verification_key->pub_inputs_offset));
const FF lookup_grand_product_delta =
compute_lookup_grand_product_delta<FF>(beta, gamma, inst->verification_key->circuit_size);
inst->relation_parameters =
Expand All @@ -105,7 +106,7 @@ template <class VerifierInstances> void ProtoGalaxyRecursiveVerifier_<VerifierIn
if (!inst->is_accumulator) {
receive_and_finalise_instance(inst, domain_separator);
inst->target_sum = 0;
inst->gate_challenges = std::vector<FF>(inst->verification_key->log_circuit_size, 0);
inst->gate_challenges = std::vector<FF>(static_cast<size_t>(inst->verification_key->log_circuit_size), 0);
}
index++;

Expand All @@ -128,11 +129,12 @@ std::shared_ptr<typename VerifierInstances::Instance> ProtoGalaxyRecursiveVerifi

auto delta = transcript->template get_challenge<FF>("delta");
auto accumulator = get_accumulator();
auto deltas = compute_round_challenge_pows(accumulator->verification_key->log_circuit_size, delta);
auto deltas =
compute_round_challenge_pows(static_cast<size_t>(accumulator->verification_key->log_circuit_size), delta);

std::vector<FF> perturbator_coeffs(accumulator->verification_key->log_circuit_size + 1, 0);
std::vector<FF> perturbator_coeffs(static_cast<size_t>(accumulator->verification_key->log_circuit_size) + 1, 0);
if (accumulator->is_accumulator) {
for (size_t idx = 1; idx <= accumulator->verification_key->log_circuit_size; idx++) {
for (size_t idx = 1; idx <= static_cast<size_t>(accumulator->verification_key->log_circuit_size); idx++) {
perturbator_coeffs[idx] =
transcript->template receive_from_prover<FF>("perturbator_" + std::to_string(idx));
}
Expand Down Expand Up @@ -200,7 +202,8 @@ std::shared_ptr<typename VerifierInstances::Instance> ProtoGalaxyRecursiveVerifi
comm_idx++;
}

next_accumulator->public_inputs = std::vector<FF>(next_accumulator->verification_key->num_public_inputs, 0);
next_accumulator->public_inputs =
std::vector<FF>(static_cast<size_t>(next_accumulator->verification_key->num_public_inputs), 0);
size_t public_input_idx = 0;
for (auto& public_input : next_accumulator->public_inputs) {
size_t inst = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ template <IsRecursiveFlavor Flavor> class RecursiveVerifierInstance_ {
: verification_key(std::make_shared<VerificationKey>(instance->verification_key->circuit_size,
instance->verification_key->num_public_inputs))
, is_accumulator(bool(instance->is_accumulator))
, public_inputs(std::vector<FF>(instance->verification_key->num_public_inputs))
, public_inputs(std::vector<FF>(static_cast<size_t>(instance->verification_key->num_public_inputs)))
{

verification_key->pub_inputs_offset = instance->verification_key->pub_inputs_offset;
Expand Down Expand Up @@ -113,7 +113,7 @@ template <IsRecursiveFlavor Flavor> class RecursiveVerifierInstance_ {
VerifierInstance inst(inst_verification_key);
inst.is_accumulator = is_accumulator;

inst.public_inputs = std::vector<NativeFF>(verification_key->num_public_inputs);
inst.public_inputs = std::vector<NativeFF>(static_cast<size_t>(verification_key->num_public_inputs));
for (auto [public_input, inst_public_input] : zip_view(public_inputs, inst.public_inputs)) {
inst_public_input = public_input.get_value();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,9 +459,9 @@ class UltraFlavor {
}
}
// TODO(https://github.com/AztecProtocol/barretenberg/issues/964): Clean the boilerplate up.
VerificationKey(const size_t circuit_size,
const size_t num_public_inputs,
const size_t pub_inputs_offset,
VerificationKey(const uint64_t circuit_size,
const uint64_t num_public_inputs,
const uint64_t pub_inputs_offset,
const Commitment& q_m,
const Commitment& q_c,
const Commitment& q_l,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ bool GoblinTranslatorVerifier::verify_proof(const HonkProof& proof)
const size_t log_circuit_size = numeric::get_msb(circuit_size);
auto sumcheck = SumcheckVerifier<Flavor>(log_circuit_size, transcript);
FF alpha = transcript->template get_challenge<FF>("Sumcheck:alpha");
std::vector<FF> gate_challenges(numeric::get_msb(key->circuit_size));
std::vector<FF> gate_challenges(static_cast<size_t>(numeric::get_msb(key->circuit_size)));
for (size_t idx = 0; idx < gate_challenges.size(); idx++) {
gate_challenges[idx] = transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,11 @@ template <IsUltraFlavor Flavor> void OinkVerifier<Flavor>::execute_log_derivativ
*/
template <IsUltraFlavor Flavor> void OinkVerifier<Flavor>::execute_grand_product_computation_round()
{
const FF public_input_delta = compute_public_input_delta<Flavor>(
public_inputs, relation_parameters.beta, relation_parameters.gamma, key->circuit_size, key->pub_inputs_offset);
const FF public_input_delta = compute_public_input_delta<Flavor>(public_inputs,
relation_parameters.beta,
relation_parameters.gamma,
key->circuit_size,
static_cast<size_t>(key->pub_inputs_offset));
const FF lookup_grand_product_delta =
compute_lookup_grand_product_delta<FF>(relation_parameters.beta, relation_parameters.gamma, key->circuit_size);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ template <typename Flavor> bool UltraVerifier_<Flavor>::verify_proof(const HonkP
}

// Execute Sumcheck Verifier
const size_t log_circuit_size = numeric::get_msb(key->circuit_size);
const size_t log_circuit_size = static_cast<size_t>(numeric::get_msb(key->circuit_size));
auto sumcheck = SumcheckVerifier<Flavor>(log_circuit_size, transcript);

auto gate_challenges = std::vector<FF>(log_circuit_size);
Expand Down
72 changes: 72 additions & 0 deletions barretenberg/ts/src/barretenberg_api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,42 @@ export class BarretenbergApi {
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out as any;
}

async acirProveUltraHonk(constraintSystemBuf: Uint8Array, witnessBuf: Uint8Array): Promise<Uint8Array> {
const inArgs = [constraintSystemBuf, witnessBuf].map(serializeBufferable);
const outTypes: OutputType[] = [BufferDeserializer()];
const result = await this.wasm.callWasmExport(
'acir_prove_ultra_honk',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out[0];
}

async acirVerifyUltraHonk(proofBuf: Uint8Array, vkBuf: Uint8Array): Promise<boolean> {
const inArgs = [proofBuf, vkBuf].map(serializeBufferable);
const outTypes: OutputType[] = [BoolDeserializer()];
const result = await this.wasm.callWasmExport(
'acir_verify_ultra_honk',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out[0];
}

async acirWriteVkUltraHonk(constraintSystemBuf: Uint8Array): Promise<Uint8Array> {
const inArgs = [constraintSystemBuf].map(serializeBufferable);
const outTypes: OutputType[] = [BufferDeserializer()];
const result = await this.wasm.callWasmExport(
'acir_write_vk_ultra_honk',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out[0];
}
}
export class BarretenbergApiSync {
constructor(protected wasm: BarretenbergWasm) {}
Expand Down Expand Up @@ -1111,4 +1147,40 @@ export class BarretenbergApiSync {
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out as any;
}

acirUltraHonkProve(constraintSystemBuf: Uint8Array, witnessBuf: Uint8Array): Uint8Array {
const inArgs = [constraintSystemBuf, witnessBuf].map(serializeBufferable);
const outTypes: OutputType[] = [BufferDeserializer()];
const result = this.wasm.callWasmExport(
'acir_prove_ultra_honk',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out[0];
}

acirVerifyUltraHonk(proofBuf: Uint8Array, vkBuf: Uint8Array): boolean {
const inArgs = [proofBuf, vkBuf].map(serializeBufferable);
const outTypes: OutputType[] = [BoolDeserializer()];
const result = this.wasm.callWasmExport(
'acir_verify_ultra_honk',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out[0];
}

acirWriteVkUltraHonk(constraintSystemBuf: Uint8Array): Uint8Array {
const inArgs = [constraintSystemBuf].map(serializeBufferable);
const outTypes: OutputType[] = [BufferDeserializer()];
const result = this.wasm.callWasmExport(
'acir_write_vk_ultra_honk',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out[0];
}
}
Loading

0 comments on commit c9b3206

Please sign in to comment.