diff --git a/src/ops/reduce.rs b/src/ops/reduce.rs index 039eb5fb..f163c19a 100644 --- a/src/ops/reduce.rs +++ b/src/ops/reduce.rs @@ -831,7 +831,7 @@ mod tests { use rten_tensor::prelude::*; use rten_tensor::test_util::{eq_with_nans, expect_equal}; - use rten_tensor::{NdTensor, Tensor}; + use rten_tensor::{NdTensor, SliceRange, Tensor}; use crate::ops::tests::{new_pool, run_op}; use crate::ops::{ @@ -1049,6 +1049,8 @@ mod tests { Ok(()) } + // Tests for ReduceMean specifically that also cover common functionality + // across the different reductions. #[test] fn test_reduce_mean() -> Result<(), Box> { let pool = new_pool(); @@ -1108,6 +1110,44 @@ mod tests { .unwrap(); assert_eq!(result.to_vec(), &[5.0]); + // Reduce non-contiguous lane + let tensor = Tensor::from([0., 1., 2., 3., 4., 5., 6.]); + let slice = tensor.slice(SliceRange::new(0, None, 2)); + let expected_mean = slice.iter().sum::() / slice.len() as f32; + let result = reduce_mean(&pool, slice.view(), Some(&[0]), false /* keep_dims */).unwrap(); + assert_eq!(result.to_vec(), &[expected_mean]); + + // Reduce contiguous lanes in non-contiguous tensor + let tensor = Tensor::from([[0., 1.], [2., 3.], [4., 5.]]); + let slice = tensor.slice(SliceRange::new(0, None, 2)); + let result = reduce_mean(&pool, slice.view(), Some(&[1]), false /* keep_dims */).unwrap(); + assert_eq!(result.to_vec(), &[0.5, 4.5]); + + // Reduce multiple non-contiguous dimensions + let tensor = Tensor::from([[0., 1.], [2., 3.], [4., 5.]]); + let slice = tensor.slice((SliceRange::new(0, None, 2), SliceRange::new(0, None, 2))); + let expected_mean = slice.iter().sum::() / slice.len() as f32; + let result = reduce_mean( + &pool, + slice.view(), + Some(&[0, 1]), + false, /* keep_dims */ + ) + .unwrap(); + assert_eq!(result.to_vec(), &[expected_mean]); + + // Reduce multiple contiguous dimensions in non-contiguous tensor + let tensor = Tensor::from([[[0.], [1.]], [[2.], [3.]], [[4.], [5.]]]); + let slice = tensor.slice(SliceRange::new(0, None, 2)); + let result = reduce_mean( + &pool, + slice.view(), + Some(&[1, 2]), + false, /* keep_dims */ + ) + .unwrap(); + assert_eq!(result.to_vec(), &[0.5, 4.5]); + Ok(()) }