From 28675ebeaff673c087d563365c1ba5e878a82261 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Thu, 19 Dec 2024 07:09:41 +0000 Subject: [PATCH] Vectorize variance calculation and input scaling in InstanceNormalization InstanceNormalization has three steps: 1. Compute channel mean 2. Compute channel variance 3. Shift and scale result to normalize the mean and variance, and then apply a per-channel scale and bias Previously only step 1 was vectorized. This vectorizes steps 2 and 3 as well. Tested using the wav2vec example on x64, this made InstanceNormalization ~3x faster (~26ms -> ~8ms per run). --- src/ops/norm.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/ops/norm.rs b/src/ops/norm.rs index 8393b87a..dcef8761 100644 --- a/src/ops/norm.rs +++ b/src/ops/norm.rs @@ -3,7 +3,10 @@ use std::mem::MaybeUninit; use rayon::prelude::*; use rten_tensor::prelude::*; use rten_tensor::{NdTensorView, Tensor, TensorView}; -use rten_vecmath::{vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square}; +use rten_vecmath::{ + vec_shift_scale_bias, vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square, + vec_sum_square_sub, +}; use crate::ops::static_dims; use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList}; @@ -164,23 +167,18 @@ pub fn instance_normalization_in_place( )); } - // Needed for `slice_sum` below. + // Needed for `vec_*` ops below. 
input.make_contiguous(); for n in 0..batch { for c in 0..chans { let mut slice = input.slice_mut([n, c]); + let chan_data = slice.data_mut().unwrap(); + let chan_scale = scale[[c]]; let chan_bias = bias[[c]]; - let chan_mean = vec_sum(slice.data().unwrap()) / slice.len() as f32; - let chan_variance = slice - .iter() - .map(|x| { - let diff = *x - chan_mean; - diff * diff - }) - .sum::<f32>() - / slice.len() as f32; + let chan_mean = vec_sum(chan_data) / chan_data.len() as f32; + let chan_variance = vec_sum_square_sub(chan_data, chan_mean) / chan_data.len() as f32; // The instance norm formula, from the ONNX spec, is: // @@ -188,8 +186,7 @@ pub fn instance_normalization_in_place( // // It has been rewritten here to optimize the inner loop. let scaled_std_dev_reciprocal = chan_scale / (chan_variance + epsilon).sqrt(); - - slice.apply(|x| (*x - chan_mean) * scaled_std_dev_reciprocal + chan_bias) + vec_shift_scale_bias(chan_data, chan_mean, scaled_std_dev_reciprocal, chan_bias); } }