Skip to content

Commit

Permalink
Add test cases for reducing non-contiguous views
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Dec 16, 2024
1 parent 8d518fa commit 8a249ad
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion src/ops/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ mod tests {

use rten_tensor::prelude::*;
use rten_tensor::test_util::{eq_with_nans, expect_equal};
use rten_tensor::{NdTensor, Tensor};
use rten_tensor::{NdTensor, SliceRange, Tensor};

use crate::ops::tests::{new_pool, run_op};
use crate::ops::{
Expand Down Expand Up @@ -1049,6 +1049,8 @@ mod tests {
Ok(())
}

// Tests for ReduceMean specifically that also cover common functionality
// across the different reductions.
#[test]
fn test_reduce_mean() -> Result<(), Box<dyn Error>> {
let pool = new_pool();
Expand Down Expand Up @@ -1108,6 +1110,44 @@ mod tests {
.unwrap();
assert_eq!(result.to_vec(), &[5.0]);

// Reduce non-contiguous lane
let tensor = Tensor::from([0., 1., 2., 3., 4., 5., 6.]);
let slice = tensor.slice(SliceRange::new(0, None, 2));
let expected_mean = slice.iter().sum::<f32>() / slice.len() as f32;
let result = reduce_mean(&pool, slice.view(), Some(&[0]), false /* keep_dims */).unwrap();
assert_eq!(result.to_vec(), &[expected_mean]);

// Reduce contiguous lanes in non-contiguous tensor
let tensor = Tensor::from([[0., 1.], [2., 3.], [4., 5.]]);
let slice = tensor.slice(SliceRange::new(0, None, 2));
let result = reduce_mean(&pool, slice.view(), Some(&[1]), false /* keep_dims */).unwrap();
assert_eq!(result.to_vec(), &[0.5, 4.5]);

// Reduce multiple non-contiguous dimensions
let tensor = Tensor::from([[0., 1.], [2., 3.], [4., 5.]]);
let slice = tensor.slice((SliceRange::new(0, None, 2), SliceRange::new(0, None, 2)));
let expected_mean = slice.iter().sum::<f32>() / slice.len() as f32;
let result = reduce_mean(
&pool,
slice.view(),
Some(&[0, 1]),
false, /* keep_dims */
)
.unwrap();
assert_eq!(result.to_vec(), &[expected_mean]);

// Reduce multiple contiguous dimensions in non-contiguous tensor
let tensor = Tensor::from([[[0.], [1.]], [[2.], [3.]], [[4.], [5.]]]);
let slice = tensor.slice(SliceRange::new(0, None, 2));
let result = reduce_mean(
&pool,
slice.view(),
Some(&[1, 2]),
false, /* keep_dims */
)
.unwrap();
assert_eq!(result.to_vec(), &[0.5, 4.5]);

Ok(())
}

Expand Down

0 comments on commit 8a249ad

Please sign in to comment.