diff --git a/rten-vecmath/src/lib.rs b/rten-vecmath/src/lib.rs index c522ba2b..a4e8f31c 100644 --- a/rten-vecmath/src/lib.rs +++ b/rten-vecmath/src/lib.rs @@ -35,7 +35,7 @@ pub use exp::{ exp, sigmoid, silu, vec_exp, vec_exp_in_place, vec_sigmoid, vec_sigmoid_in_place, vec_silu, vec_silu_in_place, }; -pub use shift_scale::vec_shift_scale_in_place; +pub use shift_scale::{vec_shift_scale_bias, vec_shift_scale_in_place}; pub use softmax::{vec_softmax, vec_softmax_in_place}; -pub use sum::{vec_sum, vec_sum_square}; +pub use sum::{vec_sum, vec_sum_square, vec_sum_square_sub}; pub use tanh::{tanh, vec_tanh, vec_tanh_in_place}; diff --git a/rten-vecmath/src/shift_scale.rs b/rten-vecmath/src/shift_scale.rs index cd358f60..7934adde 100644 --- a/rten-vecmath/src/shift_scale.rs +++ b/rten-vecmath/src/shift_scale.rs @@ -80,9 +80,69 @@ pub fn vec_shift_scale_in_place( dispatch(simd_op); } +struct SimdShiftScaleBias<'a> { + data: &'a mut [f32], + x_bias: f32, + scale: f32, + bias: f32, +} + +impl<'a> SimdOp for SimdShiftScaleBias<'a> { + type Output = &'a mut [f32]; + + #[inline(always)] + unsafe fn eval<S: SimdFloat>(self) -> Self::Output { + let Self { + data, + x_bias, + scale, + bias, + } = self; + + let mut out_ptr = data.as_mut_ptr(); + let mut n = data.len(); + + let x_bias_vec = S::splat(x_bias); + let scale_vec = S::splat(scale); + let bias_vec = S::splat(bias); + + while n >= S::LEN { + let y = S::load(out_ptr) + .sub(x_bias_vec) + .mul_add(scale_vec, bias_vec); + y.store(out_ptr); + + out_ptr = out_ptr.add(S::LEN); + n -= S::LEN; + } + + if n > 0 { + let y = S::load_partial(out_ptr, n, 0.) + .sub(x_bias_vec) + .mul_add(scale_vec, bias_vec); + y.store_partial(out_ptr, n); + } + + data + } +} + +/// Shift and scale each element in the input. +/// +/// This updates `xs` as `xs[i] = (xs[i] - x_bias) * scale + bias`.
+pub fn vec_shift_scale_bias(xs: &mut [f32], x_bias: f32, scale: f32, bias: f32) { + let op = SimdShiftScaleBias { + data: xs, + x_bias, + scale, + bias, + }; + dispatch(op); +} + #[cfg(test)] mod tests { - use super::vec_shift_scale_in_place; + use super::{vec_shift_scale_bias, vec_shift_scale_in_place}; fn reference_shift_scale( data: &mut [f32], @@ -95,6 +155,12 @@ mod tests { } } + fn reference_shift_scale_bias(data: &mut [f32], x_bias: f32, scale: f32, bias: f32) { + for i in 0..data.len() { + data[i] = (data[i] - x_bias).mul_add(scale, bias); + } + } + #[test] fn test_vec_shift_scale() { let data: Vec<_> = (0..10).map(|i| i as f32 * 0.1).collect(); @@ -120,4 +186,20 @@ mod tests { assert_eq!(actual, expected); } + + #[test] + fn test_vec_shift_scale_bias() { + let data: Vec<_> = (0..10).map(|i| i as f32 * 0.1).collect(); + let x_bias = 0.123; + let scale = 0.456; + let bias = 0.89; + + let mut expected = data.clone(); + reference_shift_scale_bias(&mut expected, x_bias, scale, bias); + + let mut actual = data.clone(); + vec_shift_scale_bias(&mut actual, x_bias, scale, bias); + + assert_eq!(actual, expected); + } } diff --git a/rten-vecmath/src/sum.rs b/rten-vecmath/src/sum.rs index f9ce44c4..81a01f61 100644 --- a/rten-vecmath/src/sum.rs +++ b/rten-vecmath/src/sum.rs @@ -64,13 +64,55 @@ pub fn vec_sum_square(xs: &[f32]) -> f32 { dispatch(op) } +struct SimdSumSquareSub<'a> { + input: &'a [f32], + offset: f32, +} + +impl SimdOp for SimdSumSquareSub<'_> { + type Output = f32; + + #[inline(always)] + unsafe fn eval<S: SimdFloat>(self) -> Self::Output { + let offset_vec = S::splat(self.offset); + let vec_sum = simd_fold( + self.input.into(), + S::zero(), + #[inline(always)] + |sum, x| { + let x_offset = x.sub(offset_vec); + x_offset.mul_add(x_offset, sum) + }, + // Padding value chosen so that `x - offset` is zero for unused + // positions in the final update, and thus the accumulator is not + // modified in those positions.
+ self.offset, + ); + vec_sum.sum() + } +} + +/// Compute the sum of squares of `xs - offset`. +/// +/// This is a variant of [`vec_sum_square`] which subtracts a constant value +/// from each element before squaring it. A typical use case is to compute the +/// variance of a sequence, which is defined as `mean((X - x_mean)^2)`. +pub fn vec_sum_square_sub(xs: &[f32], offset: f32) -> f32 { + let op = SimdSumSquareSub { input: xs, offset }; + dispatch(op) +} + #[cfg(test)] mod tests { - use super::{vec_sum, vec_sum_square}; + use super::{vec_sum, vec_sum_square, vec_sum_square_sub}; + + // Chosen to not be a multiple of vector size, so that tail handling is + // exercised. + const LEN: usize = 100; #[test] fn test_vec_sum() { - let xs: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect(); + let xs: Vec<f32> = (0..LEN).map(|i| i as f32 * 0.1).collect(); let expected_sum: f32 = xs.iter().sum(); let sum = vec_sum(&xs); assert_eq!(sum, expected_sum); @@ -78,9 +120,18 @@ mod tests { #[test] fn test_vec_sum_square() { - let xs: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect(); + let xs: Vec<f32> = (0..LEN).map(|i| i as f32 * 0.1).collect(); let expected_sum: f32 = xs.iter().copied().map(|x| x * x).sum(); let sum = vec_sum_square(&xs); assert_eq!(sum, expected_sum); } + + #[test] + fn test_vec_sum_square_sub() { + let xs: Vec<f32> = (0..LEN).map(|i| i as f32 * 0.1).collect(); + let mean = xs.iter().sum::<f32>() / xs.len() as f32; + let expected_sum: f32 = xs.iter().copied().map(|x| (x - mean) * (x - mean)).sum(); + let sum = vec_sum_square_sub(&xs, mean); + assert_eq!(sum, expected_sum); + } }