diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index cf770c5fb..b2fa8caf6 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -375,18 +375,29 @@ pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize { let m = end - start; let s = (*stride) as isize; - // Data pointer offset - let mut offset = stride_offset(start, *stride); - // Adjust for strides - // - // How to implement negative strides: - // - // Increase start pointer by - // old stride * (old dim - 1) - // to put the pointer completely in the other end - if step < 0 { - offset += stride_offset(m - 1, *stride); - } + // Compute data pointer offset. + let offset = if m == 0 { + // In this case, the resulting array is empty, so we *can* avoid performing a nonzero + // offset. + // + // In two special cases (which are the true reason for this `m == 0` check), we *must* avoid + // the nonzero offset corresponding to the general case. + // + // * When `end == 0 && step < 0`. (These conditions imply that `m == 0` since `to_abs_slice` + // ensures that `0 <= start <= end`.) We cannot execute `stride_offset(end - 1, *stride)` + // because the `end - 1` would underflow. + // + // * When `start == *dim && step > 0`. (These conditions imply that `m == 0` since + // `to_abs_slice` ensures that `start <= end <= *dim`.) We cannot use the offset returned + // by `stride_offset(start, *stride)` because that would be past the end of the axis. + 0 + } else if step < 0 { + // When the step is negative, the new first element is `end - 1`, not `start`, since the + // direction is reversed. + stride_offset(end - 1, *stride) + } else { + stride_offset(start, *stride) + }; // Update dimension. let abs_step = step.abs() as usize; diff --git a/tests/array.rs b/tests/array.rs index bff526b21..75cd4edd0 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -88,6 +88,16 @@ fn test_slice() { assert!(vi.iter().zip(A.iter()).all(|(a, b)| a == b)); } +#[test] +fn test_slice_edge_cases() { + let mut arr = Array3::::zeros((3, 4, 5)); + arr.slice_collapse(s![0..0;-1, .., ..]); + assert_eq!(arr.shape(), &[0, 4, 5]); + let mut arr = Array2::::from_shape_vec((1, 1).strides((10, 1)), vec![5]).unwrap(); + arr.slice_collapse(s![1..1, ..]); + assert_eq!(arr.shape(), &[0, 1]); +} + #[test] fn test_slice_inclusive_range() { let arr = array![[1, 2, 3], [4, 5, 6]];