Vectorize inner loop of BatchNormalization
BatchNormalization and InstanceNormalization are very similar, except
BatchNormalization uses pre-computed mean and variance statistics while
InstanceNormalization computes them dynamically.

Extract the vectorized normalization step from `instance_normalization` and
re-use it for `batch_normalization`.
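The rewrite hinges on the fact that the division by `sqrt(var + epsilon)` and the multiplication by `scale` fold into a single precomputed factor per channel, leaving only a shift, a multiply, and an add per element. A quick standalone sketch of that equivalence (illustrative; not part of this commit):

```rust
// Standalone check that the rearranged formula used by `normalize_slice`
// matches the ONNX definition. All names here are illustrative.
fn main() {
    let (mean, var, epsilon, scale, bias) = (0.5f32, 2.0f32, 1e-5f32, 1.5f32, 0.25f32);

    for x in [-1.0f32, 0.0, 0.5, 3.0] {
        // ONNX spec form: Y = (X - mean) / sqrt(var + epsilon) * scale + bias
        let spec = (x - mean) / (var + epsilon).sqrt() * scale + bias;

        // Rearranged form: hoist the division out of the per-element work.
        let scaled_std_dev_reciprocal = scale / (var + epsilon).sqrt();
        let fast = (x - mean) * scaled_std_dev_reciprocal + bias;

        // The two forms agree up to floating-point rounding.
        assert!((spec - fast).abs() <= 1e-6);
    }
}
```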
robertknight committed Dec 19, 2024
1 parent 28675eb commit 2c21c45
Showing 1 changed file with 72 additions and 18 deletions.
90 changes: 72 additions & 18 deletions src/ops/norm.rs
@@ -13,6 +13,54 @@ use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Outpu
 use crate::slice_reductions::slice_max;
 use crate::tensor_pool::TensorPool;
 
+struct NormalizeOptions {
+    /// Pre-computed mean of the input data.
+    mean: f32,
+
+    /// Pre-computed variance of the input data.
+    variance: f32,
+
+    /// Epsilon value used to avoid divide-by-zero in sqrt.
+    epsilon: f32,
+
+    /// Constant scale to multiply normalized value by.
+    scale: f32,
+
+    /// Constant bias to add to normalized value.
+    bias: f32,
+}
+
+/// Normalize the mean and variance of elements in `data` and apply a constant
+/// scale and bias to the result.
+///
+/// ```text
+/// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
+/// ```
+fn normalize_slice(data: &mut [f32], opts: NormalizeOptions) {
+    let NormalizeOptions {
+        mean,
+        variance,
+        epsilon,
+        scale,
+        bias,
+    } = opts;
+
+    // To avoid divisions in the vectorized loop, we re-arrange:
+    //
+    // ```
+    // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
+    // ```
+    //
+    // As:
+    //
+    // ```
+    // scaled_std_dev_reciprocal = scale / (input_var + epsilon).sqrt()
+    // Y = (X - input_mean) * scaled_std_dev_reciprocal + bias
+    // ```
+    let scaled_std_dev_reciprocal = scale / (variance + epsilon).sqrt();
+    vec_shift_scale_bias(data, mean, scaled_std_dev_reciprocal, bias);
+}
+
 /// Perform in-place batch normalization on the `NC*` tensor `out`.
 ///
 /// See <https://github.com/onnx/onnx/blob/main/docs/Operators.md#batchnormalization>.
@@ -31,23 +79,26 @@ pub fn batch_norm_in_place(
     let batch = input.size(0);
     let chans = input.size(1);
 
+    input.make_contiguous();
+
     for n in 0..batch {
         for c in 0..chans {
             let chan_mean = mean[[c]];
             let chan_var = var[[c]];
             let chan_scale = scale[[c]];
             let chan_bias = bias[[c]];
 
-            // The batch norm formula, from the ONNX spec, is:
-            //
-            // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
-            //
-            // It has been rewritten here to simplify the inner loop below.
-            let scaled_std_dev_reciprocal = chan_scale / (chan_var + epsilon).sqrt();
-
-            input
-                .slice_mut([n, c])
-                .apply(|el| (*el - chan_mean) * scaled_std_dev_reciprocal + chan_bias);
-
+            let mut chan = input.slice_mut([n, c]);
+            let chan_data = chan.data_mut().unwrap();
+            normalize_slice(
+                chan_data,
+                NormalizeOptions {
+                    mean: chan_mean,
+                    variance: chan_var,
+                    epsilon,
+                    scale: chan_scale,
+                    bias: chan_bias,
+                },
+            );
         }
     }
@@ -180,13 +231,16 @@ pub fn instance_normalization_in_place(
         let chan_mean = vec_sum(chan_data) / chan_data.len() as f32;
         let chan_variance = vec_sum_square_sub(chan_data, chan_mean) / chan_data.len() as f32;
 
-        // The instance norm formula, from the ONNX spec, is:
-        //
-        // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
-        //
-        // It has been rewritten here to optimize the inner loop.
-        let scaled_std_dev_reciprocal = chan_scale / (chan_variance + epsilon).sqrt();
-        vec_shift_scale_bias(chan_data, chan_mean, scaled_std_dev_reciprocal, chan_bias);
+        normalize_slice(
+            chan_data,
+            NormalizeOptions {
+                mean: chan_mean,
+                variance: chan_variance,
+                epsilon,
+                scale: chan_scale,
+                bias: chan_bias,
+            },
+        );
     }
 }

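`vec_shift_scale_bias` itself is not shown in this diff; from its call sites and the removed comments it is assumed to compute `y = (x - shift) * scale + bias` over a slice, using SIMD where available. A scalar sketch of the semantics `normalize_slice` relies on (assumed, for illustration):

```rust
// Scalar reference for the behavior assumed of `vec_shift_scale_bias`:
// y = (x - shift) * scale + bias for every element. The version in the
// repository is vectorized; this form is only for illustration.
fn shift_scale_bias_ref(data: &mut [f32], shift: f32, scale: f32, bias: f32) {
    for x in data {
        *x = (*x - shift) * scale + bias;
    }
}

fn main() {
    let mut data = [1.0f32, 2.0, 3.0, 4.0];
    let (mean, variance, epsilon, scale, bias) = (2.5f32, 1.25f32, 1e-5f32, 1.0f32, 0.0f32);

    // Same rearrangement as `normalize_slice`: fold scale and the
    // standard-deviation reciprocal into one factor.
    let scaled_std_dev_reciprocal = scale / (variance + epsilon).sqrt();
    shift_scale_bias_ref(&mut data, mean, scaled_std_dev_reciprocal, bias);

    // With mean 2.5 and variance 1.25, the output is approximately
    // [-1.342, -0.447, 0.447, 1.342]: zero mean, unit variance.
    println!("{:?}", data);
}
```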