Skip to content

Commit

Permalink
Vectorize variance calculation and input scaling in InstanceNormaliza…
Browse files Browse the repository at this point in the history
…tion

InstanceNormalization has three steps:

 1. Compute channel mean
 2. Compute channel variance
 3. Shift and scale result to normalize the mean and variance, and then
    apply a per-channel scale and bias

Previously only step 1 was vectorized. This vectorizes steps 2 and 3 as well.

Tested using the wav2vec example on x64, this made InstanceNormalization ~3x
faster (~26ms -> ~8ms per run).
  • Loading branch information
robertknight committed Dec 19, 2024
1 parent aa81094 commit 28675eb
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/ops/norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::mem::MaybeUninit;
use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};
use rten_vecmath::{vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square};
use rten_vecmath::{
vec_shift_scale_bias, vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square,
vec_sum_square_sub,
};

use crate::ops::static_dims;
use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList};
Expand Down Expand Up @@ -164,32 +167,26 @@ pub fn instance_normalization_in_place(
));
}

// Needed for `slice_sum` below.
// Needed for `vec_*` ops below.
input.make_contiguous();

for n in 0..batch {
for c in 0..chans {
let mut slice = input.slice_mut([n, c]);
let chan_data = slice.data_mut().unwrap();

let chan_scale = scale[[c]];
let chan_bias = bias[[c]];
let chan_mean = vec_sum(slice.data().unwrap()) / slice.len() as f32;
let chan_variance = slice
.iter()
.map(|x| {
let diff = *x - chan_mean;
diff * diff
})
.sum::<f32>()
/ slice.len() as f32;
let chan_mean = vec_sum(chan_data) / chan_data.len() as f32;
let chan_variance = vec_sum_square_sub(chan_data, chan_mean) / chan_data.len() as f32;

// The instance norm formula, from the ONNX spec, is:
//
// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + bias
//
// It has been rewritten here to optimize the inner loop.
let scaled_std_dev_reciprocal = chan_scale / (chan_variance + epsilon).sqrt();

slice.apply(|x| (*x - chan_mean) * scaled_std_dev_reciprocal + chan_bias)
vec_shift_scale_bias(chan_data, chan_mean, scaled_std_dev_reciprocal, chan_bias);
}
}

Expand Down

0 comments on commit 28675eb

Please sign in to comment.