From f228d238f68fba75ac63c162a26768a444907e1c Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Sun, 18 Dec 2022 04:01:37 +0000 Subject: [PATCH] feat: KZG commitments with single strategy for proving (#81) --- .github/workflows/rust.yml | 76 +++++++++- .gitignore | 2 + Cargo.toml | 3 +- src/bin/ezkl.rs | 192 ++++++++++++++---------- src/circuit/eltwise.rs | 6 +- src/commands.rs | 95 +++--------- src/graph/mod.rs | 11 +- src/graph/model.rs | 33 ++--- src/pfsys/{kzg => }/aggregation.rs | 181 ++++++++++++++++++++++- src/pfsys/ipa.rs | 128 ---------------- src/pfsys/kzg/mod.rs | 187 ------------------------ src/pfsys/mod.rs | 225 ++++++++++++++++++++++++++++- tests/integration_tests.rs | 211 +++++++++++++++++---------- 13 files changed, 754 insertions(+), 596 deletions(-) rename src/pfsys/{kzg => }/aggregation.rs (61%) delete mode 100644 src/pfsys/ipa.rs delete mode 100644 src/pfsys/kzg/mod.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9746b93b3..eea63e505 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -16,8 +16,6 @@ jobs: steps: - uses: actions/checkout@v3 - with: - submodules: recursive - uses: actions-rs/toolchain@v1 with: toolchain: nightly @@ -25,5 +23,75 @@ jobs: components: rustfmt, clippy - name: Build run: cargo build --verbose - - name: Run tests - run: cargo test --verbose + + library-tests: + + runs-on: self-hosted + + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true + components: rustfmt, clippy + - name: Doc tests + run: cargo test --doc --verbose + - name: Library tests + run: cargo test --lib --verbose + + mock-proving-tests: + + runs-on: self-hosted + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true + components: rustfmt, clippy + + - name: Mock proving tests (public outputs) + run: cargo test --release --verbose tests::mock_public_outputs_ + - name: Mock proving tests (public inputs) + run: cargo test --release --verbose tests::mock_public_inputs_ + - name: Mock proving tests (public params) + run: cargo test --release --verbose tests::mock_public_params_ + + full-proving-tests: + + runs-on: self-hosted + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true + components: rustfmt, clippy + + - name: IPA full-prove tests + run: cargo test --release --verbose tests::ipa_fullprove_ -- --test-threads 1 + - name: KZG full-prove tests + run: cargo test --release --verbose tests::kzg_fullprove_ -- --test-threads 1 + + prove-and-verify-tests: + + runs-on: self-hosted + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true + components: rustfmt, clippy + + - name: IPA prove and verify tests + run: cargo test --release --verbose tests::ipa_prove_and_verify_ -- --test-threads 1 + - name: KZG prove and verify tests + run: cargo test --release --verbose tests::kzg_prove_and_verify_ -- --test-threads 1 diff --git a/.gitignore b/.gitignore index 1120eafc9..3e58434eb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,7 @@ target Cargo.lock data *.pf +*.vk +*.params *~ \#*\# diff --git a/Cargo.toml b/Cargo.toml index b218c6302..7332d8254 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_10_22"} +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", rev = "50ee8ad"} halo2curves = { git = 'https://github.com/privacy-scaling-explorations/halo2curves', tag = "0.3.0" } # nalgebra = "0.31" rand = "0.8" @@ -35,6 +35,7 @@ plonk_verifier = { git = "https://github.com/privacy-scaling-explorations/plonk criterion = {version = "0.3", features = ["html_reports"]} seq-macro = "0.3.1" test-case = "2.2.2" +ctor = "0.1.26" [[bench]] name = "affine" diff --git a/src/bin/ezkl.rs b/src/bin/ezkl.rs index 6cddd2c78..074c7bc1a 100644 --- a/src/bin/ezkl.rs +++ b/src/bin/ezkl.rs @@ -3,38 +3,42 @@ use ezkl::abort; use ezkl::commands::{Cli, Commands, ProofSystem}; use ezkl::fieldutils::i32_to_felt; use ezkl::graph::Model; -use ezkl::pfsys::ipa::{create_ipa_proof, verify_ipa_proof}; #[cfg(feature = "evm")] -use ezkl::pfsys::kzg::{ +use ezkl::pfsys::kzg::aggregation::{ aggregation::AggregationCircuit, evm_verify, gen_aggregation_evm_verifier, gen_application_snark, gen_kzg_proof, gen_pk, gen_srs, }; -use ezkl::pfsys::Proof; -use ezkl::pfsys::{parse_prover_errors, prepare_circuit_and_public_input, prepare_data}; +use ezkl::pfsys::{create_keys, load_params, load_vk, Proof}; +#[cfg(not(feature = "evm"))] +use ezkl::pfsys::{ + create_proof_model, parse_prover_errors, prepare_circuit_and_public_input, prepare_data, + save_params, save_vk, verify_proof_model, +}; #[cfg(feature = "evm")] use halo2_proofs::poly::commitment::Params; +use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme; +use halo2_proofs::poly::ipa::multiopen::ProverIPA; +use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; +use halo2_proofs::poly::kzg::multiopen::ProverGWC; +#[cfg(not(feature = "evm"))] +use halo2_proofs::poly::kzg::{ + commitment::ParamsKZG, multiopen::VerifierGWC, strategy::SingleStrategy as KZGSingleStrategy, +}; use halo2_proofs::{ dev::MockProver, - plonk::verify_proof, poly::{ commitment::ParamsProver, - ipa::{commitment::ParamsIPA, strategy::SingleStrategy}, + ipa::{commitment::ParamsIPA, strategy::SingleStrategy as IPASingleStrategy}, VerificationStrategy, }, - transcript::{Blake2bRead, Challenge255, TranscriptReadBuffer}, }; -#[cfg(feature = "evm")] -use halo2curves::bn256::G1Affine; +use halo2curves::bn256::{Bn256, Fr}; use halo2curves::pasta::vesta; use halo2curves::pasta::Fp; use log::{error, info, trace}; #[cfg(feature = "evm")] use plonk_verifier::system::halo2::transcript::evm::EvmTranscript; use rand::seq::SliceRandom; -use std::fs::File; -use std::io::{Read, Write}; -use std::ops::Deref; -use std::time::Instant; use tabled::Table; pub fn main() { @@ -81,46 +85,59 @@ pub fn main() { model: _, pfsys, } => { + // A direct proof let args = Cli::parse(); let data = prepare_data(data); - let (circuit, public_inputs) = prepare_circuit_and_public_input(&data); - info!("full proof with {}", pfsys); match pfsys { ProofSystem::IPA => { - // A direct proof + let (circuit, public_inputs) = prepare_circuit_and_public_input::(&data); + info!("full proof with {}", pfsys); + let params: ParamsIPA = ParamsIPA::new(args.logrows); + let pk = create_keys::, Fp>(&circuit, ¶ms); + let strategy = IPASingleStrategy::new(¶ms); trace!("params computed"); - let (pk, proof, _dims) = - create_ipa_proof(circuit, public_inputs.clone(), ¶ms); + let (proof, _dims) = create_proof_model::< + IPACommitmentScheme<_>, + Fp, + ProverIPA<_>, + >( + &circuit, &public_inputs, ¶ms, &pk + ); + + assert!(verify_proof_model(proof, ¶ms, pk.get_vk(), strategy)); + } + #[cfg(not(feature = "evm"))] + ProofSystem::KZG => { + // A direct proof + let (circuit, public_inputs) = prepare_circuit_and_public_input::(&data); + let params: ParamsKZG = ParamsKZG::new(args.logrows); + let pk = create_keys::, Fr>(&circuit, ¶ms); + let strategy = KZGSingleStrategy::new(¶ms); + trace!("params computed"); - let pi_inner: Vec> = public_inputs - .iter() - .map(|i| i.iter().map(|e| i32_to_felt::(*e)).collect::>()) - .collect::>>(); - let pi_inner = pi_inner.iter().map(|e| e.deref()).collect::>(); - let pi_for_real_prover: &[&[&[Fp]]] = &[&pi_inner]; + let (proof, _dims) = create_proof_model::< + KZGCommitmentScheme<_>, + Fr, + ProverGWC<_>, + >( + &circuit, &public_inputs, ¶ms, &pk + ); - let now = Instant::now(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof( + assert!(verify_proof_model::<_, VerifierGWC<'_, Bn256>, _, _>( + proof, ¶ms, pk.get_vk(), - strategy, - pi_for_real_prover, - &mut transcript - ) - .is_ok()); - info!("verify took {}", now.elapsed().as_secs()); + strategy + )); } - #[cfg(not(feature = "evm"))] - ProofSystem::KZG => todo!(), #[cfg(feature = "evm")] ProofSystem::KZG => { // We will need aggregator k > application k > bits // let application_logrows = args.logrows; //bits + 1; + let (circuit, public_inputs) = prepare_circuit_and_public_input(&data); let aggregation_logrows = args.logrows + 6; let params = gen_srs(aggregation_logrows); @@ -159,63 +176,84 @@ pub fn main() { Commands::Prove { data, model: _, - output, + proof_path, + vk_path, + params_path, pfsys, } => { let args = Cli::parse(); let data = prepare_data(data); - let (circuit, public_inputs) = prepare_circuit_and_public_input(&data); - info!("proof with {}", pfsys); - let params: ParamsIPA = ParamsIPA::new(args.logrows); - trace!("params computed"); - let (_pk, proof, _input_dims) = - create_ipa_proof(circuit.clone(), public_inputs.clone(), ¶ms); + match pfsys { + ProofSystem::IPA => { + info!("proof with {}", pfsys); + let (circuit, public_inputs) = prepare_circuit_and_public_input::(&data); + let params: ParamsIPA = ParamsIPA::new(args.logrows); + let pk = create_keys::, Fp>(&circuit, ¶ms); + trace!("params computed"); + + let (proof, _) = create_proof_model::, Fp, ProverIPA<_>>( + &circuit, + &public_inputs, + ¶ms, + &pk, + ); - let pi: Vec<_> = public_inputs - .into_iter() - .map(|i| i.into_iter().collect()) - .collect(); + proof.save(&proof_path); + save_params::>(¶ms_path, ¶ms); + save_vk::>(&vk_path, pk.get_vk()); + } + ProofSystem::KZG => { + info!("proof with {}", pfsys); + let (circuit, public_inputs) = prepare_circuit_and_public_input(&data); + let params: ParamsKZG = ParamsKZG::new(args.logrows); + let pk = create_keys::, Fr>(&circuit, ¶ms); + trace!("params computed"); - let checkable_pf = Proof { - input_shapes: circuit.inputs.iter().map(|i| i.dims().to_vec()).collect(), - public_inputs: pi, - proof, - }; + let (proof, _input_dims) = create_proof_model::< + KZGCommitmentScheme, + Fr, + ProverGWC<'_, Bn256>, + >( + &circuit, &public_inputs, ¶ms, &pk + ); - let serialized = match serde_json::to_string(&checkable_pf) { - Ok(s) => s, - Err(e) => { - abort!("failed to convert proof json to string {:?}", e); + proof.save(&proof_path); + save_params::>(¶ms_path, ¶ms); + save_vk::>(&vk_path, pk.get_vk()); } }; - - let mut file = std::fs::File::create(output).expect("create failed"); - file.write_all(serialized.as_bytes()).expect("write failed"); } Commands::Verify { model: _, - proof, - pfsys: _, + proof_path, + vk_path, + params_path, + pfsys, } => { - let mut file = match File::open(proof) { - Ok(f) => f, - Err(e) => { - abort!("failed to open proof file {:?}", e); + let proof = Proof::load(&proof_path); + match pfsys { + ProofSystem::IPA => { + let params: ParamsIPA = + load_params::>(params_path); + let strategy = IPASingleStrategy::new(¶ms); + let vk = load_vk::, Fp>(vk_path, ¶ms); + let result = verify_proof_model(proof, ¶ms, &vk, strategy); + info!("verified: {}", result); + assert!(result); } - }; - let mut data = String::new(); - match file.read_to_string(&mut data) { - Ok(_) => {} - Err(e) => { - abort!("failed to read file {:?}", e); + ProofSystem::KZG => { + let params: ParamsKZG = + load_params::>(params_path); + let strategy = KZGSingleStrategy::new(¶ms); + let vk = load_vk::, Fr>(vk_path, ¶ms); + let result = verify_proof_model::<_, VerifierGWC<'_, Bn256>, _, _>( + proof, ¶ms, &vk, strategy, + ); + info!("verified: {}", result); + assert!(result); } - }; - let proof: Proof = serde_json::from_str(&data).expect("JSON was not well-formatted"); - - let result = verify_ipa_proof(proof); - info!("verified: {}", result); - assert!(result); + } } } } diff --git a/src/circuit/eltwise.rs b/src/circuit/eltwise.rs index b85c28bd5..5f566c7de 100644 --- a/src/circuit/eltwise.rs +++ b/src/circuit/eltwise.rs @@ -358,7 +358,7 @@ mod tests { for i in -127..127 { let r = as Nonlinearity>::nonlinearity(i, &[1]); if i <= 0 { - assert!(r == F::from(0 as u64)) + assert!(r == F::from(0_u64)) } else { assert!(r == F::from(i as u64)) } @@ -390,7 +390,7 @@ mod tests { #[test] fn relucircuit() { let input: Tensor> = - Tensor::new(Some(&[Value::::known(F::from(1 as u64))]), &[1]).unwrap(); + Tensor::new(Some(&[Value::::known(F::from(1_u64))]), &[1]).unwrap(); let assigned: Nonlin1d> = Nonlin1d { input: ValTensor::from(input.clone()), output: ValTensor::from(input), @@ -402,7 +402,7 @@ mod tests { _marker: PhantomData, }; - let prover = MockProver::run(4 as u32, &circuit, vec![]).unwrap(); + let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap(); prover.assert_satisfied(); } } diff --git a/src/commands.rs b/src/commands.rs index 3715e25a4..c221f5e7d 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -2,7 +2,7 @@ use clap::{Parser, Subcommand, ValueEnum}; use log::info; use std::io::{stdin, stdout, Write}; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -95,10 +95,16 @@ pub enum Commands { /// The path to the .onnx model file #[arg(short = 'M', long, default_value = "")] - model: String, + model: PathBuf, /// The path to the desired output file #[arg(short = 'O', long, default_value = "")] - output: PathBuf, + proof_path: PathBuf, + /// The path to output to the desired verfication key file (optional) + #[arg(long, default_value = "")] + vk_path: PathBuf, + /// The path to output to the desired verfication key file (optional) + #[arg(long, default_value = "")] + params_path: PathBuf, // /// The path to the Params for the proof system // #[arg(short = 'P', long, default_value = "")] @@ -119,11 +125,17 @@ pub enum Commands { Verify { /// The path to the .onnx model file #[arg(short = 'M', long, default_value = "")] - model: String, + model: PathBuf, /// The path to the proof file #[arg(short = 'P', long, default_value = "")] - proof: PathBuf, + proof_path: PathBuf, + /// The path to output to the desired verfication key file (optional) + #[arg(long, default_value = "")] + vk_path: PathBuf, + /// The path to output to the desired verfication key file (optional) + #[arg(long, default_value = "")] + params_path: PathBuf, // /// The path to the Params for the proof system // #[arg(short = 'P', long, default_value = "")] @@ -138,80 +150,7 @@ pub enum Commands { value_enum )] pfsys: ProofSystem, - // todo, allow optional vkey and params when applicable }, - // Awaiting PR to stabilize VK and PK formats - // /// Loads model and prepares verification key, saving in --output - // Vkey { - // /// The path to the .onnx model file - // #[arg(short = 'M', long, default_value = "")] - // model: String, - // /// The path to the desired output file - // #[arg(short = 'O', long, default_value = "")] - // output: PathBuf, - - // /// The path to the Params for the proof system - // #[arg(short = 'P', long, default_value = "")] - // params: PathBuf, - - // #[arg( - // long, - // short = 'B', - // require_equals = true, - // num_args = 0..=1, - // default_value_t = ProofSystem::IPA, - // default_missing_value = "always", - // value_enum - // )] - // pfsys: ProofSystem, - // }, - - // /// Loads model and prepares verification and proving key, saving proving key in --output - // Pkey { - // /// The path to the .onnx model file - // #[arg(short = 'M', long, default_value = "")] - // model: String, - // /// The path to the desired output file - // #[arg(short = 'O', long, default_value = "")] - // output: PathBuf, - - // /// The path to the Params for the proof system - // #[arg(short = 'P', long, default_value = "")] - // params: PathBuf, - - // #[arg( - // long, - // short = 'B', - // require_equals = true, - // num_args = 0..=1, - // default_value_t = ProofSystem::IPA, - // default_missing_value = "always", - // value_enum - // )] - // pfsys: ProofSystem, - // // todo, optionally allow supplying verification key - // }, -} - -pub fn model_path(model: String) -> PathBuf { - let mut s = String::new(); - let model_path = match model.is_empty() { - false => { - info!("loading model from {}", model.clone()); - Path::new(&model) - } - true => { - info!("please enter a path to a .onnx file containing a model: "); - let _ = stdout().flush(); - let _ = &stdin() - .read_line(&mut s) - .expect("did not enter a correct string"); - s.truncate(s.len() - 1); - Path::new(&s) - } - }; - assert!(model_path.exists()); - model_path.into() } pub fn data_path(data: String) -> PathBuf { diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 4dbe3d91d..cf9de7dcb 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -34,23 +34,22 @@ impl Circuit for ModelCircuit { fn configure(cs: &mut ConstraintSystem) -> Self::Config { let model = Model::from_arg(); - let num_advice: usize; let mut num_fixed = 0; let row_cap = model.max_node_size(); // TODO: extract max number of params in a given fused layer - if model.visibility.params.is_public() { + let num_advice: usize = if model.visibility.params.is_public() { num_fixed += model.max_node_params(); // this is the maximum of variables in non-fused layer, and the maximum of variables (non-params) in fused layers - num_advice = max(model.max_node_vars_non_fused(), model.max_node_vars_fused()); + max(model.max_node_vars_non_fused(), model.max_node_vars_fused()) } else { // this is the maximum of variables in non-fused layer, and the maximum of variables (non-params) in fused layers // + the max number of params in a fused layer - num_advice = max( + max( model.max_node_vars_non_fused(), model.max_node_params() + model.max_node_vars_fused(), - ); - } + ) + }; // for now the number of instances corresponds to the number of graph / model outputs let mut num_instances = 0; let mut instance_shapes = vec![]; diff --git a/src/graph/model.rs b/src/graph/model.rs index 11f71b2d3..8f3dd63d3 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -4,7 +4,7 @@ use super::vars::*; use crate::circuit::eltwise::{DivideBy, EltwiseConfig, ReLu, Sigmoid}; use crate::circuit::fused::*; use crate::circuit::range::*; -use crate::commands::{model_path, Cli, Commands}; +use crate::commands::{Cli, Commands}; use crate::tensor::TensorType; use crate::tensor::{Tensor, ValTensor, VarTensor}; use anyhow::{Context, Result}; @@ -108,7 +108,7 @@ impl Model { let visibility = VarVisibility::from_args(); match args.command { Commands::Table { model } => Model::new( - model_path(model), + model, args.scale, args.bits, args.logrows, @@ -116,8 +116,8 @@ impl Model { Mode::Table, visibility, ), - Commands::Mock { data: _, model } => Model::new( - model_path(model), + Commands::Mock { model, .. } => Model::new( + model, args.scale, args.bits, args.logrows, @@ -125,12 +125,8 @@ impl Model { Mode::Mock, visibility, ), - Commands::Fullprove { - data: _, + Commands::Fullprove { model, .. } => Model::new( model, - pfsys: _, - } => Model::new( - model_path(model), args.scale, args.bits, args.logrows, @@ -138,13 +134,8 @@ impl Model { Mode::FullProve, visibility, ), - Commands::Prove { - data: _, + Commands::Prove { model, .. } => Model::new( model, - output: _, - pfsys: _, - } => Model::new( - model_path(model), args.scale, args.bits, args.logrows, @@ -152,12 +143,8 @@ impl Model { Mode::Prove, visibility, ), - Commands::Verify { + Commands::Verify { model, .. } => Model::new( model, - proof: _, - pfsys: _, - } => Model::new( - model_path(model), args.scale, args.bits, args.logrows, @@ -309,7 +296,7 @@ impl Model { .filter(|i| !nodes.contains_key(&i.idx) && seen.insert(i.idx)) .map(|f| { let s = f.out_dims.clone(); - let a = if f.opkind.is_const() && self.visibility.params.is_public() { + if f.opkind.is_const() && self.visibility.params.is_public() { let vars = (f.idx, vars.fixed[fixed_idx].reshape(&s)); fixed_idx += 1; vars @@ -317,9 +304,7 @@ impl Model { let vars = (f.idx, vars.advices[advice_idx].reshape(&s)); advice_idx += 1; vars - }; - - a + } }) .collect_vec() }) diff --git a/src/pfsys/kzg/aggregation.rs b/src/pfsys/aggregation.rs similarity index 61% rename from src/pfsys/kzg/aggregation.rs rename to src/pfsys/aggregation.rs index 350f5cf52..6133c27f2 100644 --- a/src/pfsys/kzg/aggregation.rs +++ b/src/pfsys/aggregation.rs @@ -1,7 +1,29 @@ +use super::super::prepare_circuit_and_public_input; +use super::super::ModelInput; +use crate::fieldutils::i32_to_felt; +#[cfg(feature = "evm")] +use ethereum_types::Address; +#[cfg(feature = "evm")] +use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; +#[cfg(feature = "evm")] +use halo2_proofs::plonk::VerifyingKey; use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{self, Circuit, ConstraintSystem}, - poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, + dev::MockProver, + plonk::{ + self, create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ConstraintSystem, + ProvingKey, + }, + poly::{ + commitment::{Params, ParamsProver}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + VerificationStrategy, + }, + transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, }; use halo2_wrong_ecc::{ integer::rns::Rns, @@ -11,11 +33,23 @@ use halo2_wrong_ecc::{ }, EccConfig, }; +#[cfg(feature = "evm")] +use halo2curves::bn256::Fq; use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}; use itertools::Itertools; use log::trace; +#[cfg(feature = "evm")] +use plonk_verifier::{ + loader::evm::{encode_calldata, EvmLoader}, + system::halo2::transcript::evm::EvmTranscript, + verifier::PlonkVerifier, +}; use plonk_verifier::{ - loader::{self, native::NativeLoader}, + loader::native::NativeLoader, + system::halo2::{compile, Config}, +}; +use plonk_verifier::{ + loader::{self}, pcs::{ kzg::{Gwc19, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, AccumulationScheme, AccumulationSchemeProver, @@ -26,6 +60,9 @@ use plonk_verifier::{ Protocol, }; use rand::rngs::OsRng; +use std::io::Cursor; +#[cfg(feature = "evm")] +use std::rc::Rc; use std::{iter, rc::Rc}; const LIMBS: usize = 4; @@ -304,3 +341,141 @@ impl Circuit for AggregationCircuit { Ok(()) } } + +pub fn gen_application_snark(params: &ParamsKZG, data: &ModelInput) -> Snark { + let (circuit, public_inputs) = prepare_circuit_and_public_input::(data); + + let pk = gen_pk(params, &circuit); + let number_instance = public_inputs[0].len(); + trace!("number_instance {:?}", number_instance); + let protocol = compile( + params, + pk.get_vk(), + Config::kzg().with_num_instance(vec![number_instance]), + ); + let pi_inner: Vec> = public_inputs + .iter() + .map(|i| i.iter().map(|e| i32_to_felt::(*e)).collect::>()) + .collect::>>(); + // let pi_inner = pi_inner.iter().map(|e| e.deref()).collect::>(); + trace!("pi_inner {:?}", pi_inner); + let proof = gen_kzg_proof::< + _, + _, + PoseidonTranscript, + PoseidonTranscript, + >(params, &pk, circuit, pi_inner.clone()); + Snark::new(protocol, pi_inner, proof) +} + +#[cfg(feature = "evm")] +pub fn gen_aggregation_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + num_instance: Vec, + accumulator_indices: Vec<(usize, usize)>, +) -> Vec { + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg() + .with_num_instance(num_instance.clone()) + .with_accumulator_indices(accumulator_indices), + ); + + let loader = EvmLoader::new::(); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + + let instances = transcript.load_instances(num_instance); + let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + loader.deployment_code() +} + +#[cfg(feature = "evm")] +pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default() + .with_gas_limit(u64::MAX.into()) + .build(Backend::new(MultiFork::new().0, None)); + + let caller = Address::from_low_u64_be(0xfe); + let verifier = evm + .deploy(caller, deployment_code.into(), 0.into(), None) + .unwrap() + .address; + let result = evm + .call_raw(caller, verifier, calldata.into(), 0.into()) + .unwrap(); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + +pub fn gen_srs(k: u32) -> ParamsKZG { + ParamsKZG::::setup(k, OsRng) +} + +pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() +} + +/// Generates proof for either application circuit (model) or aggregation circuit. +pub fn gen_kzg_proof< + C: Circuit, + E: EncodedChallenge, + TR: TranscriptReadBuffer>, G1Affine, E>, + TW: TranscriptWriterBuffer, G1Affine, E>, +>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, +) -> Vec { + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); + + let instances = instances + .iter() + .map(|instances| instances.as_slice()) + .collect_vec(); + let proof = { + let mut transcript = TW::init(Vec::new()); + create_proof::, ProverGWC<_>, _, _, TW, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = TR::init(Cursor::new(proof.clone())); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, TR, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof +} diff --git a/src/pfsys/ipa.rs b/src/pfsys/ipa.rs deleted file mode 100644 index 302531668..000000000 --- a/src/pfsys/ipa.rs +++ /dev/null @@ -1,128 +0,0 @@ -use super::Proof; -use crate::abort; -use crate::commands::Cli; -use crate::fieldutils::i32_to_felt; -use crate::graph::ModelCircuit; -use crate::tensor::Tensor; -use clap::Parser; -use halo2_proofs::{ - // arithmetic::FieldExt, - // dev::{MockProver, VerifyFailure}, - plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey}, - poly::{ - commitment::ParamsProver, - ipa::{ - commitment::{IPACommitmentScheme, ParamsIPA}, - multiopen::ProverIPA, - strategy::SingleStrategy, - }, - VerificationStrategy, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, -}; -use halo2curves::pasta::vesta; -use halo2curves::pasta::Fp; -use halo2curves::pasta::{EqAffine, Fp as F}; -use log::{error, info, trace}; -use rand::rngs::OsRng; -use std::marker::PhantomData; -use std::ops::Deref; -use std::time::Instant; - -pub fn create_ipa_proof( - circuit: ModelCircuit, - public_inputs: Vec>, - params: &ParamsIPA, -) -> (ProvingKey, Vec, Vec>) { - // Real proof - let empty_circuit = circuit.without_witnesses(); - - // Initialize the proving key - let now = Instant::now(); - trace!("preparing VK"); - let vk = keygen_vk(params, &empty_circuit).expect("keygen_vk should not fail"); - info!("VK took {}", now.elapsed().as_secs()); - let now = Instant::now(); - let pk = keygen_pk(params, vk, &empty_circuit).expect("keygen_pk should not fail"); - info!("PK took {}", now.elapsed().as_secs()); - let now = Instant::now(); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - let mut rng = OsRng; - - let pi_inner: Vec> = public_inputs - .iter() - .map(|i| i.iter().map(|e| i32_to_felt::(*e)).collect::>()) - .collect::>>(); - let pi_inner = pi_inner.iter().map(|e| e.deref()).collect::>(); - let pi_for_real_prover: &[&[&[F]]] = &[&pi_inner]; - trace!("pi for real prover {:?}", pi_for_real_prover); - - let dims = circuit.inputs.iter().map(|i| i.dims().to_vec()).collect(); - - create_proof::, ProverIPA<_>, _, _, _, _>( - params, - &pk, - &[circuit], - pi_for_real_prover, - &mut rng, - &mut transcript, - ) - .expect("proof generation should not fail"); - let proof = transcript.finalize(); - info!("Proof took {}", now.elapsed().as_secs()); - - (pk, proof, dims) -} - -pub fn verify_ipa_proof(proof: Proof) -> bool { - let args = Cli::parse(); - let params: ParamsIPA = ParamsIPA::new(args.logrows); - - let inputs = proof - .input_shapes - .iter() - .map( - |s| match Tensor::new(Some(&vec![0; s.iter().product()]), s) { - Ok(t) => t, - Err(e) => { - abort!("failed to initialize tensor {:?}", e); - } - }, - ) - .collect(); - let circuit = ModelCircuit:: { - inputs, - _marker: PhantomData, - }; - let empty_circuit = circuit.without_witnesses(); - let vk = keygen_vk(¶ms, &empty_circuit).expect("keygen_vk should not fail"); - let pk = keygen_pk(¶ms, vk, &empty_circuit).expect("keygen_pk should not fail"); - - let pi_inner: Vec> = proof - .public_inputs - .iter() - .map(|i| i.iter().map(|e| i32_to_felt::(*e)).collect::>()) - .collect::>>(); - let pi_inner = pi_inner.iter().map(|e| e.deref()).collect::>(); - let pi_for_real_prover: &[&[&[F]]] = &[&pi_inner]; - trace!("pi for real prover {:?}", pi_for_real_prover); - - let now = Instant::now(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof.proof[..]); - - trace!("params computed"); - - let result = verify_proof( - ¶ms, - pk.get_vk(), - strategy, - pi_for_real_prover, - &mut transcript, - ) - .is_ok(); - info!("verify took {}", now.elapsed().as_secs()); - result -} diff --git a/src/pfsys/kzg/mod.rs b/src/pfsys/kzg/mod.rs deleted file mode 100644 index 483c16397..000000000 --- a/src/pfsys/kzg/mod.rs +++ /dev/null @@ -1,187 +0,0 @@ -/// Aggregation circuit -pub mod aggregation; - -use super::prepare_circuit_and_public_input; -use super::ModelInput; -use crate::fieldutils::i32_to_felt; -#[cfg(feature = "evm")] -use aggregation::Plonk; -use aggregation::{PoseidonTranscript, Snark}; -#[cfg(feature = "evm")] -use ethereum_types::Address; -#[cfg(feature = "evm")] -use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; -#[cfg(feature = "evm")] -use halo2_proofs::plonk::VerifyingKey; -use halo2_proofs::{ - dev::MockProver, - plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey}, - poly::{ - commitment::{Params, ParamsProver}, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverGWC, VerifierGWC}, - strategy::AccumulatorStrategy, - }, - VerificationStrategy, - }, - transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, -}; -#[cfg(feature = "evm")] -use halo2curves::bn256::Fq; -use halo2curves::bn256::{Bn256, Fr, G1Affine}; -use itertools::Itertools; -use log::trace; -#[cfg(feature = "evm")] -use plonk_verifier::{ - loader::evm::{encode_calldata, EvmLoader}, - system::halo2::transcript::evm::EvmTranscript, - verifier::PlonkVerifier, -}; -use plonk_verifier::{ - loader::native::NativeLoader, - system::halo2::{compile, Config}, -}; - -use rand::rngs::OsRng; -use std::io::Cursor; -#[cfg(feature = "evm")] -use std::rc::Rc; - -pub fn gen_application_snark(params: &ParamsKZG, data: &ModelInput) -> Snark { - let (circuit, public_inputs) = prepare_circuit_and_public_input::(data); - - let pk = gen_pk(params, &circuit); - let number_instance = public_inputs[0].len(); - trace!("number_instance {:?}", number_instance); - let protocol = compile( - params, - pk.get_vk(), - Config::kzg().with_num_instance(vec![number_instance]), - ); - let pi_inner: Vec> = public_inputs - .iter() - .map(|i| i.iter().map(|e| i32_to_felt::(*e)).collect::>()) - .collect::>>(); - // let pi_inner = pi_inner.iter().map(|e| e.deref()).collect::>(); - trace!("pi_inner {:?}", pi_inner); - let proof = gen_kzg_proof::< - _, - _, - PoseidonTranscript, - PoseidonTranscript, - >(params, &pk, circuit, pi_inner.clone()); - Snark::new(protocol, pi_inner, proof) -} - -#[cfg(feature = "evm")] -pub fn gen_aggregation_evm_verifier( - params: &ParamsKZG, - vk: &VerifyingKey, - num_instance: Vec, - accumulator_indices: Vec<(usize, usize)>, -) -> Vec { - let svk = params.get_g()[0].into(); - let dk = (params.g2(), params.s_g2()).into(); - let protocol = compile( - params, - vk, - Config::kzg() - .with_num_instance(num_instance.clone()) - .with_accumulator_indices(accumulator_indices), - ); - - let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); - - let instances = transcript.load_instances(num_instance); - let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); - Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - - loader.deployment_code() -} - -#[cfg(feature = "evm")] -pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { - let calldata = encode_calldata(&instances, &proof); - let success = { - let mut evm = ExecutorBuilder::default() - .with_gas_limit(u64::MAX.into()) - .build(Backend::new(MultiFork::new().0, None)); - - let caller = Address::from_low_u64_be(0xfe); - let verifier = evm - .deploy(caller, deployment_code.into(), 0.into(), None) - .unwrap() - .address; - let result = evm - .call_raw(caller, verifier, calldata.into(), 0.into()) - .unwrap(); - - dbg!(result.gas_used); - - !result.reverted - }; - assert!(success); -} - -pub fn gen_srs(k: u32) -> ParamsKZG { - ParamsKZG::::setup(k, OsRng) -} - -pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { - let vk = keygen_vk(params, circuit).unwrap(); - keygen_pk(params, vk, circuit).unwrap() -} - -/// Generates proof for either application circuit (model) or aggregation circuit. -pub fn gen_kzg_proof< - C: Circuit, - E: EncodedChallenge, - TR: TranscriptReadBuffer>, G1Affine, E>, - TW: TranscriptWriterBuffer, G1Affine, E>, ->( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: C, - instances: Vec>, -) -> Vec { - MockProver::run(params.k(), &circuit, instances.clone()) - .unwrap() - .assert_satisfied(); - - let instances = instances - .iter() - .map(|instances| instances.as_slice()) - .collect_vec(); - let proof = { - let mut transcript = TW::init(Vec::new()); - create_proof::, ProverGWC<_>, _, _, TW, _>( - params, - pk, - &[circuit], - &[instances.as_slice()], - OsRng, - &mut transcript, - ) - .unwrap(); - transcript.finalize() - }; - - let accept = { - let mut transcript = TR::init(Cursor::new(proof.clone())); - VerificationStrategy::<_, VerifierGWC<_>>::finalize( - verify_proof::<_, VerifierGWC<_>, _, TR, _>( - params.verifier_params(), - pk.get_vk(), - AccumulatorStrategy::new(params.verifier_params()), - &[instances.as_slice()], - &mut transcript, - ) - .unwrap(), - ) - }; - assert!(accept); - - proof -} diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index 0c0f9c700..1c55f5b1d 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -1,18 +1,31 @@ -// Really these are halo2 plonkish IOP + commitment scheme, but we only support Plonkish IOP so far, so there is no ambiguity -pub mod ipa; -pub mod kzg; +/// Aggregation circuit +#[cfg(feature = "evm")] +pub mod aggregation; use crate::abort; use crate::commands::{data_path, Cli}; +use crate::fieldutils::i32_to_felt; use crate::graph::{utilities::vector_to_quantized, Model, ModelCircuit}; -use crate::tensor::Tensor; +use crate::tensor::{Tensor, TensorType}; use clap::Parser; +use halo2_proofs::plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey, +}; +use halo2_proofs::poly::commitment::{CommitmentScheme, Params, Prover, Verifier}; +use halo2_proofs::poly::VerificationStrategy; +use halo2_proofs::transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, +}; use halo2_proofs::{arithmetic::FieldExt, dev::VerifyFailure}; use log::{error, info, trace}; +use rand::rngs::OsRng; use serde::{Deserialize, Serialize}; use std::fs::File; -use std::io::Read; +use std::io::{BufReader, BufWriter, Read, Write}; use std::marker::PhantomData; +use std::ops::Deref; +use std::path::PathBuf; +use std::time::Instant; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelInput { @@ -28,6 +41,37 @@ pub struct Proof { pub proof: Vec, } +impl Proof { + pub fn save(&self, proof_path: &PathBuf) { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + abort!("failed to convert proof json to string {:?}", e); + } + }; + + let mut file = std::fs::File::create(proof_path).expect("create failed"); + file.write_all(serialized.as_bytes()).expect("write failed"); + } + + pub fn load(proof_path: &PathBuf) -> Self { + let mut file = match File::open(proof_path) { + Ok(f) => f, + Err(e) => { + abort!("failed to open proof file {:?}", e); + } + }; + let mut data = String::new(); + match file.read_to_string(&mut data) { + Ok(_) => {} + Err(e) => { + abort!("failed to read file {:?}", e); + } + }; + serde_json::from_str(&data).expect("JSON was not well-formatted") + } +} + /// Helper function to print helpful error messages after verification has failed. pub fn parse_prover_errors(f: &VerifyFailure) { match f { @@ -164,3 +208,174 @@ pub fn prepare_data(datapath: String) -> ModelInput { data } + +pub fn create_keys( + circuit: &ModelCircuit, + params: &'_ Scheme::ParamsProver, +) -> ProvingKey +where + ModelCircuit: Circuit, +{ + // Real proof + let empty_circuit = circuit.without_witnesses(); + + // Initialize the proving key + let now = Instant::now(); + trace!("preparing VK"); + let vk = keygen_vk(params, &empty_circuit).expect("keygen_vk should not fail"); + info!("VK took {}", now.elapsed().as_secs()); + let now = Instant::now(); + let pk = keygen_pk(params, vk, &empty_circuit).expect("keygen_pk should not fail"); + info!("PK took {}", now.elapsed().as_secs()); + pk +} + +/// a wrapper around halo2's create_proof +pub fn create_proof_model< + 'params, + Scheme: CommitmentScheme, + F: FieldExt + TensorType, + P: Prover<'params, Scheme>, +>( + circuit: &ModelCircuit, + public_inputs: &[Tensor], + params: &'params Scheme::ParamsProver, + pk: &ProvingKey, +) -> (Proof, Vec>) +where + ModelCircuit: Circuit, +{ + let now = Instant::now(); + let mut transcript = Blake2bWrite::<_, Scheme::Curve, Challenge255<_>>::init(vec![]); + let mut rng = OsRng; + let pi_inner: Vec> = public_inputs + .iter() + .map(|i| { + i.iter() + .map(|e| i32_to_felt::(*e)) + .collect::>() + }) + .collect::>>(); + let pi_inner = pi_inner + .iter() + .map(|e| e.deref()) + .collect::>(); + let instances: &[&[&[Scheme::Scalar]]] = &[&pi_inner]; + trace!("instances {:?}", instances); + + let dims = circuit.inputs.iter().map(|i| i.dims().to_vec()).collect(); + + create_proof::( + params, + pk, + &[circuit.clone()], + instances, + &mut rng, + &mut transcript, + ) + .expect("proof generation should not fail"); + let proof = transcript.finalize(); + info!("Proof took {}", now.elapsed().as_secs()); + + let checkable_pf = Proof { + input_shapes: circuit.inputs.iter().map(|i| i.dims().to_vec()).collect(), + public_inputs: public_inputs + .iter() + .map(|i| i.clone().into_iter().collect()) + .collect(), + proof, + }; + + (checkable_pf, dims) +} + +/// a wrapper around halo2's verify_proof +pub fn verify_proof_model< + 'params, + F: FieldExt, + V: Verifier<'params, Scheme>, + Scheme: CommitmentScheme, + Strategy: VerificationStrategy<'params, Scheme, V>, +>( + proof: Proof, + params: &'params Scheme::ParamsVerifier, + vk: &VerifyingKey, + strategy: Strategy, +) -> bool +where + ModelCircuit: Circuit, +{ + let pi_inner: Vec> = proof + .public_inputs + .iter() + .map(|i| { + i.iter() + .map(|e| i32_to_felt::(*e)) + .collect::>() + }) + .collect::>>(); + let pi_inner = pi_inner + .iter() + .map(|e| e.deref()) + .collect::>(); + let instances: &[&[&[Scheme::Scalar]]] = &[&pi_inner]; + trace!("instances {:?}", instances); + + let now = Instant::now(); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof.proof[..]); + + let result = + verify_proof::(params, vk, strategy, instances, &mut transcript) + .is_ok(); + info!("verify took {}", now.elapsed().as_secs()); + result +} + +pub fn load_vk( + path: PathBuf, + params: &'_ Scheme::ParamsVerifier, +) -> VerifyingKey +where + ModelCircuit: Circuit, +{ + info!("loading verification key from {:?}", path); + let f = match File::open(path) { + Ok(f) => f, + Err(e) => { + abort!("failed to load vk {}", e); + } + }; + let mut reader = BufReader::new(f); + VerifyingKey::::read::<_, ModelCircuit>(&mut reader, params).unwrap() +} + +pub fn load_params(path: PathBuf) -> Scheme::ParamsVerifier { + info!("loading params from {:?}", path); + let f = match File::open(path) { + Ok(f) => f, + Err(e) => { + abort!("failed to load params {}", e); + } + }; + let mut reader = BufReader::new(f); + Params::<'_, Scheme::Curve>::read(&mut reader).unwrap() +} + +pub fn save_vk(vk_path: &PathBuf, vk: &VerifyingKey) { + info!("saving verification key 💾"); + let f = File::create(vk_path).unwrap(); + let mut writer = BufWriter::new(f); + vk.write(&mut writer).unwrap(); + writer.flush().unwrap(); +} + +pub fn save_params( + params_path: &PathBuf, + params: &'_ Scheme::ParamsVerifier, +) { + info!("saving parameters 💾"); + let f = File::create(params_path).unwrap(); + let mut writer = BufWriter::new(f); + params.write(&mut writer).unwrap(); + writer.flush().unwrap(); +} diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 7951528f6..cfde190f9 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1,5 +1,11 @@ use std::process::Command; +#[cfg(test)] +#[ctor::ctor] +fn init() { + build_ezkl(); +} + const TESTS: [&str; 12] = [ "1l_mlp", "1l_flatten", @@ -25,13 +31,15 @@ macro_rules! test_func { use crate::mock; use crate::mock_public_inputs; use crate::mock_public_params; - use crate::fullprove; - use crate::prove_and_verify; + use crate::ipa_fullprove; + use crate::ipa_prove_and_verify; use crate::kzg_fullprove; + use crate::kzg_prove_and_verify; + use crate::kzg_evm_fullprove; seq!(N in 0..=11 { #(#[test_case(TESTS[N])])* - fn mock_(test: &str) { + fn mock_public_outputs_(test: &str) { mock(test.to_string()); } @@ -46,21 +54,32 @@ macro_rules! test_func { } #(#[test_case(TESTS[N])])* - fn fullprove_(test: &str) { - fullprove(test.to_string()); + fn ipa_fullprove_(test: &str) { + ipa_fullprove(test.to_string()); } #(#[test_case(TESTS[N])])* - fn prove_and_verify_(test: &str) { - prove_and_verify(test.to_string()); + fn ipa_prove_and_verify_(test: &str) { + ipa_prove_and_verify(test.to_string()); } - // these take a particularly long time to run #(#[test_case(TESTS[N])])* - #[ignore] fn kzg_fullprove_(test: &str) { kzg_fullprove(test.to_string()); } + + #(#[test_case(TESTS[N])])* + // #[ignore] + fn kzg_prove_and_verify_(test: &str) { + kzg_prove_and_verify(test.to_string()); + } + + // these take a particularly long time to run + #(#[test_case(TESTS[N])])* + #[ignore] + fn kzg_evm_fullprove_(test: &str) { + kzg_evm_fullprove(test.to_string()); + } }); } }; @@ -70,17 +89,10 @@ test_func!(); // Mock prove (fast, but does not cover some potential issues) fn mock(example_name: String) { - let status = Command::new("cargo") + let status = Command::new("target/release/ezkl") .args([ - "run", - "--release", - "--bin", - "ezkl", - "--", - "--bits", - "16", - "-K", - "17", + "--bits=16", + "-K=17", "mock", "-D", format!("./examples/onnx/examples/{}/input.json", example_name).as_str(), @@ -96,18 +108,11 @@ fn mock(example_name: String) { // Mock prove (fast, but does not cover some potential issues) fn mock_public_inputs(example_name: String) { - let status = Command::new("cargo") + let status = Command::new("target/release/ezkl") .args([ - "run", - "--release", - "--bin", - "ezkl", - "--", "--public-inputs", - "--bits", - "16", - "-K", - "17", + "--bits=16", + "-K=17", "mock", "-D", format!("./examples/onnx/examples/{}/input.json", example_name).as_str(), @@ -123,18 +128,11 @@ fn mock_public_inputs(example_name: String) { // Mock prove (fast, but does not cover some potential issues) fn mock_public_params(example_name: String) { - let status = Command::new("cargo") + let status = Command::new("target/release/ezkl") .args([ - "run", - "--release", - "--bin", - "ezkl", - "--", "--public-params", - "--bits", - "16", - "-K", - "17", + "--bits=16", + "-K=17", "mock", "-D", format!("./examples/onnx/examples/{}/input.json", example_name).as_str(), @@ -149,18 +147,11 @@ fn mock_public_params(example_name: String) { } // full prove (slower, covers more, but still reuses the pk) -fn fullprove(example_name: String) { - let status = Command::new("cargo") +fn ipa_fullprove(example_name: String) { + let status = Command::new("target/release/ezkl") .args([ - "run", - "--release", - "--bin", - "ezkl", - "--", - "--bits", - "16", - "-K", - "17", + "--bits=16", + "-K=17", "fullprove", "-D", format!("./examples/onnx/examples/{}/input.json", example_name).as_str(), @@ -175,54 +166,109 @@ fn fullprove(example_name: String) { } // prove-serialize-verify, the usual full path -fn prove_and_verify(example_name: String) { - let status = Command::new("cargo") +fn ipa_prove_and_verify(example_name: String) { + let status = Command::new("target/release/ezkl") .args([ - "run", - "--release", - "--bin", - "ezkl", - "--", - "--bits", - "16", - "-K", - "17", + "--bits=16", + "-K=17", "prove", "-D", format!("./examples/onnx/examples/{}/input.json", example_name).as_str(), "-M", format!("./examples/onnx/examples/{}/network.onnx", example_name).as_str(), "-O", - format!("pav_{}.pf", example_name).as_str(), + format!("ipa_{}.pf", example_name).as_str(), + "--vk-path", + format!("ipa_{}.vk", example_name).as_str(), + "--params-path", + format!("ipa_{}.params", example_name).as_str(), ]) .status() .expect("failed to execute process"); assert!(status.success()); - let status = Command::new("cargo") + let status = Command::new("target/release/ezkl") .args([ - "run", - "--release", - "--bin", - "ezkl", - "--", - "--bits", - "16", - "-K", - "17", + "--bits=16", + "-K=17", "verify", "-M", format!("./examples/onnx/examples/{}/network.onnx", example_name).as_str(), "-P", - format!("pav_{}.pf", example_name).as_str(), + format!("ipa_{}.pf", example_name).as_str(), + "--vk-path", + format!("ipa_{}.vk", example_name).as_str(), + "--params-path", + format!("ipa_{}.params", example_name).as_str(), ]) .status() .expect("failed to execute process"); assert!(status.success()); } -// KZG / EVM tests +// prove-serialize-verify, the usual full path +fn kzg_prove_and_verify(example_name: String) { + let status = Command::new("target/release/ezkl") + .args([ + "--bits=16", + "-K=17", + "prove", + "--pfsys=kzg", + "-D", + format!("./examples/onnx/examples/{}/input.json", example_name).as_str(), + "-M", + format!("./examples/onnx/examples/{}/network.onnx", example_name).as_str(), + "-O", + format!("kzg_{}.pf", example_name).as_str(), + "--vk-path", + format!("kzg_{}.vk", example_name).as_str(), + "--params-path", + format!("kzg_{}.params", example_name).as_str(), + ]) + .status() + .expect("failed to execute process"); + assert!(status.success()); + let status = Command::new("target/release/ezkl") + .args([ + "--bits=16", + "-K=17", + "verify", + "--pfsys=kzg", + "-M", + format!("./examples/onnx/examples/{}/network.onnx", example_name).as_str(), + "-P", + format!("kzg_{}.pf", example_name).as_str(), + "--vk-path", + format!("kzg_{}.vk", example_name).as_str(), + "--params-path", + format!("kzg_{}.params", example_name).as_str(), + ]) + .status() + .expect("failed to execute process"); + assert!(status.success()); +} + +// KZG tests // full prove (slower, covers more, but still reuses the pk) fn kzg_fullprove(example_name: String) { + let status = Command::new("target/release/ezkl") + .args([ + "--bits=16", + "-K=17", + "fullprove", + "--pfsys=kzg", + "-D", + format!("./examples/onnx/examples/{}/input.json", example_name).as_str(), + "-M", + format!("./examples/onnx/examples/{}/network.onnx", example_name).as_str(), + ]) + .status() + .expect("failed to execute process"); + assert!(status.success()); +} + +// KZG / EVM tests +// full prove (slower, covers more, but still reuses the pk) +fn kzg_evm_fullprove(example_name: String) { let status = Command::new("cargo") .args([ "run", @@ -232,19 +278,24 @@ fn kzg_fullprove(example_name: String) { "--bin", "ezkl", "--", - "--bits", - "16", - "-K", - "17", + "--bits=16", + "-K=17", "fullprove", + "--pfsys=kzg", "-D", format!("./examples/onnx/examples/{}/input.json", example_name).as_str(), "-M", format!("./examples/onnx/examples/{}/network.onnx", example_name).as_str(), - "--pfsys", - "kzg", ]) .status() .expect("failed to execute process"); assert!(status.success()); } + +fn build_ezkl() { + let status = Command::new("cargo") + .args(["build", "--release", "--bin", "ezkl"]) + .status() + .expect("failed to execute process"); + assert!(status.success()); +}