Skip to content

Commit

Permalink
Optimize LayerNormalization for better cache efficiency + SIMD usage
Browse files Browse the repository at this point in the history
Instead of performing each step of normalization on the whole input before
moving on to the next step, perform the full normalization over each input
slice before moving on to the next slice. This is more cache efficient. Also
fuse and vectorize the steps that scale the input to normalize the variance
and apply elementwise scales.

With these changes the operator is ~2.5-3x faster on x64 assuming the input is
already contiguous.

The `LayerNormalization` operator specification allows for the `bias` and
`scale` values to have any shape that can be broadcast to the input shape.
However, the models seen so far always set these shapes to match the
normalized axes of the input. Hence, this change drops support for other
bias/scale input shapes for the time being.
  • Loading branch information
robertknight committed Dec 17, 2024
1 parent f9ad7af commit 1e03a4f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 70 deletions.
4 changes: 2 additions & 2 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1387,9 +1387,9 @@ mod tests {
input_node, instance_norm_scale, instance_norm_bias
], { epsilon: Some(1e-5) });

let layer_norm_scale_val = Tensor::from([1.0]);
let layer_norm_scale_val = Tensor::full(&[input_shape[input_shape.len() - 1]], 1.);
let layer_norm_scale = graph_builder.add_constant(layer_norm_scale_val.view());
let layer_norm_bias_val = Tensor::from([1.0]);
let layer_norm_bias_val = layer_norm_scale_val.clone();
let layer_norm_bias = graph_builder.add_constant(layer_norm_bias_val.view());
add_operator!(LayerNormalization, [
input_node, layer_norm_scale, layer_norm_bias
Expand Down
106 changes: 67 additions & 39 deletions src/ops/norm.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use rayon::prelude::*;
use std::mem::MaybeUninit;

use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};
use rten_vecmath::{vec_softmax_in_place, vec_sum};
use smallvec::SmallVec;
use rten_vecmath::{vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square};

use crate::ops::reduce::reduce_inverse_rms;
use crate::ops::{add_in_place, mul_in_place, reduce_mean, static_dims, sub};
use crate::ops::static_dims;
use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList};
use crate::slice_reductions::slice_max;
use crate::tensor_pool::{AutoReturn, TensorPool};
use crate::tensor_pool::TensorPool;

/// Perform in-place batch normalization on the `NC*` tensor `out`.
///
Expand Down Expand Up @@ -253,52 +252,81 @@ pub fn layer_normalization(
axis: isize,
epsilon: Option<f32>,
) -> Result<Tensor, OpError> {
let epsilon = epsilon.unwrap_or(1e-5);
let resolved_axis = resolve_axis(input.ndim(), axis)?;
let normalized_slice_shape = &input.shape()[resolved_axis..];

if !scale.can_broadcast_to(input.shape()) {
return Err(OpError::IncompatibleInputShapes(
"`scale` cannot be broadcast to input shape",
));
}
if scale.shape() != normalized_slice_shape {
return Err(OpError::UnsupportedValue(
"`scale` shape does not match normalized axes of input",
));
}

if let Some(bias) = bias.as_ref() {
if !bias.can_broadcast_to(input.shape()) {
return Err(OpError::IncompatibleInputShapes(
"`bias` cannot be broadcast to input shape",
));
}
if bias.shape() != normalized_slice_shape {
return Err(OpError::UnsupportedValue(
"`bias` shape does not match normalized axes of input",
));
}
}

let epsilon = epsilon.unwrap_or(1e-5);
let resolved_axis = resolve_axis(input.ndim(), axis)?;
let normalized_axes: SmallVec<[i32; 5]> = (resolved_axis..input.ndim())
.map(|axis| axis as i32)
.collect();

// First step: standardize input elements to have zero mean and unit variance.
let mean = reduce_mean(
pool,
input.view(),
Some(normalized_axes.as_slice()),
true, /* keep_dims */
)?
.auto_return(pool);
let mut normalized = sub(pool, input, mean.view())?.auto_return(pool);

let inverse_std_dev = reduce_inverse_rms(
pool,
normalized.view(),
Some(normalized_axes.as_slice()),
true, /* keep_dims */
epsilon,
)?
.auto_return(pool);
mul_in_place(normalized.view_mut(), inverse_std_dev.view());

// Second step: Shift and scale input.
mul_in_place(normalized.view_mut(), scale);
if let Some(bias) = bias {
add_in_place(normalized.view_mut(), bias);
}

Ok(normalized.take())
let input = input.to_contiguous_in(pool);

let mut output = pool.alloc(input.len());
let chunk_size = input.shape()[resolved_axis..].iter().product();

let bias = bias.map(|b| b.to_contiguous_in(pool));
let bias_data = bias.as_ref().map(|b| b.data().unwrap());

let scale = scale.to_contiguous_in(pool);
let scale_data = scale.data().unwrap();

let mut n_init = 0;
for (in_chunk, out_chunk) in input
.data()
.unwrap()
.chunks(chunk_size)
.zip(output.spare_capacity_mut().chunks_mut(chunk_size))
{
// Zero mean
let sum = vec_sum(in_chunk);
let mean = sum / in_chunk.len() as f32;
for (x, y) in in_chunk.iter().zip(out_chunk.iter_mut()) {
y.write(x - mean);
}

// Compute standard deviation of input
let out_chunk =
unsafe { std::mem::transmute::<&mut [MaybeUninit<f32>], &mut [f32]>(out_chunk) };
let sum_square = vec_sum_square(out_chunk);
let mean_squared = sum_square / out_chunk.len() as f32;
let inverse_rms = 1. / (mean_squared + epsilon).sqrt();

// Shift and scale output. If there was no scale or bias, this
// would result in a mean of zero and unit variance.
//
// `y = y * inverse_rms * scale + bias`
vec_shift_scale_in_place(out_chunk, inverse_rms, scale_data, bias_data);

n_init += out_chunk.len();
}

// Safety: We initialized `n_init` elements.
unsafe {
output.set_len(n_init);
}

Ok(Tensor::from_data(input.shape(), output))
}

#[derive(Debug)]
Expand Down
29 changes: 0 additions & 29 deletions src/ops/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,35 +408,6 @@ pub fn reduce_mean(
reduce(pool, input, axes, keep_dims, &MeanKernel {})
}

/// Reduce the selected axes of `input` to their inverse Root Mean Squared
/// (RMS) value.
///
/// Every reduced slice `x` is replaced by:
///
/// ```text
/// 1. / (mean(x^2) + epsilon).sqrt()
/// ```
///
/// `pool`, `axes` and `keep_dims` are passed through unchanged to the shared
/// `reduce` helper; `epsilon` is added to the mean square before the square
/// root is taken.
pub fn reduce_inverse_rms(
    pool: &TensorPool,
    input: TensorView,
    axes: Option<&[i32]>,
    keep_dims: bool,
    epsilon: f32,
) -> Result<Tensor, OpError> {
    /// Reduction kernel which maps a slice to its inverse RMS.
    struct InverseRmsKernel {
        /// Value added to the mean square before taking the square root.
        epsilon: f32,
    }

    impl ReduceKernel<f32> for InverseRmsKernel {
        fn reduce_slice(&self, slice: &[f32]) -> f32 {
            let count = slice.len() as f32;
            let mean_square = vec_sum_square(slice) / count;
            let rms = (mean_square + self.epsilon).sqrt();
            1. / rms
        }
    }

    let kernel = InverseRmsKernel { epsilon };
    reduce(pool, input, axes, keep_dims, &kernel)
}

#[derive(Debug)]
pub struct ReduceMean {
pub axes: Option<Vec<i32>>,
Expand Down

0 comments on commit 1e03a4f

Please sign in to comment.