From fc97a28f16676f9083a4deb7e2d3189828bbfd2e Mon Sep 17 00:00:00 2001 From: AuruTus <33182215+AuruTus@users.noreply.github.com> Date: Fri, 15 Dec 2023 07:09:24 +0800 Subject: [PATCH] fix/ndarray: remove reversed axes check (#1058) --- burn-ndarray/src/ops/base.rs | 43 +++++++++++++++++++++++++++++- burn-ndarray/src/tensor.rs | 9 +------ burn-tensor/src/tests/ops/stack.rs | 17 ++++++++++++ 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/burn-ndarray/src/ops/base.rs b/burn-ndarray/src/ops/base.rs index 6c391e6148..dae04bce0d 100644 --- a/burn-ndarray/src/ops/base.rs +++ b/burn-ndarray/src/ops/base.rs @@ -72,7 +72,9 @@ where .unwrap() .into_shared(); - NdArrayTensor { array } + // Transform column-major layout into row-major (standard) layout. (fix #1053) + let array = NdArrayTensor { array }; + Self::reshape(array.clone(), array.shape()) } fn to_slice_args( @@ -496,3 +498,42 @@ fn arg( array: output.into_shared(), } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_generate_row_major_layout_for_cat() { + let expected_shape: &[usize] = &[4, 6, 2]; + let expected_strides: &[isize] = &[12, 2, 1]; + let expected_array = NdArrayTensor::from_data(Data::::from([ + [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]], + [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]], + [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]], + [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]], + ])); + + // unsqueeze dim on the outermost axis + let array = NdArrayOps::reshape( + NdArrayTensor::from_data(Data::::from([ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + [19, 20, 21, 22, 23, 24], + ])), + Shape::from([4, 6, 1]), + ); + let zeros = NdArrayTensor::::from_data(Data::zeros([4, 6, 1])); + // make `ndarray` concatenates array on the outermost axis + let array = NdArrayOps::cat([array, zeros].to_vec(), 2); + + assert!(array.array.is_standard_layout()); + assert_eq!(array.array.shape(), expected_shape); + assert_eq!(array.array.strides(), expected_strides); + assert_eq!( + array.array.into_iter().collect::>(), + expected_array.array.into_iter().collect::>(), + ); + } +} diff --git a/burn-ndarray/src/tensor.rs b/burn-ndarray/src/tensor.rs index db99bc87e6..700f39201d 100644 --- a/burn-ndarray/src/tensor.rs +++ b/burn-ndarray/src/tensor.rs @@ -63,14 +63,7 @@ macro_rules! reshape { array $array:expr ) => {{ let dim = $crate::to_typed_dims!($n, $shape.dims, justdim); - let safe_into_shape = - $array.is_standard_layout() || - ( - $array.ndim() > 1 && - $array.raw_view().reversed_axes().is_standard_layout() - ); - - let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape { + let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() { true => $array .into_shape(dim) .expect("Safe to change shape without relayout") diff --git a/burn-tensor/src/tests/ops/stack.rs b/burn-tensor/src/tests/ops/stack.rs index 4dbe639bb7..302d0d34db 100644 --- a/burn-tensor/src/tests/ops/stack.rs +++ b/burn-tensor/src/tests/ops/stack.rs @@ -90,4 +90,21 @@ mod tests { let output: Tensor = TestTensor::stack(vec![tensor_1, tensor_2], 3); } + + #[test] + fn should_generate_row_major_layout() { + let data_expected = Data::from([ + [1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0], + [7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0], + [13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0], + [19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0], + ]); + + let tensor = Tensor::::arange(1..25).reshape([4, 6]); + let zeros: Tensor = Tensor::zeros([4, 6]); + let intersperse = + Tensor::stack::<3>([tensor.clone(), zeros.clone()].to_vec(), 2).reshape([4, 12]); + + assert_eq!(data_expected, intersperse.into_data()); + } }