diff --git a/Cargo.lock b/Cargo.lock index 374bd05..048ab85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,7 +118,7 @@ dependencies = [ "educe", "fnv", "hashbrown 0.15.1", - "itertools", + "itertools 0.13.0", "num-bigint", "num-integer", "num-traits", @@ -139,7 +139,7 @@ dependencies = [ "arrayvec", "digest", "educe", - "itertools", + "itertools 0.13.0", "num-bigint", "num-traits", "paste", @@ -275,6 +275,18 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "blake2" version = "0.10.6" @@ -284,6 +296,17 @@ dependencies = [ "digest", ] +[[package]] +name = "blake2b_simd" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23285ad32269793932e830392f2fe2f83e26488fd3ec778883a93c8323735780" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + [[package]] name = "blake3" version = "1.5.4" @@ -489,12 +512,29 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "ff" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "bitvec", + "rand_core", + "subtle", +] + [[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "generic-array" version = "0.14.7" @@ -516,6 +556,50 @@ dependencies = [ "wasi", ] +[[package]] +name = "goldilocks" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno-Goldilocks#29a15d186ce4375dab346a3cc9eca6e43540cb8d" +dependencies = [ + "ff", + "halo2curves", + "itertools 0.12.1", + "rand_core", + "serde", + "subtle", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core", + "subtle", +] + +[[package]] +name = "halo2curves" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6b1142bd1059aacde1b477e0c80c142910f1ceae67fc619311d6a17428007ab" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + "num-bigint", + "num-traits", + "pasta_curves", + "paste", + "rand", + "rand_core", + "static_assertions", + "subtle", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -552,6 +636,15 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -581,6 +674,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "libc" @@ -675,6 +771,21 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "pasta_curves" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + "rand", + "static_assertions", + "subtle", +] + [[package]] name = "paste" version = "1.0.15" @@ -714,6 +825,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.8.5" @@ -829,6 +946,18 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strsim" version = "0.11.1" @@ -863,6 +992,32 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "tracing" version = "0.1.40" @@ -942,6 +1097,7 @@ dependencies = [ "blake3", "clap", "derivative", + "goldilocks", "lazy_static", "nimue", "nimue-pow", @@ -951,6 +1107,7 @@ dependencies = [ "serde", "serde_json", "sha3", + "thiserror", ] [[package]] @@ -1026,6 +1183,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "zerocopy" version = "0.7.35" diff --git a/Cargo.toml b/Cargo.toml index cf6e503..ab8c041 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,12 +27,14 @@ nimue-pow = { git = "https://github.com/arkworks-rs/nimue" } lazy_static = "1.4" rayon = { version = "1.10.0", optional = true } +goldilocks = { git = "https://github.com/scroll-tech/ceno-Goldilocks" } +thiserror = "1" + [profile.release] debug = true [features] -default = ["parallel"] -#default = [] +default = ["parallel", "ceno"] parallel = [ "dep:rayon", "ark-poly/parallel", @@ -40,5 +42,4 @@ parallel = [ "ark-crypto-primitives/parallel", ] rayon = ["dep:rayon"] - - +ceno = [] diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..5d1274a --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2024-10-03" diff --git a/src/ceno_binding/mod.rs b/src/ceno_binding/mod.rs new file mode 100644 index 0000000..6e64e05 --- /dev/null +++ b/src/ceno_binding/mod.rs @@ -0,0 +1,71 @@ +mod pcs; +pub use pcs::Whir; + +use ark_ff::FftField; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use serde::{de::DeserializeOwned, Serialize}; +use std::fmt::Debug; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error(transparent)] + ProofError(#[from] nimue::ProofError), +} + +pub trait PolynomialCommitmentScheme: Clone { + type Param: Clone + Debug + Serialize + DeserializeOwned; + type CommitmentWithData: Clone + Debug + Serialize + DeserializeOwned; + type Proof: Clone + CanonicalSerialize + CanonicalDeserialize + Serialize + DeserializeOwned; + type Poly: Clone + Debug + Serialize + DeserializeOwned; + type Transcript; + + fn setup(poly_size: usize) -> Self::Param; + + fn commit_and_write( + pp: &Self::Param, + poly: &Self::Poly, + transcript: &mut Self::Transcript, + ) -> Result; + + fn batch_commit( + pp: &Self::Param, + polys: &[Self::Poly], + ) -> Result; + + fn open( + pp: &Self::Param, + comm: Self::CommitmentWithData, + point: &[E], + eval: &E, + transcript: &mut Self::Transcript, + ) -> Result; + + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment. + /// 3. The point is already a random point generated by a sum-check. + fn batch_open( + pp: &Self::Param, + polys: &[Self::Poly], + comm: Self::CommitmentWithData, + point: &[E], + evals: &[E], + transcript: &mut Self::Transcript, + ) -> Result; + + fn verify( + vp: &Self::Param, + point: &[E], + eval: &E, + proof: &Self::Proof, + transcript: &Self::Transcript, + ) -> Result<(), Error>; + + fn batch_verify( + vp: &Self::Param, + point: &[E], + evals: &[E], + proof: &Self::Proof, + transcript: &mut Self::Transcript, + ) -> Result<(), Error>; +} diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs new file mode 100644 index 0000000..fb63e2f --- /dev/null +++ b/src/ceno_binding/pcs.rs @@ -0,0 +1,603 @@ +use super::{Error, PolynomialCommitmentScheme}; +use crate::crypto::merkle_tree::blake3::{self as mt, MerkleTreeParams}; +use crate::parameters::{ + default_max_pow, FoldType, MultivariateParameters, SoundnessType, WhirParameters, +}; +use crate::poly_utils::{coeffs::CoefficientList, MultilinearPoint}; +use crate::whir::{ + committer::{Committer, Witness}, + iopattern::WhirIOPattern, + parameters::WhirConfig, + prover::Prover, + verifier::Verifier, + Statement, WhirProof, +}; + +use ark_crypto_primitives::merkle_tree::{Config, MerkleTree}; +use ark_ff::FftField; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use nimue::{DefaultHash, IOPattern, Merlin}; +use nimue_pow::blake3::Blake3PoW; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +#[cfg(feature = "parallel")] +use rayon::slice::ParallelSlice; +use serde::ser::SerializeStruct; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::fmt::{self, Debug, Formatter}; +use std::marker::PhantomData; +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Whir(PhantomData); + +type MerkleConfig = MerkleTreeParams; +type PowStrategy = Blake3PoW; +type WhirPCSConfig = WhirConfig, PowStrategy>; + +// Wrapper for WhirConfig +pub struct WhirConfigWrapper { + inner: WhirConfig, PowStrategy>, +} + +impl Serialize for WhirConfigWrapper { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut s = serializer.serialize_struct("WhirConfigWrapper", 17)?; + s.serialize_field( + "num_variables", + &(self.inner.mv_parameters.num_variables as u32), + )?; + s.serialize_field("initial_statement", &self.inner.initial_statement)?; + s.serialize_field("starting_log_inv_rate", &self.inner.starting_log_inv_rate)?; + s.serialize_field("folding_factor", &self.inner.folding_factor)?; + s.serialize_field("soundness_type", &self.inner.soundness_type)?; + s.serialize_field("security_level", &self.inner.security_level)?; + s.serialize_field("pow_bits", &self.inner.max_pow_bits)?; + s.serialize_field("fold_optimisation", &self.inner.fold_optimisation)?; + + s.end() + } +} + +impl<'de, E: FftField> Deserialize<'de> for WhirConfigWrapper { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct Visitor { + marker: PhantomData, + } + + impl<'de, E: FftField> serde::de::Visitor<'de> for Visitor { + type Value = WhirConfig, PowStrategy>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct WhirConfigWrapper") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut num_variables = None; + let mut soundness_type = None; + let mut security_level = None; + let mut pow_bits = None; + let mut initial_statement = None; + let mut starting_log_inv_rate = None; + let mut folding_factor = None; + let mut fold_optimisation = None; + + while let Some(key) = map.next_key()? { + match key { + "num_variables" => { + if num_variables.is_some() { + return Err(serde::de::Error::duplicate_field("num_variables")); + } + num_variables = Some(map.next_value()?); + } + "soundness_type" => { + if soundness_type.is_some() { + return Err(serde::de::Error::duplicate_field("soundness_type")); + } + soundness_type = Some(map.next_value()?); + } + "security_level" => { + if security_level.is_some() { + return Err(serde::de::Error::duplicate_field("security_level")); + } + security_level = Some(map.next_value()?); + } + "pow_bits" => { + if pow_bits.is_some() { + return Err(serde::de::Error::duplicate_field("pow_bits")); + } + pow_bits = Some(map.next_value()?); + } + "initial_statement" => { + if initial_statement.is_some() { + return Err(serde::de::Error::duplicate_field("initial_statement")); + } + initial_statement = Some(map.next_value()?); + } + "starting_log_inv_rate" => { + if starting_log_inv_rate.is_some() { + return Err(serde::de::Error::duplicate_field( + "starting_log_inv_rate", + )); + } + starting_log_inv_rate = Some(map.next_value()?); + } + "folding_factor" => { + if folding_factor.is_some() { + return Err(serde::de::Error::duplicate_field("folding_factor")); + } + folding_factor = Some(map.next_value()?); + } + "fold_optimisation" => { + if fold_optimisation.is_some() { + return Err(serde::de::Error::duplicate_field("fold_optimisation")); + } + fold_optimisation = Some(map.next_value()?); + } + _ => { + return Err(serde::de::Error::unknown_field( + key, + &[ + "num_variables", + "soundness_type", + "security_level", + "pow_bits", + "initial_statement", + "starting_log_inv_rate", + "folding_factor", + "fold_optimisation", + ], + )); + } + } + } + + let num_variables = num_variables + .ok_or_else(|| serde::de::Error::missing_field("num_variables"))?; + let soundness_type = soundness_type + .ok_or_else(|| serde::de::Error::missing_field("soundness_type"))?; + let security_level = security_level + .ok_or_else(|| serde::de::Error::missing_field("security_level"))?; + let pow_bits = + pow_bits.ok_or_else(|| serde::de::Error::missing_field("pow_bits"))?; + let initial_statement = initial_statement + .ok_or_else(|| serde::de::Error::missing_field("initial_statement"))?; + let starting_log_inv_rate = starting_log_inv_rate + .ok_or_else(|| serde::de::Error::missing_field("starting_log_inv_rate"))?; + let folding_factor = folding_factor + .ok_or_else(|| serde::de::Error::missing_field("folding_factor"))?; + let fold_optimisation = fold_optimisation + .ok_or_else(|| serde::de::Error::missing_field("fold_optimisation"))?; + + let mut rng = ChaCha8Rng::from_seed([0u8; 32]); + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + Ok(WhirConfig::new( + MultivariateParameters::new(num_variables), + WhirParameters { + initial_statement, + starting_log_inv_rate, + folding_factor, + soundness_type, + security_level, + pow_bits, + fold_optimisation, + _pow_parameters: PhantomData::, + // Merkle tree parameters + leaf_hash_params, + two_to_one_params, + }, + )) + } + } + + let config = deserializer.deserialize_struct( + "WhirConfigWrapper", + &[ + "num_variables", + "soundness_type", + "security_level", + "pow_bits", + "initial_statement", + "starting_log_inv_rate", + "folding_factor", + "fold_optimisation", + ], + Visitor { + marker: PhantomData, + }, + )?; + + Ok(WhirConfigWrapper { inner: config }) + } +} + +impl Debug for WhirConfigWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.write_str("WhirConfigWrapper") + } +} + +impl Clone for WhirConfigWrapper { + fn clone(&self) -> Self { + WhirConfigWrapper { + inner: self.inner.clone(), + } + } +} + +// Wrapper for Witness +pub struct WitnessWrapper { + inner: Witness>, +} + +impl Serialize for WitnessWrapper +where + F: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut s = serializer.serialize_struct("WitnessWrapper", 3)?; + s.serialize_field("polynomial", &self.inner.polynomial)?; + s.serialize_field("merkle_leaves", &self.inner.merkle_leaves)?; + s.serialize_field("ood_points", &self.inner.ood_points)?; + s.serialize_field("ood_answers", &self.inner.ood_answers)?; + s.serialize_field("tree_height", &self.inner.merkle_tree.height())?; + s.end() + } +} + +impl<'de, F> Deserialize<'de> for WitnessWrapper +where + F: Deserialize<'de> + FftField + CanonicalDeserialize + CanonicalSerialize, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct Visitor { + marker: PhantomData, + } + + impl<'de, F: FftField> serde::de::Visitor<'de> for Visitor + where + F: FftField + Deserialize<'de> + CanonicalDeserialize + CanonicalSerialize, + { + type Value = WitnessWrapper; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct WitnessWrapper") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut polynomial = None; + let mut merkle_leaves = None; + let mut ood_points = None; + let mut ood_answers = None; + let mut tree_height = None; + + while let Some(key) = map.next_key()? { + match key { + "polynomial" => { + if polynomial.is_some() { + return Err(serde::de::Error::duplicate_field("polynomial")); + } + polynomial = Some(map.next_value()?); + } + "merkle_leaves" => { + if merkle_leaves.is_some() { + return Err(serde::de::Error::duplicate_field("merkle_leaves")); + } + merkle_leaves = Some(map.next_value()?); + } + "ood_points" => { + if ood_points.is_some() { + return Err(serde::de::Error::duplicate_field("ood_points")); + } + ood_points = Some(map.next_value()?); + } + "ood_answers" => { + if ood_answers.is_some() { + return Err(serde::de::Error::duplicate_field("ood_answers")); + } + ood_answers = Some(map.next_value()?); + } + "tree_height" => { + if tree_height.is_some() { + return Err(serde::de::Error::duplicate_field("tree_height")); + } + tree_height = Some(map.next_value()?); + } + _ => { + return Err(serde::de::Error::unknown_field( + key, + &[ + "polynomial", + "merkle_leaves", + "ood_points", + "ood_answers", + "tree_height", + ], + )); + } + } + } + + let polynomial = + polynomial.ok_or_else(|| serde::de::Error::missing_field("polynomial"))?; + let merkle_leaves: Vec = merkle_leaves + .ok_or_else(|| serde::de::Error::missing_field("merkle_leaves"))?; + let ood_points = + ood_points.ok_or_else(|| serde::de::Error::missing_field("ood_points"))?; + let ood_answers = + ood_answers.ok_or_else(|| serde::de::Error::missing_field("ood_answers"))?; + + let mut rng = ChaCha8Rng::from_seed([0u8; 32]); + let (leaf_hash_param, two_to_one_hash_param) = mt::default_config::(&mut rng); + + let tree_height: usize = + tree_height.ok_or_else(|| serde::de::Error::missing_field("tree_height"))?; + let leaf_node_size = 1 << (tree_height - 1); + let fold_size = merkle_leaves.len() / leaf_node_size; + #[cfg(not(feature = "parallel"))] + let leafs_iter = merkle_leaves.chunks_exact(fold_size); + #[cfg(feature = "parallel")] + let leafs_iter = merkle_leaves.par_chunks_exact(fold_size); + let merkle_tree = + MerkleTree::new(&leaf_hash_param, &two_to_one_hash_param, leafs_iter).map_err( + |_| serde::de::Error::custom("Failed to construct the merkle tree"), + )?; + + Ok(WitnessWrapper { + inner: Witness { + polynomial, + merkle_tree, + merkle_leaves, + ood_points, + ood_answers, + }, + }) + } + } + + deserializer.deserialize_struct( + "WitnessWrapper", + &[ + "polynomial", + "merkle_leaves", + "ood_points", + "ood_answers", + "tree_height", + ], + Visitor { + marker: PhantomData, + }, + ) + } +} + +impl Debug for WitnessWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.write_str("WitnessWrapper") + } +} + +impl Clone for WitnessWrapper { + fn clone(&self) -> Self { + WitnessWrapper { + inner: self.inner.clone(), + } + } +} + +// Wrapper for WhirProof +#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] +pub struct WhirProofWrapper(WhirProof) +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize; + +impl Serialize for WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let proof = &self.0 .0; + // Create a buffer that implements the `Write` trait + let mut buffer = Vec::new(); + proof.serialize_compressed(&mut buffer).unwrap(); + serializer.serialize_bytes(&buffer) + } +} + +impl<'de, MerkleConfig, F> Deserialize<'de> for WhirProofWrapper +where + MerkleConfig: Config, + F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + // Deserialize the bytes into a buffer + let buffer: Vec = Deserialize::deserialize(deserializer)?; + // Deserialize the buffer into a proof + let proof = WhirProof::deserialize_compressed(&buffer[..]).unwrap(); + Ok(WhirProofWrapper(proof)) + } +} + +impl PolynomialCommitmentScheme for Whir +where + E: FftField + CanonicalSerialize + CanonicalDeserialize + Serialize + DeserializeOwned + Debug, + E::BasePrimeField: Serialize + DeserializeOwned + Debug, +{ + type Param = WhirConfigWrapper; + type CommitmentWithData = WitnessWrapper; + type Proof = WhirProofWrapper, E>; + type Poly = CoefficientList; + type Transcript = Merlin; + + fn setup(poly_size: usize) -> Self::Param { + let mv_params = MultivariateParameters::::new(poly_size); + let starting_rate = 1; + let pow_bits = default_max_pow(poly_size, starting_rate); + let mut rng = ChaCha8Rng::from_seed([0u8; 32]); + + let (leaf_hash_params, two_to_one_params) = mt::default_config::(&mut rng); + + let whir_params = WhirParameters::, PowStrategy> { + initial_statement: true, + security_level: 100, + pow_bits, + folding_factor: 4, + leaf_hash_params, + two_to_one_params, + soundness_type: SoundnessType::ConjectureList, + fold_optimisation: FoldType::ProverHelps, + _pow_parameters: Default::default(), + starting_log_inv_rate: starting_rate, + }; + + WhirConfigWrapper { + inner: WhirConfig::, PowStrategy>::new(mv_params, whir_params), + } + } + + fn commit_and_write( + pp: &Self::Param, + poly: &Self::Poly, + transcript: &mut Self::Transcript, + ) -> Result { + let committer = Committer::new(pp.inner.clone()); + let witness = committer.commit(transcript, poly.clone())?; + + Ok(WitnessWrapper { inner: witness }) + } + + fn batch_commit( + _pp: &Self::Param, + _polys: &[Self::Poly], + ) -> Result { + todo!() + } + + fn open( + pp: &Self::Param, + witness: Self::CommitmentWithData, + point: &[E], + eval: &E, + transcript: &mut Self::Transcript, + ) -> Result { + let prover = Prover(pp.inner.clone()); + let statement = Statement { + points: vec![MultilinearPoint(point.to_vec())], + evaluations: vec![eval.clone()], + }; + + let proof = prover.prove(transcript, statement, witness.inner)?; + Ok(WhirProofWrapper(proof)) + } + + fn batch_open( + _pp: &Self::Param, + _polys: &[Self::Poly], + _comm: Self::CommitmentWithData, + _point: &[E], + _evals: &[E], + _transcript: &mut Self::Transcript, + ) -> Result { + todo!() + } + + fn verify( + vp: &Self::Param, + point: &[E], + eval: &E, + proof: &Self::Proof, + transcript: &Self::Transcript, + ) -> Result<(), Error> { + let reps = 1000; + let verifier = Verifier::new(vp.inner.clone()); + let io = IOPattern::::new("🌪️") + .commit_statement(&vp.inner) + .add_whir_proof(&vp.inner); + + let statement = Statement { + points: vec![MultilinearPoint(point.to_vec())], + evaluations: vec![eval.clone()], + }; + + for _ in 0..reps { + let mut arthur = io.to_arthur(transcript.transcript()); + verifier.verify(&mut arthur, &statement, &proof.0)?; + } + Ok(()) + } + + fn batch_verify( + _vp: &Self::Param, + _point: &[E], + _evals: &[E], + _proof: &Self::Proof, + _transcript: &mut Self::Transcript, + ) -> Result<(), Error> { + todo!() + } +} + +#[cfg(test)] +mod tests { + use ark_ff::{Field, Fp2, MontBackend, MontConfig}; + use rand::Rng; + + use crate::crypto::fields::F2Config64; + + use super::*; + + type Field64_2 = Fp2; + + type F = Field64_2; + + #[test] + fn single_point_verify() { + let poly_size = 10; + let num_coeffs = 1 << poly_size; + let pp = Whir::::setup(poly_size); + + let poly = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + + let io = IOPattern::::new("🌪️") + .commit_statement(&pp) + .add_whir_proof(&pp); + let mut merlin = io.to_merlin(); + + let witness = Whir::::commit_and_write(&pp, &poly, &mut merlin).unwrap(); + + let mut rng = rand::thread_rng(); + let point: Vec = (0..poly_size).map(|_| F::from(rng.gen::())).collect(); + let eval = poly.evaluate_at_extension(&MultilinearPoint(point.clone())); + + let proof = Whir::::open(&pp, witness, &point, &eval, &mut merlin).unwrap(); + Whir::::verify(&pp, &point, &eval, &proof, &merlin).unwrap(); + } +} diff --git a/src/crypto/merkle_tree/blake3.rs b/src/crypto/merkle_tree/blake3.rs index ddcce8f..e47c7e2 100644 --- a/src/crypto/merkle_tree/blake3.rs +++ b/src/crypto/merkle_tree/blake3.rs @@ -10,8 +10,11 @@ use ark_crypto_primitives::{ }; use ark_ff::Field; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use nimue::{Arthur, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, ProofResult}; +use nimue::{ + Arthur, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, ProofResult, +}; use rand::RngCore; +use serde::{Deserialize, Serialize}; #[derive( Debug, Default, Clone, Copy, Eq, PartialEq, Hash, CanonicalSerialize, CanonicalDeserialize, @@ -106,7 +109,7 @@ impl TwoToOneCRHScheme for Blake3TwoToOneCRHScheme { pub type LeafH = Blake3LeafHash; pub type CompressH = Blake3TwoToOneCRHScheme; -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct MerkleTreeParams(PhantomData); impl Config for MerkleTreeParams { @@ -144,10 +147,10 @@ impl DigestWriter> for Merlin { } } -impl <'a, F: Field> DigestReader> for Arthur<'a> { +impl<'a, F: Field> DigestReader> for Arthur<'a> { fn read_digest(&mut self) -> ProofResult { let mut digest = [0; 32]; self.fill_next_bytes(&mut digest)?; Ok(Blake3Digest(digest)) } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index 7bb62c1..9ba03f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "ceno")] +pub mod ceno_binding; // Connect whir with ceno pub mod cmdline_utils; pub mod crypto; // Crypto utils pub mod domain; // Domain that we are evaluating over diff --git a/src/parameters.rs b/src/parameters.rs index 633670d..2464362 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -1,13 +1,13 @@ use std::{fmt::Display, marker::PhantomData, str::FromStr}; use ark_crypto_primitives::merkle_tree::{Config, LeafParam, TwoToOneParam}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; pub fn default_max_pow(num_variables: usize, log_inv_rate: usize) -> usize { num_variables + log_inv_rate - 3 } -#[derive(Debug, Clone, Copy, Serialize)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum SoundnessType { UniqueDecoding, ProvableList, @@ -64,7 +64,7 @@ impl Display for MultivariateParameters { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum FoldType { Naive, ProverHelps, diff --git a/src/poly_utils/coeffs.rs b/src/poly_utils/coeffs.rs index 5fbbea5..991a912 100644 --- a/src/poly_utils/coeffs.rs +++ b/src/poly_utils/coeffs.rs @@ -2,6 +2,7 @@ use super::{evals::EvaluationsList, hypercube::BinaryHypercubePoint, Multilinear use crate::ntt::wavelet_transform; use ark_ff::Field; use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial}; +use serde::{Deserialize, Serialize}; #[cfg(feature = "parallel")] use { rayon::{join, prelude::*}, @@ -19,7 +20,7 @@ use { /// - coeffs[1] is the coefficient of X_2 /// - coeffs[2] is the coefficient of X_1 /// - coeffs[4] is the coefficient of X_0 -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct CoefficientList { coeffs: Vec, // list of coefficients. For multilinear polynomials, we have coeffs.len() == 1 << num_variables. num_variables: usize, // number of variables diff --git a/src/whir/committer.rs b/src/whir/committer.rs index 30d3484..9d298b2 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -16,6 +16,7 @@ use crate::whir::fs_utils::DigestWriter; #[cfg(feature = "parallel")] use rayon::prelude::*; +#[derive(Clone)] pub struct Witness where MerkleConfig: Config, @@ -35,7 +36,7 @@ where impl Committer where F: FftField, - MerkleConfig: Config + MerkleConfig: Config, { pub fn new(config: WhirConfig) -> Self { Self(config) diff --git a/src/whir/mod.rs b/src/whir/mod.rs index 7db951b..d228a99 100644 --- a/src/whir/mod.rs +++ b/src/whir/mod.rs @@ -4,11 +4,11 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use crate::poly_utils::MultilinearPoint; pub mod committer; +pub mod fs_utils; pub mod iopattern; pub mod parameters; pub mod prover; pub mod verifier; -pub mod fs_utils; #[derive(Debug, Clone, Default)] pub struct Statement { @@ -18,7 +18,7 @@ pub struct Statement { // Only includes the authentication paths #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct WhirProof(Vec<(MultiPath, Vec>)>) +pub struct WhirProof(pub(crate) Vec<(MultiPath, Vec>)>) where MerkleConfig: Config, F: Sized + Clone + CanonicalSerialize + CanonicalDeserialize; diff --git a/src/whir/parameters.rs b/src/whir/parameters.rs index 7a9d711..0cb7ca0 100644 --- a/src/whir/parameters.rs +++ b/src/whir/parameters.rs @@ -3,6 +3,7 @@ use std::{f64::consts::LOG2_10, fmt::Display, marker::PhantomData}; use ark_crypto_primitives::merkle_tree::{Config, LeafParam, TwoToOneParam}; use ark_ff::FftField; +use serde::{Deserialize, Serialize}; use crate::{ crypto::fields::FieldWithSize, @@ -45,7 +46,7 @@ where pub(crate) two_to_one_params: TwoToOneParam, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct RoundConfig { pub(crate) pow_bits: f64, pub(crate) folding_pow_bits: f64,