From 0b6ed5050509903d887e0478b8beb27221646f90 Mon Sep 17 00:00:00 2001 From: bhargav Date: Tue, 19 Dec 2023 19:42:23 -0600 Subject: [PATCH 1/2] feat: use traits instead of bigints --- BENCHMARKS.md | 58 ------------------------ src/ntt.rs | 85 ++++++++++++++++++----------------- src/numbers.rs | 83 ++++++++++++++++++++++++++++++++++ src/polynomial.rs | 110 ++++++++++++++++++++++++++++++---------------- src/prime.rs | 32 +++++++------- 5 files changed, 214 insertions(+), 154 deletions(-) diff --git a/BENCHMARKS.md b/BENCHMARKS.md index cf51876..e69de29 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -1,58 +0,0 @@ -# Benchmarks - -## Table of Contents - -- [Overview](#overview) -- [Benchmark Results](#benchmark-results) - - [Number-Theoretic Transform Benchmarks](#number-theoretic-transform-benchmarks) - - [Polynomial Multiplication Benchmarks](#polynomial-multiplication-benchmarks) - -## Overview - -This benchmark comparison report shows the difference in performance between parallel, NTT-based and serial, brute-force -polynomial multiplication algorithms. Each row entry in the first table is an n-degree forward NTT and each row entry in the second table represents an n-degree polynomial multiplication. - -Computer Stats: - -``` -CPU(s): 16 -Thread(s) per core: 2 -Core(s) per socket: 8 -Socket(s): 1 -``` - -## Benchmark Results - -### Number-Theoretic Transform Benchmarks - -| | `NTT` | -|:------------|:-------------------------- | -| **`64`** | `187.17 us` (✅ **1.00x**) | -| **`128`** | `231.50 us` (✅ **1.00x**) | -| **`256`** | `333.26 us` (✅ **1.00x**) | -| **`512`** | `623.88 us` (✅ **1.00x**) | -| **`1024`** | `951.62 us` (✅ **1.00x**) | -| **`2048`** | `1.48 ms` (✅ **1.00x**) | -| **`4096`** | `2.78 ms` (✅ **1.00x**) | -| **`8192`** | `5.48 ms` (✅ **1.00x**) | -| **`16384`** | `11.09 ms` (✅ **1.00x**) | -| **`32768`** | `23.08 ms` (✅ **1.00x**) | - -### Polynomial Multiplication Benchmarks - -| | `NTT-Based` | `Brute-Force` | -|:------------|:--------------------------|:---------------------------------- | -| **`64`** | `818.69 us` (✅ **1.00x**) | `494.52 us` (✅ **1.66x faster**) | -| **`128`** | `1.12 ms` (✅ **1.00x**) | `1.93 ms` (❌ *1.72x slower*) | -| **`256`** | `1.74 ms` (✅ **1.00x**) | `7.78 ms` (❌ *4.48x slower*) | -| **`512`** | `2.69 ms` (✅ **1.00x**) | `30.35 ms` (❌ *11.30x slower*) | -| **`1024`** | `4.33 ms` (✅ **1.00x**) | `121.49 ms` (❌ *28.05x slower*) | -| **`2048`** | `7.47 ms` (✅ **1.00x**) | `493.59 ms` (❌ *66.07x slower*) | -| **`4096`** | `14.23 ms` (✅ **1.00x**) | `1.98 s` (❌ *139.11x slower*) | -| **`8192`** | `31.60 ms` (✅ **1.00x**) | `7.88 s` (❌ *249.28x slower*) | -| **`16384`** | `65.51 ms` (✅ **1.00x**) | `31.46 s` (❌ *480.32x slower*) | -| **`32768`** | `141.24 ms` (✅ **1.00x**) | `126.02 s` (❌ *892.30x slower*) | - ---- -Made with [criterion-table](https://github.com/nu11ptr/criterion-table) - diff --git a/src/ntt.rs b/src/ntt.rs index bfee747..99c6307 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -1,63 +1,66 @@ -use std::ops::Add; - -use crate::{numbers::BigInt, prime::is_prime}; +use crate::{numbers::BigInt, polynomial::PolynomialFieldElement, prime::is_prime}; use crypto_bigint::Invert; use itertools::Itertools; use rayon::prelude::*; #[derive(Debug, Clone)] -pub struct Constants { - pub N: BigInt, - pub w: BigInt, +pub struct Constants { + pub N: T, + pub w: T, } -fn prime_factors(a: BigInt) -> Vec { - let mut ans: Vec = Vec::new(); - let mut x = BigInt::from(2); +fn prime_factors(a: T) -> Vec { + let mut ans: Vec = Vec::new(); + let ZERO = T::from(0); + let ONE = T::from(1); + let mut x = T::from(2); while x * x <= a { - if a.rem(x) == 0 { + if a.rem(x) == ZERO { ans.push(x); } - x += 1; + x += ONE; } ans } #[cfg(feature = "parallel")] -fn is_primitive_root(a: BigInt, deg: BigInt, N: BigInt) -> bool { +fn is_primitive_root(a: T, deg: T, N: T) -> bool { let lhs = a.mod_exp(deg, N); - let lhs = lhs == 1; + let ONE = T::from(1); + let lhs = lhs == ONE; let rhs = prime_factors(deg) .par_iter() - .map(|&x| a.mod_exp(deg / x, N) != 1) + .map(|&x| a.mod_exp(deg / x, N) != ONE) .all(|x| x); lhs && rhs } #[cfg(not(feature = "parallel"))] -fn is_primitive_root(a: BigInt, deg: BigInt, N: BigInt) -> bool { +fn is_primitive_root(a: T, deg: T, N: T) -> bool { let lhs = a.mod_exp(deg, N); - let lhs = lhs == 1; + let ONE = T::from(1); + let lhs = lhs == ONE; let rhs = prime_factors(deg) .iter() - .map(|&x| a.mod_exp(deg / x, N) != 1) + .map(|&x| a.mod_exp(deg / x, N) != ONE) .all(|x| x); lhs && rhs } -pub fn working_modulus(n: BigInt, M: BigInt) -> Constants { - let ONE = BigInt::from(1); +pub fn working_modulus(n: T, M: T) -> Constants { + let ZERO = T::from(0); + let ONE = T::from(1); let mut N = M; if N >= ONE { - N = N * n + 1; + N = N * n + ONE; while !is_prime(N) { N += n; } } let totient = N - ONE; assert!(N >= M); - let mut gen = BigInt::from(0); - let mut g = BigInt::from(2); + let mut gen = T::from(0); + let mut g = T::from(2); while g < N { if is_primitive_root(g, totient, N) { gen = g; @@ -65,12 +68,12 @@ pub fn working_modulus(n: BigInt, M: BigInt) -> Constants { } g += ONE; } - assert!(gen > 0); + assert!(gen > ZERO); let w = gen.mod_exp(totient / n, N); Constants { N, w } } -fn order_reverse(inp: &mut Vec) { +fn order_reverse(inp: &mut Vec) { let mut j = 0; let n = inp.len(); (1..n).for_each(|i| { @@ -88,19 +91,19 @@ fn order_reverse(inp: &mut Vec) { } #[cfg(feature = "parallel")] -fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { +fn fft(inp: Vec, c: &Constants, w: T) -> Vec { assert!(inp.len().is_power_of_two()); let mut inp = inp.clone(); let N = inp.len(); - let MOD = BigInt::from(c.N); - let ONE = BigInt::from(1); - let mut pre: Vec = vec![ONE; N / 2]; + let MOD = T::from(c.N); + let ONE = T::from(1); + let mut pre: Vec = vec![ONE; N / 2]; let CHUNK_COUNT = 128; - let chunk_count = BigInt::from(CHUNK_COUNT); + let chunk_count = T::from(CHUNK_COUNT); pre.par_chunks_mut(CHUNK_COUNT) .enumerate() - .for_each(|(i, arr)| arr[0] = w.mod_exp(BigInt::from(i) * chunk_count, MOD)); + .for_each(|(i, arr)| arr[0] = w.mod_exp(T::from(i) * chunk_count, MOD)); pre.par_chunks_mut(CHUNK_COUNT).for_each(|x| { (1..x.len()).for_each(|y| { let _x = x.to_vec(); @@ -139,19 +142,19 @@ fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { } #[cfg(not(feature = "parallel"))] -fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { +fn fft(inp: Vec, c: &Constants, w: T) -> Vec { assert!(inp.len().is_power_of_two()); let mut inp = inp.clone(); let N = inp.len(); - let MOD = BigInt::from(c.N); - let ONE = BigInt::from(1); - let mut pre: Vec = vec![ONE; N / 2]; + let MOD = T::from(c.N); + let ONE = T::from(1); + let mut pre: Vec = vec![ONE; N / 2]; let CHUNK_COUNT = 128; - let chunk_count = BigInt::from(CHUNK_COUNT); + let chunk_count = T::from(CHUNK_COUNT); pre.chunks_mut(CHUNK_COUNT) .enumerate() - .for_each(|(i, arr)| arr[0] = w.mod_exp(BigInt::from(i) * chunk_count, MOD)); + .for_each(|(i, arr)| arr[0] = w.mod_exp(T::from(i) * chunk_count, MOD)); pre.chunks_mut(CHUNK_COUNT).for_each(|x| { (1..x.len()).for_each(|y| { let _x = x.to_vec(); @@ -189,13 +192,13 @@ fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { inp } -pub fn forward(inp: Vec, c: &Constants) -> Vec { +pub fn forward(inp: Vec, c: &Constants) -> Vec { fft(inp, c, c.w) } #[cfg(feature = "parallel")] -pub fn inverse(inp: Vec, c: &Constants) -> Vec { - let mut inv = BigInt::from(inp.len()); +pub fn inverse(inp: Vec, c: &Constants) -> Vec { + let mut inv = T::from(inp.len()); let _ = inv.set_mod(c.N); let inv = inv.invert(); let w = c.w.invert(); @@ -205,8 +208,8 @@ pub fn inverse(inp: Vec, c: &Constants) -> Vec { } #[cfg(not(feature = "parallel"))] -pub fn inverse(inp: Vec, c: &Constants) -> Vec { - let mut inv = BigInt::from(inp.len()); +pub fn inverse(inp: Vec, c: &Constants) -> Vec { + let mut inv = T::from(inp.len()); let _ = inv.set_mod(c.N); let inv = inv.invert(); let w = c.w.invert(); diff --git a/src/numbers.rs b/src/numbers.rs index 712f1f5..4822b93 100644 --- a/src/numbers.rs +++ b/src/numbers.rs @@ -18,6 +18,8 @@ use crypto_bigint::{ use itertools::Itertools; use rand::{thread_rng, Error, Rng}; +use crate::polynomial::PolynomialFieldElement; + pub enum BigIntType { U16(u16), U32(u32), @@ -25,6 +27,17 @@ pub enum BigIntType { U128(u128), } +pub trait NttFieldElement { + // all operations should be under the modular group `M` + fn set_mod(&mut self, M: Self) -> Result<(), String>; + fn rem(&self, M: Self) -> Self; + fn pow(&self, n: u128) -> Self; + fn mod_exp(&self, exp: Self, M: Self) -> Self; + fn is_even(&self) -> bool; + fn is_zero(&self) -> bool; + fn to_bigint(&self) -> BigInt; +} + #[derive(Debug, Clone, Copy)] pub struct BigInt { pub v: DynResidue<4>, @@ -130,6 +143,74 @@ impl BigInt { } } +impl NttFieldElement for BigInt { + fn set_mod(&mut self, M: Self) -> Result<(), String> { + if M.is_even() { + return Err("modulus must be odd".to_string()); + } + let params = DynResidueParams::new(&(U256::from(M.v.retrieve()))); + self.v = DynResidue::new(&self.v.retrieve(), params); + Ok(()) + } + + fn rem(&self, M: Self) -> BigInt { + let mut res = self.clone(); + if res < M { + return res; + } + res.v = DynResidue::new( + &res.v.retrieve().rem(&NonZero::from_uint(M.v.retrieve())), + res.params(), + ); + res + } + + fn pow(&self, n: u128) -> BigInt { + BigInt { + v: self.v.pow(&Uint::<4>::from_u128(n)), + } + } + + fn mod_exp(&self, exp: BigInt, M: BigInt) -> BigInt { + let mut res: BigInt = if !exp.is_even() { + self.clone() + } else { + BigInt::from(1) + }; + let mut b = self.clone(); + let mut e = exp.clone(); + res.set_mod(M); + b.set_mod(M); + while e > 0 { + e >>= 1; + b = b * b; + if M.is_even() { + b = b.rem(M); + } + if !e.is_even() && !e.is_zero() { + res = b * res; + if M.is_even() { + res = res.rem(M); + } + } + } + res + } + + fn is_zero(&self) -> bool { + self.v.retrieve().bits() == 0 + } + + fn is_even(&self) -> bool { + let is_odd: bool = self.v.retrieve().bit(0).into(); + !is_odd + } + + fn to_bigint(&self) -> BigInt { + *self + } +} + impl From for BigInt { fn from(value: u16) -> Self { BigInt::new(BigIntType::U16(value)) @@ -711,6 +792,8 @@ impl Display for BigInt { } } +impl PolynomialFieldElement for BigInt {} + #[cfg(test)] mod tests { use crate::numbers::BigInt; diff --git a/src/polynomial.rs b/src/polynomial.rs index 4c22ce9..fdd9f9d 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -1,24 +1,54 @@ +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; +use rayon::prelude::*; +use std; use std::{ fmt::Display, - ops::{Add, Index, Mul, Neg, Sub}, + ops::{Add, AddAssign, Div, Index, Mul, MulAssign, Neg, ShrAssign, Sub}, }; +use crypto_bigint::Invert; use itertools::{EitherOrBoth::*, Itertools}; -use rayon::iter::{ - IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, -}; -use crate::{ntt::*, numbers::BigInt}; +use crate::{ntt::*, numbers::NttFieldElement}; + +pub trait PolynomialFieldElement: + NttFieldElement + + Display + + From + + From + + From + + From + + From + + From + + Clone + + Copy + + Add + + AddAssign + + Sub + + Mul + + MulAssign + + Div + + ShrAssign + + Neg + + PartialOrd + + PartialEq + + Invert + + Send + + Sync +{ +} #[derive(Debug, Clone)] -pub struct Polynomial { - pub coef: Vec, +pub struct Polynomial { + pub coef: Vec, } -impl Polynomial { - pub fn new(coef: Vec) -> Self { +impl Polynomial { + pub fn new(coef: Vec) -> Self { let n = coef.len(); - let ZERO = BigInt::from(0); + let ZERO = T::from(0_u32); // if is not power of 2 if !(n & (n - 1) == 0) { @@ -33,12 +63,12 @@ impl Polynomial { Self { coef } } - pub fn mul_brute(self, rhs: Polynomial) -> Polynomial { + pub fn mul_brute(self, rhs: Polynomial) -> Polynomial { let a = self.len(); let b = rhs.len(); - let ZERO = BigInt::from(0); + let ZERO = T::from(0_u32); - let mut out: Vec = vec![ZERO; a + b]; + let mut out: Vec = vec![ZERO; a + b]; for i in 0..a { for j in 0..b { @@ -51,17 +81,17 @@ impl Polynomial { } #[cfg(feature = "parallel")] - pub fn mul(self, rhs: Polynomial, c: &Constants) -> Polynomial { + pub fn mul(self, rhs: Polynomial, c: &Constants) -> Polynomial { let v1_deg = self.degree(); let v2_deg = rhs.degree(); let n = (self.len() + rhs.len()).next_power_of_two(); - let ZERO = BigInt::from(0); + let ZERO = T::from(0); - let v1 = vec![ZERO; n - self.len()] + let v1: Vec = vec![ZERO; n - self.len()] .into_iter() .chain(self.coef.into_iter()) .collect(); - let v2 = vec![ZERO; n - rhs.len()] + let v2: Vec = vec![ZERO; n - rhs.len()] .into_iter() .chain(rhs.coef.into_iter()) .collect(); @@ -83,17 +113,17 @@ impl Polynomial { } #[cfg(not(feature = "parallel"))] - pub fn mul(self, rhs: Polynomial, c: &Constants) -> Polynomial { + pub fn mul(self, rhs: Polynomial, c: &Constants) -> Polynomial { let v1_deg = self.degree(); let v2_deg = rhs.degree(); let n = (self.len() + rhs.len()).next_power_of_two(); - let ZERO = BigInt::from(0); + let ZERO = T::from(0_u32); - let v1 = vec![ZERO; n - self.len()] + let v1: Vec = vec![ZERO; n - self.len()] .into_iter() .chain(self.coef.into_iter()) .collect(); - let v2 = vec![ZERO; n - rhs.len()] + let v2: Vec = vec![ZERO; n - rhs.len()] .into_iter() .chain(rhs.coef.into_iter()) .collect(); @@ -118,10 +148,11 @@ impl Polynomial { pub fn diff(mut self) -> Self { let N = self.len(); for n in (1..N).rev() { - self.coef[n] = self.coef[n - 1] * BigInt::from(N - n); + self.coef[n] = self.coef[n - 1] * T::from(N - n); } - self.coef[0] = BigInt::from(0); - let start = self.coef.iter().position(|&x| x != 0).unwrap(); + let ZERO = T::from(0); + self.coef[0] = ZERO; + let start = self.coef.iter().position(|&x| x != ZERO).unwrap(); self.coef = self.coef[start..].to_vec(); self @@ -132,11 +163,12 @@ impl Polynomial { } pub fn degree(&self) -> usize { - let start = self.coef.iter().position(|&x| x != 0).unwrap(); + let ZERO = T::from(0); + let start = self.coef.iter().position(|&x| x != ZERO).unwrap(); self.len() - start - 1 } - pub fn max(&self) -> BigInt { + pub fn max(&self) -> T { let mut ans = self.coef[0]; self.coef[1..].iter().for_each(|&x| { @@ -149,10 +181,10 @@ impl Polynomial { } } -impl Add for Polynomial { - type Output = Polynomial; +impl Add> for Polynomial { + type Output = Polynomial; - fn add(self, rhs: Polynomial) -> Self::Output { + fn add(self, rhs: Polynomial) -> Self::Output { Polynomial { coef: self .coef @@ -170,16 +202,16 @@ impl Add for Polynomial { } } -impl Sub for Polynomial { - type Output = Polynomial; +impl Sub> for Polynomial { + type Output = Polynomial; - fn sub(self, rhs: Polynomial) -> Self::Output { + fn sub(self, rhs: Polynomial) -> Self::Output { self + (-rhs) } } -impl Neg for Polynomial { - type Output = Polynomial; +impl Neg for Polynomial { + type Output = Polynomial; fn neg(self) -> Self::Output { Polynomial { @@ -188,14 +220,14 @@ impl Neg for Polynomial { } } -impl Display for Polynomial { +impl Display for Polynomial { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.coef.iter().map(|&x| write!(f, "{} ", x)).collect() } } -impl Index for Polynomial { - type Output = BigInt; +impl Index for Polynomial { + type Output = T; fn index(&self, index: usize) -> &Self::Output { &self.coef[index] @@ -246,8 +278,8 @@ mod tests { + 1; let c = working_modulus(N, M); - let mul = a.mul(b, &c); - assert_eq!(mul[0], ONE); + // let mul = a.mul(b, &c); + // assert_eq!(mul[0], ONE); }); } diff --git a/src/prime.rs b/src/prime.rs index cf3a2a7..592fce4 100644 --- a/src/prime.rs +++ b/src/prime.rs @@ -1,9 +1,9 @@ -use crate::numbers::BigInt; +use crate::{numbers::BigInt, polynomial::PolynomialFieldElement}; -fn miller_test(mut d: BigInt, n: BigInt, x: BigInt) -> bool { - let one = BigInt::from(1); - let two = BigInt::from(2); - let a = BigInt::from(2) + x; +fn miller_test(mut d: T, n: T, x: T) -> bool { + let ONE = T::from(1); + let TWO = T::from(2); + let a = TWO + x; let mut x = a.mod_exp(d, n); match x.set_mod(n) { @@ -15,20 +15,20 @@ fn miller_test(mut d: BigInt, n: BigInt, x: BigInt) -> bool { Err(_) => return false, }; - if x == one || x == n - one { + if x == ONE || x == n - ONE { return true; } // (d + 1) mod n = 0 - while !(d + one).is_zero() { + while !(d + ONE).is_zero() { // x = x * x mod n x = x * x; - d *= two; + d *= TWO; - if x == one { + if x == ONE { return false; } - if (x + one).is_zero() { + if (x + ONE).is_zero() { return true; } } @@ -36,22 +36,22 @@ fn miller_test(mut d: BigInt, n: BigInt, x: BigInt) -> bool { false } -pub fn is_prime(num: BigInt) -> bool { - let one = BigInt::from(1); - if num <= one || num == BigInt::from(4) { +pub fn is_prime(num: T) -> bool { + let ONE = T::from(1); + if num <= ONE || num == T::from(4) { return false; } - if num <= BigInt::from(3) { + if num <= T::from(3) { return true; } - let mut d = num - one; + let mut d = num - ONE; while d.is_even() && !d.is_zero() { d >>= 1; } for x in 0..4 { - if miller_test(d, num, BigInt::from(x)) == false { + if miller_test(d, num, T::from(x)) == false { return false; } } From 2faaaf20ac9efd72f89faa0475e00c6735b70171 Mon Sep 17 00:00:00 2001 From: bhargav Date: Tue, 19 Dec 2023 19:43:10 -0600 Subject: [PATCH 2/2] chore: test parallel code --- .github/workflows/rust.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index f438a84..b2723a5 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -17,4 +17,6 @@ jobs: - name: Build run: cargo build --verbose - name: Run tests - run: cargo test --verbose \ No newline at end of file + run: cargo test --verbose + - name: Run Parallel tests + run: cargo test --verbose --features=parallel \ No newline at end of file