Skip to content

Commit

Permalink
Implement Polynomial for MultilinearExtension (#691)
Browse files Browse the repository at this point in the history
Co-authored-by: Antonio Mejías Gil <[email protected]>
  • Loading branch information
mmagician and Antonio95 authored Oct 25, 2023
1 parent c0666a8 commit d106993
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 76 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

### Features

- [\#691](https://github.com/arkworks-rs/algebra/pull/691) (`ark-poly`) Implement `Polynomial` for `SparseMultilinearExtension` and `DenseMultilinearExtension`.

### Improvements

### Bugfixes
Expand Down
4 changes: 2 additions & 2 deletions poly/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ let g: DenseMultilinearExtension<Fq> = DenseMultilinearExtension::from_evaluatio
);
// when evaluated at any point within the Boolean hypercube, f and g should be equal
let point_within_hypercube = &vec![Fq::from(0), Fq::from(1), Fq::from(1)];
assert_eq!(f.evaluate(&point_within_hypercube), g.evaluate(&point_within_hypercube).unwrap());
assert_eq!(f.evaluate(&point_within_hypercube), g.evaluate(&point_within_hypercube));

// We can also define a MLE g'(x_0, x_1, x_2) by providing the list of non-zero evaluations:
let g_prime: SparseMultilinearExtension<Fq> = SparseMultilinearExtension::from_evaluations(
Expand All @@ -135,7 +135,7 @@ let g_prime: SparseMultilinearExtension<Fq> = SparseMultilinearExtension::from_e
);
// at any random point (X0, X1, X2), g == g' with negligible probability, unless they are the same function
let random_point = &vec![Fq::from(123), Fq::from(456), Fq::from(789)];
assert_eq!(g_prime.evaluate(&random_point).unwrap(), g.evaluate(&random_point).unwrap());
assert_eq!(g_prime.evaluate(&random_point), g.evaluate(&random_point));

```

Expand Down
4 changes: 2 additions & 2 deletions poly/benches/dense_multilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
extern crate criterion;

use ark_ff::Field;
use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
use ark_poly::{DenseMultilinearExtension, MultilinearExtension, Polynomial};
use ark_std::{ops::Range, test_rng};
use ark_test_curves::bls12_381;
use criterion::{black_box, BenchmarkId, Criterion};
Expand Down Expand Up @@ -40,7 +40,7 @@ fn evaluation_op_bench<F: Field>(c: &mut Criterion) {
group.bench_with_input(BenchmarkId::from_parameter(nv), &nv, |b, &nv| {
let poly = DenseMultilinearExtension::<F>::rand(nv, &mut rng);
let point: Vec<_> = (0..nv).map(|_| F::rand(&mut rng)).collect();
b.iter(|| black_box(poly.evaluate(&point).unwrap()))
b.iter(|| black_box(poly.evaluate(&point)))
});
}
group.finish();
Expand Down
4 changes: 2 additions & 2 deletions poly/benches/sparse_multilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
extern crate criterion;

use ark_ff::Field;
use ark_poly::{MultilinearExtension, SparseMultilinearExtension};
use ark_poly::{Polynomial, SparseMultilinearExtension};
use ark_std::{ops::Range, test_rng};
use ark_test_curves::bls12_381;
use criterion::{black_box, BenchmarkId, Criterion};
Expand Down Expand Up @@ -72,7 +72,7 @@ fn evaluation_op_bench<F: Field>(c: &mut Criterion) {
&mut rng,
);
let point: Vec<_> = (0..nv).map(|_| F::rand(&mut rng)).collect();
b.iter(|| black_box(poly.evaluate(&point).unwrap()))
b.iter(|| black_box(poly.evaluate(&point)))
},
);
}
Expand Down
84 changes: 46 additions & 38 deletions poly/src/evaluations/multivariate/multilinear/dense.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Multilinear polynomial represented in dense evaluation form.
use crate::evaluations::multivariate::multilinear::{swap_bits, MultilinearExtension};
use crate::{
evaluations::multivariate::multilinear::{swap_bits, MultilinearExtension},
Polynomial,
};
use ark_ff::{Field, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::{
Expand Down Expand Up @@ -86,32 +89,6 @@ impl<F: Field> MultilinearExtension<F> for DenseMultilinearExtension<F> {
self.num_vars
}

/// Evaluate the dense MLE at the given point
/// # Example
/// ```
/// use ark_test_curves::bls12_381::Fr;
/// # use ark_poly::{MultilinearExtension, DenseMultilinearExtension};
/// # use ark_ff::One;
///
/// // The two-variate polynomial x_0 + 3 * x_0 * x_1 + 2 evaluates to [2, 3, 2, 6]
/// // in the two-dimensional hypercube with points [00, 10, 01, 11]
/// let mle = DenseMultilinearExtension::from_evaluations_vec(
/// 2, vec![2, 3, 2, 6].iter().map(|x| Fr::from(*x as u64)).collect()
/// );
///
/// // By the uniqueness of MLEs, `mle` is precisely the above polynomial, which
/// // takes the value 54 at the point (1, 17)
/// let eval = mle.evaluate(&[Fr::one(), Fr::from(17)]).unwrap();
/// assert_eq!(eval, Fr::from(54));
/// ```
fn evaluate(&self, point: &[F]) -> Option<F> {
if point.len() == self.num_vars {
Some(self.fix_variables(point)[0])
} else {
None
}
}

fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self {
Self::from_evaluations_vec(
num_vars,
Expand Down Expand Up @@ -307,9 +284,40 @@ impl<F: Field> Zero for DenseMultilinearExtension<F> {
}
}

impl<F: Field> Polynomial<F> for DenseMultilinearExtension<F> {
type Point = Vec<F>;

fn degree(&self) -> usize {
self.num_vars
}

/// Evaluate the dense MLE at the given point
/// # Example
/// ```
/// use ark_test_curves::bls12_381::Fr;
/// # use ark_poly::{MultilinearExtension, DenseMultilinearExtension, Polynomial};
/// # use ark_ff::One;
///
/// // The two-variate polynomial x_0 + 3 * x_0 * x_1 + 2 evaluates to [2, 3, 2, 6]
/// // in the two-dimensional hypercube with points [00, 10, 01, 11]
/// let mle = DenseMultilinearExtension::from_evaluations_vec(
/// 2, vec![2, 3, 2, 6].iter().map(|x| Fr::from(*x as u64)).collect()
/// );
///
/// // By the uniqueness of MLEs, `mle` is precisely the above polynomial, which
/// // takes the value 54 at the point (1, 17)
/// let eval = mle.evaluate(&[Fr::one(), Fr::from(17)].into());
/// assert_eq!(eval, Fr::from(54));
/// ```
fn evaluate(&self, point: &Self::Point) -> F {
assert!(point.len() == self.num_vars);
self.fix_variables(&point)[0]
}
}

#[cfg(test)]
mod tests {
use crate::{DenseMultilinearExtension, MultilinearExtension};
use crate::{DenseMultilinearExtension, MultilinearExtension, Polynomial};
use ark_ff::{Field, Zero};
use ark_std::{ops::Neg, test_rng, vec::Vec, UniformRand};
use ark_test_curves::bls12_381::Fr;
Expand Down Expand Up @@ -340,7 +348,7 @@ mod tests {
let point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();
assert_eq!(
evaluate_data_array(&poly.evaluations, &point),
poly.evaluate(&point).unwrap()
poly.evaluate(&point)
)
}
}
Expand Down Expand Up @@ -390,32 +398,32 @@ mod tests {
let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
let poly1 = DenseMultilinearExtension::rand(NV, &mut rng);
let poly2 = DenseMultilinearExtension::rand(NV, &mut rng);
let v1 = poly1.evaluate(&point).unwrap();
let v2 = poly2.evaluate(&point).unwrap();
let v1 = poly1.evaluate(&point);
let v2 = poly2.evaluate(&point);
// test add
assert_eq!((&poly1 + &poly2).evaluate(&point).unwrap(), v1 + v2);
assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2);
// test sub
assert_eq!((&poly1 - &poly2).evaluate(&point).unwrap(), v1 - v2);
assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2);
// test negate
assert_eq!(poly1.clone().neg().evaluate(&point).unwrap(), -v1);
assert_eq!(poly1.clone().neg().evaluate(&point), -v1);
// test add assign
{
let mut poly1 = poly1.clone();
poly1 += &poly2;
assert_eq!(poly1.evaluate(&point).unwrap(), v1 + v2)
assert_eq!(poly1.evaluate(&point), v1 + v2)
}
// test sub assign
{
let mut poly1 = poly1.clone();
poly1 -= &poly2;
assert_eq!(poly1.evaluate(&point).unwrap(), v1 - v2)
assert_eq!(poly1.evaluate(&point), v1 - v2)
}
// test add assign with scalar
{
let mut poly1 = poly1.clone();
let scalar = Fr::rand(&mut rng);
poly1 += (scalar, &poly2);
assert_eq!(poly1.evaluate(&point).unwrap(), v1 + scalar * v2)
assert_eq!(poly1.evaluate(&point), v1 + scalar * v2)
}
// test additive identity
{
Expand All @@ -428,7 +436,7 @@ mod tests {
let mut zero = DenseMultilinearExtension::zero();
let scalar = Fr::rand(&mut rng);
zero += (scalar, &poly1);
assert_eq!(zero.evaluate(&point).unwrap(), scalar * v1);
assert_eq!(zero.evaluate(&point), scalar * v1);
}
}
}
Expand Down
7 changes: 3 additions & 4 deletions poly/src/evaluations/multivariate/multilinear/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use ark_ff::{Field, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::rand::Rng;

use crate::Polynomial;

/// This trait describes an interface for the multilinear extension
/// of an array.
/// The latter is a multilinear polynomial represented in terms of its
Expand All @@ -39,14 +41,11 @@ pub trait MultilinearExtension<F: Field>:
+ for<'a> AddAssign<(F, &'a Self)>
+ for<'a> SubAssign<&'a Self>
+ Index<usize>
+ Polynomial<F, Point = Vec<F>>
{
/// Returns the number of variables in `self`
fn num_vars(&self) -> usize;

/// Evaluates `self` at the given the vector `point` in slice.
/// If the number of variables does not match, return `None`.
fn evaluate(&self, point: &[F]) -> Option<F>;

/// Outputs an `l`-variate multilinear extension where value of evaluations
/// are sampled uniformly at random.
fn rand<R: Rng>(num_vars: usize, rng: &mut R) -> Self;
Expand Down
62 changes: 34 additions & 28 deletions poly/src/evaluations/multivariate/multilinear/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::{
evaluations::multivariate::multilinear::swap_bits, DenseMultilinearExtension,
MultilinearExtension,
MultilinearExtension, Polynomial,
};
use ark_ff::{Field, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
Expand Down Expand Up @@ -120,14 +120,6 @@ impl<F: Field> MultilinearExtension<F> for SparseMultilinearExtension<F> {
self.num_vars
}

fn evaluate(&self, point: &[F]) -> Option<F> {
if point.len() == self.num_vars {
Some(self.fix_variables(point)[0])
} else {
None
}
}

/// Outputs an `l`-variate multilinear extension where value of evaluations
/// are sampled uniformly at random. The number of nonzero entries is
/// `sqrt(2^num_vars)` and indices of those nonzero entries are distributed
Expand Down Expand Up @@ -227,6 +219,19 @@ impl<F: Field> Index<usize> for SparseMultilinearExtension<F> {
}
}

impl<F: Field> Polynomial<F> for SparseMultilinearExtension<F> {
type Point = Vec<F>;

fn degree(&self) -> usize {
self.num_vars
}

fn evaluate(&self, point: &Self::Point) -> F {
assert!(point.len() == self.num_vars);
self.fix_variables(point)[0]
}
}

impl<F: Field> Add for SparseMultilinearExtension<F> {
type Output = SparseMultilinearExtension<F>;

Expand Down Expand Up @@ -399,7 +404,8 @@ fn hashmap_to_treemap<F: Field>(map: &HashMap<usize, F>) -> BTreeMap<usize, F> {
#[cfg(test)]
mod tests {
use crate::{
evaluations::multivariate::multilinear::MultilinearExtension, SparseMultilinearExtension,
evaluations::multivariate::multilinear::MultilinearExtension, Polynomial,
SparseMultilinearExtension,
};
use ark_ff::{One, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
Expand Down Expand Up @@ -453,7 +459,7 @@ mod tests {
let mut rng = test_rng();
let ev1 = Fr::rand(&mut rng);
let poly1 = SparseMultilinearExtension::from_evaluations(0, &vec![(0, ev1)]);
assert_eq!(poly1.evaluate(&[]).unwrap(), ev1);
assert_eq!(poly1.evaluate(&[].into()), ev1);

// test single-variate polynomial
let ev2 = vec![Fr::rand(&mut rng), Fr::rand(&mut rng)];
Expand All @@ -462,7 +468,7 @@ mod tests {

let x = Fr::rand(&mut rng);
assert_eq!(
poly2.evaluate(&[x]).unwrap(),
poly2.evaluate(&[x].into()),
x * ev2[1] + (Fr::one() - x) * ev2[0]
);

Expand All @@ -471,7 +477,7 @@ mod tests {
let poly2 = SparseMultilinearExtension::from_evaluations(1, &vec![(1, ev3)]);

let x = Fr::rand(&mut rng);
assert_eq!(poly2.evaluate(&[x]).unwrap(), x * ev3);
assert_eq!(poly2.evaluate(&[x].into()), x * ev3);
}

#[test]
Expand Down Expand Up @@ -500,32 +506,32 @@ mod tests {
let point: Vec<_> = (0..NV).map(|_| Fr::rand(&mut rng)).collect();
let poly1 = SparseMultilinearExtension::rand(NV, &mut rng);
let poly2 = SparseMultilinearExtension::rand(NV, &mut rng);
let v1 = poly1.evaluate(&point).unwrap();
let v2 = poly2.evaluate(&point).unwrap();
let v1 = poly1.evaluate(&point);
let v2 = poly2.evaluate(&point);
// test add
assert_eq!((&poly1 + &poly2).evaluate(&point).unwrap(), v1 + v2);
assert_eq!((&poly1 + &poly2).evaluate(&point), v1 + v2);
// test sub
assert_eq!((&poly1 - &poly2).evaluate(&point).unwrap(), v1 - v2);
assert_eq!((&poly1 - &poly2).evaluate(&point), v1 - v2);
// test negate
assert_eq!(poly1.clone().neg().evaluate(&point).unwrap(), -v1);
assert_eq!(poly1.clone().neg().evaluate(&point), -v1);
// test add assign
{
let mut poly1 = poly1.clone();
poly1 += &poly2;
assert_eq!(poly1.evaluate(&point).unwrap(), v1 + v2)
assert_eq!(poly1.evaluate(&point), v1 + v2)
}
// test sub assign
{
let mut poly1 = poly1.clone();
poly1 -= &poly2;
assert_eq!(poly1.evaluate(&point).unwrap(), v1 - v2)
assert_eq!(poly1.evaluate(&point), v1 - v2)
}
// test add assign with scalar
{
let mut poly1 = poly1.clone();
let scalar = Fr::rand(&mut rng);
poly1 += (scalar, &poly2);
assert_eq!(poly1.evaluate(&point).unwrap(), v1 + scalar * v2)
assert_eq!(poly1.evaluate(&point), v1 + scalar * v2)
}
// test additive identity
{
Expand All @@ -538,7 +544,7 @@ mod tests {
let mut zero = SparseMultilinearExtension::zero();
let scalar = Fr::rand(&mut rng);
zero += (scalar, &poly1);
assert_eq!(zero.evaluate(&point).unwrap(), scalar * v1);
assert_eq!(zero.evaluate(&point), scalar * v1);
}
}
}
Expand All @@ -551,29 +557,29 @@ mod tests {
let mut poly = SparseMultilinearExtension::rand(10, &mut rng);
let mut point: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect();

let expected = poly.evaluate(&point).unwrap();
let expected = poly.evaluate(&point);

poly = poly.relabel(2, 2, 1); // should have no effect
assert_eq!(expected, poly.evaluate(&point).unwrap());
assert_eq!(expected, poly.evaluate(&point));

poly = poly.relabel(3, 4, 1); // should switch 3 and 4
point.swap(3, 4);
assert_eq!(expected, poly.evaluate(&point).unwrap());
assert_eq!(expected, poly.evaluate(&point));

poly = poly.relabel(7, 5, 1);
point.swap(7, 5);
assert_eq!(expected, poly.evaluate(&point).unwrap());
assert_eq!(expected, poly.evaluate(&point));

poly = poly.relabel(2, 5, 3);
point.swap(2, 5);
point.swap(3, 6);
point.swap(4, 7);
assert_eq!(expected, poly.evaluate(&point).unwrap());
assert_eq!(expected, poly.evaluate(&point));

poly = poly.relabel(7, 0, 2);
point.swap(0, 7);
point.swap(1, 8);
assert_eq!(expected, poly.evaluate(&point).unwrap());
assert_eq!(expected, poly.evaluate(&point));
}
}

Expand Down

0 comments on commit d106993

Please sign in to comment.