Skip to content

Commit

Permalink
fix/ndarray: remove reversed axes check (#1058)
Browse files Browse the repository at this point in the history
  • Loading branch information
AuruTus authored Dec 14, 2023
1 parent 4608cd9 commit fc97a28
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 9 deletions.
43 changes: 42 additions & 1 deletion burn-ndarray/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ where
.unwrap()
.into_shared();

NdArrayTensor { array }
// Transform column-major layout into row-major (standard) layout. (fix #1053)
let array = NdArrayTensor { array };
Self::reshape(array.clone(), array.shape())
}

fn to_slice_args<const D1: usize, const D2: usize>(
Expand Down Expand Up @@ -496,3 +498,42 @@ fn arg<E: NdArrayElement, const D: usize>(
array: output.into_shared(),
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn should_generate_row_major_layout_for_cat() {
let expected_shape: &[usize] = &[4, 6, 2];
let expected_strides: &[isize] = &[12, 2, 1];
let expected_array = NdArrayTensor::from_data(Data::<i32, 3>::from([
[[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],
[[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],
[[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],
[[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],
]));

// unsqueeze dim on the outermost axis
let array = NdArrayOps::reshape(
NdArrayTensor::from_data(Data::<i32, 2>::from([
[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
])),
Shape::from([4, 6, 1]),
);
let zeros = NdArrayTensor::<i32, 3>::from_data(Data::zeros([4, 6, 1]));
// make `ndarray` concatenates array on the outermost axis
let array = NdArrayOps::cat([array, zeros].to_vec(), 2);

assert!(array.array.is_standard_layout());
assert_eq!(array.array.shape(), expected_shape);
assert_eq!(array.array.strides(), expected_strides);
assert_eq!(
array.array.into_iter().collect::<Vec<_>>(),
expected_array.array.into_iter().collect::<Vec<_>>(),
);
}
}
9 changes: 1 addition & 8 deletions burn-ndarray/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,7 @@ macro_rules! reshape {
array $array:expr
) => {{
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
let safe_into_shape =
$array.is_standard_layout() ||
(
$array.ndim() > 1 &&
$array.raw_view().reversed_axes().is_standard_layout()
);

let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
true => $array
.into_shape(dim)
.expect("Safe to change shape without relayout")
Expand Down
17 changes: 17 additions & 0 deletions burn-tensor/src/tests/ops/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,21 @@ mod tests {

let output: Tensor<TestBackend, 4> = TestTensor::stack(vec![tensor_1, tensor_2], 3);
}

#[test]
fn should_generate_row_major_layout() {
let data_expected = Data::from([
[1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0],
[7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0],
[13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0],
[19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0],
]);

let tensor = Tensor::<TestBackend, 1, Int>::arange(1..25).reshape([4, 6]);
let zeros: Tensor<TestBackend, 2, Int> = Tensor::zeros([4, 6]);
let intersperse =
Tensor::stack::<3>([tensor.clone(), zeros.clone()].to_vec(), 2).reshape([4, 12]);

assert_eq!(data_expected, intersperse.into_data());
}
}

0 comments on commit fc97a28

Please sign in to comment.