diff --git a/CHANGELOG.md b/CHANGELOG.md index 92698252c..cffa2839a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ ### Features +- [\#691](https://github.com/arkworks-rs/algebra/pull/691) (`ark-poly`) Implement `Polynomial` for `SparseMultilinearExtension` and `DenseMultilinearExtension`. + ### Improvements ### Bugfixes diff --git a/poly/README.md b/poly/README.md index 4f53d91c1..26439687c 100644 --- a/poly/README.md +++ b/poly/README.md @@ -120,7 +120,7 @@ let g: DenseMultilinearExtension = DenseMultilinearExtension::from_evaluatio ); // when evaluated at any point within the Boolean hypercube, f and g should be equal let point_within_hypercube = &vec![Fq::from(0), Fq::from(1), Fq::from(1)]; -assert_eq!(f.evaluate(&point_within_hypercube), g.evaluate(&point_within_hypercube).unwrap()); +assert_eq!(f.evaluate(&point_within_hypercube), g.evaluate(&point_within_hypercube)); // We can also define a MLE g'(x_0, x_1, x_2) by providing the list of non-zero evaluations: let g_prime: SparseMultilinearExtension = SparseMultilinearExtension::from_evaluations( @@ -135,7 +135,7 @@ let g_prime: SparseMultilinearExtension = SparseMultilinearExtension::from_e ); // at any random point (X0, X1, X2), g == g' with negligible probability, unless they are the same function let random_point = &vec![Fq::from(123), Fq::from(456), Fq::from(789)]; -assert_eq!(g_prime.evaluate(&random_point).unwrap(), g.evaluate(&random_point).unwrap()); +assert_eq!(g_prime.evaluate(&random_point), g.evaluate(&random_point)); ``` diff --git a/poly/benches/dense_multilinear.rs b/poly/benches/dense_multilinear.rs index 32b258c62..fbc0f6c4b 100644 --- a/poly/benches/dense_multilinear.rs +++ b/poly/benches/dense_multilinear.rs @@ -2,7 +2,7 @@ extern crate criterion; use ark_ff::Field; -use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension, Polynomial}; use ark_std::{ops::Range, test_rng}; use ark_test_curves::bls12_381; use criterion::{black_box, BenchmarkId, Criterion}; @@ -40,7 +40,7 @@ fn evaluation_op_bench(c: &mut Criterion) { group.bench_with_input(BenchmarkId::from_parameter(nv), &nv, |b, &nv| { let poly = DenseMultilinearExtension::::rand(nv, &mut rng); let point: Vec<_> = (0..nv).map(|_| F::rand(&mut rng)).collect(); - b.iter(|| black_box(poly.evaluate(&point).unwrap())) + b.iter(|| black_box(poly.evaluate(&point))) }); } group.finish(); diff --git a/poly/benches/sparse_multilinear.rs b/poly/benches/sparse_multilinear.rs index 58523ce2f..68c2d60a2 100644 --- a/poly/benches/sparse_multilinear.rs +++ b/poly/benches/sparse_multilinear.rs @@ -2,7 +2,7 @@ extern crate criterion; use ark_ff::Field; -use ark_poly::{MultilinearExtension, SparseMultilinearExtension}; +use ark_poly::{Polynomial, SparseMultilinearExtension}; use ark_std::{ops::Range, test_rng}; use ark_test_curves::bls12_381; use criterion::{black_box, BenchmarkId, Criterion}; @@ -72,7 +72,7 @@ fn evaluation_op_bench(c: &mut Criterion) { &mut rng, ); let point: Vec<_> = (0..nv).map(|_| F::rand(&mut rng)).collect(); - b.iter(|| black_box(poly.evaluate(&point).unwrap())) + b.iter(|| black_box(poly.evaluate(&point))) }, ); } diff --git a/poly/src/evaluations/multivariate/multilinear/dense.rs b/poly/src/evaluations/multivariate/multilinear/dense.rs index 285c44968..6b69ff2c5 100644 --- a/poly/src/evaluations/multivariate/multilinear/dense.rs +++ b/poly/src/evaluations/multivariate/multilinear/dense.rs @@ -1,6 +1,9 @@ //! Multilinear polynomial represented in dense evaluation form. -use crate::evaluations::multivariate::multilinear::{swap_bits, MultilinearExtension}; +use crate::{ + evaluations::multivariate::multilinear::{swap_bits, MultilinearExtension}, + Polynomial, +}; use ark_ff::{Field, Zero}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{ @@ -86,32 +89,6 @@ impl MultilinearExtension for DenseMultilinearExtension { self.num_vars } - /// Evaluate the dense MLE at the given point - /// # Example - /// ``` - /// use ark_test_curves::bls12_381::Fr; - /// # use ark_poly::{MultilinearExtension, DenseMultilinearExtension}; - /// # use ark_ff::One; - /// - /// // The two-variate polynomial x_0 + 3 * x_0 * x_1 + 2 evaluates to [2, 3, 2, 6] - /// // in the two-dimensional hypercube with points [00, 10, 01, 11] - /// let mle = DenseMultilinearExtension::from_evaluations_vec( - /// 2, vec![2, 3, 2, 6].iter().map(|x| Fr::from(*x as u64)).collect() - /// ); - /// - /// // By the uniqueness of MLEs, `mle` is precisely the above polynomial, which - /// // takes the value 54 at the point (1, 17) - /// let eval = mle.evaluate(&[Fr::one(), Fr::from(17)]).unwrap(); - /// assert_eq!(eval, Fr::from(54)); - /// ``` - fn evaluate(&self, point: &[F]) -> Option { - if point.len() == self.num_vars { - Some(self.fix_variables(point)[0]) - } else { - None - } - } - fn rand(num_vars: usize, rng: &mut R) -> Self { Self::from_evaluations_vec( num_vars, @@ -307,9 +284,40 @@ impl Zero for DenseMultilinearExtension { } } +impl Polynomial for DenseMultilinearExtension { + type Point = Vec; + + fn degree(&self) -> usize { + self.num_vars + } + + /// Evaluate the dense MLE at the given point + /// # Example + /// ``` + /// use ark_test_curves::bls12_381::Fr; + /// # use ark_poly::{MultilinearExtension, DenseMultilinearExtension, Polynomial}; + /// # use ark_ff::One; + /// + /// // The two-variate polynomial x_0 + 3 * x_0 * x_1 + 2 evaluates to [2, 3, 2, 6] + /// // in the two-dimensional hypercube with points [00, 10, 01, 11] + /// let mle = DenseMultilinearExtension::from_evaluations_vec( + /// 2, vec![2, 3, 2, 6].iter().map(|x| Fr::from(*x as u64)).collect() + /// ); + /// + /// // By the uniqueness of MLEs, `mle` is precisely the above polynomial, which + /// // takes the value 54 at the point (1, 17) + /// let eval = mle.evaluate(&[Fr::one(), Fr::from(17)].into()); + /// assert_eq!(eval, Fr::from(54)); + /// ``` + fn evaluate(&self, point: &Self::Point) -> F { + assert!(point.len() == self.num_vars); + self.fix_variables(&point)[0] + } +} + #[cfg(test)] mod tests { - use crate::{DenseMultilinearExtension, MultilinearExtension}; + use crate::{DenseMultilinearExtension, MultilinearExtension, Polynomial}; use ark_ff::{Field, Zero}; use ark_std::{ops::Neg, test_rng, vec::Vec, UniformRand}; use ark_test_curves::bls12_381::Fr; @@ -340,7 +348,7 @@ mod tests { let point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect(); assert_eq!( evaluate_data_array(&poly.evaluations, &point), - poly.evaluate(&point).unwrap() + poly.evaluate(&point) ) } } @@ -390,32 +398,32 @@ mod tests { let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect(); let poly1 = DenseMultilinearExtension::rand(NV, &mut rng); let poly2 = DenseMultilinearExtension::rand(NV, &mut rng); - let v1 = poly1.evaluate(&point).unwrap(); - let v2 = poly2.evaluate(&point).unwrap(); + let v1 = poly1.evaluate(&point); + let v2 = poly2.evaluate(&point); // test add - assert_eq!((&poly1 + &poly2).evaluate(&point).unwrap(), v1 + v2); + assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2); // test sub - assert_eq!((&poly1 - &poly2).evaluate(&point).unwrap(), v1 - v2); + assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2); // test negate - assert_eq!(poly1.clone().neg().evaluate(&point).unwrap(), -v1); + assert_eq!(poly1.clone().neg().evaluate(&point), -v1); // test add assign { let mut poly1 = poly1.clone(); poly1 += &poly2; - assert_eq!(poly1.evaluate(&point).unwrap(), v1 + v2) + assert_eq!(poly1.evaluate(&point), v1 + v2) } // test sub assign { let mut poly1 = poly1.clone(); poly1 -= &poly2; - assert_eq!(poly1.evaluate(&point).unwrap(), v1 - v2) + assert_eq!(poly1.evaluate(&point), v1 - v2) } // test add assign with scalar { let mut poly1 = poly1.clone(); let scalar = Fr::rand(&mut rng); poly1 += (scalar, &poly2); - assert_eq!(poly1.evaluate(&point).unwrap(), v1 + scalar * v2) + assert_eq!(poly1.evaluate(&point), v1 + scalar * v2) } // test additive identity { @@ -428,7 +436,7 @@ mod tests { let mut zero = DenseMultilinearExtension::zero(); let scalar = Fr::rand(&mut rng); zero += (scalar, &poly1); - assert_eq!(zero.evaluate(&point).unwrap(), scalar * v1); + assert_eq!(zero.evaluate(&point), scalar * v1); } } } diff --git a/poly/src/evaluations/multivariate/multilinear/mod.rs b/poly/src/evaluations/multivariate/multilinear/mod.rs index 0f5ec1a5d..fbabd0a26 100644 --- a/poly/src/evaluations/multivariate/multilinear/mod.rs +++ b/poly/src/evaluations/multivariate/multilinear/mod.rs @@ -16,6 +16,8 @@ use ark_ff::{Field, Zero}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::rand::Rng; +use crate::Polynomial; + /// This trait describes an interface for the multilinear extension /// of an array. /// The latter is a multilinear polynomial represented in terms of its @@ -39,14 +41,11 @@ pub trait MultilinearExtension: + for<'a> AddAssign<(F, &'a Self)> + for<'a> SubAssign<&'a Self> + Index + + Polynomial> { /// Returns the number of variables in `self` fn num_vars(&self) -> usize; - /// Evaluates `self` at the given the vector `point` in slice. - /// If the number of variables does not match, return `None`. - fn evaluate(&self, point: &[F]) -> Option; - /// Outputs an `l`-variate multilinear extension where value of evaluations /// are sampled uniformly at random. fn rand(num_vars: usize, rng: &mut R) -> Self; diff --git a/poly/src/evaluations/multivariate/multilinear/sparse.rs b/poly/src/evaluations/multivariate/multilinear/sparse.rs index f4f8040ca..ba8227c7f 100644 --- a/poly/src/evaluations/multivariate/multilinear/sparse.rs +++ b/poly/src/evaluations/multivariate/multilinear/sparse.rs @@ -2,7 +2,7 @@ use crate::{ evaluations::multivariate::multilinear::swap_bits, DenseMultilinearExtension, - MultilinearExtension, + MultilinearExtension, Polynomial, }; use ark_ff::{Field, Zero}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; @@ -120,14 +120,6 @@ impl MultilinearExtension for SparseMultilinearExtension { self.num_vars } - fn evaluate(&self, point: &[F]) -> Option { - if point.len() == self.num_vars { - Some(self.fix_variables(point)[0]) - } else { - None - } - } - /// Outputs an `l`-variate multilinear extension where value of evaluations /// are sampled uniformly at random. The number of nonzero entries is /// `sqrt(2^num_vars)` and indices of those nonzero entries are distributed @@ -227,6 +219,19 @@ impl Index for SparseMultilinearExtension { } } +impl Polynomial for SparseMultilinearExtension { + type Point = Vec; + + fn degree(&self) -> usize { + self.num_vars + } + + fn evaluate(&self, point: &Self::Point) -> F { + assert!(point.len() == self.num_vars); + self.fix_variables(point)[0] + } +} + impl Add for SparseMultilinearExtension { type Output = SparseMultilinearExtension; @@ -399,7 +404,8 @@ fn hashmap_to_treemap(map: &HashMap) -> BTreeMap { #[cfg(test)] mod tests { use crate::{ - evaluations::multivariate::multilinear::MultilinearExtension, SparseMultilinearExtension, + evaluations::multivariate::multilinear::MultilinearExtension, Polynomial, + SparseMultilinearExtension, }; use ark_ff::{One, Zero}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; @@ -453,7 +459,7 @@ mod tests { let mut rng = test_rng(); let ev1 = Fr::rand(&mut rng); let poly1 = SparseMultilinearExtension::from_evaluations(0, &vec![(0, ev1)]); - assert_eq!(poly1.evaluate(&[]).unwrap(), ev1); + assert_eq!(poly1.evaluate(&[].into()), ev1); // test single-variate polynomial let ev2 = vec![Fr::rand(&mut rng), Fr::rand(&mut rng)]; @@ -462,7 +468,7 @@ mod tests { let x = Fr::rand(&mut rng); assert_eq!( - poly2.evaluate(&[x]).unwrap(), + poly2.evaluate(&[x].into()), x * ev2[1] + (Fr::one() - x) * ev2[0] ); @@ -471,7 +477,7 @@ mod tests { let poly2 = SparseMultilinearExtension::from_evaluations(1, &vec![(1, ev3)]); let x = Fr::rand(&mut rng); - assert_eq!(poly2.evaluate(&[x]).unwrap(), x * ev3); + assert_eq!(poly2.evaluate(&[x].into()), x * ev3); } #[test] @@ -500,32 +506,32 @@ mod tests { let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect(); let poly1 = SparseMultilinearExtension::rand(NV, &mut rng); let poly2 = SparseMultilinearExtension::rand(NV, &mut rng); - let v1 = poly1.evaluate(&point).unwrap(); - let v2 = poly2.evaluate(&point).unwrap(); + let v1 = poly1.evaluate(&point); + let v2 = poly2.evaluate(&point); // test add - assert_eq!((&poly1 + &poly2).evaluate(&point).unwrap(), v1 + v2); + assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2); // test sub - assert_eq!((&poly1 - &poly2).evaluate(&point).unwrap(), v1 - v2); + assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2); // test negate - assert_eq!(poly1.clone().neg().evaluate(&point).unwrap(), -v1); + assert_eq!(poly1.clone().neg().evaluate(&point), -v1); // test add assign { let mut poly1 = poly1.clone(); poly1 += &poly2; - assert_eq!(poly1.evaluate(&point).unwrap(), v1 + v2) + assert_eq!(poly1.evaluate(&point), v1 + v2) } // test sub assign { let mut poly1 = poly1.clone(); poly1 -= &poly2; - assert_eq!(poly1.evaluate(&point).unwrap(), v1 - v2) + assert_eq!(poly1.evaluate(&point), v1 - v2) } // test add assign with scalar { let mut poly1 = poly1.clone(); let scalar = Fr::rand(&mut rng); poly1 += (scalar, &poly2); - assert_eq!(poly1.evaluate(&point).unwrap(), v1 + scalar * v2) + assert_eq!(poly1.evaluate(&point), v1 + scalar * v2) } // test additive identity { @@ -538,7 +544,7 @@ mod tests { let mut zero = SparseMultilinearExtension::zero(); let scalar = Fr::rand(&mut rng); zero += (scalar, &poly1); - assert_eq!(zero.evaluate(&point).unwrap(), scalar * v1); + assert_eq!(zero.evaluate(&point), scalar * v1); } } } @@ -551,29 +557,29 @@ mod tests { let mut poly = SparseMultilinearExtension::rand(10, &mut rng); let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect(); - let expected = poly.evaluate(&point).unwrap(); + let expected = poly.evaluate(&point); poly = poly.relabel(2, 2, 1); // should have no effect - assert_eq!(expected, poly.evaluate(&point).unwrap()); + assert_eq!(expected, poly.evaluate(&point)); poly = poly.relabel(3, 4, 1); // should switch 3 and 4 point.swap(3, 4); - assert_eq!(expected, poly.evaluate(&point).unwrap()); + assert_eq!(expected, poly.evaluate(&point)); poly = poly.relabel(7, 5, 1); point.swap(7, 5); - assert_eq!(expected, poly.evaluate(&point).unwrap()); + assert_eq!(expected, poly.evaluate(&point)); poly = poly.relabel(2, 5, 3); point.swap(2, 5); point.swap(3, 6); point.swap(4, 7); - assert_eq!(expected, poly.evaluate(&point).unwrap()); + assert_eq!(expected, poly.evaluate(&point)); poly = poly.relabel(7, 0, 2); point.swap(0, 7); point.swap(1, 8); - assert_eq!(expected, poly.evaluate(&point).unwrap()); + assert_eq!(expected, poly.evaluate(&point)); } }