Rename Reducer -> ReduceKernel
For consistency with other code, use the term "Kernel" to describe the code
which handles the inner loop of reduction ops.
robertknight committed Dec 16, 2024
1 parent 8a249ad commit 3168392
Showing 1 changed file with 33 additions and 28 deletions: src/ops/reduce.rs
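
For context before the diff: the pattern this commit standardizes pairs a small per-slice "kernel" with a generic outer loop. The sketch below is a minimal, self-contained illustration of that split, not the actual implementation: ReduceKernel, reduce_slice, and MeanKernel mirror names from the diff, while reduce_contiguous is a simplified stand-in for the real reduce, which also takes a TensorPool and an axes list, packs non-contiguous slices into a scratch buffer, and returns Result<Tensor<T>, OpError>.

/// Kernel that reduces a single contiguous slice of the input to one value.
trait ReduceKernel<T> {
    fn reduce_slice(&self, slice: &[T]) -> T;
}

/// Toy stand-in for the MeanKernel defined in the diff.
struct MeanKernel;

impl ReduceKernel<f32> for MeanKernel {
    fn reduce_slice(&self, slice: &[f32]) -> f32 {
        slice.iter().sum::<f32>() / slice.len() as f32
    }
}

/// Simplified outer loop: split a contiguous buffer into independent
/// slices and invoke the kernel on each one. (The real reduce falls
/// back to packing a slice into a scratch buffer when the input is
/// not contiguous.)
fn reduce_contiguous<T: Copy>(
    data: &[T],
    slice_len: usize,
    kernel: &dyn ReduceKernel<T>,
) -> Vec<T> {
    data.chunks(slice_len)
        .map(|chunk| kernel.reduce_slice(chunk))
        .collect()
}

fn main() {
    // A 2x3 "tensor" reduced over its last axis: mean of each row.
    let data = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(reduce_contiguous(&data, 3, &MeanKernel), vec![2.0, 5.0]);
}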
@@ -262,18 +262,23 @@ impl Operator for NonZero {
     }
 }
 
-/// Trait for reducing a subset of elements from a tensor to a single value.
-trait Reducer<T> {
+/// Kernel that handles reducing a single slice of the input.
+trait ReduceKernel<T> {
     /// Reduce a contiguous slice of values to a single value.
     fn reduce_slice(&self, slice: &[T]) -> T;
 }
 
+/// Outer loop of reduction operations.
+///
+/// This iterates over slices of the input that are reduced independently and
+/// invokes the kernel on that slice. If the input is not contiguous, the slice
+/// is packed before calling the kernel.
 fn reduce<T: Copy>(
     pool: &TensorPool,
     input: TensorView<T>,
     axes: Option<&[i32]>,
     keep_dims: bool,
-    reducer: &dyn Reducer<T>,
+    kernel: &dyn ReduceKernel<T>,
 ) -> Result<Tensor<T>, OpError> {
     let mut resolved_axes = match axes {
         Some(axes) if !axes.is_empty() => resolve_axes(input.ndim(), axes.iter())?,
@@ -293,7 +298,7 @@ fn reduce<T: Copy>(
 
     if input.ndim() == 0 {
         let item = input.item().unwrap();
-        return Ok(Tensor::from_scalar(reducer.reduce_slice(&[*item])));
+        return Ok(Tensor::from_scalar(kernel.reduce_slice(&[*item])));
     }
 
     // nb. Some reduce operations cannot produce a meaningful result with
@@ -336,7 +341,7 @@ fn reduce<T: Copy>(
             reduced_data.extend(
                 input_data
                     .chunks(slice_len)
-                    .map(|chunk| reducer.reduce_slice(chunk)),
+                    .map(|chunk| kernel.reduce_slice(chunk)),
             );
         }
         _ => {
@@ -345,11 +350,11 @@ fn reduce<T: Copy>(
                 let resolved_axis = resolved_axes[0];
                 reduced_data.extend(input.lanes(resolved_axis).map(|lane| {
                     if let Some(lane_slice) = lane.as_slice() {
-                        reducer.reduce_slice(lane_slice)
+                        kernel.reduce_slice(lane_slice)
                     } else {
                         tmp_buf.clear();
                         tmp_buf.extend(lane.copied());
-                        reducer.reduce_slice(&tmp_buf)
+                        kernel.reduce_slice(&tmp_buf)
                     }
                 }));
             } else {
@@ -363,12 +368,12 @@ fn reduce<T: Copy>(
                     // The reduced dimensions may be contiguous even if the
                     // tensor is not.
                     let reduced = if let Some(data) = slice.data() {
-                        reducer.reduce_slice(data)
+                        kernel.reduce_slice(data)
                     } else {
                         tmp_buf.clear();
                         let tmp_uninit = &mut tmp_buf.spare_capacity_mut()[..slice.len()];
                         let tmp = slice.copy_into_slice(tmp_uninit);
-                        reducer.reduce_slice(tmp)
+                        kernel.reduce_slice(tmp)
                     };
                     reduced_data.push(reduced);
                 }
@@ -393,14 +398,14 @@ pub fn reduce_mean(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor, OpError> {
-    struct MeanReducer {}
-    impl Reducer<f32> for MeanReducer {
+    struct MeanKernel {}
+    impl ReduceKernel<f32> for MeanKernel {
         fn reduce_slice(&self, slice: &[f32]) -> f32 {
             vec_sum(slice) / slice.len() as f32
         }
     }
 
-    reduce(pool, input, axes, keep_dims, &MeanReducer {})
+    reduce(pool, input, axes, keep_dims, &MeanKernel {})
 }
 
 /// Reduces axes of a tensor using an inverse Root Mean Squared (RMS)
@@ -418,18 +423,18 @@ pub fn reduce_inverse_rms(
     keep_dims: bool,
     epsilon: f32,
 ) -> Result<Tensor, OpError> {
-    struct InverseRmsReducer {
+    struct InverseRmsKernel {
         epsilon: f32,
     }
 
-    impl Reducer<f32> for InverseRmsReducer {
+    impl ReduceKernel<f32> for InverseRmsKernel {
         fn reduce_slice(&self, slice: &[f32]) -> f32 {
             let mean_square = vec_sum_square(slice) / slice.len() as f32;
             1. / (mean_square + self.epsilon).sqrt()
         }
     }
 
-    reduce(pool, input, axes, keep_dims, &InverseRmsReducer { epsilon })
+    reduce(pool, input, axes, keep_dims, &InverseRmsKernel { epsilon })
 }
 
 #[derive(Debug)]
@@ -462,14 +467,14 @@ pub fn reduce_l2(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor, OpError> {
-    struct L2Reducer {}
-    impl Reducer<f32> for L2Reducer {
+    struct L2ReduceKernel {}
+    impl ReduceKernel<f32> for L2ReduceKernel {
         fn reduce_slice(&self, slice: &[f32]) -> f32 {
             vec_sum_square(slice).sqrt()
         }
     }
 
-    reduce(pool, input, axes, keep_dims, &L2Reducer {})
+    reduce(pool, input, axes, keep_dims, &L2ReduceKernel {})
 }
 
 #[derive(Debug)]
@@ -534,7 +539,7 @@ fn reduce_min_max<T: Copy + PartialOrd + IsNaN>(
     struct MinMaxReducer {
         max: bool,
     }
-    impl<T: Copy + PartialOrd + IsNaN> Reducer<T> for MinMaxReducer {
+    impl<T: Copy + PartialOrd + IsNaN> ReduceKernel<T> for MinMaxReducer {
         fn reduce_slice(&self, slice: &[T]) -> T {
             let reduced = if self.max {
                 slice.iter().copied().max_by(|a, b| cmp_nan_greater(*a, *b))
@@ -622,13 +627,13 @@ pub fn reduce_prod<T: Copy + std::iter::Product>(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor<T>, OpError> {
-    struct ProdReducer {}
-    impl<T: Copy + std::iter::Product> Reducer<T> for ProdReducer {
+    struct ProdKernel {}
+    impl<T: Copy + std::iter::Product> ReduceKernel<T> for ProdKernel {
         fn reduce_slice(&self, slice: &[T]) -> T {
             slice.iter().copied().product()
         }
     }
-    reduce(pool, input, axes, keep_dims, &ProdReducer {})
+    reduce(pool, input, axes, keep_dims, &ProdKernel {})
 }
 
 #[derive(Debug)]
@@ -655,13 +660,13 @@ pub fn reduce_sum<T: Copy + Default + std::ops::Add<T, Output = T>>(
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor<T>, OpError> {
-    struct SumReducer {}
-    impl<T: Copy + Default + std::ops::Add<T, Output = T>> Reducer<T> for SumReducer {
+    struct SumKernel {}
+    impl<T: Copy + Default + std::ops::Add<T, Output = T>> ReduceKernel<T> for SumKernel {
         fn reduce_slice(&self, slice: &[T]) -> T {
             slice_sum(slice)
         }
     }
-    reduce(pool, input, axes, keep_dims, &SumReducer {})
+    reduce(pool, input, axes, keep_dims, &SumKernel {})
 }
 
 #[derive(Debug)]
@@ -688,13 +693,13 @@ pub fn reduce_sum_square<T: Copy + std::ops::Mul<T, Output = T> + std::iter::Sum
     axes: Option<&[i32]>,
     keep_dims: bool,
 ) -> Result<Tensor<T>, OpError> {
-    struct SumSquareReducer {}
-    impl<T: Copy + std::iter::Sum + std::ops::Mul<Output = T>> Reducer<T> for SumSquareReducer {
+    struct SumSquareKernel {}
+    impl<T: Copy + std::iter::Sum + std::ops::Mul<Output = T>> ReduceKernel<T> for SumSquareKernel {
         fn reduce_slice(&self, slice: &[T]) -> T {
             slice.iter().copied().map(|x| x * x).sum()
         }
     }
-    reduce(pool, input, axes, keep_dims, &SumSquareReducer {})
+    reduce(pool, input, axes, keep_dims, &SumSquareKernel {})
 }
 
 #[derive(Debug)]
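The renamed kernel types are private to src/ops/reduce.rs, so public entry points such as reduce_mean are unchanged by this commit. As a hedged usage sketch from inside the crate: the reduce_mean signature below matches the diff, but Tensor::from_data and to_vec are assumed constructors/accessors from the repo's tensor crate and may differ in name.

fn mean_over_rows(pool: &TensorPool) -> Result<Tensor, OpError> {
    // Assumed constructor: build a 2x3 tensor from a flat Vec.
    let input = Tensor::from_data(&[2, 3], vec![1., 2., 3., 4., 5., 6.]);
    // Reduce over axis 1; keep_dims=false drops the reduced axis.
    // Internally this now dispatches through reduce(..., &MeanKernel {}).
    let means = reduce_mean(pool, input.view(), Some(&[1]), false)?;
    assert_eq!(means.to_vec(), vec![2., 5.]);
    Ok(means)
}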
