Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize SumcheckSingle::new #8

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 146 additions & 14 deletions src/crypto/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
//! A global cache is used for twiddle factors.

use ark_ff::{FftField, Field};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, LazyLock, Mutex, RwLock, RwLockReadGuard};
use std::{
any::{Any, TypeId},
collections::HashMap,
mem::swap,
sync::{Arc, LazyLock, Mutex, RwLock, RwLockReadGuard},
};

#[cfg(feature = "parallel")]
use {
Expand Down Expand Up @@ -348,34 +351,163 @@ impl<F: Field> NttEngine<F> {
}
}

/// In-place fast wavelet transform of a single slice.
///
/// Recursively applies the 2x2 kernel
///
/// ```text
/// [1 0]
/// [1 1]
/// ```
///
/// to `values`. The slice length must be a power of two; this is only
/// checked with `debug_assert!`.
pub fn wavelet_transform<F: Field>(values: &mut [F]) {
    let len = values.len();
    debug_assert!(len.is_power_of_two());
    // A single transform is a batch whose chunk size is the whole slice.
    wavelet_transform_batch(values, len);
}

/// Apply the wavelet transform in place to every consecutive `size`-element
/// chunk of `values`.
///
/// `values.len()` must be a multiple of `size` and `size` must be a power of
/// two; both conditions are only checked with `debug_assert!`.
///
/// Sizes 2, 4, 8 and 16 are handled by unrolled kernels. Larger sizes are
/// split as `size = n1 * n2` and computed as transforms of size `n1`, a
/// transpose, transforms of size `n2`, and a transpose back.
///
/// With the `parallel` feature enabled, a slice longer than
/// `NttEngine::<F>::WORKLOAD_SIZE` that holds more than one transform is
/// split into jobs processed via `par_chunks_mut`.
pub fn wavelet_transform_batch<F: Field>(values: &mut [F], size: usize) {
    debug_assert_eq!(values.len() % size, 0);
    debug_assert!(size.is_power_of_two());

    #[cfg(feature = "parallel")]
    if values.len() > NttEngine::<F>::WORKLOAD_SIZE && values.len() != size {
        // Several independent transforms: distribute them over parallel jobs.
        // Each job holds a whole number of transforms (at least one), and
        // stays within `WORKLOAD_SIZE` elements whenever `size` permits.
        let transforms_per_job = max(1, NttEngine::<F>::WORKLOAD_SIZE / size);
        let job_len = size * transforms_per_job;
        values
            .par_chunks_mut(job_len)
            .for_each(|job| wavelet_transform_batch(job, size));
        return;
    }

    match size {
        // Nothing to do for empty or single-element transforms.
        0 | 1 => {}
        2 => {
            for pair in values.chunks_exact_mut(2) {
                pair[1] += pair[0];
            }
        }
        4 => {
            for quad in values.chunks_exact_mut(4) {
                // Distance-1 additions first, then distance-2.
                quad[1] += quad[0];
                quad[3] += quad[2];
                quad[2] += quad[0];
                quad[3] += quad[1];
            }
        }
        8 => {
            for oct in values.chunks_exact_mut(8) {
                // Size-4 transform on the lower half.
                oct[1] += oct[0];
                oct[3] += oct[2];
                oct[2] += oct[0];
                oct[3] += oct[1];
                // Size-4 transform on the upper half.
                oct[5] += oct[4];
                oct[7] += oct[6];
                oct[6] += oct[4];
                oct[7] += oct[5];
                // Fold the lower half into the upper half.
                oct[4] += oct[0];
                oct[5] += oct[1];
                oct[6] += oct[2];
                oct[7] += oct[3];
            }
        }
        16 => {
            for block in values.chunks_exact_mut(16) {
                // Four independent size-4 transforms ...
                for quad in block.chunks_exact_mut(4) {
                    quad[1] += quad[0];
                    quad[3] += quad[2];
                    quad[2] += quad[0];
                    quad[3] += quad[1];
                }
                // ... then the same size-4 pattern applied across the four
                // quads, one lane at a time.
                let (lower, upper) = block.split_at_mut(8);
                let (q0, q1) = lower.split_at_mut(4);
                let (q2, q3) = upper.split_at_mut(4);
                for lane in 0..4 {
                    q1[lane] += q0[lane];
                    q3[lane] += q2[lane];
                    q2[lane] += q0[lane];
                    q3[lane] += q1[lane];
                }
            }
        }
        _ => {
            // Factor `size` into two powers of two, `n1 <= n2`, with `n1`
            // taking the lower half of the exponent.
            let n1 = 1 << (size.trailing_zeros() / 2);
            let n2 = size / n1;
            wavelet_transform_batch(values, n1);
            transpose(values, n2, n1);
            wavelet_transform_batch(values, n2);
            transpose(values, n1, n2);
        }
    }
}

/// Transpose a matrix in-place.
/// Will batch transpose multiple matrices if the length of the slice is a multiple of rows * cols.
pub fn transpose<T: Copy>(matrix: &mut [T], rows: usize, cols: usize) {
pub fn transpose<F: Field>(matrix: &mut [F], rows: usize, cols: usize) {
debug_assert_eq!(matrix.len() % rows * cols, 0);
if rows == cols {
// TODO: Cache-oblivious recursive parallel algorithm.
for matrix in matrix.chunks_exact_mut(rows * cols) {
for i in 0..rows {
for j in (i + 1)..cols {
matrix.swap(i * cols + j, j * rows + i);
}
}
transpose_square(matrix, rows, cols);
}
} else {
// TODO: Re-use scratch space.
// TODO: Cache-oblivious recursive parallel algorithm.
// TODO: Special case for rows = 2 * cols and cols = 2 * rows.
let mut scratch = vec![F::ZERO; rows * cols];
for matrix in matrix.chunks_exact_mut(rows * cols) {
let copy = matrix.to_vec();
scratch.copy_from_slice(matrix);
for i in 0..rows {
for j in 0..cols {
matrix[j * rows + i] = copy[i * cols + j];
matrix[j * rows + i] = scratch[i * cols + j];
}
}
}
}
}

/// Transpose a square power-of-two matrix in-place.
///
/// Element `(i, j)` lives at `matrix[i * stride + j]`, so the matrix may be a
/// `size` x `size` window inside a wider row-major buffer. Both the minimum
/// slice length and the power-of-two requirement are only checked with
/// `debug_assert!`.
fn transpose_square<F: Field>(matrix: &mut [F], size: usize, stride: usize) {
    debug_assert!(matrix.len() >= (size - 1) * stride + size);
    debug_assert!(size.is_power_of_two());

    if size * size <= NttEngine::<F>::WORKLOAD_SIZE {
        // Small enough: swap every element above the diagonal with its mirror.
        for row in 0..size {
            for col in (row + 1)..size {
                matrix.swap(row * stride + col, col * stride + row);
            }
        }
        return;
    }

    // Recurse into quadrants.
    // This results in a cache-oblivious algorithm.
    let half = size / 2;
    let (top, bottom) = matrix.split_at_mut(half * stride);
    // Ideally we'd parallelize this, but it's not possible to
    // express the strided matrices without unsafe code.
    // Diagonal quadrants are transposed in place; the two off-diagonal
    // quadrants are transposed and exchanged with each other.
    transpose_square(top, half, stride);
    transpose_square_swap(&mut top[half..], bottom, half, stride);
    transpose_square(&mut bottom[half..], half, stride);
}

/// Transpose and swap two square power-of-two size matrices.
///
/// Every element `a[i * stride + j]` is exchanged with `b[j * stride + i]`,
/// so afterwards `a` holds the transpose of the old `b` and `b` holds the
/// transpose of the old `a`. Both matrices are `size` x `size` windows with
/// row pitch `stride`; length and power-of-two requirements are only checked
/// with `debug_assert!`.
fn transpose_square_swap<F: Field>(a: &mut [F], b: &mut [F], size: usize, stride: usize) {
    debug_assert!(a.len() >= (size - 1) * stride + size);
    debug_assert!(b.len() >= (size - 1) * stride + size);
    debug_assert!(size.is_power_of_two());

    if size * size <= NttEngine::<F>::WORKLOAD_SIZE {
        for row in 0..size {
            for col in 0..size {
                // The compiler does not eliminate the bounds checks here,
                // but this doesn't matter as it is bottlenecked by memory bandwidth.
                swap(&mut a[row * stride + col], &mut b[col * stride + row]);
            }
        }
        return;
    }

    // Recurse into quadrants.
    // This results in a cache-oblivious algorithm.
    let half = size / 2;
    let (a_top, a_bottom) = a.split_at_mut(half * stride);
    let (b_top, b_bottom) = b.split_at_mut(half * stride);
    // Ideally we'd parallelize this, but it's not possible to
    // express the strided matrices without unsafe code.
    // Each quadrant of `a` pairs with the mirrored quadrant of `b`.
    transpose_square_swap(a_top, b_top, half, stride);
    transpose_square_swap(&mut a_top[half..], b_bottom, half, stride);
    transpose_square_swap(a_bottom, &mut b_top[half..], half, stride);
    transpose_square_swap(&mut a_bottom[half..], &mut b_bottom[half..], half, stride);
}

/// Compute the largest factor of n that is <= sqrt(n).
/// Assumes n is of the form 2^k * {1,3,9}.
fn sqrt_factor(n: usize) -> usize {
Expand Down
17 changes: 2 additions & 15 deletions src/poly_utils/coeffs.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{evals::EvaluationsList, hypercube::BinaryHypercubePoint, MultilinearPoint};
use crate::crypto::ntt::wavelet_transform;
use ark_ff::Field;
use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial};
#[cfg(feature = "parallel")]
Expand Down Expand Up @@ -226,21 +227,7 @@ where
{
/// Build an `EvaluationsList` from a `CoefficientList` by running
/// `wavelet_transform` over the coefficient vector in place.
///
/// `wavelet_transform` debug-asserts that the slice length is a power of two.
fn from(value: CoefficientList<F>) -> Self {
    // Take ownership of the coefficients; no copy is made.
    let mut evals = value.coeffs;
    // Apply the transform exactly once. It replaces the earlier hand-rolled
    // `evals[i + j + step] += evals[i + j]` loop; running both would apply
    // the kernel twice and give wrong values.
    wavelet_transform(&mut evals);
    EvaluationsList::new(evals)
}
}
Expand Down