From f9ad7af180be62f7dc0d2848c675102361afe568 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 16 Dec 2024 22:16:29 +0100 Subject: [PATCH 1/3] Add `vec_shift_scale_in_place` SIMD kernel --- rten-vecmath/src/lib.rs | 2 + rten-vecmath/src/shift_scale.rs | 114 ++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 rten-vecmath/src/shift_scale.rs diff --git a/rten-vecmath/src/lib.rs b/rten-vecmath/src/lib.rs index aa3b148b..c522ba2b 100644 --- a/rten-vecmath/src/lib.rs +++ b/rten-vecmath/src/lib.rs @@ -19,6 +19,7 @@ mod erf; mod exp; +mod shift_scale; mod softmax; mod sum; mod tanh; @@ -34,6 +35,7 @@ pub use exp::{ exp, sigmoid, silu, vec_exp, vec_exp_in_place, vec_sigmoid, vec_sigmoid_in_place, vec_silu, vec_silu_in_place, }; +pub use shift_scale::vec_shift_scale_in_place; pub use softmax::{vec_softmax, vec_softmax_in_place}; pub use sum::{vec_sum, vec_sum_square}; pub use tanh::{tanh, vec_tanh, vec_tanh_in_place}; diff --git a/rten-vecmath/src/shift_scale.rs b/rten-vecmath/src/shift_scale.rs new file mode 100644 index 00000000..f3f9be04 --- /dev/null +++ b/rten-vecmath/src/shift_scale.rs @@ -0,0 +1,114 @@ +use rten_simd::dispatch::{dispatch, SimdOp}; +use rten_simd::SimdFloat; + +struct SimdShiftScale<'a> { + data: &'a mut [f32], + bias: Option<&'a [f32]>, + scale: &'a [f32], + const_scale: f32, +} + +impl<'a> SimdOp for SimdShiftScale<'a> { + type Output = &'a mut [f32]; + + #[inline(always)] + unsafe fn eval(self) -> Self::Output { + let Self { + data, + bias, + scale, + const_scale, + } = self; + + let mut out_ptr = data.as_mut_ptr(); + let mut scale_ptr = scale.as_ptr(); + let mut bias_ptr = bias.map(|b| b.as_ptr()); + let mut n = data.len(); + + let zero = S::zero(); + let const_scale_vec = S::splat(const_scale); + + while n >= S::LEN { + let scale_vec = S::load(scale_ptr).mul(const_scale_vec); + let bias_vec = bias_ptr.map(|b| S::load(b)).unwrap_or(zero); + let y = S::load(out_ptr).mul_add(scale_vec, bias_vec); + y.store(out_ptr); + + out_ptr = out_ptr.add(S::LEN); + scale_ptr = scale_ptr.add(S::LEN); + bias_ptr = bias_ptr.map(|b| b.add(S::LEN)); + + n -= S::LEN; + } + + if n > 0 { + let scale_vec = S::load_partial(scale_ptr, n, 0.).mul(const_scale_vec); + let bias_vec = bias_ptr.map(|b| S::load_partial(b, n, 0.)).unwrap_or(zero); + let y = S::load_partial(out_ptr, n, 0.).mul_add(scale_vec, bias_vec); + y.store_partial(out_ptr, n); + } + + data + } +} + +/// Shift and scale each element in the input. +/// +/// This scales and shifts each element using `y[i] = y[i] * const_scale * +/// scale[i] + bias[i]`. +pub fn vec_shift_scale_in_place( + data: &mut [f32], + const_scale: f32, + scale: &[f32], + bias: Option<&[f32]>, +) { + let simd_op = SimdShiftScale { + data, + bias, + scale, + const_scale, + }; + dispatch(simd_op); +} + +#[cfg(test)] +mod tests { + use super::vec_shift_scale_in_place; + + fn reference_shift_scale( + data: &mut [f32], + const_scale: f32, + scale: &[f32], + bias: Option<&[f32]>, + ) { + for i in 0..data.len() { + data[i] = data[i].mul_add(const_scale * scale[i], bias.map(|b| b[i]).unwrap_or(0.)); + } + } + + #[test] + fn test_vec_shift_scale() { + let data: Vec<_> = (0..10).map(|i| i as f32 * 0.1).collect(); + let const_scale = 0.123; + let scale: Vec<_> = (0..data.len()).map(|i| 1.0 + i as f32 * 0.1).collect(); + let bias: Vec<_> = (0..data.len()).map(|i| -0.5 + i as f32 * 0.2).collect(); + + // With bias + let mut expected = data.clone(); + reference_shift_scale(&mut expected[..], const_scale, &scale, Some(&bias)); + + let mut actual = data.clone(); + vec_shift_scale_in_place(&mut actual[..], const_scale, &scale, Some(&bias)); + + assert_eq!(actual, expected); + + // Without bias + let mut expected = data.clone(); + reference_shift_scale(&mut expected[..], const_scale, &scale, None); + + let mut actual = data.clone(); + vec_shift_scale_in_place(&mut actual[..], const_scale, &scale, None); + + assert_eq!(actual, expected); + } +} From 2b61acf8bacb13728d6b8f508dc6707d270a8a1f Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 16 Dec 2024 22:00:25 +0100 Subject: [PATCH 2/3] Optimize LayerNormalization for better cache efficiency + SIMD usage Instead of performing each step of normalization on the whole input before moving onto the next, perform the full normalization over each input slice before moving on to the next. This is more cache efficient. Also fuse and vectorize the steps that scale the input to normalize the variance and apply elementwise scales. With these changes the operator is ~2.5-3x faster on x64 assuming the input is already contiguous. The `LayerNormalization` operator specification allows for the `bias` and `scale` values to have any shape that can be broadcast to the input shape. However actual models seen so far always set these shapes to match the normalized axes of the input. Hence this change drops support for other bias/scale input shapes for the time being. --- src/model.rs | 4 +- src/ops/norm.rs | 207 +++++++++++++++++++++++++++++++--------------- src/ops/reduce.rs | 29 ------- 3 files changed, 143 insertions(+), 97 deletions(-) diff --git a/src/model.rs b/src/model.rs index 212650f8..b03a1918 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1387,9 +1387,9 @@ mod tests { input_node, instance_norm_scale, instance_norm_bias ], { epsilon: Some(1e-5) }); - let layer_norm_scale_val = Tensor::from([1.0]); + let layer_norm_scale_val = Tensor::full(&[input_shape[input_shape.len() - 1]], 1.); let layer_norm_scale = graph_builder.add_constant(layer_norm_scale_val.view()); - let layer_norm_bias_val = Tensor::from([1.0]); + let layer_norm_bias_val = layer_norm_scale_val.clone(); let layer_norm_bias = graph_builder.add_constant(layer_norm_bias_val.view()); add_operator!(LayerNormalization, [ input_node, layer_norm_scale, layer_norm_bias diff --git a/src/ops/norm.rs b/src/ops/norm.rs index 26edd849..1bf204b9 100644 --- a/src/ops/norm.rs +++ b/src/ops/norm.rs @@ -1,15 +1,14 @@ -use rayon::prelude::*; +use std::mem::MaybeUninit; +use rayon::prelude::*; use rten_tensor::prelude::*; use rten_tensor::{NdTensorView, Tensor, TensorView}; -use rten_vecmath::{vec_softmax_in_place, vec_sum}; -use smallvec::SmallVec; +use rten_vecmath::{vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square}; -use crate::ops::reduce::reduce_inverse_rms; -use crate::ops::{add_in_place, mul_in_place, reduce_mean, static_dims, sub}; +use crate::ops::static_dims; use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList}; use crate::slice_reductions::slice_max; -use crate::tensor_pool::{AutoReturn, TensorPool}; +use crate::tensor_pool::TensorPool; /// Perform in-place batch normalization on the `NC*` tensor `out`. /// @@ -253,52 +252,83 @@ pub fn layer_normalization( axis: isize, epsilon: Option, ) -> Result { + let epsilon = epsilon.unwrap_or(1e-5); + let resolved_axis = resolve_axis(input.ndim(), axis)?; + let normalized_slice_shape = &input.shape()[resolved_axis..]; + if !scale.can_broadcast_to(input.shape()) { return Err(OpError::IncompatibleInputShapes( "`scale` cannot be broadcast to input shape", )); } + if scale.shape() != normalized_slice_shape { + return Err(OpError::UnsupportedValue( + "`scale` shape does not match normalized axes of input", + )); + } + if let Some(bias) = bias.as_ref() { if !bias.can_broadcast_to(input.shape()) { return Err(OpError::IncompatibleInputShapes( "`bias` cannot be broadcast to input shape", )); } + if bias.shape() != normalized_slice_shape { + return Err(OpError::UnsupportedValue( + "`bias` shape does not match normalized axes of input", + )); + } } - let epsilon = epsilon.unwrap_or(1e-5); - let resolved_axis = resolve_axis(input.ndim(), axis)?; - let normalized_axes: SmallVec<[i32; 5]> = (resolved_axis..input.ndim()) - .map(|axis| axis as i32) - .collect(); - - // First step: standardize input elements to have zero mean and unit variance. - let mean = reduce_mean( - pool, - input.view(), - Some(normalized_axes.as_slice()), - true, /* keep_dims */ - )? - .auto_return(pool); - let mut normalized = sub(pool, input, mean.view())?.auto_return(pool); - - let inverse_std_dev = reduce_inverse_rms( - pool, - normalized.view(), - Some(normalized_axes.as_slice()), - true, /* keep_dims */ - epsilon, - )? - .auto_return(pool); - mul_in_place(normalized.view_mut(), inverse_std_dev.view()); - - // Second step: Shift and scale input. - mul_in_place(normalized.view_mut(), scale); - if let Some(bias) = bias { - add_in_place(normalized.view_mut(), bias); - } - - Ok(normalized.take()) + let input = input.to_contiguous_in(pool); + + let mut output = pool.alloc(input.len()); + let chunk_size = input.shape()[resolved_axis..].iter().product(); + + let bias = bias.map(|b| b.to_contiguous_in(pool)); + let bias_data = bias.as_ref().map(|b| b.data().unwrap()); + + let scale = scale.to_contiguous_in(pool); + let scale_data = scale.data().unwrap(); + + let mut n_init = 0; + for (in_chunk, out_chunk) in input + .data() + .unwrap() + .chunks(chunk_size) + .zip(output.spare_capacity_mut().chunks_mut(chunk_size)) + { + // Subtract mean + let sum = vec_sum(in_chunk); + let mean = sum / in_chunk.len() as f32; + for (x, y) in in_chunk.iter().zip(out_chunk.iter_mut()) { + y.write(x - mean); + } + + // Safety: We have initialized all elements of `out_chunk`. + let out_chunk = + unsafe { std::mem::transmute::<&mut [MaybeUninit], &mut [f32]>(out_chunk) }; + + // Compute inverse Root Mean Square + let sum_square = vec_sum_square(out_chunk); + let mean_squared = sum_square / out_chunk.len() as f32; + let inverse_rms = 1. / (mean_squared + epsilon).sqrt(); + + // Scale output by inverse RMS so that it has zero mean and unit + // variance, then apply per-element scale and bias. + // + // For efficiency these steps are fused into one pass over the data. + vec_shift_scale_in_place(out_chunk, inverse_rms, scale_data, bias_data); + + n_init += out_chunk.len(); + } + + // Safety: We initialized `n_init` elements. + unsafe { + output.set_len(n_init); + } + + Ok(Tensor::from_data(input.shape(), output)) } #[derive(Debug)] @@ -645,35 +675,80 @@ mod tests { fn test_layer_normalization() -> Result<(), Box> { let pool = new_pool(); - // Sample values generated using `torch.rand`. - let input = Tensor::from([[ - [0.9562, 0.0572], - [0.4366, 0.5655], - [0.2017, 0.0230], - [0.7941, 0.1554], - [0.3226, 0.120], - ]]); - let scale = Tensor::from([0.0751, 0.6952]); - let bias = Tensor::from([0.9993, 0.7632]); + struct Case { + input: Tensor, + scale: Tensor, + bias: Option, + axis: isize, + expected: Result, + } - let result = layer_normalization( - &pool, - input.view(), - scale.view(), - Some(bias.view()), - -1, /* axis */ - None, /* epsilon */ - ) - .unwrap(); + let cases = [ + // Normalize last axis + Case { + // Sample values generated using `torch.rand`. + input: Tensor::from([[ + [0.9562, 0.0572], + [0.4366, 0.5655], + [0.2017, 0.0230], + [0.7941, 0.1554], + [0.3226, 0.120], + ]]), + scale: Tensor::from([0.0751, 0.6952]), + bias: Some(Tensor::from([0.9993, 0.7632])), + axis: -1, + expected: Ok(Tensor::from([[ + [1.0744, 0.0680], + [0.9243, 1.4576], + [1.0744, 0.0684], + [1.0744, 0.0680], + [1.0744, 0.0683], + ]])), + }, + // Unsupported scale shape + Case { + input: Tensor::from([[1., 2., 3.], [4., 5., 6.]]), + scale: Tensor::full(&[2, 3], 1.0), + bias: None, + axis: -1, + expected: Err(OpError::UnsupportedValue( + "`scale` shape does not match normalized axes of input", + )), + }, + // Unsupported bias shape + Case { + input: Tensor::from([[1., 2., 3.], [4., 5., 6.]]), + scale: Tensor::from([1., 1., 1.]), + bias: Some(Tensor::full(&[2, 3], 1.0)), + axis: -1, + expected: Err(OpError::UnsupportedValue( + "`bias` shape does not match normalized axes of input", + )), + }, + ]; - let expected = Tensor::from([[ - [1.0744, 0.0680], - [0.9243, 1.4576], - [1.0744, 0.0684], - [1.0744, 0.0680], - [1.0744, 0.0683], - ]]); - expect_eq_1e4(&result, &expected)?; + for Case { + input, + scale, + bias, + axis, + expected, + } in cases + { + let result = layer_normalization( + &pool, + input.view(), + scale.view(), + bias.as_ref().map(|b| b.view()), + axis, + None, /* epsilon */ + ); + + match (result, expected) { + (Ok(result), Ok(expected)) => expect_eq_1e4(&result, &expected)?, + (result, expected) => assert_eq!(result, expected), + } + } Ok(()) } diff --git a/src/ops/reduce.rs b/src/ops/reduce.rs index 9206aabf..da89fdfe 100644 --- a/src/ops/reduce.rs +++ b/src/ops/reduce.rs @@ -408,35 +408,6 @@ pub fn reduce_mean( reduce(pool, input, axes, keep_dims, &MeanKernel {}) } -/// Reduces axes of a tensor using an inverse Root Mean Squared (RMS) -/// operation. -/// -/// This reduces axes according to the formula: -/// -/// ```text -/// 1. / (mean(x^2) + epsilon).sqrt() -/// ``` -pub fn reduce_inverse_rms( - pool: &TensorPool, - input: TensorView, - axes: Option<&[i32]>, - keep_dims: bool, - epsilon: f32, -) -> Result { - struct InverseRmsKernel { - epsilon: f32, - } - - impl ReduceKernel for InverseRmsKernel { - fn reduce_slice(&self, slice: &[f32]) -> f32 { - let mean_square = vec_sum_square(slice) / slice.len() as f32; - 1. / (mean_square + self.epsilon).sqrt() - } - } - - reduce(pool, input, axes, keep_dims, &InverseRmsKernel { epsilon }) -} - #[derive(Debug)] pub struct ReduceMean { pub axes: Option>, From 7685b8b31033d8d15c8b3f54a86a27fe1d2b6ca5 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Tue, 17 Dec 2024 11:16:49 +0100 Subject: [PATCH 3/3] Add test for normalizing multiple axes in LayerNormalization --- src/ops/norm.rs | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/ops/norm.rs b/src/ops/norm.rs index 1bf204b9..8393b87a 100644 --- a/src/ops/norm.rs +++ b/src/ops/norm.rs @@ -705,6 +705,27 @@ mod tests { [1.0744, 0.0683], ]])), }, + // Normalize multiple axes + Case { + // Sample values generated using `torch.rand`. + input: Tensor::from([[ + [0.9562, 0.0572], + [0.4366, 0.5655], + [0.2017, 0.0230], + [0.7941, 0.1554], + [0.3226, 0.120], + ]]), + scale: Tensor::full(&[5, 2], 1.1), + bias: Some(Tensor::full(&[5, 2], 0.1)), + axis: -2, + expected: Ok(Tensor::from([[ + [2.2467697, -1.0079411], + [0.36562642, 0.83229196], + [-0.48479798, -1.1317577], + [1.6599079, -0.65242106], + [-0.04709549, -0.7805821], + ]])), + }, // Unsupported scale shape Case { input: Tensor::from([[1., 2., 3.], [4., 5., 6.]]), @@ -745,7 +766,9 @@ mod tests { ); match (result, expected) { - (Ok(result), Ok(expected)) => expect_eq_1e4(&result, &expected)?, + (Ok(result), Ok(expected)) => { + expect_eq_1e4(&result, &expected)?; + } (result, expected) => assert_eq!(result, expected), } }