From 34f01a8f4b3084f5f039a6c4972e31f2964b347a Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 7 Aug 2023 13:50:07 -0700 Subject: [PATCH 1/3] feat: allow circuit size (number of rows) to be loaded as witness --- .../{standard_plonk.rs => vkey_as_witness.rs} | 24 ++------------ snark-verifier-sdk/src/halo2/aggregation.rs | 31 ++++++++++++------- snark-verifier/examples/recursion.rs | 2 +- snark-verifier/src/loader.rs | 6 ++++ snark-verifier/src/loader/evm/loader.rs | 4 +++ snark-verifier/src/loader/halo2/loader.rs | 11 +++++++ snark-verifier/src/loader/halo2/shim.rs | 19 ++++++++++++ snark-verifier/src/loader/native.rs | 7 ++++- snark-verifier/src/system/halo2.rs | 1 + snark-verifier/src/verifier/plonk.rs | 8 +++-- snark-verifier/src/verifier/plonk/protocol.rs | 30 ++++++++++++++++-- 11 files changed, 103 insertions(+), 40 deletions(-) rename snark-verifier-sdk/examples/{standard_plonk.rs => vkey_as_witness.rs} (89%) diff --git a/snark-verifier-sdk/examples/standard_plonk.rs b/snark-verifier-sdk/examples/vkey_as_witness.rs similarity index 89% rename from snark-verifier-sdk/examples/standard_plonk.rs rename to snark-verifier-sdk/examples/vkey_as_witness.rs index 35bb573b..9f715910 100644 --- a/snark-verifier-sdk/examples/standard_plonk.rs +++ b/snark-verifier-sdk/examples/vkey_as_witness.rs @@ -153,7 +153,7 @@ fn main() { let params_app = gen_srs(8); let dummy_snark = gen_application_snark(¶ms_app, ComputeFlag::All); - let k = 22u32; + let k = 15u32; let params = gen_srs(k); let lookup_bits = k as usize - 1; BASE_CONFIG_PARAMS.with(|config| { @@ -170,9 +170,7 @@ fn main() { ); agg_circuit.config(k, Some(10)); - let start0 = start_timer!(|| "gen vk & pk"); - let pk = gen_pk(¶ms, &agg_circuit, Some(Path::new("./examples/agg.pk"))); - end_timer!(start0); + let pk = gen_pk(¶ms, &agg_circuit, None); let break_points = agg_circuit.break_points(); let snarks = [ComputeFlag::All, ComputeFlag::SkipFixed, ComputeFlag::SkipCopy] @@ -189,22 +187,4 @@ fn main() { let _snark = gen_snark_shplonk(¶ms, &pk, agg_circuit, None::<&str>); println!("snark {i} success"); } - - /* - #[cfg(feature = "loader_evm")] - { - // do one more time to verify - let num_instances = agg_circuit.num_instance(); - let instances = agg_circuit.instances(); - let proof_calldata = gen_evm_proof_shplonk(¶ms, &pk, agg_circuit, instances.clone()); - - let deployment_code = gen_evm_verifier_shplonk::>( - ¶ms, - pk.get_vk(), - num_instances, - Some(Path::new("./examples/standard_plonk.yul")), - ); - evm_verify(deployment_code, instances, proof_calldata); - } - */ } diff --git a/snark-verifier-sdk/src/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs index b6bbabf1..6232f7d0 100644 --- a/snark-verifier-sdk/src/halo2/aggregation.rs +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -54,7 +54,7 @@ pub struct SnarkAggregationWitness<'a> { /// These can then be exposed as public instances. /// /// This is only useful if preprocessed digest is loaded as witness (i.e., `preprocessed_as_witness` is true in `aggregate`), so we set it to `None` otherwise. - pub preprocessed_digest: Option>>>, + pub preprocessed_digests: Option>>>, } #[allow(clippy::type_complexity)] @@ -99,7 +99,7 @@ where }; let mut previous_instances = Vec::with_capacity(snarks.len()); - let mut preprocessed_digest = Vec::with_capacity(snarks.len()); + let mut preprocessed_digests = Vec::with_capacity(snarks.len()); // to avoid re-loading the spec each time, we create one transcript and clear the stream let mut transcript = PoseidonTranscript::>, &[u8]>::from_spec( loader, @@ -109,9 +109,10 @@ where let mut accumulators = snarks .iter() - .flat_map(|snark| { + .flat_map(|snark: &Snark| { let protocol = if preprocessed_as_witness { - snark.protocol.loaded_preprocessed_as_witness(loader) + // always load `domain.n` as witness if vkey is witness + snark.protocol.loaded_preprocessed_as_witness(loader, true) } else { snark.protocol.loaded(loader) }; @@ -128,6 +129,7 @@ where .chain( protocol.transcript_initial_state.clone().map(|scalar| scalar.into_assigned()), ) + .chain(protocol.n_as_witness.clone().map(|n| n.into_assigned())) // If `n` is witness, add it as part of input .collect_vec(); let instances = assign_instances(&snark.instances); @@ -147,7 +149,7 @@ where previous_instances.push( instances.into_iter().flatten().map(|scalar| scalar.into_assigned()).collect(), ); - preprocessed_digest.push(inputs); + preprocessed_digests.push(inputs); accumulator }) @@ -166,9 +168,9 @@ where } else { accumulators.pop().unwrap() }; - let preprocessed_digest = preprocessed_as_witness.then(|| preprocessed_digest); + let preprocessed_digests = preprocessed_as_witness.then_some(preprocessed_digests); - SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digest } + SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digests } } /// Same as `FlexGateConfigParams` except we assume a single Phase and default 'Vertical' strategy. @@ -208,6 +210,11 @@ pub struct AggregationCircuit { // the public instances from previous snarks that were aggregated, now collected as PRIVATE assigned values // the user can optionally append these to `inner.assigned_instances` to expose them pub previous_instances: Vec>>, + /// This returns the assigned `preprocessed` and `transcript_initial_state` and `domain.n` values as a vector of assigned values, one for each aggregated snark. + /// These can then be exposed as public instances. + /// + /// This is only useful if preprocessed digest is loaded as witness (i.e., `preprocessed_as_witness` is true in `aggregate`), so we set it to `None` otherwise. + pub preprocessed_digests: Option>>>, // accumulation scheme proof, private input pub as_proof: Vec, // not sure this needs to be stored, keeping for now } @@ -239,9 +246,11 @@ impl AggregationCircuit { /// Given snarks, this creates a circuit and runs the `GateThreadBuilder` to verify all the snarks. /// By default, the returned circuit has public instances equal to the limbs of the pair of elliptic curve points, referred to as the `accumulator`, that need to be verified in a final pairing check. /// - /// The user can optionally modify the circuit after calling this function to add more instances to `assigned_instances` to expose. + /// - If `preprocessed_as_witness` is set to `true`, then the verifying keys of each snark in `snarks` is loaded as a witness in the circuit. Moreover, the number of rows `n` of each snark in `snarks` is also loaded as a witness. By default, these witnesses are _private_ and returned in `self.preprocessed_ + /// - The user can optionally modify the circuit after calling this function to add more instances to `assigned_instances` to expose. /// - /// Warning: will fail silently if `snarks` were created using a different multi-open scheme than `AS` + /// # Warning + /// Will fail silently if `snarks` were created using a different multi-open scheme than `AS` /// where `AS` can be either [`crate::SHPLONK`] or [`crate::GWC`] (for original PLONK multi-open scheme) pub fn new( stage: CircuitBuilderStage, @@ -300,7 +309,7 @@ impl AggregationCircuit { let ecc_chip = BaseFieldEccChip::new(&fp_chip); let loader = Halo2Loader::new(ecc_chip, builder); - let SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digest } = + let SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digests } = aggregate::(&svk, &loader, &snarks, as_proof.as_slice(), preprocessed_as_witness); let lhs = accumulator.lhs.assigned(); let rhs = accumulator.rhs.assigned(); @@ -333,7 +342,7 @@ impl AggregationCircuit { } }; let inner = RangeWithInstanceCircuitBuilder::new(circuit, assigned_instances); - Self { inner, previous_instances, as_proof } + Self { inner, previous_instances, preprocessed_digests, as_proof } } pub fn public( diff --git a/snark-verifier/examples/recursion.rs b/snark-verifier/examples/recursion.rs index 7415e1ab..4138e372 100644 --- a/snark-verifier/examples/recursion.rs +++ b/snark-verifier/examples/recursion.rs @@ -358,7 +358,7 @@ mod recursion { ) -> (Vec>>, Vec>>>) { let protocol = if let Some(preprocessed_digest) = preprocessed_digest { let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); - let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); + let protocol = snark.protocol.loaded_preprocessed_as_witness(loader, false); let inputs = protocol .preprocessed .iter() diff --git a/snark-verifier/src/loader.rs b/snark-verifier/src/loader.rs index 77a8f54b..0a4e8f38 100644 --- a/snark-verifier/src/loader.rs +++ b/snark-verifier/src/loader.rs @@ -67,6 +67,12 @@ pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { acc } + /// Returns power to exponent, where exponent is also [`LoadedScalar`]. + /// If `Loader` is for Halo2, then `exp` must have at most `exp_max_bits` bits (otherwise constraints will fail). + /// + /// Currently **unimplemented** for EvmLoader + fn pow_var(&self, exp: &Self, exp_max_bits: usize) -> Self; + /// Returns powers up to exponent `n-1`. fn powers(&self, n: usize) -> Vec { iter::once(self.loader().load_one()) diff --git a/snark-verifier/src/loader/evm/loader.rs b/snark-verifier/src/loader/evm/loader.rs index 98ca5ca4..5aa3fecf 100644 --- a/snark-verifier/src/loader/evm/loader.rs +++ b/snark-verifier/src/loader/evm/loader.rs @@ -632,6 +632,10 @@ impl> LoadedScalar for Scalar { fn loader(&self) -> &Self::Loader { &self.loader } + + fn pow_var(&self, _exp: &Self, _exp_max_bits: usize) -> Self { + todo!() + } } impl EcPointLoader for Rc diff --git a/snark-verifier/src/loader/halo2/loader.rs b/snark-verifier/src/loader/halo2/loader.rs index 31be9841..1f1bb16d 100644 --- a/snark-verifier/src/loader/halo2/loader.rs +++ b/snark-verifier/src/loader/halo2/loader.rs @@ -306,6 +306,17 @@ impl> LoadedScalar for Sc fn loader(&self) -> &Self::Loader { &self.loader } + + fn pow_var(&self, exp: &Self, max_bits: usize) -> Self { + let loader = self.loader(); + let res = loader.scalar_chip().pow_var( + &mut loader.ctx_mut(), + &self.assigned(), + &exp.assigned(), + max_bits, + ); + loader.scalar_from_assigned(res) + } } impl> Debug for Scalar { diff --git a/snark-verifier/src/loader/halo2/shim.rs b/snark-verifier/src/loader/halo2/shim.rs index 790c9e22..1196f0a1 100644 --- a/snark-verifier/src/loader/halo2/shim.rs +++ b/snark-verifier/src/loader/halo2/shim.rs @@ -65,6 +65,15 @@ pub trait IntegerInstructions: Clone + Debug { lhs: &Self::AssignedInteger, rhs: &Self::AssignedInteger, ); + + /// Returns `base^exponent` and constrains that `exponent` has at most `max_bits` bits. + fn pow_var( + &self, + ctx: &mut Self::Context, + base: &Self::AssignedInteger, + exponent: &Self::AssignedInteger, + max_bits: usize, + ) -> Self::AssignedInteger; } /// Instructions to handle elliptic curve point operations. @@ -232,6 +241,16 @@ mod halo2_lib { ) { ctx.main(0).constrain_equal(a, b); } + + fn pow_var( + &self, + ctx: &mut Self::Context, + base: &Self::AssignedInteger, + exponent: &Self::AssignedInteger, + max_bits: usize, + ) -> Self::AssignedInteger { + GateInstructions::pow_var(self, ctx.main(0), *base, *exponent, max_bits) + } } impl<'chip, C: CurveAffineExt> EccInstructions for BaseFieldEccChip<'chip, C> diff --git a/snark-verifier/src/loader/native.rs b/snark-verifier/src/loader/native.rs index 783aaa89..a9aa86ff 100644 --- a/snark-verifier/src/loader/native.rs +++ b/snark-verifier/src/loader/native.rs @@ -2,7 +2,7 @@ use crate::{ loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::arithmetic::{Curve, CurveAffine, FieldOps, PrimeField}, + util::arithmetic::{fe_to_big, Curve, CurveAffine, FieldOps, PrimeField}, Error, }; use lazy_static::lazy_static; @@ -38,6 +38,11 @@ impl LoadedScalar for F { fn loader(&self) -> &NativeLoader { &LOADER } + + fn pow_var(&self, exp: &Self, _: usize) -> Self { + let exp = fe_to_big(*exp).to_u64_digits(); + self.pow_vartime(exp) + } } impl EcPointLoader for NativeLoader { diff --git a/snark-verifier/src/system/halo2.rs b/snark-verifier/src/system/halo2.rs index 98f4488c..9715d996 100644 --- a/snark-verifier/src/system/halo2.rs +++ b/snark-verifier/src/system/halo2.rs @@ -141,6 +141,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( PlonkProtocol { domain, + n_as_witness: None, preprocessed, num_instance: polynomials.num_instance(), num_witness: polynomials.num_witness(), diff --git a/snark-verifier/src/verifier/plonk.rs b/snark-verifier/src/verifier/plonk.rs index d5937ab8..54fbaf74 100644 --- a/snark-verifier/src/verifier/plonk.rs +++ b/snark-verifier/src/verifier/plonk.rs @@ -62,8 +62,12 @@ where proof: &Self::Proof, ) -> Result { let common_poly_eval = { - let mut common_poly_eval = - CommonPolynomialEvaluation::new(&protocol.domain, protocol.langranges(), &proof.z); + let mut common_poly_eval = CommonPolynomialEvaluation::new( + &protocol.domain, + protocol.langranges(), + &proof.z, + &protocol.n_as_witness, + ); L::batch_invert(common_poly_eval.denoms()); common_poly_eval.evaluate(); diff --git a/snark-verifier/src/verifier/plonk/protocol.rs b/snark-verifier/src/verifier/plonk/protocol.rs index a3a84346..6256d522 100644 --- a/snark-verifier/src/verifier/plonk/protocol.rs +++ b/snark-verifier/src/verifier/plonk/protocol.rs @@ -1,7 +1,7 @@ use crate::{ loader::{native::NativeLoader, LoadedScalar, Loader}, util::{ - arithmetic::{CurveAffine, Domain, Field, Fraction, Rotation}, + arithmetic::{CurveAffine, Domain, Field, Fraction, PrimeField, Rotation}, Itertools, }, }; @@ -29,6 +29,14 @@ where ))] /// Working domain. pub domain: Domain, + + #[serde(bound( + serialize = "L::LoadedScalar: Serialize", + deserialize = "L::LoadedScalar: Deserialize<'de>" + ))] + /// Optional: load `domain.n` as a witness + pub n_as_witness: Option, + #[serde(bound( serialize = "L::LoadedEcPoint: Serialize", deserialize = "L::LoadedEcPoint: Deserialize<'de>" @@ -115,6 +123,7 @@ where .map(|transcript_initial_state| loader.load_const(transcript_initial_state)); PlonkProtocol { domain: self.domain.clone(), + n_as_witness: None, preprocessed, num_instance: self.num_instance.clone(), num_witness: self.num_witness.clone(), @@ -149,7 +158,10 @@ mod halo2 { pub fn loaded_preprocessed_as_witness>( &self, loader: &Rc>, + load_n_as_witness: bool, ) -> PlonkProtocol>> { + let n_as_witness = load_n_as_witness + .then(|| loader.assign_scalar(C::Scalar::from(self.domain.n as u64))); let preprocessed = self .preprocessed .iter() @@ -161,6 +173,7 @@ mod halo2 { .map(|transcript_initial_state| loader.assign_scalar(*transcript_initial_state)); PlonkProtocol { domain: self.domain.clone(), + n_as_witness, preprocessed, num_instance: self.num_instance.clone(), num_witness: self.num_witness.clone(), @@ -201,21 +214,32 @@ where C: CurveAffine, L: Loader, { + // if `n_as_witness` is Some, then we assume `n_as_witness` has value equal to `domain.n` (i.e., number of rows in the circuit) + // and is loaded as a witness instead of a constant. pub fn new( domain: &Domain, langranges: impl IntoIterator, z: &L::LoadedScalar, + n_as_witness: &Option, ) -> Self { let loader = z.loader(); - let zn = z.pow_const(domain.n as u64); + let zn = if let Some(n) = n_as_witness.as_ref() { + z.pow_var(n, C::Scalar::S as usize + 1) + } else { + z.pow_const(domain.n as u64) + }; let langranges = langranges.into_iter().sorted().dedup().collect_vec(); let one = loader.load_one(); let zn_minus_one = zn.clone() - &one; let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone()); - let n_inv = loader.load_const(&domain.n_inv); + let n_inv = if let Some(n) = n_as_witness.as_ref() { + n.invert().expect("n is not zero") + } else { + loader.load_const(&domain.n_inv) + }; let numer = zn_minus_one.clone() * &n_inv; let omegas = langranges .iter() From debda9e72727bc63a4e587c4ba6bc9ea16bfc0c6 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 7 Aug 2023 13:51:32 -0700 Subject: [PATCH 2/3] chore: clippy fix --- snark-verifier-sdk/examples/vkey_as_witness.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/snark-verifier-sdk/examples/vkey_as_witness.rs b/snark-verifier-sdk/examples/vkey_as_witness.rs index 9f715910..73d1cc70 100644 --- a/snark-verifier-sdk/examples/vkey_as_witness.rs +++ b/snark-verifier-sdk/examples/vkey_as_witness.rs @@ -1,23 +1,20 @@ use application::ComputeFlag; -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::builder::{set_lookup_bits, CircuitBuilderStage, BASE_CONFIG_PARAMS}; + +use halo2_base::gates::builder::{CircuitBuilderStage, BASE_CONFIG_PARAMS}; use halo2_base::halo2_proofs; use halo2_base::halo2_proofs::arithmetic::Field; use halo2_base::halo2_proofs::halo2curves::bn256::Fr; -use halo2_base::halo2_proofs::poly::commitment::Params; use halo2_base::utils::fs::gen_srs; use halo2_proofs::halo2curves as halo2_curves; -use halo2_proofs::plonk::Circuit; + use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG}; use rand::rngs::OsRng; +use snark_verifier_sdk::SHPLONK; use snark_verifier_sdk::{ - evm::{evm_verify, gen_evm_proof_shplonk, gen_evm_verifier_shplonk}, gen_pk, halo2::{aggregation::AggregationCircuit, gen_snark_shplonk}, Snark, }; -use snark_verifier_sdk::{CircuitExt, SHPLONK}; -use std::path::Path; mod application { use super::halo2_curves::bn256::Fr; @@ -26,7 +23,7 @@ mod application { plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, poly::Rotation, }; - use rand::RngCore; + use snark_verifier_sdk::CircuitExt; #[derive(Clone, Copy)] From 253b5b5ea961ea733f95adc2afcc30a1db88bbab Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 10 Aug 2023 00:46:04 -0600 Subject: [PATCH 3/3] fix(n_as_witness): computation of shifts depends on `omega` `omega` which changes when `k` changes, so all shift computations need to be done as witness. Current implementation is likely not the most optimal. Instead of storing `shift` as `omega^i`, we store just `Rotation(i)`. We de-duplicate when possible using `BTreeMap` of `Rotation`. Note you must use `Rotation` instead of `F` for `BTreeMap` because the ordering of `omega^i` may change depending on `omega`. --- snark-verifier-sdk/benches/standard_plonk.rs | 6 +- snark-verifier-sdk/examples/n_as_witness.rs | 190 ++++++++++++++++++ .../examples/vkey_as_witness.rs | 5 +- snark-verifier-sdk/src/halo2.rs | 34 ++-- snark-verifier-sdk/src/halo2/aggregation.rs | 65 +++++- snark-verifier/src/loader/halo2/loader.rs | 9 +- snark-verifier/src/pcs.rs | 24 ++- snark-verifier/src/pcs/ipa/multiopen/bgh19.rs | 85 ++++---- .../src/pcs/kzg/multiopen/bdfg21.rs | 183 +++++++---------- snark-verifier/src/pcs/kzg/multiopen/gwc19.rs | 51 +++-- snark-verifier/src/system/halo2.rs | 2 +- snark-verifier/src/verifier/plonk.rs | 11 +- snark-verifier/src/verifier/plonk/proof.rs | 43 ++-- snark-verifier/src/verifier/plonk/protocol.rs | 93 ++++++--- 14 files changed, 543 insertions(+), 258 deletions(-) create mode 100644 snark-verifier-sdk/examples/n_as_witness.rs diff --git a/snark-verifier-sdk/benches/standard_plonk.rs b/snark-verifier-sdk/benches/standard_plonk.rs index 24e3ae73..643254c7 100644 --- a/snark-verifier-sdk/benches/standard_plonk.rs +++ b/snark-verifier-sdk/benches/standard_plonk.rs @@ -12,7 +12,7 @@ use halo2_proofs::{ use pprof::criterion::{Output, PProfProfiler}; use rand::rngs::OsRng; use snark_verifier_sdk::evm::{evm_verify, gen_evm_proof_shplonk, gen_evm_verifier_shplonk}; -use snark_verifier_sdk::halo2::aggregation::AggregationConfigParams; +use snark_verifier_sdk::halo2::aggregation::{AggregationConfigParams, VerifierUniversality}; use snark_verifier_sdk::{ gen_pk, halo2::{aggregation::AggregationCircuit, gen_proof_shplonk, gen_snark_shplonk}, @@ -209,7 +209,7 @@ fn bench(c: &mut Criterion) { lookup_bits, params, snarks.clone(), - false, + VerifierUniversality::None, ); let instances = agg_circuit.instances(); gen_proof_shplonk(params, pk, agg_circuit, instances, None) @@ -227,7 +227,7 @@ fn bench(c: &mut Criterion) { lookup_bits, ¶ms, snarks.clone(), - false, + VerifierUniversality::None, ); let num_instances = agg_circuit.num_instance(); let instances = agg_circuit.instances(); diff --git a/snark-verifier-sdk/examples/n_as_witness.rs b/snark-verifier-sdk/examples/n_as_witness.rs new file mode 100644 index 00000000..e0970174 --- /dev/null +++ b/snark-verifier-sdk/examples/n_as_witness.rs @@ -0,0 +1,190 @@ +use halo2_base::gates::builder::{CircuitBuilderStage, BASE_CONFIG_PARAMS}; +use halo2_base::halo2_proofs; +use halo2_base::halo2_proofs::arithmetic::Field; +use halo2_base::halo2_proofs::halo2curves::bn256::Fr; +use halo2_base::halo2_proofs::poly::commitment::Params; +use halo2_base::utils::fs::gen_srs; +use halo2_proofs::halo2curves as halo2_curves; + +use rand::rngs::OsRng; +use snark_verifier_sdk::halo2::aggregation::VerifierUniversality; +use snark_verifier_sdk::SHPLONK; +use snark_verifier_sdk::{ + gen_pk, + halo2::{aggregation::AggregationCircuit, gen_snark_shplonk}, + Snark, +}; + +mod application { + use super::halo2_curves::bn256::Fr; + use super::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, + }; + + use snark_verifier_sdk::CircuitExt; + + #[derive(Clone, Copy)] + pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, + } + + impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = + [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } + } + } + + #[derive(Clone)] + pub struct StandardPlonk(pub Fr, pub usize); + + impl CircuitExt for StandardPlonk { + fn num_instance(&self) -> Vec { + vec![1] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0]] + } + } + + impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self(Fr::zero(), self.1) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + region.assign_advice(config.a, 0, Value::known(self.0)); + region.assign_fixed(config.q_a, 0, -Fr::one()); + region.assign_advice(config.a, 1, Value::known(-Fr::from(5u64))); + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, 1, Fr::from(idx as u64)); + } + let a = region.assign_advice(config.a, 2, Value::known(Fr::one())); + a.copy_advice(&mut region, config.b, 3); + a.copy_advice(&mut region, config.c, 4); + + // assuming <= 10 blinding factors + // fill in most of circuit with a computation + /*let n = self.1; + for offset in 5..n - 10 { + region.assign_advice(config.a, offset, Value::known(-Fr::from(5u64))); + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, offset, Fr::from(idx as u64)); + } + }*/ + + Ok(()) + }, + ) + } + } +} + +fn gen_application_snark(k: u32) -> Snark { + let params = gen_srs(k); + let circuit = application::StandardPlonk(Fr::random(OsRng), params.n() as usize); + + let pk = gen_pk(¶ms, &circuit, None); + gen_snark_shplonk(¶ms, &pk, circuit, None::<&str>) +} + +fn main() { + let dummy_snark = gen_application_snark(8); + + let k = 15u32; + let params = gen_srs(k); + let lookup_bits = k as usize - 1; + BASE_CONFIG_PARAMS.with(|config| { + config.borrow_mut().lookup_bits = Some(lookup_bits); + config.borrow_mut().k = k as usize; + }); + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Keygen, + None, + lookup_bits, + ¶ms, + vec![dummy_snark], + VerifierUniversality::Full, + ); + agg_circuit.config(k, Some(10)); + + let pk = gen_pk(¶ms, &agg_circuit, None); + let break_points = agg_circuit.break_points(); + + let snarks = [8, 12, 15, 20].map(|k| (k, gen_application_snark(k))); + for (k, snark) in snarks { + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, + ¶ms, + vec![snark], + VerifierUniversality::Full, + ); + let _snark = gen_snark_shplonk(¶ms, &pk, agg_circuit, None::<&str>); + println!("snark with k = {k} success"); + } +} diff --git a/snark-verifier-sdk/examples/vkey_as_witness.rs b/snark-verifier-sdk/examples/vkey_as_witness.rs index 73d1cc70..ea76d63f 100644 --- a/snark-verifier-sdk/examples/vkey_as_witness.rs +++ b/snark-verifier-sdk/examples/vkey_as_witness.rs @@ -9,6 +9,7 @@ use halo2_proofs::halo2curves as halo2_curves; use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG}; use rand::rngs::OsRng; +use snark_verifier_sdk::halo2::aggregation::VerifierUniversality; use snark_verifier_sdk::SHPLONK; use snark_verifier_sdk::{ gen_pk, @@ -163,7 +164,7 @@ fn main() { lookup_bits, ¶ms, vec![dummy_snark], - true, + VerifierUniversality::PreprocessedAsWitness, ); agg_circuit.config(k, Some(10)); @@ -179,7 +180,7 @@ fn main() { lookup_bits, ¶ms, vec![snark], - true, + VerifierUniversality::PreprocessedAsWitness, ); let _snark = gen_snark_shplonk(¶ms, &pk, agg_circuit, None::<&str>); println!("snark {i} success"); diff --git a/snark-verifier-sdk/src/halo2.rs b/snark-verifier-sdk/src/halo2.rs index b0710230..06d5e622 100644 --- a/snark-verifier-sdk/src/halo2.rs +++ b/snark-verifier-sdk/src/halo2.rs @@ -35,6 +35,7 @@ use snark_verifier::{ AccumulationScheme, PolynomialCommitmentScheme, Query, }, system::halo2::{compile, Config}, + util::arithmetic::Rotation, util::transcript::TranscriptWrite, verifier::plonk::PlonkProof, }; @@ -122,20 +123,25 @@ where end_timer!(proof_time); // validate proof before caching - assert!({ - let mut transcript_read = - PoseidonTranscript::::from_spec(&proof[..], POSEIDON_SPEC.clone()); - VerificationStrategy::<_, V>::finalize( - verify_proof::<_, V, _, _, _>( - params.verifier_params(), - pk.get_vk(), - AccumulatorStrategy::new(params.verifier_params()), - &[instances.as_slice()], - &mut transcript_read, + assert!( + { + let mut transcript_read = PoseidonTranscript::::from_spec( + &proof[..], + POSEIDON_SPEC.clone(), + ); + VerificationStrategy::<_, V>::finalize( + verify_proof::<_, V, _, _, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript_read, + ) + .unwrap(), ) - .unwrap(), - ) - }); + }, + "SNARK proof failed to verify" + ); if let Some((instance_path, proof_path)) = path { write_instances(&instances, instance_path); @@ -286,7 +292,7 @@ where NativeLoader, Accumulator = KzgAccumulator, VerifyingKey = KzgAsVerifyingKey, - > + CostEstimation>>, + > + CostEstimation>>, { struct CsProxy(PhantomData<(F, C)>); diff --git a/snark-verifier-sdk/src/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs index 6232f7d0..fbebeb23 100644 --- a/snark-verifier-sdk/src/halo2/aggregation.rs +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -57,6 +57,27 @@ pub struct SnarkAggregationWitness<'a> { pub preprocessed_digests: Option>>>, } +/// Different possible stages of universality the aggregation circuit can support +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum VerifierUniversality { + /// Default: verifier is specific to a single circuit + None, + /// Preprocessed digest (commitments to fixed columns) is loaded as witness + PreprocessedAsWitness, + /// Preprocessed as witness and number of rows in the circuit `n` loaded as witness + Full, +} + +impl VerifierUniversality { + pub fn preprocessed_as_witness(&self) -> bool { + self != &VerifierUniversality::None + } + + pub fn n_as_witness(&self) -> bool { + self == &VerifierUniversality::Full + } +} + #[allow(clippy::type_complexity)] /// Core function used in `synthesize` to aggregate multiple `snarks`. /// @@ -65,6 +86,7 @@ pub struct SnarkAggregationWitness<'a> { /// one vector per snark, for convenience. /// /// - `preprocessed_as_witness`: flag for whether preprocessed digest (i.e., verifying key) should be loaded as witness (if false, loaded as constant) +/// - If `preprocessed_as_witness` is true, number of circuit rows `domain.n` is also loaded as a witness /// /// # Assumptions /// * `snarks` is not empty @@ -73,7 +95,7 @@ pub fn aggregate<'a, AS>( loader: &Rc>, snarks: &[Snark], as_proof: &[u8], - preprocessed_as_witness: bool, + universality: VerifierUniversality, ) -> SnarkAggregationWitness<'a> where AS: PolynomialCommitmentScheme< @@ -107,12 +129,13 @@ where POSEIDON_SPEC.clone(), ); + let preprocessed_as_witness = universality.preprocessed_as_witness(); let mut accumulators = snarks .iter() .flat_map(|snark: &Snark| { let protocol = if preprocessed_as_witness { // always load `domain.n` as witness if vkey is witness - snark.protocol.loaded_preprocessed_as_witness(loader, true) + snark.protocol.loaded_preprocessed_as_witness(loader, universality.n_as_witness()) } else { snark.protocol.loaded(loader) }; @@ -129,7 +152,18 @@ where .chain( protocol.transcript_initial_state.clone().map(|scalar| scalar.into_assigned()), ) - .chain(protocol.n_as_witness.clone().map(|n| n.into_assigned())) // If `n` is witness, add it as part of input + .chain( + protocol + .domain_as_witness + .as_ref() + .map(|domain| domain.n.clone().into_assigned()), + ) // If `n` is witness, add it as part of input + .chain( + protocol + .domain_as_witness + .as_ref() + .map(|domain| domain.gen.clone().into_assigned()), + ) // If `n` is witness, add the generator of the order `n` subgroup as part of input .collect_vec(); let instances = assign_instances(&snark.instances); @@ -258,7 +292,7 @@ impl AggregationCircuit { lookup_bits: usize, params: &ParamsKZG, snarks: impl IntoIterator, - preprocessed_as_witness: bool, + universality: VerifierUniversality, ) -> Self where AS: for<'a> Halo2KzgAccumulationScheme<'a>, @@ -310,7 +344,7 @@ impl AggregationCircuit { let loader = Halo2Loader::new(ecc_chip, builder); let SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digests } = - aggregate::(&svk, &loader, &snarks, as_proof.as_slice(), preprocessed_as_witness); + aggregate::(&svk, &loader, &snarks, as_proof.as_slice(), universality); let lhs = accumulator.lhs.assigned(); let rhs = accumulator.rhs.assigned(); let assigned_instances = lhs @@ -356,7 +390,14 @@ impl AggregationCircuit { where AS: for<'a> Halo2KzgAccumulationScheme<'a>, { - let mut private = Self::new::(stage, break_points, lookup_bits, params, snarks, false); + let mut private = Self::new::( + stage, + break_points, + lookup_bits, + params, + snarks, + VerifierUniversality::None, + ); private.expose_previous_instances(has_prev_accumulator); private } @@ -370,8 +411,14 @@ impl AggregationCircuit { let lookup_bits = BASE_CONFIG_PARAMS .with(|conf| conf.borrow().lookup_bits) .unwrap_or(params.k() as usize - 1); - let circuit = - Self::new::(CircuitBuilderStage::Keygen, None, lookup_bits, params, snarks, false); + let circuit = Self::new::( + CircuitBuilderStage::Keygen, + None, + lookup_bits, + params, + snarks, + VerifierUniversality::None, + ); circuit.config(params.k(), Some(10)); circuit } @@ -394,7 +441,7 @@ impl AggregationCircuit { lookup_bits, params, snarks, - false, + VerifierUniversality::None, ) } diff --git a/snark-verifier/src/loader/halo2/loader.rs b/snark-verifier/src/loader/halo2/loader.rs index 1f1bb16d..0ea79f88 100644 --- a/snark-verifier/src/loader/halo2/loader.rs +++ b/snark-verifier/src/loader/halo2/loader.rs @@ -309,12 +309,9 @@ impl> LoadedScalar for Sc fn pow_var(&self, exp: &Self, max_bits: usize) -> Self { let loader = self.loader(); - let res = loader.scalar_chip().pow_var( - &mut loader.ctx_mut(), - &self.assigned(), - &exp.assigned(), - max_bits, - ); + let base = self.clone().into_assigned(); + let exp = exp.clone().into_assigned(); + let res = loader.scalar_chip().pow_var(&mut loader.ctx_mut(), &base, &exp, max_bits); loader.scalar_from_assigned(res) } } diff --git a/snark-verifier/src/pcs.rs b/snark-verifier/src/pcs.rs index 65b1325b..1ca9eedc 100644 --- a/snark-verifier/src/pcs.rs +++ b/snark-verifier/src/pcs.rs @@ -3,7 +3,7 @@ use crate::{ loader::{native::NativeLoader, Loader}, util::{ - arithmetic::{CurveAffine, PrimeField}, + arithmetic::{CurveAffine, Rotation}, msm::Msm, transcript::{TranscriptRead, TranscriptWrite}, }, @@ -18,24 +18,26 @@ pub mod kzg; /// Query to an oracle. /// It assumes all queries are based on the same point, but with some `shift`. #[derive(Clone, Debug)] -pub struct Query { +pub struct Query { /// Index of polynomial to query pub poly: usize, /// Shift of the query point. - pub shift: F, + pub shift: S, + /// Shift loaded as either constant or witness. It is user's job to ensure this is correctly constrained to have value equal to `shift` + pub loaded_shift: T, /// Evaluation read from transcript. pub eval: T, } -impl Query { +impl Query { /// Initialize [`Query`] without evaluation. - pub fn new(poly: usize, shift: F) -> Self { - Self { poly, shift, eval: () } + pub fn new(poly: usize, shift: S) -> Self { + Self { poly, shift, loaded_shift: (), eval: () } } - /// Returns [`Query`] with evaluation. - pub fn with_evaluation(self, eval: T) -> Query { - Query { poly: self.poly, shift: self.shift, eval } + /// Returns [`Query`] with evaluation and optionally the shift are loaded as. + pub fn with_evaluation(self, loaded_shift: T, eval: T) -> Query { + Query { poly: self.poly, shift: self.shift, loaded_shift, eval } } } @@ -55,7 +57,7 @@ where /// Read [`PolynomialCommitmentScheme::Proof`] from transcript. fn read_proof( vk: &Self::VerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result where @@ -66,7 +68,7 @@ where vk: &Self::VerifyingKey, commitments: &[Msm], point: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Self::Proof, ) -> Result; } diff --git a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs index cae77a5f..87455114 100644 --- a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs +++ b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs @@ -5,7 +5,7 @@ use crate::{ PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{CurveAffine, FieldExt, Fraction}, + arithmetic::{CurveAffine, FieldExt, Fraction, Rotation}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -35,7 +35,7 @@ where fn read_proof( svk: &Self::VerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result where @@ -48,7 +48,7 @@ where svk: &Self::VerifyingKey, commitments: &[Msm], x: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Self::Proof, ) -> Result { let loader = x.loader(); @@ -119,7 +119,7 @@ where { fn read>( svk: &IpaSuccinctVerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result { // Multiopen @@ -157,28 +157,33 @@ where } } -fn query_sets(queries: &[Query]) -> Vec> +fn query_sets(queries: &[Query]) -> Vec> where - F: FieldExt, + S: PartialEq + Ord + Copy, T: Clone, { let poly_shifts = - queries.iter().fold(Vec::<(usize, Vec, Vec<&T>)>::new(), |mut poly_shifts, query| { + queries.iter().fold(Vec::<(usize, Vec<_>, Vec<&T>)>::new(), |mut poly_shifts, query| { if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) { let (_, shifts, evals) = &mut poly_shifts[pos]; - if !shifts.contains(&query.shift) { - shifts.push(query.shift); + if !shifts.iter().map(|(shift, _)| shift).contains(&query.shift) { + shifts.push((query.shift, query.loaded_shift.clone())); evals.push(&query.eval); } } else { - poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); + poly_shifts.push(( + query.poly, + vec![(query.shift, query.loaded_shift.clone())], + vec![&query.eval], + )); } poly_shifts }); - poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { + poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + BTreeSet::from_iter(set.shifts.iter().map(|(shift, _)| shift)) + == BTreeSet::from_iter(shifts.iter().map(|(shift, _)| shift)) }) { let set = &mut sets[pos]; if !set.polys.contains(&poly) { @@ -187,7 +192,7 @@ where set.shifts .iter() .map(|lhs| { - let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + let idx = shifts.iter().position(|rhs| lhs.0 == rhs.0).unwrap(); evals[idx] }) .collect(), @@ -201,18 +206,23 @@ where }) } -fn query_set_coeffs(sets: &[QuerySet], x: &T, x_3: &T) -> Vec> +fn query_set_coeffs( + sets: &[QuerySet], + x: &T, + x_3: &T, +) -> Vec> where F: FieldExt, T: LoadedScalar, { - let loader = x.loader(); - let superset = sets.iter().flat_map(|set| set.shifts.clone()).sorted().dedup(); + let superset = BTreeMap::from_iter(sets.iter().flat_map(|set| set.shifts.clone())); let size = sets.iter().map(|set| set.shifts.len()).chain(Some(2)).max().unwrap(); let powers_of_x = x.powers(size); let x_3_minus_x_shift_i = BTreeMap::from_iter( - superset.map(|shift| (shift, x_3.clone() - x.clone() * loader.load_const(&shift))), + superset + .into_iter() + .map(|(shift, loaded_shift)| (shift, x_3.clone() - x.clone() * loaded_shift)), ); let mut coeffs = sets @@ -228,23 +238,22 @@ where } #[derive(Clone, Debug)] -struct QuerySet<'a, F, T> { - shifts: Vec, +struct QuerySet<'a, S, T> { + shifts: Vec<(S, T)>, polys: Vec, evals: Vec>, } -impl<'a, F, T> QuerySet<'a, F, T> -where - F: FieldExt, - T: LoadedScalar, -{ +impl<'a, S, T> QuerySet<'a, S, T> { fn msm>( &self, commitments: &[Msm<'a, C, L>], q_eval: &T, powers_of_x_1: &[T], - ) -> Msm { + ) -> Msm + where + T: LoadedScalar, + { self.polys .iter() .rev() @@ -254,7 +263,10 @@ where - Msm::constant(q_eval.clone()) } - fn f_eval(&self, coeff: &QuerySetCoeff, q_eval: &T, powers_of_x_1: &[T]) -> T { + fn f_eval(&self, coeff: &QuerySetCoeff, q_eval: &T, powers_of_x_1: &[T]) -> T + where + T: LoadedScalar, + { let loader = q_eval.loader(); let r_eval = { let r_evals = self @@ -291,7 +303,12 @@ where F: FieldExt, T: LoadedScalar, { - fn new(shifts: &[F], powers_of_x: &[T], x_3: &T, x_3_minus_x_shift_i: &BTreeMap) -> Self { + fn new( + shifts: &[(Rotation, T)], + powers_of_x: &[T], + x_3: &T, + x_3_minus_x_shift_i: &BTreeMap, + ) -> Self { let loader = x_3.loader(); let normalized_ell_primes = shifts .iter() @@ -301,9 +318,9 @@ where .iter() .enumerate() .filter(|&(i, _)| i != j) - .map(|(_, shift_i)| (*shift_j - shift_i)) + .map(|(_, shift_i)| (shift_j.1.clone() - &shift_i.1)) .reduce(|acc, value| acc * value) - .unwrap_or_else(|| F::one()) + .unwrap_or_else(|| loader.load_const(&F::one())) }) .collect_vec(); @@ -313,17 +330,15 @@ where let barycentric_weights = shifts .iter() .zip(normalized_ell_primes.iter()) - .map(|(shift, normalized_ell_prime)| { - loader.sum_products_with_coeff(&[ - (*normalized_ell_prime, x_pow_k_minus_one, x_3), - (-(*normalized_ell_prime * shift), x_pow_k_minus_one, x), - ]) + .map(|((_, loaded_shift), normalized_ell_prime)| { + let tmp = normalized_ell_prime.clone() * x_pow_k_minus_one; + loader.sum_products(&[(&tmp, x_3), (&-(tmp.clone() * loaded_shift), x)]) }) .map(Fraction::one_over) .collect_vec(); let f_eval_coeff = Fraction::one_over(loader.product( - &shifts.iter().map(|shift| x_3_minus_x_shift_i.get(shift).unwrap()).collect_vec(), + &shifts.iter().map(|(shift, _)| x_3_minus_x_shift_i.get(shift).unwrap()).collect_vec(), )); Self { diff --git a/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs index 3a448056..e51ccc2e 100644 --- a/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs @@ -6,7 +6,7 @@ use crate::{ PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{CurveAffine, FieldExt, Fraction, MultiMillerLoop}, + arithmetic::{CurveAffine, FieldExt, Fraction, MultiMillerLoop, Rotation}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -35,7 +35,7 @@ where fn read_proof( _: &KzgSuccinctVerifyingKey, - _: &[Query], + _: &[Query], transcript: &mut T, ) -> Result, Error> where @@ -48,16 +48,15 @@ where svk: &KzgSuccinctVerifyingKey, commitments: &[Msm], z: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Bdfg21Proof, ) -> Result { let sets = query_sets(queries); let f = { let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); - let powers_of_mu = proof - .mu - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let powers_of_mu = + proof.mu.powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); let msms = sets .iter() .zip(coeffs.iter()) @@ -72,10 +71,7 @@ where let rhs = Msm::base(&proof.w_prime); let lhs = f + rhs.clone() * &proof.z_prime; - Ok(KzgAccumulator::new( - lhs.evaluate(Some(svk.g)), - rhs.evaluate(Some(svk.g)), - )) + Ok(KzgAccumulator::new(lhs.evaluate(Some(svk.g)), rhs.evaluate(Some(svk.g)))) } } @@ -104,94 +100,71 @@ where let w = transcript.read_ec_point()?; let z_prime = transcript.squeeze_challenge(); let w_prime = transcript.read_ec_point()?; - Ok(Bdfg21Proof { - mu, - gamma, - w, - z_prime, - w_prime, - }) + Ok(Bdfg21Proof { mu, gamma, w, z_prime, w_prime }) } } -fn query_sets(queries: &[Query]) -> Vec> { - let poly_shifts = queries.iter().fold( - Vec::<(usize, Vec, Vec<&T>)>::new(), - |mut poly_shifts, query| { - if let Some(pos) = poly_shifts - .iter() - .position(|(poly, _, _)| *poly == query.poly) - { +fn query_sets(queries: &[Query]) -> Vec> { + let poly_shifts = + queries.iter().fold(Vec::<(usize, Vec<_>, Vec<&T>)>::new(), |mut poly_shifts, query| { + if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) { let (_, shifts, evals) = &mut poly_shifts[pos]; - if !shifts.contains(&query.shift) { - shifts.push(query.shift); + if !shifts.iter().map(|(shift, _)| shift).contains(&query.shift) { + shifts.push((query.shift, query.loaded_shift.clone())); evals.push(&query.eval); } } else { - poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); + poly_shifts.push(( + query.poly, + vec![(query.shift, query.loaded_shift.clone())], + vec![&query.eval], + )); } poly_shifts - }, - ); - - poly_shifts.into_iter().fold( - Vec::>::new(), - |mut sets, (poly, shifts, evals)| { - if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) - }) { - let set = &mut sets[pos]; - if !set.polys.contains(&poly) { - set.polys.push(poly); - set.evals.push( - set.shifts - .iter() - .map(|lhs| { - let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); - evals[idx] - }) - .collect(), - ); - } - } else { - let set = QuerySet { - shifts, - polys: vec![poly], - evals: vec![evals], - }; - sets.push(set); + }); + + poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter().map(|(shift, _)| shift)) + == BTreeSet::from_iter(shifts.iter().map(|(shift, _)| shift)) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs.0 == rhs.0).unwrap(); + evals[idx] + }) + .collect(), + ); } - sets - }, - ) + } else { + let set = QuerySet { shifts, polys: vec![poly], evals: vec![evals] }; + sets.push(set); + } + sets + }) } fn query_set_coeffs<'a, F: FieldExt, T: LoadedScalar>( - sets: &[QuerySet<'a, F, T>], + sets: &[QuerySet<'a, Rotation, T>], z: &T, z_prime: &T, ) -> Vec> { - let loader = z.loader(); + // map of shift => loaded_shift, removing duplicate `shift` values + // shift is the rotation, not omega^rotation, to ensure BTreeMap does not depend on omega (otherwise ordering can change) + let superset = BTreeMap::from_iter(sets.iter().flat_map(|set| set.shifts.clone())); - let superset = sets - .iter() - .flat_map(|set| set.shifts.clone()) - .sorted() - .dedup(); - - let size = sets - .iter() - .map(|set| set.shifts.len()) - .chain(Some(2)) - .max() - .unwrap(); + let size = sets.iter().map(|set| set.shifts.len()).chain(Some(2)).max().unwrap(); let powers_of_z = z.powers(size); - let z_prime_minus_z_shift_i = BTreeMap::from_iter(superset.map(|shift| { - ( - shift, - z_prime.clone() - z.clone() * loader.load_const(&shift), - ) - })); + let z_prime_minus_z_shift_i = BTreeMap::from_iter( + superset + .into_iter() + .map(|(shift, loaded_shift)| (shift, z_prime.clone() - z.clone() * loaded_shift)), + ); let mut z_s_1 = None; let mut coeffs = sets @@ -219,19 +192,22 @@ fn query_set_coeffs<'a, F: FieldExt, T: LoadedScalar>( } #[derive(Clone, Debug)] -struct QuerySet<'a, F, T> { - shifts: Vec, +struct QuerySet<'a, S, T> { + shifts: Vec<(S, T)>, // vec of (shift, loaded_shift) polys: Vec, evals: Vec>, } -impl<'a, F: FieldExt, T: LoadedScalar> QuerySet<'a, F, T> { +impl<'a, S, T> QuerySet<'a, S, T> { fn msm>( &self, - coeff: &QuerySetCoeff, + coeff: &QuerySetCoeff, commitments: &[Msm<'a, C, L>], powers_of_mu: &[T], - ) -> Msm { + ) -> Msm + where + T: LoadedScalar, + { self.polys .iter() .zip(self.evals.iter()) @@ -274,10 +250,10 @@ where T: LoadedScalar, { fn new( - shifts: &[F], + shifts: &[(Rotation, T)], powers_of_z: &[T], z_prime: &T, - z_prime_minus_z_shift_i: &BTreeMap, + z_prime_minus_z_shift_i: &BTreeMap, z_s_1: &Option, ) -> Self { let loader = z_prime.loader(); @@ -290,9 +266,9 @@ where .iter() .enumerate() .filter(|&(i, _)| i != j) - .map(|(_, shift_i)| (*shift_j - shift_i)) + .map(|(_, shift_i)| (shift_j.1.clone() - &shift_i.1)) .reduce(|acc, value| acc * value) - .unwrap_or_else(|| F::one()) + .unwrap_or_else(|| loader.load_const(&F::one())) }) .collect_vec(); @@ -302,11 +278,9 @@ where let barycentric_weights = shifts .iter() .zip(normalized_ell_primes.iter()) - .map(|(shift, normalized_ell_prime)| { - loader.sum_products_with_coeff(&[ - (*normalized_ell_prime, z_pow_k_minus_one, z_prime), - (-(*normalized_ell_prime * shift), z_pow_k_minus_one, z), - ]) + .map(|((_, loaded_shift), normalized_ell_prime)| { + let tmp = normalized_ell_prime.clone() * z_pow_k_minus_one; + loader.sum_products(&[(&tmp, z_prime), (&-(tmp.clone() * loaded_shift), z)]) }) .map(Fraction::one_over) .collect_vec(); @@ -314,7 +288,7 @@ where let z_s = loader.product( &shifts .iter() - .map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()) + .map(|(shift, _)| z_prime_minus_z_shift_i.get(shift).unwrap()) .collect_vec(), ); let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); @@ -344,13 +318,8 @@ where .iter_mut() .chain(self.commitment_coeff.as_mut()) .for_each(Fraction::evaluate); - let barycentric_weights_sum = loader.sum( - &self - .eval_coeffs - .iter() - .map(Fraction::evaluated) - .collect_vec(), - ); + let barycentric_weights_sum = + loader.sum(&self.eval_coeffs.iter().map(Fraction::evaluated).collect_vec()); self.r_eval_coeff = Some(match self.commitment_coeff.as_ref() { Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), None => Fraction::one_over(barycentric_weights_sum), @@ -370,13 +339,9 @@ impl CostEstimation for KzgAs where M: MultiMillerLoop, { - type Input = Vec>; + type Input = Vec>; - fn estimate_cost(_: &Vec>) -> Cost { - Cost { - num_commitment: 2, - num_msm: 2, - ..Default::default() - } + fn estimate_cost(_: &Vec>) -> Cost { + Cost { num_commitment: 2, num_msm: 2, ..Default::default() } } } diff --git a/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs index e5741163..492d769f 100644 --- a/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs @@ -6,7 +6,7 @@ use crate::{ PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{CurveAffine, MultiMillerLoop, PrimeField}, + arithmetic::{CurveAffine, MultiMillerLoop, Rotation}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -31,7 +31,7 @@ where fn read_proof( _: &Self::VerifyingKey, - queries: &[Query], + queries: &[Query], transcript: &mut T, ) -> Result where @@ -44,22 +44,23 @@ where svk: &Self::VerifyingKey, commitments: &[Msm], z: &L::LoadedScalar, - queries: &[Query], + queries: &[Query], proof: &Self::Proof, ) -> Result { let sets = query_sets(queries); let powers_of_u = &proof.u.powers(sets.len()); let f = { - let powers_of_v = proof - .v - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let powers_of_v = proof.v.powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); sets.iter() .map(|set| set.msm(commitments, &powers_of_v)) .zip(powers_of_u.iter()) .map(|(msm, power_of_u)| msm * power_of_u) .sum::>() }; - let z_omegas = sets.iter().map(|set| z.loader().load_const(&set.shift) * z); + let z_omegas = sets.iter().map(|set| { + let loaded_shift = set.loaded_shift.clone(); + loaded_shift * z + }); let rhs = proof .ws @@ -67,11 +68,7 @@ where .zip(powers_of_u.iter()) .map(|(w, power_of_u)| Msm::base(w) * power_of_u) .collect_vec(); - let lhs = f + rhs - .iter() - .zip(z_omegas) - .map(|(uw, z_omega)| uw.clone() * &z_omega) - .sum(); + let lhs = f + rhs.iter().zip(z_omegas).map(|(uw, z_omega)| uw.clone() * &z_omega).sum(); Ok(KzgAccumulator::new( lhs.evaluate(Some(svk.g)), @@ -97,7 +94,7 @@ where C: CurveAffine, L: Loader, { - fn read(queries: &[Query], transcript: &mut T) -> Result + fn read(queries: &[Query], transcript: &mut T) -> Result where T: TranscriptRead, { @@ -108,22 +105,25 @@ where } } -struct QuerySet<'a, F, T> { - shift: F, +struct QuerySet<'a, S, T> { + shift: S, + loaded_shift: T, polys: Vec, evals: Vec<&'a T>, } -impl<'a, F, T> QuerySet<'a, F, T> +impl<'a, S, T> QuerySet<'a, S, T> where - F: PrimeField, T: Clone, { fn msm>( &self, commitments: &[Msm<'a, C, L>], powers_of_v: &[L::LoadedScalar], - ) -> Msm { + ) -> Msm + where + T: LoadedScalar, + { self.polys .iter() .zip(self.evals.iter().cloned()) @@ -137,9 +137,9 @@ where } } -fn query_sets(queries: &[Query]) -> Vec> +fn query_sets(queries: &[Query]) -> Vec> where - F: PrimeField, + S: PartialEq + Copy, T: Clone + PartialEq, { queries.iter().fold(Vec::new(), |mut sets, query| { @@ -149,6 +149,7 @@ where } else { sets.push(QuerySet { shift: query.shift, + loaded_shift: query.loaded_shift.clone(), polys: vec![query.poly], evals: vec![&query.eval], }); @@ -161,14 +162,10 @@ impl CostEstimation for KzgAs where M: MultiMillerLoop, { - type Input = Vec>; + type Input = Vec>; - fn estimate_cost(queries: &Vec>) -> Cost { + fn estimate_cost(queries: &Vec>) -> Cost { let num_w = query_sets(queries).len(); - Cost { - num_commitment: num_w, - num_msm: num_w, - ..Default::default() - } + Cost { num_commitment: num_w, num_msm: num_w, ..Default::default() } } } diff --git a/snark-verifier/src/system/halo2.rs b/snark-verifier/src/system/halo2.rs index 9715d996..be7a0cdc 100644 --- a/snark-verifier/src/system/halo2.rs +++ b/snark-verifier/src/system/halo2.rs @@ -141,7 +141,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( PlonkProtocol { domain, - n_as_witness: None, + domain_as_witness: None, preprocessed, num_instance: polynomials.num_instance(), num_witness: polynomials.num_witness(), diff --git a/snark-verifier/src/verifier/plonk.rs b/snark-verifier/src/verifier/plonk.rs index 54fbaf74..9c910400 100644 --- a/snark-verifier/src/verifier/plonk.rs +++ b/snark-verifier/src/verifier/plonk.rs @@ -15,7 +15,10 @@ use crate::{ AccumulationDecider, AccumulationScheme, AccumulatorEncoding, PolynomialCommitmentScheme, Query, }, - util::{arithmetic::CurveAffine, transcript::TranscriptRead}, + util::{ + arithmetic::{CurveAffine, Rotation}, + transcript::TranscriptRead, + }, verifier::{plonk::protocol::CommonPolynomialEvaluation, SnarkVerifier}, Error, }; @@ -66,7 +69,7 @@ where &protocol.domain, protocol.langranges(), &proof.z, - &protocol.n_as_witness, + &protocol.domain_as_witness, ); L::batch_invert(common_poly_eval.denoms()); @@ -144,7 +147,7 @@ where L: Loader, AS: AccumulationScheme + PolynomialCommitmentScheme - + CostEstimation>>, + + CostEstimation>>, { type Input = PlonkProtocol; @@ -172,7 +175,7 @@ where L: Loader, AS: AccumulationScheme + PolynomialCommitmentScheme - + CostEstimation>>, + + CostEstimation>>, { type Input = PlonkProtocol; diff --git a/snark-verifier/src/verifier/plonk/proof.rs b/snark-verifier/src/verifier/plonk/proof.rs index 7adba7ac..5536dd06 100644 --- a/snark-verifier/src/verifier/plonk/proof.rs +++ b/snark-verifier/src/verifier/plonk/proof.rs @@ -13,7 +13,10 @@ use crate::{ }, Error, }; -use std::{collections::HashMap, iter}; +use std::{ + collections::{BTreeMap, HashMap}, + iter, +}; /// Proof of PLONK with [`PolynomialCommitmentScheme`] that has /// [`AccumulationScheme`]. @@ -153,26 +156,42 @@ where } /// Empty queries - pub fn empty_queries(protocol: &PlonkProtocol) -> Vec> { - protocol - .queries - .iter() - .map(|query| { - let shift = protocol.domain.rotate_scalar(C::Scalar::one(), query.rotation); - pcs::Query::new(query.poly, shift) - }) - .collect() + pub fn empty_queries(protocol: &PlonkProtocol) -> Vec> { + // `preprocessed` should always be non-empty, unless the circuit has no constraints or constants + protocol.queries.iter().map(|query| pcs::Query::new(query.poly, query.rotation)).collect() } pub(super) fn queries( &self, protocol: &PlonkProtocol, mut evaluations: HashMap, - ) -> Vec> { + ) -> Vec> { + if protocol.queries.is_empty() { + return vec![]; + } + let loader = evaluations[&protocol.queries[0]].loader(); + let rotations = + protocol.queries.iter().map(|query| query.rotation).sorted().dedup().collect_vec(); + let loaded_shifts = if let Some(domain) = protocol.domain_as_witness.as_ref() { + // the `rotation`s are still constants, it is only generator `omega` that might be witness + BTreeMap::from_iter( + rotations.into_iter().map(|rotation| (rotation, domain.rotate_one(rotation))), + ) + } else { + BTreeMap::from_iter(rotations.into_iter().map(|rotation| { + ( + rotation, + loader.load_const(&protocol.domain.rotate_scalar(C::Scalar::one(), rotation)), + ) + })) + }; Self::empty_queries(protocol) .into_iter() .zip(protocol.queries.iter().map(|query| evaluations.remove(query).unwrap())) - .map(|(query, eval)| query.with_evaluation(eval)) + .map(|(query, eval)| { + let shift = loaded_shifts[&query.shift].clone(); + query.with_evaluation(shift, eval) + }) .collect() } diff --git a/snark-verifier/src/verifier/plonk/protocol.rs b/snark-verifier/src/verifier/plonk/protocol.rs index 6256d522..023b6ebb 100644 --- a/snark-verifier/src/verifier/plonk/protocol.rs +++ b/snark-verifier/src/verifier/plonk/protocol.rs @@ -9,13 +9,44 @@ use num_integer::Integer; use num_traits::One; use serde::{Deserialize, Serialize}; use std::{ - cmp::max, + cmp::{max, Ordering}, collections::{BTreeMap, BTreeSet}, fmt::Debug, iter::{self, Sum}, ops::{Add, Mul, Neg, Sub}, }; +/// Domain parameters to be optionally loaded as witnesses +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct DomainAsWitness +where + C: CurveAffine, + L: Loader, +{ + /// Number of rows in the domain + pub n: L::LoadedScalar, + /// Generator of the domain + pub gen: L::LoadedScalar, + /// Inverse generator of the domain + pub gen_inv: L::LoadedScalar, +} + +impl DomainAsWitness +where + C: CurveAffine, + L: Loader, +{ + /// Rotate `F::one()` to given `rotation`. + pub fn rotate_one(&self, rotation: Rotation) -> L::LoadedScalar { + let loader = self.gen.loader(); + match rotation.0.cmp(&0) { + Ordering::Equal => loader.load_one(), + Ordering::Greater => self.gen.pow_const(rotation.0 as u64), + Ordering::Less => self.gen_inv.pow_const(-rotation.0 as u64), + } + } +} + /// Protocol specifying configuration of a PLONK. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PlonkProtocol @@ -34,8 +65,8 @@ where serialize = "L::LoadedScalar: Serialize", deserialize = "L::LoadedScalar: Deserialize<'de>" ))] - /// Optional: load `domain.n` as a witness - pub n_as_witness: Option, + /// Optional: load `domain.n` and `domain.gen` as a witness + pub domain_as_witness: Option>, #[serde(bound( serialize = "L::LoadedEcPoint: Serialize", @@ -123,7 +154,7 @@ where .map(|transcript_initial_state| loader.load_const(transcript_initial_state)); PlonkProtocol { domain: self.domain.clone(), - n_as_witness: None, + domain_as_witness: None, preprocessed, num_instance: self.num_instance.clone(), num_witness: self.num_witness.clone(), @@ -142,12 +173,17 @@ where #[cfg(feature = "loader_halo2")] mod halo2 { use crate::{ - loader::halo2::{EccInstructions, Halo2Loader}, + loader::{ + halo2::{EccInstructions, Halo2Loader}, + LoadedScalar, + }, util::arithmetic::CurveAffine, verifier::plonk::PlonkProtocol, }; use std::rc::Rc; + use super::DomainAsWitness; + impl PlonkProtocol where C: CurveAffine, @@ -160,8 +196,13 @@ mod halo2 { loader: &Rc>, load_n_as_witness: bool, ) -> PlonkProtocol>> { - let n_as_witness = load_n_as_witness - .then(|| loader.assign_scalar(C::Scalar::from(self.domain.n as u64))); + let domain_as_witness = load_n_as_witness.then(|| { + let n = loader.assign_scalar(C::Scalar::from(self.domain.n as u64)); + let gen = loader.assign_scalar(self.domain.gen); + let gen_inv = gen.invert().expect("subgroup generation is invertible"); + DomainAsWitness { n, gen, gen_inv } + }); + let preprocessed = self .preprocessed .iter() @@ -173,7 +214,7 @@ mod halo2 { .map(|transcript_initial_state| loader.assign_scalar(*transcript_initial_state)); PlonkProtocol { domain: self.domain.clone(), - n_as_witness, + domain_as_witness, preprocessed, num_instance: self.num_instance.clone(), num_witness: self.num_witness.clone(), @@ -216,35 +257,37 @@ where { // if `n_as_witness` is Some, then we assume `n_as_witness` has value equal to `domain.n` (i.e., number of rows in the circuit) // and is loaded as a witness instead of a constant. + // The generator of `domain` also depends on `n`. pub fn new( domain: &Domain, - langranges: impl IntoIterator, + lagranges: impl IntoIterator, z: &L::LoadedScalar, - n_as_witness: &Option, + domain_as_witness: &Option>, ) -> Self { let loader = z.loader(); - let zn = if let Some(n) = n_as_witness.as_ref() { - z.pow_var(n, C::Scalar::S as usize + 1) + let lagranges = lagranges.into_iter().sorted().dedup().collect_vec(); + let one = loader.load_one(); + + let (zn, n_inv, omegas) = if let Some(domain) = domain_as_witness.as_ref() { + let zn = z.pow_var(&domain.n, C::Scalar::S as usize + 1); + let n_inv = domain.n.invert().expect("n is not zero"); + let omegas = lagranges.iter().map(|&i| domain.rotate_one(Rotation(i))).collect_vec(); + (zn, n_inv, omegas) } else { - z.pow_const(domain.n as u64) + let zn = z.pow_const(domain.n as u64); + let n_inv = loader.load_const(&domain.n_inv); + let omegas = lagranges + .iter() + .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::one(), Rotation(i)))) + .collect_vec(); + (zn, n_inv, omegas) }; - let langranges = langranges.into_iter().sorted().dedup().collect_vec(); - let one = loader.load_one(); let zn_minus_one = zn.clone() - &one; let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone()); - let n_inv = if let Some(n) = n_as_witness.as_ref() { - n.invert().expect("n is not zero") - } else { - loader.load_const(&domain.n_inv) - }; let numer = zn_minus_one.clone() * &n_inv; - let omegas = langranges - .iter() - .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::one(), Rotation(i)))) - .collect_vec(); let lagrange_evals = omegas .iter() .map(|omega| Fraction::new(numer.clone() * omega, z.clone() - omega)) @@ -255,7 +298,7 @@ where zn_minus_one, zn_minus_one_inv, identity: z.clone(), - lagrange: langranges.into_iter().zip(lagrange_evals).collect(), + lagrange: lagranges.into_iter().zip(lagrange_evals).collect(), } }