diff --git a/arkworks/groth16/src/lib.rs b/arkworks/groth16/src/lib.rs index def723a5..2d177920 100644 --- a/arkworks/groth16/src/lib.rs +++ b/arkworks/groth16/src/lib.rs @@ -34,6 +34,8 @@ pub mod prover; /// Verify proofs for the Groth16 zkSNARK construction. pub mod verifier; +pub mod reveal; + /// Constraints for the Groth16 verifier. #[cfg(feature = "r1cs")] pub mod constraints; diff --git a/examples/bin_test_groth16.rs b/examples/bin_test_groth16.rs index 51066f8c..d600b73d 100644 --- a/examples/bin_test_groth16.rs +++ b/examples/bin_test_groth16.rs @@ -16,5 +16,8 @@ struct Opt { fn main() { let opt = Opt::from_args(); Net::init_from_file(opt.input.to_str().unwrap(), opt.id); - // groth16::mpc_test_prove_and_verify(1); + zk_mpc::groth16::mpc_test_prove_and_verify::< + ark_bls12_377::Bls12_377, + mpc_algebra::AdditivePairingShare, + >(1); } diff --git a/mpc-algebra/src/share/additive.rs b/mpc-algebra/src/share/additive.rs index c477536c..55cab1e7 100644 --- a/mpc-algebra/src/share/additive.rs +++ b/mpc-algebra/src/share/additive.rs @@ -402,6 +402,11 @@ impl Reveal for AdditiveGroupShare { macro_rules! impl_group_basics { ($share:ident, $bound:ident) => { + impl Display for $share { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.val) + } + } impl Debug for $share { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self.val) diff --git a/mpc-algebra/src/share/group.rs b/mpc-algebra/src/share/group.rs index 7607bbc0..f5f5e92d 100644 --- a/mpc-algebra/src/share/group.rs +++ b/mpc-algebra/src/share/group.rs @@ -5,17 +5,22 @@ use ark_serialize::{ CanonicalDeserialize, CanonicalDeserializeWithFlags, CanonicalSerialize, CanonicalSerializeWithFlags, }; +use ark_std::end_timer; +use ark_std::start_timer; use std::fmt::Debug; +use std::fmt::Display; use std::hash::Hash; use crate::Reveal; use super::field::FieldShare; +use super::BeaverSource; pub trait GroupShare: Clone + Copy + Debug + + Display + Send + Sync + Eq @@ -63,6 +68,47 @@ pub trait GroupShare: fn shift(&mut self, other: &G) -> &mut Self; + fn scale>( + self, + other: Self::FieldShare, + source: &mut S, + ) -> Self { + let timer = start_timer!(|| "SS scalar multiplication"); + let (mut x, y, z) = source.triple(); + let s = self; + let o = other; + // output: z - open(s + x)y - x*open(o + y) + open(s + x)open(o + y) + // xy - sy - xy - ox - yx + so + sy + xo + xy + // so + let mut sx = { + let mut t = s; + t.add(&x).open() + }; + let oy = { + let mut t = o; + t.add(&y).open() + }; + let mut out = z.clone(); + out.sub(&Self::scale_pub_group(sx.clone(), &y)); + out.sub(x.scale_pub_scalar(&oy)); + sx *= oy; + out.shift(&sx); + #[cfg(debug_assertions)] + { + let a = s.reveal(); + let b = o.reveal(); + let mut acp = a.clone(); + acp *= b; + let r = out.reveal(); + if acp != r { + println!("Bad multiplication!.\n{}\n*\n{}\n=\n{}", a, b, r); + panic!("Bad multiplication"); + } + } + end_timer!(timer); + out + } + /// Compute \sum_i (s_i * g_i) /// where the s_i are shared and the g_i are public. fn multi_scale_pub_group(bases: &[G], scalars: &[Self::FieldShare]) -> Self { diff --git a/mpc-algebra/src/wire/group.rs b/mpc-algebra/src/wire/group.rs index d0f90528..98c34d3b 100644 --- a/mpc-algebra/src/wire/group.rs +++ b/mpc-algebra/src/wire/group.rs @@ -1,5 +1,6 @@ use std::fmt::{self, Display}; use std::io::{self, Read, Write}; +use std::marker::PhantomData; use std::ops::*; use std::iter::Sum; @@ -12,10 +13,12 @@ use ark_serialize::{ CanonicalSerializeWithFlags, }; use ark_serialize::{Flags, SerializationError}; +use derivative::Derivative; +use mpc_net::{MpcMultiNet as Net, MpcNet}; use mpc_trait::MpcWire; use crate::share::group::GroupShare; -use crate::Reveal; +use crate::{BeaverSource, Reveal}; use super::field::MpcField; @@ -25,6 +28,45 @@ pub enum MpcGroup> { Shared(S), } +#[derive(Derivative)] +#[derivative(Default(bound = ""), Clone(bound = ""), Copy(bound = ""))] +pub struct DummyGroupTripleSource { + _scalar: PhantomData, + _share: PhantomData, +} + +impl> BeaverSource + for DummyGroupTripleSource +{ + #[inline] + fn triple(&mut self) -> (S, S::FieldShare, S) { + ( + S::from_add_shared(T::zero()), + ::from_add_shared(if Net::am_king() { + T::ScalarField::one() + } else { + T::ScalarField::zero() + }), + S::from_add_shared(T::zero()), + ) + } + #[inline] + fn inv_pair(&mut self) -> (S::FieldShare, S::FieldShare) { + ( + ::from_add_shared(if Net::am_king() { + T::ScalarField::one() + } else { + T::ScalarField::zero() + }), + ::from_add_shared(if Net::am_king() { + T::ScalarField::one() + } else { + T::ScalarField::zero() + }), + ) + } +} + impl> MpcGroup { pub fn map, FT: Fn(T) -> TT, FS: Fn(S) -> SS>( self, @@ -78,8 +120,11 @@ impl> Mul> fo } impl> Display for MpcGroup { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> fmt::Result { - todo!() + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result { + match self { + MpcGroup::Public(x) => write!(f, "{x} (public)"), + MpcGroup::Shared(x) => write!(f, "{x} (shared)"), + } } } @@ -308,16 +353,18 @@ impl<'a, T: Group, S: GroupShare> MulAssign<&'a MpcField { *x *= *y; } - MpcField::Shared(_y) => { - todo!() + MpcField::Shared(y) => { + let t = MpcGroup::Shared(S::scale_pub_group(*x, y)); + *self = t; } }, - MpcGroup::Shared(_x) => match other { - MpcField::Public(_y) => { - todo!() + MpcGroup::Shared(x) => match other { + MpcField::Public(y) => { + x.scale_pub_scalar(y); } - MpcField::Shared(_y) => { - todo!() + MpcField::Shared(y) => { + let t = x.scale(*y, &mut DummyGroupTripleSource::default()); + *x = t; } }, } diff --git a/mpc-algebra/src/wire/pairing.rs b/mpc-algebra/src/wire/pairing.rs index ce3e5314..9042b546 100644 --- a/mpc-algebra/src/wire/pairing.rs +++ b/mpc-algebra/src/wire/pairing.rs @@ -549,13 +549,14 @@ macro_rules! impl_pairing_curve_wrapper { } } #[inline] - fn from_add_shared(_t: Self::Base) -> Self { - todo!() + fn from_add_shared(t: Self::Base) -> Self { + Self { + val: $wrapped::from_add_shared(t), + } } #[inline] fn unwrap_as_public(self) -> Self::Base { - // self.val.unwrap_as_public() - todo!() + self.val.unwrap_as_public() } #[inline] fn king_share(_f: Self::Base, _rng: &mut R) -> Self { @@ -578,8 +579,8 @@ macro_rules! impl_pairing_curve_wrapper { } impl> MulAssign> for $wrap { - fn mul_assign(&mut self, _rhs: MpcField) { - todo!() + fn mul_assign(&mut self, other: MpcField) { + self.val.mul_assign(other); } } }; diff --git a/src/circuits/circuit.rs b/src/circuits/circuit.rs index d5ce7f0a..c9ad3a1a 100644 --- a/src/circuits/circuit.rs +++ b/src/circuits/circuit.rs @@ -76,14 +76,12 @@ impl> ConstraintSynthesizer> { +pub struct MySimpleCircuit { pub a: Option, pub b: Option, } -impl> ConstraintSynthesizer - for MySimpleCircuit -{ +impl ConstraintSynthesizer for MySimpleCircuit { fn generate_constraints( self, cs: ConstraintSystemRef, diff --git a/src/groth16.rs b/src/groth16.rs index 145530f7..de52bf3d 100644 --- a/src/groth16.rs +++ b/src/groth16.rs @@ -2,6 +2,342 @@ //! //! This module provides functions for setting up, proving, and verifying MPC (Multi-Party Computation) circuits using the Groth16 zkSNARK. +use std::ops::{AddAssign, Deref}; + +use ark_ec::ProjectiveCurve; +use ark_ec::{AffineCurve, PairingEngine}; +use ark_ff::{One, PrimeField, UniformRand, Zero}; +use ark_groth16::{ + generate_random_parameters, prepare_verifying_key, verify_proof, Proof, ProvingKey, +}; +use ark_poly::{EvaluationDomain, GeneralEvaluationDomain}; +use ark_relations::r1cs::{ + ConstraintSynthesizer, ConstraintSystem, ConstraintSystemRef, OptimizationGoal, + Result as R1CSResult, SynthesisError, +}; +use ark_std::{cfg_iter, cfg_iter_mut, end_timer, start_timer, test_rng}; +// use log::debug; +use mpc_algebra::{MpcField, MpcPairingEngine, PairingShare, Reveal}; +use rand::Rng; + +use crate::circuits::circuit::MySimpleCircuit; + +/// Create a Groth16 proof that is zero-knowledge. +/// This method samples randomness for zero knowledges via `rng`. +#[inline] +pub fn create_random_proof( + circuit: C, + pk: &ProvingKey, + rng: &mut R, +) -> R1CSResult> +where + E: PairingEngine, + //E::Fr: BatchProd, + C: ConstraintSynthesizer<::Fr>, + R: Rng, +{ + //use ark_ff::One; + //let r = ::Fr::one(); + //let s = ::Fr::one(); + let t = start_timer!(|| "zk sampling"); + let r = ::Fr::rand(rng); + let s = ::Fr::rand(rng); + end_timer!(t); + + create_proof::(circuit, pk, r, s) +} + +/// Create a Groth16 proof that is *not* zero-knowledge. +#[inline] +pub fn create_proof_no_zk(circuit: C, pk: &ProvingKey) -> R1CSResult> +where + E: PairingEngine, + //E::Fr: BatchProd, + C: ConstraintSynthesizer<::Fr>, +{ + create_proof::( + circuit, + pk, + ::Fr::zero(), + ::Fr::zero(), + ) +} + +/// Create a Groth16 proof using randomness `r` and `s`. +#[inline] +pub fn create_proof( + circuit: C, + pk: &ProvingKey, + r: ::Fr, + s: ::Fr, +) -> R1CSResult> +where + E: PairingEngine, + //E::Fr: BatchProd, + C: ConstraintSynthesizer<::Fr>, +{ + println!("r: {}", r); + println!("s: {}", s); + type D = GeneralEvaluationDomain; + + let prover_time = start_timer!(|| "Groth16::Prover"); + let cs = ConstraintSystem::new_ref(); + + // Set the optimization goal + cs.set_optimization_goal(OptimizationGoal::Constraints); + + // Synthesize the circuit. + let synthesis_time = start_timer!(|| "Constraint synthesis"); + circuit.generate_constraints(cs.clone())?; + //debug_assert!(cs.is_satisfied().unwrap()); + end_timer!(synthesis_time); + + let lc_time = start_timer!(|| "Inlining LCs"); + cs.finalize(); + end_timer!(lc_time); + + let witness_map_time = start_timer!(|| "R1CS to QAP witness map"); + let h = R1CStoQAP::witness_map::<::Fr, D<::Fr>>( + cs.clone(), + )?; + end_timer!(witness_map_time); + let prover_crypto_time = start_timer!(|| "crypto"); + let c_acc_time = start_timer!(|| "Compute C"); + let h_acc = <::G1Affine as AffineCurve>::multi_scalar_mul(&pk.h_query, &h); + println!("h_acc: {}", h_acc); + // Compute C + let prover = cs.borrow().unwrap(); + let l_aux_acc = <::G1Affine as AffineCurve>::multi_scalar_mul( + &pk.l_query, + &prover.witness_assignment, + ); + + let r_s_delta_g1 = pk.delta_g1.into_projective().scalar_mul(&r).scalar_mul(&s); + println!("r_s_delta_g1: {}", r_s_delta_g1); + + end_timer!(c_acc_time); + + let assignment: Vec<::Fr> = prover.instance_assignment[1..] + .iter() + .chain(prover.witness_assignment.iter()) + .cloned() + .collect(); + drop(prover); + drop(cs); + + // Compute A + let a_acc_time = start_timer!(|| "Compute A"); + let r_g1 = pk.delta_g1.scalar_mul(r); + println!("r_g1: {}", r_g1); + // debug!("Assignment:"); + // for (i, a) in assignment.iter().enumerate() { + // debug!(" a[{}]: {}", i, a); + // } + + let g_a = calculate_coeff(r_g1, &pk.a_query, pk.vk.alpha_g1, &assignment); + println!("g_a: {}", g_a); + + let s_g_a = g_a.scalar_mul(&s); + println!("s_g_a: {}", s_g_a); + end_timer!(a_acc_time); + + // Compute B in G1 if needed + // let g1_b = if !r.is_zero() { + let b_g1_acc_time = start_timer!(|| "Compute B in G1"); + let s_g1 = pk.delta_g1.scalar_mul(s); + let g1_b = calculate_coeff(s_g1, &pk.b_g1_query, pk.beta_g1, &assignment); + + end_timer!(b_g1_acc_time); + // + // g1_b + // } else { + // ::G1Projective::zero() + // }; + + // Compute B in G2 + let b_g2_acc_time = start_timer!(|| "Compute B in G2"); + let s_g2 = pk.vk.delta_g2.scalar_mul(s); + let g2_b = calculate_coeff(s_g2, &pk.b_g2_query, pk.vk.beta_g2, &assignment); + let r_g1_b = g1_b.scalar_mul(&r); + println!("r_g1_b: {}", r_g1_b); + drop(assignment); + + end_timer!(b_g2_acc_time); + + let c_time = start_timer!(|| "Finish C"); + let mut g_c = s_g_a; + g_c += &r_g1_b; + g_c -= &r_s_delta_g1; + g_c += &l_aux_acc; + g_c += &h_acc; + end_timer!(c_time); + end_timer!(prover_crypto_time); + + end_timer!(prover_time); + + Ok(Proof { + a: g_a.into_affine(), + b: g2_b.into_affine(), + c: g_c.into_affine(), + }) +} + +fn calculate_coeff( + initial: G::Projective, + query: &[G], + vk_param: G, + assignment: &[G::ScalarField], +) -> G::Projective where { + let el = query[0]; + let t = start_timer!(|| format!("MSM size {} {}", query.len() - 1, assignment.len())); + let acc = G::multi_scalar_mul(&query[1..], assignment); + end_timer!(t); + let mut res = initial; + res.add_assign_mixed(&el); + res += &acc; + res.add_assign_mixed(&vk_param); + + res +} + +/// r1cs to qap +#[inline] +fn evaluate_constraint<'a, LHS, RHS, R>(terms: &'a [(LHS, usize)], assignment: &'a [RHS]) -> R +where + LHS: One + Send + Sync + PartialEq, + RHS: Send + Sync + core::ops::Mul<&'a LHS, Output = RHS> + Copy, + R: Zero + Send + Sync + AddAssign + core::iter::Sum, +{ + // Need to wrap in a closure when using Rayon + #[cfg(feature = "parallel")] + let zero = || R::zero(); + #[cfg(not(feature = "parallel"))] + let zero = R::zero(); + + let res = cfg_iter!(terms).fold(zero, |mut sum, (coeff, index)| { + let val = &assignment[*index]; + + if coeff.is_one() { + sum += *val; + } else { + sum += val.mul(coeff); + } + + sum + }); + + // Need to explicitly call `.sum()` when using Rayon + #[cfg(feature = "parallel")] + return res.sum(); + #[cfg(not(feature = "parallel"))] + return res; +} + +pub struct R1CStoQAP; + +impl R1CStoQAP { + #[inline] + pub fn witness_map>( + prover: ConstraintSystemRef, + ) -> R1CSResult> { + let matrices = prover.to_matrices().unwrap(); + let zero = F::zero(); + let num_inputs = prover.num_instance_variables(); + let num_constraints = prover.num_constraints(); + let cs = prover.borrow().unwrap(); + let prover = cs.deref(); + + let full_assignment = [ + prover.instance_assignment.as_slice(), + prover.witness_assignment.as_slice(), + ] + .concat(); + + let domain = + D::new(num_constraints + num_inputs).ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + let domain_size = domain.size(); + + let mut a = vec![zero; domain_size]; + let mut b = vec![zero; domain_size]; + + cfg_iter_mut!(a[..num_constraints]) + .zip(cfg_iter_mut!(b[..num_constraints])) + .zip(cfg_iter!(&matrices.a)) + .zip(cfg_iter!(&matrices.b)) + .for_each(|(((a, b), at_i), bt_i)| { + *a = evaluate_constraint(&at_i, &full_assignment); + *b = evaluate_constraint(&bt_i, &full_assignment); + }); + + { + let start = num_constraints; + let end = start + num_inputs; + a[start..end].clone_from_slice(&full_assignment[..num_inputs]); + } + + domain.ifft_in_place(&mut a); + domain.ifft_in_place(&mut b); + + domain.coset_fft_in_place(&mut a); + domain.coset_fft_in_place(&mut b); + let mut ab = a.clone(); + let batch_product_timer = start_timer!(|| "batch product"); + F::batch_product_in_place(&mut ab, &b); + end_timer!(batch_product_timer); + + let mut c = vec![zero; domain_size]; + cfg_iter_mut!(c[..prover.num_constraints]) + .enumerate() + .for_each(|(i, c)| { + *c = evaluate_constraint(&matrices.c[i], &full_assignment); + }); + + domain.ifft_in_place(&mut c); + domain.coset_fft_in_place(&mut c); + + cfg_iter_mut!(ab) + .zip(c) + .for_each(|(ab_i, c_i)| *ab_i -= &c_i); + + domain.divide_by_vanishing_poly_on_coset_in_place(&mut ab); + domain.coset_ifft_in_place(&mut ab); + + Ok(ab) + } +} + +pub fn mpc_test_prove_and_verify>(n_iters: usize) { + let rng = &mut test_rng(); + + let params = + generate_random_parameters::(MySimpleCircuit { a: None, b: None }, rng).unwrap(); + + let pvk = prepare_verifying_key::(¶ms.vk); + let mpc_params: ProvingKey> = ProvingKey::from_public(params); + + for _ in 0..n_iters { + let a = MpcField::::rand(rng); + let b = MpcField::::rand(rng); + + let mut c = a; + c *= &b; + + let mpc_circuit = MySimpleCircuit { + a: Some(a), + b: Some(b), + }; + + let mpc_proof = create_random_proof(mpc_circuit, &mpc_params, rng).unwrap(); + + let proof = mpc_proof.reveal(); + let pub_a = a.reveal(); + let pub_c = c.reveal(); + + assert!(verify_proof(&pvk, &proof, &[pub_c]).unwrap()); + assert!(!verify_proof(&pvk, &proof, &[pub_a]).unwrap()); + } +} + #[cfg(test)] mod tests { use ark_bls12_377::{Bls12_377, Fr}; @@ -34,41 +370,4 @@ mod tests { assert!(Groth16::::verify(&circuit_vk, &[c], &proof).unwrap()); assert!(!Groth16::::verify(&circuit_vk, &[a], &proof).unwrap()); } - - // #[test] - // fn test_mpc() { - // let mut rng = rand::thread_rng(); - - // // let a = Fr::rand(&mut rng); - // // let b = Fr::rand(&mut rng); - - // let a = AngleShare::rand(&mut rng); - // let b = AngleShare::rand(&mut rng); - - // let mut c = a; - // c = c * b; - - // let circuit = MyCircuit:: { - // a: Some(a), - // b: Some(b), - // }; - - // // let params = generate_random_parameters(circuit, &mut rng); - // let (circuit_pk, circuit_vk) = - // Groth16::::circuit_specific_setup(circuit.clone(), &mut rng).unwrap(); - - // // let pvk = prepare_verifying_key::(¶ms.vk); - - // // let mpc_proof = prover::create_random_proof(circuit, &circuit_pk, &mut rng); - - // // let proof = mpc_proof.reveal(); - - // // TODO: implement reveal - // // let pub_a = a.reveal(); - // // let pub_c = c.reveal(); - - // // assert!(verify_proof(&pvk, &proof, &[pub_c]).unwrap()); - // // assert!(Groth16::::verify(&circuit_vk, &[pub_c], &proof).unwrap()); - // // assert!(!Groth16::::verify(&circuit_vk, &[pub_a], &proof).unwrap()); - // } }