From 3168392e5053d10a5b60ddd0afac3b7299d551e7 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Mon, 16 Dec 2024 20:57:42 +0100
Subject: [PATCH] Rename Reducer -> ReduceKernel

For consistency with other code, use the term "Kernel" to describe the
code which handles the inner loop of reduction ops.
---
 src/ops/reduce.rs | 61 +++++++++++++++++++++++++++++++++----------------------------
 1 file changed, 33 insertions(+), 28 deletions(-)

diff --git a/src/ops/reduce.rs b/src/ops/reduce.rs
index f163c19a..9206aabf 100644
--- a/src/ops/reduce.rs
+++ b/src/ops/reduce.rs
@@ -262,18 +262,23 @@ impl Operator for NonZero {
     }
 }
 
-/// Trait for reducing a subset of elements from a tensor to a single value.
-trait Reducer<T> {
+/// Kernel that handles reducing a single slice of the input.
+trait ReduceKernel<T> {
     /// Reduce a contiguous slice of values to a single value.
     fn reduce_slice(&self, slice: &[T]) -> T;
 }
 
+/// Outer loop of reduction operations.
+///
+/// This iterates over slices of the input that are reduced independently and
+/// invokes the kernel on that slice. If the input is not contiguous, the slice
+/// is packed before calling the kernel.
 fn reduce<T: Copy>(
     pool: &TensorPool,
     input: TensorView<T>,
     axes: Option<&[i32]>,
     keep_dims: bool,
-    reducer: &dyn Reducer<T>,
+    kernel: &dyn ReduceKernel<T>,
 ) -> Result<Tensor<T>, OpError> {
     let mut resolved_axes = match axes {
         Some(axes) if !axes.is_empty() => resolve_axes(input.ndim(), axes.iter())?,
@@ -293,7 +298,7 @@ fn reduce<T: Copy>(
 
     if input.ndim() == 0 {
         let item = input.item().unwrap();
-        return Ok(Tensor::from_scalar(reducer.reduce_slice(&[*item])));
+        return Ok(Tensor::from_scalar(kernel.reduce_slice(&[*item])));
     }
 
     // nb. Some reduce operations cannot produce a meaningful result with
@@ -336,7 +341,7 @@ fn reduce<T: Copy>(
             reduced_data.extend(
                 input_data
                     .chunks(slice_len)
-                    .map(|chunk| reducer.reduce_slice(chunk)),
+                    .map(|chunk| kernel.reduce_slice(chunk)),
             );
         }
         _ => {
@@ -345,11 +350,11 @@ fn reduce<T: Copy>(
                 let resolved_axis = resolved_axes[0];
                 reduced_data.extend(input.lanes(resolved_axis).map(|lane| {
                     if let Some(lane_slice) = lane.as_slice() {
-                        reducer.reduce_slice(lane_slice)
+                        kernel.reduce_slice(lane_slice)
                     } else {
                         tmp_buf.clear();
                         tmp_buf.extend(lane.copied());
-                        reducer.reduce_slice(&tmp_buf)
+                        kernel.reduce_slice(&tmp_buf)
                     }
                 }));
             } else {
@@ -363,12 +368,12 @@ fn reduce<T: Copy>(
                     // The reduced dimensions may be contiguous even if the
                     // tensor is not.
                     let reduced = if let Some(data) = slice.data() {
-                        reducer.reduce_slice(data)
+                        kernel.reduce_slice(data)
                     } else {
                         tmp_buf.clear();
                         let tmp_uninit = &mut tmp_buf.spare_capacity_mut()[..slice.len()];
                         let tmp = slice.copy_into_slice(tmp_uninit);
-                        reducer.reduce_slice(tmp)
+                        kernel.reduce_slice(tmp)
                     };
                     reduced_data.push(reduced);
                 }
@@ -393,14 +398,14 @@ pub fn reduce_mean(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor, OpError> {
-    struct MeanReducer {}
-    impl Reducer<f32> for MeanReducer {
+    struct MeanKernel {}
+    impl ReduceKernel<f32> for MeanKernel {
         fn reduce_slice(&self, slice: &[f32]) -> f32 {
             vec_sum(slice) / slice.len() as f32
         }
     }
 
-    reduce(pool, input, axes, keep_dims, &MeanReducer {})
+    reduce(pool, input, axes, keep_dims, &MeanKernel {})
 }
 
 /// Reduces axes of a tensor using an inverse Root Mean Squared (RMS)
@@ -418,18 +423,18 @@ pub fn reduce_inverse_rms(
     keep_dims: bool,
     epsilon: f32,
 ) -> Result<Tensor, OpError> {
-    struct InverseRmsReducer {
-        epsilon: f32,
-    }
-
-    impl Reducer<f32> for InverseRmsReducer {
+    struct InverseRmsKernel {
+        epsilon: f32,
+    }
+
+    impl ReduceKernel<f32> for InverseRmsKernel {
         fn reduce_slice(&self, slice: &[f32]) -> f32 {
             let mean_square = vec_sum_square(slice) / slice.len() as f32;
             1. / (mean_square + self.epsilon).sqrt()
         }
     }
 
-    reduce(pool, input, axes, keep_dims, &InverseRmsReducer { epsilon })
+    reduce(pool, input, axes, keep_dims, &InverseRmsKernel { epsilon })
 }
 
 #[derive(Debug)]
@@ -462,14 +467,14 @@ pub fn reduce_l2(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor, OpError> {
-    struct L2Reducer {}
-    impl Reducer<f32> for L2Reducer {
+    struct L2ReduceKernel {}
+    impl ReduceKernel<f32> for L2ReduceKernel {
         fn reduce_slice(&self, slice: &[f32]) -> f32 {
             vec_sum_square(slice).sqrt()
         }
     }
 
-    reduce(pool, input, axes, keep_dims, &L2Reducer {})
+    reduce(pool, input, axes, keep_dims, &L2ReduceKernel {})
 }
 
 #[derive(Debug)]
@@ -534,7 +539,7 @@ fn reduce_min_max(
     struct MinMaxReducer {
         max: bool,
    }
-    impl<T: Copy + PartialOrd> Reducer<T> for MinMaxReducer {
+    impl<T: Copy + PartialOrd> ReduceKernel<T> for MinMaxReducer {
         fn reduce_slice(&self, slice: &[T]) -> T {
             let reduced = if self.max {
                 slice.iter().copied().max_by(|a, b| cmp_nan_greater(*a, *b))
@@ -622,13 +627,13 @@ pub fn reduce_prod(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor<T>, OpError> {
-    struct ProdReducer {}
-    impl<T: Copy + std::iter::Product> Reducer<T> for ProdReducer {
+    struct ProdKernel {}
+    impl<T: Copy + std::iter::Product> ReduceKernel<T> for ProdKernel {
         fn reduce_slice(&self, slice: &[T]) -> T {
             slice.iter().copied().product()
         }
     }
-    reduce(pool, input, axes, keep_dims, &ProdReducer {})
+    reduce(pool, input, axes, keep_dims, &ProdKernel {})
 }
 
 #[derive(Debug)]
@@ -655,13 +660,13 @@ pub fn reduce_sum<T: Copy + std::iter::Sum<T>>(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor<T>, OpError> {
-    struct SumReducer {}
-    impl<T: Copy + std::iter::Sum<T>> Reducer<T> for SumReducer {
+    struct SumKernel {}
+    impl<T: Copy + std::iter::Sum<T>> ReduceKernel<T> for SumKernel {
         fn reduce_slice(&self, slice: &[T]) -> T {
             slice_sum(slice)
        }
     }
-    reduce(pool, input, axes, keep_dims, &SumReducer {})
+    reduce(pool, input, axes, keep_dims, &SumKernel {})
 }
 
 #[derive(Debug)]
@@ -688,13 +693,13 @@ pub fn reduce_sum_square<T: Copy + std::ops::Mul<Output = T> + std::iter::Sum<T>>(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor<T>, OpError> {
-    struct SumSquareReducer {}
-    impl<T: Copy + std::ops::Mul<Output = T> + std::iter::Sum<T>> Reducer<T> for SumSquareReducer {
+    struct SumSquareKernel {}
+    impl<T: Copy + std::ops::Mul<Output = T> + std::iter::Sum<T>> ReduceKernel<T> for SumSquareKernel {
         fn reduce_slice(&self, slice: &[T]) -> T {
             slice.iter().copied().map(|x| x * x).sum()
        }
     }
-    reduce(pool, input, axes, keep_dims, &SumSquareReducer {})
+    reduce(pool, input, axes, keep_dims, &SumSquareKernel {})
 }
 
 #[derive(Debug)]
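
The shape of the abstraction after this rename: a ReduceKernel implements only
the inner loop over one contiguous slice, while the shared outer loop in
reduce() walks the independent slices and packs non-contiguous data before
invoking the kernel. Below is a minimal, self-contained sketch of that split.
It is illustrative only: SumAbsKernel and reduce_chunks are hypothetical names,
not part of rten, and the real reduce() additionally resolves axes, handles
non-contiguous inputs and allocates the output from a TensorPool.

    /// Kernel that reduces a single contiguous slice to one value.
    /// (Mirrors the `ReduceKernel` trait introduced by this patch.)
    trait ReduceKernel<T> {
        fn reduce_slice(&self, slice: &[T]) -> T;
    }

    /// Hypothetical kernel: sum of absolute values of a slice.
    struct SumAbsKernel {}

    impl ReduceKernel<f32> for SumAbsKernel {
        fn reduce_slice(&self, slice: &[f32]) -> f32 {
            slice.iter().map(|x| x.abs()).sum()
        }
    }

    /// Toy outer loop: reduce each `slice_len`-sized chunk independently,
    /// standing in for the axis iteration done by the real `reduce`.
    fn reduce_chunks<T: Copy>(
        data: &[T],
        slice_len: usize,
        kernel: &dyn ReduceKernel<T>,
    ) -> Vec<T> {
        data.chunks(slice_len)
            .map(|chunk| kernel.reduce_slice(chunk))
            .collect()
    }

    fn main() {
        // A 2x3 input reduced along its last axis.
        let data = [1.0f32, -2.0, 3.0, -4.0, 5.0, -6.0];
        let reduced = reduce_chunks(&data, 3, &SumAbsKernel {});
        assert_eq!(reduced, vec![6.0, 15.0]);
    }

Keeping the kernel behind a &dyn trait object means each new reduction op only
supplies the per-slice inner loop, while the slicing and packing logic stays in
one place.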