diff --git a/src/ops/norm.rs b/src/ops/norm.rs index 8393b87a..dcef8761 100644 --- a/src/ops/norm.rs +++ b/src/ops/norm.rs @@ -3,7 +3,10 @@ use std::mem::MaybeUninit; use rayon::prelude::*; use rten_tensor::prelude::*; use rten_tensor::{NdTensorView, Tensor, TensorView}; -use rten_vecmath::{vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square}; +use rten_vecmath::{ + vec_shift_scale_bias, vec_shift_scale_in_place, vec_softmax_in_place, vec_sum, vec_sum_square, + vec_sum_square_sub, +}; use crate::ops::static_dims; use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList}; @@ -164,23 +167,18 @@ pub fn instance_normalization_in_place( )); } - // Needed for `slice_sum` below. + // Needed for `vec_*` ops below. input.make_contiguous(); for n in 0..batch { for c in 0..chans { let mut slice = input.slice_mut([n, c]); + let chan_data = slice.data_mut().unwrap(); + let chan_scale = scale[[c]]; let chan_bias = bias[[c]]; - let chan_mean = vec_sum(slice.data().unwrap()) / slice.len() as f32; - let chan_variance = slice - .iter() - .map(|x| { - let diff = *x - chan_mean; - diff * diff - }) - .sum::() - / slice.len() as f32; + let chan_mean = vec_sum(chan_data) / chan_data.len() as f32; + let chan_variance = vec_sum_square_sub(chan_data, chan_mean) / chan_data.len() as f32; // The instance norm formula, from the ONNX spec, is: // @@ -188,8 +186,7 @@ pub fn instance_normalization_in_place( // // It has been rewritten here to optimize the inner loop. let scaled_std_dev_reciprocal = chan_scale / (chan_variance + epsilon).sqrt(); - - slice.apply(|x| (*x - chan_mean) * scaled_std_dev_reciprocal + chan_bias) + vec_shift_scale_bias(chan_data, chan_mean, scaled_std_dev_reciprocal, chan_bias); } }