diff --git a/src/ops/norm.rs b/src/ops/norm.rs index dcef8761..915a0a64 100644 --- a/src/ops/norm.rs +++ b/src/ops/norm.rs @@ -13,6 +13,54 @@ use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Outpu use crate::slice_reductions::slice_max; use crate::tensor_pool::TensorPool; +struct NormalizeOptions { + /// Pre-computed mean of the input data. + mean: f32, + + /// Pre-computed variance of the input data. + variance: f32, + + /// Epsilon value used to avoid divide-by-zero in sqrt. + epsilon: f32, + + /// Constant scale to multiply normalized value by. + scale: f32, + + /// Constant bias to add to normalized value. + bias: f32, +} + +/// Normalize the mean and variance of elements in `data` and apply a constant +/// scale and bias to the result. +/// +/// ```text +/// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias +/// ``` +fn normalize_slice(data: &mut [f32], opts: NormalizeOptions) { + let NormalizeOptions { + mean, + variance, + epsilon, + scale, + bias, + } = opts; + + // To avoid divisions in the vectorized loop, we re-arrange: + // + // ``` + // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias + // ``` + // + // As: + // + // ``` + // scaled_std_dev_reciprocal = scale / (input_var + epsilon).sqrt() + // Y = (X - input_mean) * scaled_std_dev_reciprocal + bias + // ``` + let scaled_std_dev_reciprocal = scale / (variance + epsilon).sqrt(); + vec_shift_scale_bias(data, mean, scaled_std_dev_reciprocal, bias); +} + /// Perform in-place batch normalization on the `NC*` tensor `out`. /// /// See . @@ -31,23 +79,26 @@ pub fn batch_norm_in_place( let batch = input.size(0); let chans = input.size(1); + input.make_contiguous(); + for n in 0..batch { for c in 0..chans { let chan_mean = mean[[c]]; let chan_var = var[[c]]; let chan_scale = scale[[c]]; let chan_bias = bias[[c]]; - - // The batch norm formula, from the ONNX spec, is: - // - // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias - // - // It has been rewritten here to simplify the inner loop below. - let scaled_std_dev_reciprocal = chan_scale / (chan_var + epsilon).sqrt(); - - input - .slice_mut([n, c]) - .apply(|el| (*el - chan_mean) * scaled_std_dev_reciprocal + chan_bias); + let mut chan = input.slice_mut([n, c]); + let chan_data = chan.data_mut().unwrap(); + normalize_slice( + chan_data, + NormalizeOptions { + mean: chan_mean, + variance: chan_var, + epsilon, + scale: chan_scale, + bias: chan_bias, + }, + ); } } @@ -180,13 +231,16 @@ pub fn instance_normalization_in_place( let chan_mean = vec_sum(chan_data) / chan_data.len() as f32; let chan_variance = vec_sum_square_sub(chan_data, chan_mean) / chan_data.len() as f32; - // The instance norm formula, from the ONNX spec, is: - // - // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias - // - // It has been rewritten here to optimize the inner loop. - let scaled_std_dev_reciprocal = chan_scale / (chan_variance + epsilon).sqrt(); - vec_shift_scale_bias(chan_data, chan_mean, scaled_std_dev_reciprocal, chan_bias); + normalize_slice( + chan_data, + NormalizeOptions { + mean: chan_mean, + variance: chan_variance, + epsilon, + scale: chan_scale, + bias: chan_bias, + }, + ); } }