Skip to content

Commit

Permalink
feat: new test programs for wasm benchmarking (#8389)
Browse files Browse the repository at this point in the history
This PR creates three new test programs: bench_2_to_17, a program
constructed purely out of poseidon hashes with just under 2^17 gates
(2^17 - epsilon); fold_2_to_17, a program with two circuits of size
2^17 - epsilon each that get folded together; and single_verify_proof,
a plonk recursive verifier, which now has 2^18 + epsilon gates.

It also reworks some of the interfaces in main.ts. In particular, it
gets the foldAndVerifyProgram flow working with fold_basic (and
SMALL_TEST execution trace structure).

I used the fold_2_to_17 test program to benchmark the memory usage of
ClientIVC in WASM, which came out to 700MiB. Note that this was only
possible with the trace structure turned off, because otherwise it
fails due to too many poseidon gates.

Running ClientIVC on the fold_basic test program using
TraceStructure::SMALL_TEST (which ends up with a dyadic size of 2^18)
gives 1166.56MiB in WASM. Note that the builder memory does not reflect
full 2^18-gate circuits, because the circuits only had 22, 4539, and
16432 gates; with full 2^18-gate circuits the total should be closer to
1250MiB or so.
  • Loading branch information
lucasxia01 authored Sep 12, 2024
1 parent 7065962 commit 0b46e96
Show file tree
Hide file tree
Showing 19 changed files with 148 additions and 37 deletions.
2 changes: 2 additions & 0 deletions barretenberg/Earthfile
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ barretenberg-acir-tests-bb.js:
RUN BIN=../ts/dest/node/main.js FLOW=prove_then_verify_ultra_honk ./run_acir_tests.sh 6_array assert_statement
# Run a single arbitrary test not involving recursion through bb.js for MegaHonk
RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_mega_honk ./run_acir_tests.sh 6_array
# Run fold_basic test through bb.js which runs ClientIVC on fold basic
RUN BIN=../ts/dest/node/main.js FLOW=fold_and_verify_program ./run_acir_tests.sh fold_basic
# Run 1_mul through bb.js build, all_cmds flow, to test all cli args.
RUN BIN=../ts/dest/node/main.js FLOW=all_cmds ./run_acir_tests.sh 1_mul
# TODO(https://github.com/AztecProtocol/aztec-packages/issues/6672)
Expand Down
4 changes: 2 additions & 2 deletions barretenberg/cpp/src/barretenberg/bb/get_bn254_crs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ std::vector<g1::affine_element> get_bn254_g1_data(const std::filesystem::path& p
size_t g1_file_size = get_file_size(g1_path);

if (g1_file_size >= num_points * 64 && g1_file_size % 64 == 0) {
vinfo("using cached crs of size ", std::to_string(g1_file_size / 64), " at ", g1_path);
vinfo("using cached bn254 crs of size ", std::to_string(g1_file_size / 64), " at ", g1_path);
auto data = read_file(g1_path, g1_file_size);
auto points = std::vector<g1::affine_element>(num_points);
for (size_t i = 0; i < num_points; ++i) {
Expand All @@ -47,7 +47,7 @@ std::vector<g1::affine_element> get_bn254_g1_data(const std::filesystem::path& p
return points;
}

vinfo("downloading crs...");
vinfo("downloading bn254 crs...");
auto data = download_bn254_g1_data(num_points);
write_file(g1_path, data);

Expand Down
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/barretenberg/bb/get_grumpkin_crs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::vector<curve::Grumpkin::AffineElement> get_grumpkin_g1_data(const std::file
}
if (size >= num_points) {
auto file = path / "grumpkin_g1.dat";
vinfo("using cached crs at: ", file);
vinfo("using cached grumpkin crs of size ", size, " at: ", file);
auto data = read_file(file, 28 + num_points * 64);
auto points = std::vector<curve::Grumpkin::AffineElement>(num_points);
auto size_of_points_in_bytes = num_points * 64;
Expand Down
6 changes: 3 additions & 3 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,10 @@ bool foldAndVerifyProgram(const std::string& bytecodePath, const std::string& wi
auto stack_item = program_stack.back();

// Construct a bberg circuit from the acir representation
auto circuit = acir_format::create_circuit<Builder>(
stack_item.constraints, 0, stack_item.witness, false, ivc.goblin.op_queue);
auto builder = acir_format::create_circuit<Builder>(
stack_item.constraints, 0, stack_item.witness, /*honk_recursion=*/false, ivc.goblin.op_queue);

ivc.accumulate(circuit);
ivc.accumulate(builder);

program_stack.pop_back();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ WASM_EXPORT void acir_fold_and_verify_program_stack(uint8_t const* acir_vec, uin
program_stack.pop_back();
}
*result = ivc.prove_and_verify();
info("acir_fold_and_verify_program_stack result: ", *result);
}

WASM_EXPORT void acir_prove_and_verify_mega_honk(uint8_t const* acir_vec, uint8_t const* witness_vec, bool* result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST(reference_string, mem_bn254_file_consistency)
0);
}

TEST(reference_string, DISABLED_mem_grumpkin_file_consistency)
TEST(reference_string, mem_grumpkin_file_consistency)
{
// Load 1024 from file.
auto file_crs = FileCrsFactory<Grumpkin>("../srs_db/grumpkin", 1024);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,21 @@ MemGrumpkinCrsFactory::MemGrumpkinCrsFactory(std::vector<Grumpkin::AffineElement
std::shared_ptr<bb::srs::factories::ProverCrs<Grumpkin>> MemGrumpkinCrsFactory::get_prover_crs(size_t degree)
{
if (prover_crs_->get_monomial_size() < degree) {
throw_or_abort("prover trying to get too many points in MemGrumpkinCrsFactory!");
throw_or_abort(format("prover trying to get too many points in MemGrumpkinCrsFactory - ",
degree,
" is more than ",
prover_crs_->get_monomial_size()));
}
return prover_crs_;
}

std::shared_ptr<bb::srs::factories::VerifierCrs<Grumpkin>> MemGrumpkinCrsFactory::get_verifier_crs(size_t degree)
{
if (prover_crs_->get_monomial_size() < degree) {
throw_or_abort("verifier trying to get too many points in MemGrumpkinCrsFactory!");
throw_or_abort(format("verifier trying to get too many points in MemGrumpkinCrsFactory - ",
degree,
" is more than ",
prover_crs_->get_monomial_size()));
}
return verifier_crs_;
}
Expand Down
4 changes: 2 additions & 2 deletions barretenberg/cpp/src/barretenberg/srs/global_crs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ void init_grumpkin_crs_factory(std::string crs_path)
std::shared_ptr<factories::CrsFactory<curve::BN254>> get_bn254_crs_factory()
{
if (!crs_factory) {
throw_or_abort("You need to initalize the global CRS with a call to init_crs_factory(...)!");
throw_or_abort("You need to initialize the global CRS with a call to init_crs_factory(...)!");
}
return crs_factory;
}

std::shared_ptr<factories::CrsFactory<curve::Grumpkin>> get_grumpkin_crs_factory()
{
if (!grumpkin_crs_factory) {
throw_or_abort("You need to initalize the global CRS with a call to init_grumpkin_crs_factory(...)!");
throw_or_abort("You need to initialize the global CRS with a call to init_grumpkin_crs_factory(...)!");
}
return grumpkin_crs_factory;
}
Expand Down
83 changes: 59 additions & 24 deletions barretenberg/ts/src/main.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env node
import { Crs, Barretenberg, RawBuffer } from './index.js';
import { GrumpkinCrs } from './crs/node/index.js';
import createDebug from 'debug';
import { readFileSync, writeFileSync } from 'fs';
import { gunzipSync } from 'zlib';
Expand All @@ -9,13 +10,13 @@ import path from 'path';
createDebug.log = console.error.bind(console);
const debug = createDebug('bb.js');

// Maximum we support in node and the browser is 2^19.
// This is because both node and browser use barretenberg.wasm.
// Maximum circuit size for plonk we support in node and the browser is 2^19.
// This is because both node and browser use barretenberg.wasm which has a 4GB memory limit.
//
// This is not a restriction in the bb binary and one should be
// aware of this discrepancy, when creating proofs in bb versus
// creating the same proofs in the node CLI.
const MAX_CIRCUIT_SIZE = 2 ** 19;
const MAX_ULTRAPLONK_CIRCUIT_SIZE_IN_WASM = 2 ** 19;
const threads = +process.env.HARDWARE_CONCURRENCY! || undefined;

function getBytecode(bytecodePath: string) {
Expand All @@ -32,7 +33,7 @@ function getBytecode(bytecodePath: string) {
return decompressed;
}

async function getGates(bytecodePath: string, honkRecursion: boolean, api: Barretenberg) {
async function getGatesUltra(bytecodePath: string, honkRecursion: boolean, api: Barretenberg) {
const { total } = await computeCircuitSize(bytecodePath, honkRecursion, api);
return total;
}
Expand All @@ -50,33 +51,66 @@ async function computeCircuitSize(bytecodePath: string, honkRecursion: boolean,
return { exact, total, subgroup };
}

async function init(bytecodePath: string, crsPath: string, subgroupSizeOverride = -1, honkRecursion = false) {
async function initUltraPlonk(bytecodePath: string, crsPath: string, subgroupSizeOverride = -1, honkRecursion = false) {
const api = await Barretenberg.new({ threads });

const circuitSize = await getGates(bytecodePath, honkRecursion, api);
const circuitSize = await getGatesUltra(bytecodePath, honkRecursion, api);
// TODO(https://github.com/AztecProtocol/barretenberg/issues/811): remove subgroupSizeOverride hack for goblin
const subgroupSize = Math.max(subgroupSizeOverride, Math.pow(2, Math.ceil(Math.log2(circuitSize))));
if (subgroupSize > MAX_CIRCUIT_SIZE) {
throw new Error(`Circuit size of ${subgroupSize} exceeds max supported of ${MAX_CIRCUIT_SIZE}`);
}

if (subgroupSize > MAX_ULTRAPLONK_CIRCUIT_SIZE_IN_WASM) {
throw new Error(`Circuit size of ${subgroupSize} exceeds max supported of ${MAX_ULTRAPLONK_CIRCUIT_SIZE_IN_WASM}`);
}
debug(`circuit size: ${circuitSize}`);
debug(`subgroup size: ${subgroupSize}`);
debug('loading crs...');
// Plus 1 needed! (Move +1 into Crs?)
const crs = await Crs.new(subgroupSize + 1, crsPath);

// Important to init slab allocator as first thing, to ensure maximum memory efficiency.
// Important to init slab allocator as first thing, to ensure maximum memory efficiency for Plonk.
await api.commonInitSlabAllocator(subgroupSize);

// Load CRS into wasm global CRS state.
// TODO: Make RawBuffer be default behavior, and have a specific Vector type for when wanting length prefixed.
await api.srsInitSrs(new RawBuffer(crs.getG1Data()), crs.numPoints, new RawBuffer(crs.getG2Data()));

const acirComposer = await api.acirNewAcirComposer(subgroupSize);
return { api, acirComposer, circuitSize, subgroupSize };
}

/**
 * Creates a Barretenberg instance and loads the BN254 CRS for the UltraHonk flows.
 *
 * The gate count is computed with honk recursion enabled and rounded up to the next power of two
 * to obtain the dyadic circuit size. A CRS of `dyadicCircuitSize + 1` points is then loaded into
 * the wasm global CRS state.
 *
 * @param bytecodePath - Path to the circuit bytecode whose gate count determines the CRS size.
 * @param crsPath - Path forwarded to `Crs.new`.
 * @returns The api instance together with the circuit size and the dyadic circuit size.
 *          Callers in this file release the instance with `api.destroy()` once they are done.
 */
async function initUltraHonk(bytecodePath: string, crsPath: string) {
  const api = await Barretenberg.new({ threads });

  const circuitSize = await getGatesUltra(bytecodePath, /*honkRecursion=*/ true, api);
  // Round the gate count up to the next power of two.
  const dyadicCircuitSize = Math.pow(2, Math.ceil(Math.log2(circuitSize)));

  debug(`circuit size: ${circuitSize}`);
  debug(`dyadic circuit size: ${dyadicCircuitSize}`);
  debug('loading crs...');
  // Plus 1 needed! (Move +1 into Crs?)
  const crs = await Crs.new(dyadicCircuitSize + 1, crsPath);

  // Load CRS into wasm global CRS state.
  // TODO: Make RawBuffer be default behavior, and have a specific Vector type for when wanting length prefixed.
  await api.srsInitSrs(new RawBuffer(crs.getG1Data()), crs.numPoints, new RawBuffer(crs.getG2Data()));
  return { api, circuitSize, dyadicCircuitSize };
}

/**
 * Prepares a Barretenberg instance for the ClientIVC flow by loading both the BN254 and the
 * Grumpkin CRS into the wasm global CRS state.
 *
 * The CRS sizes are fixed constants here; they are not derived from the program.
 *
 * @param bytecodePath - Accepted for signature parity with the other init helpers; not read here.
 * @param crsPath - Path forwarded to `Crs.new` and `GrumpkinCrs.new`.
 * @returns An object holding the initialized api instance.
 */
async function initClientIVC(bytecodePath: string, crsPath: string) {
  // Plus 1 needed! (Move +1 into Crs?)
  const bn254NumPoints = 2 ** 18 + 1;
  const grumpkinNumPoints = 8192 + 1;

  const api = await Barretenberg.new({ threads });

  debug('loading BN254 and Grumpkin crs...');
  const bn254Crs = await Crs.new(bn254NumPoints, crsPath);
  const grumpkinCrs = await GrumpkinCrs.new(grumpkinNumPoints, crsPath);

  // Load CRS into wasm global CRS state.
  // TODO: Make RawBuffer be default behavior, and have a specific Vector type for when wanting length prefixed.
  await api.srsInitSrs(
    new RawBuffer(bn254Crs.getG1Data()),
    bn254Crs.numPoints,
    new RawBuffer(bn254Crs.getG2Data()),
  );
  await api.srsInitGrumpkinSrs(new RawBuffer(grumpkinCrs.getG1Data()), grumpkinCrs.numPoints);

  return { api };
}

async function initLite() {
const api = await Barretenberg.new({ threads: 1 });

Expand All @@ -94,7 +128,7 @@ export async function proveAndVerify(bytecodePath: string, witnessPath: string,
/* eslint-disable camelcase */
const acir_test = path.basename(process.cwd());

const { api, acirComposer, circuitSize, subgroupSize } = await init(bytecodePath, crsPath);
const { api, acirComposer, circuitSize, subgroupSize } = await initUltraPlonk(bytecodePath, crsPath);
try {
debug(`creating proof...`);
const bytecode = getBytecode(bytecodePath);
Expand Down Expand Up @@ -122,7 +156,7 @@ export async function proveAndVerify(bytecodePath: string, witnessPath: string,

export async function proveAndVerifyUltraHonk(bytecodePath: string, witnessPath: string, crsPath: string) {
/* eslint-disable camelcase */
const { api } = await init(bytecodePath, crsPath, -1, true);
const { api } = await initUltraHonk(bytecodePath, crsPath);
try {
const bytecode = getBytecode(bytecodePath);
const witness = getWitness(witnessPath);
Expand All @@ -137,7 +171,7 @@ export async function proveAndVerifyUltraHonk(bytecodePath: string, witnessPath:

export async function proveAndVerifyMegaHonk(bytecodePath: string, witnessPath: string, crsPath: string) {
/* eslint-disable camelcase */
const { api } = await init(bytecodePath, crsPath);
const { api } = await initUltraPlonk(bytecodePath, crsPath);
try {
const bytecode = getBytecode(bytecodePath);
const witness = getWitness(witnessPath);
Expand All @@ -152,12 +186,13 @@ export async function proveAndVerifyMegaHonk(bytecodePath: string, witnessPath:

export async function foldAndVerifyProgram(bytecodePath: string, witnessPath: string, crsPath: string) {
/* eslint-disable camelcase */
const { api } = await init(bytecodePath, crsPath);
const { api } = await initClientIVC(bytecodePath, crsPath);
try {
const bytecode = getBytecode(bytecodePath);
const witness = getWitness(witnessPath);

const verified = await api.acirFoldAndVerifyProgramStack(bytecode, witness);
debug(`verified: ${verified}`);
return verified;
} finally {
await api.destroy();
Expand All @@ -166,7 +201,7 @@ export async function foldAndVerifyProgram(bytecodePath: string, witnessPath: st
}

export async function prove(bytecodePath: string, witnessPath: string, crsPath: string, outputPath: string) {
const { api, acirComposer } = await init(bytecodePath, crsPath);
const { api, acirComposer } = await initUltraPlonk(bytecodePath, crsPath);
try {
debug(`creating proof...`);
const bytecode = getBytecode(bytecodePath);
Expand All @@ -186,10 +221,10 @@ export async function prove(bytecodePath: string, witnessPath: string, crsPath:
}
}

export async function gateCount(bytecodePath: string, honkRecursion: boolean) {
export async function gateCountUltra(bytecodePath: string, honkRecursion: boolean) {
const api = await Barretenberg.new({ threads: 1 });
try {
const numberOfGates = await getGates(bytecodePath, honkRecursion, api);
const numberOfGates = await getGatesUltra(bytecodePath, honkRecursion, api);
debug(`number of gates: : ${numberOfGates}`);
// Create an 8-byte buffer and write the number into it.
// Writing number directly to stdout will result in a variable sized
Expand Down Expand Up @@ -234,7 +269,7 @@ export async function contract(outputPath: string, vkPath: string) {
}

export async function writeVk(bytecodePath: string, crsPath: string, outputPath: string) {
const { api, acirComposer } = await init(bytecodePath, crsPath);
const { api, acirComposer } = await initUltraPlonk(bytecodePath, crsPath);
try {
debug('initing proving key...');
const bytecode = getBytecode(bytecodePath);
Expand All @@ -256,7 +291,7 @@ export async function writeVk(bytecodePath: string, crsPath: string, outputPath:
}

export async function writePk(bytecodePath: string, crsPath: string, outputPath: string) {
const { api, acirComposer } = await init(bytecodePath, crsPath);
const { api, acirComposer } = await initUltraPlonk(bytecodePath, crsPath);
try {
debug('initing proving key...');
const bytecode = getBytecode(bytecodePath);
Expand Down Expand Up @@ -326,7 +361,7 @@ export async function vkAsFields(vkPath: string, vkeyOutputPath: string) {
}

export async function proveUltraHonk(bytecodePath: string, witnessPath: string, crsPath: string, outputPath: string) {
const { api } = await init(bytecodePath, crsPath, -1, /* honkRecursion= */ true);
const { api } = await initUltraHonk(bytecodePath, crsPath);
try {
debug(`creating proof...`);
const bytecode = getBytecode(bytecodePath);
Expand All @@ -347,7 +382,7 @@ export async function proveUltraHonk(bytecodePath: string, witnessPath: string,
}

export async function writeVkUltraHonk(bytecodePath: string, crsPath: string, outputPath: string) {
const { api } = await init(bytecodePath, crsPath, -1, true);
const { api } = await initUltraHonk(bytecodePath, crsPath);
try {
const bytecode = getBytecode(bytecodePath);
debug('initing verification key...');
Expand Down Expand Up @@ -487,12 +522,12 @@ program

program
.command('gates')
.description('Print gate count to standard output.')
.description('Print Ultra Builder gate count to standard output.')
.option('-b, --bytecode-path <path>', 'Specify the bytecode path', './target/program.json')
.option('-hr, --honk-recursion', 'Specify whether to use UltraHonk recursion', false)
.action(async ({ bytecodePath: bytecodePath, honkRecursion: honkRecursion }) => {
handleGlobalOptions();
await gateCount(bytecodePath, honkRecursion);
await gateCountUltra(bytecodePath, honkRecursion);
});

program
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "bench_2_to_17"
type = "bin"
authors = [""]
compiler_version = ">=0.33.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = "3"
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use std::hash::poseidon2;

global len = 2450 * 2;
// Benchmark entry point: hashes `len` copies of `x` with Poseidon2 and asserts that the
// resulting digest is non-zero.
// Per the accompanying commit description, this program is sized to land at around
// 2^17 - epsilon gates, so statements here should not be changed casually.
fn main(x: Field) {
// Array of `len` elements, each equal to `x`; this is the input passed to Poseidon2 below.
let ped_input = [x; len];
let mut val = poseidon2::Poseidon2::hash(ped_input, len);
// The only check in the program: the hash output must not be zero.
assert(val != 0);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "fold_2_to_17"
type = "bin"
authors = [""]
compiler_version = ">=0.25.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = "2"
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use std::hash::poseidon2;

global len = 2450 * 2 - 240; // for just under 2^17 gates
// Entry point of the folding benchmark: hashes `len` copies of `x` with Poseidon2, calls `foo`
// on the same `x`, and asserts that both results agree.
// Per the accompanying commit description, `main` and `foo` are meant to be two circuits of
// just under 2^17 gates each that get folded together.
fn main(x: Field) {
// Array of `len` elements, each equal to `x`; this is the input passed to Poseidon2 below.
let ped_input = [x; len];
let mut val = poseidon2::Poseidon2::hash(ped_input, len);
// `foo` performs the same hash over the same input.
let z = foo(x);
assert(val == z);
}

// Returns the Poseidon2 hash of `len` copies of `x`, i.e. the same computation `main` performs
// inline before comparing against this function's result.
// Marked `#[fold]`; per the accompanying commit description this is the second circuit that
// gets folded together with `main`.
#[fold]
fn foo(x: Field) -> Field {
// Array of `len` elements, each equal to `x`; this is the input passed to Poseidon2 below.
let ped_input = [x; len];
let mut val = poseidon2::Poseidon2::hash(ped_input, len);
val
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "single_verify_proof"
type = "bin"
authors = [""]
compiler_version = ">=0.24.0"

[dependencies]
Loading

0 comments on commit 0b46e96

Please sign in to comment.