From 516a5047c20102871b88b8ae7aa785fe857cb392 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Fri, 2 Aug 2024 23:26:46 +0200 Subject: [PATCH] Allow aliasing in ArrayView::from_shape Changes the checks in the ArrayView::from_shape constructor so that it allows a few more cases: custom strides that lead to overlapping are allowed. Before, both ArrayViewMut and ArrayView applied the same check, that the dimensions and strides must be such that no elements can be reached by more than one index. However, this rule only applies for mutable data, for ArrayView we can allow this kind of aliasing. This is in fact how broadcasting works, where we use strides to repeat the same array data multiple times. --- src/dimension/mod.rs | 123 ++++++++++++++++++++------------- src/impl_constructors.rs | 6 +- src/impl_views/constructors.rs | 6 +- tests/array.rs | 17 +++++ 4 files changed, 99 insertions(+), 53 deletions(-) diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 6c5cd0e84..2ae851d1c 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -100,6 +100,21 @@ pub fn size_of_shape_checked(dim: &D) -> Result } } +/// Select how aliasing is checked +/// +/// For owned or mutable data: +/// +/// The strides must not allow any element to be referenced by two different indices. +/// +#[derive(Copy, Clone, PartialEq)] +pub(crate) enum CanIndexCheckMode +{ + /// Owned or mutable: No aliasing + OwnedMutable, + /// Aliasing + ReadOnly, +} + /// Checks whether the given data and dimension meet the invariants of the /// `ArrayBase` type, assuming the strides are created using /// `dim.default_strides()` or `dim.fortran_strides()`. @@ -125,12 +140,13 @@ pub fn size_of_shape_checked(dim: &D) -> Result /// `A` and in units of bytes between the least address and greatest address /// accessible by moving along all axes does not exceed `isize::MAX`. pub(crate) fn can_index_slice_with_strides( - data: &[A], dim: &D, strides: &Strides, + data: &[A], dim: &D, strides: &Strides, mode: CanIndexCheckMode, ) -> Result<(), ShapeError> { if let Strides::Custom(strides) = strides { - can_index_slice(data, dim, strides) + can_index_slice(data, dim, strides, mode) } else { + // contiguous shapes: never aliasing, mode does not matter can_index_slice_not_custom(data.len(), dim) } } @@ -239,15 +255,19 @@ where D: Dimension /// allocation. (In other words, the pointer to the first element of the array /// must be computed using `offset_from_low_addr_ptr_to_logical_ptr` so that /// negative strides are correctly handled.) -pub(crate) fn can_index_slice(data: &[A], dim: &D, strides: &D) -> Result<(), ShapeError> +/// +/// Note, condition (4) is guaranteed to be checked last +pub(crate) fn can_index_slice( + data: &[A], dim: &D, strides: &D, mode: CanIndexCheckMode, +) -> Result<(), ShapeError> { // Check conditions 1 and 2 and calculate `max_offset`. let max_offset = max_abs_offset_check_overflow::(dim, strides)?; - can_index_slice_impl(max_offset, data.len(), dim, strides) + can_index_slice_impl(max_offset, data.len(), dim, strides, mode) } fn can_index_slice_impl( - max_offset: usize, data_len: usize, dim: &D, strides: &D, + max_offset: usize, data_len: usize, dim: &D, strides: &D, mode: CanIndexCheckMode, ) -> Result<(), ShapeError> { // Check condition 3. @@ -260,7 +280,7 @@ fn can_index_slice_impl( } // Check condition 4. - if !is_empty && dim_stride_overlap(dim, strides) { + if !is_empty && mode != CanIndexCheckMode::ReadOnly && dim_stride_overlap(dim, strides) { return Err(from_kind(ErrorKind::Unsupported)); } @@ -782,6 +802,7 @@ mod test slice_min_max, slices_intersect, solve_linear_diophantine_eq, + CanIndexCheckMode, IntoDimension, }; use crate::error::{from_kind, ErrorKind}; @@ -796,11 +817,11 @@ mod test let v: alloc::vec::Vec<_> = (0..12).collect(); let dim = (2, 3, 2).into_dimension(); let strides = (1, 2, 6).into_dimension(); - assert!(super::can_index_slice(&v, &dim, &strides).is_ok()); + assert!(super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok()); let strides = (2, 4, 12).into_dimension(); assert_eq!( - super::can_index_slice(&v, &dim, &strides), + super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable), Err(from_kind(ErrorKind::OutOfBounds)) ); } @@ -848,71 +869,79 @@ mod test #[test] fn can_index_slice_ix0() { - can_index_slice::(&[1], &Ix0(), &Ix0()).unwrap(); - can_index_slice::(&[], &Ix0(), &Ix0()).unwrap_err(); + can_index_slice::(&[1], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap(); + can_index_slice::(&[], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap_err(); } #[test] fn can_index_slice_ix1() { - can_index_slice::(&[], &Ix1(0), &Ix1(0)).unwrap(); - can_index_slice::(&[], &Ix1(0), &Ix1(1)).unwrap(); - can_index_slice::(&[], &Ix1(1), &Ix1(0)).unwrap_err(); - can_index_slice::(&[], &Ix1(1), &Ix1(1)).unwrap_err(); - can_index_slice::(&[1], &Ix1(1), &Ix1(0)).unwrap(); - can_index_slice::(&[1], &Ix1(1), &Ix1(2)).unwrap(); - can_index_slice::(&[1], &Ix1(1), &Ix1(-1isize as usize)).unwrap(); - can_index_slice::(&[1], &Ix1(2), &Ix1(1)).unwrap_err(); - can_index_slice::(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err(); - can_index_slice::(&[1, 2], &Ix1(2), &Ix1(1)).unwrap(); - can_index_slice::(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap(); + let mode = CanIndexCheckMode::OwnedMutable; + can_index_slice::(&[], &Ix1(0), &Ix1(0), mode).unwrap(); + can_index_slice::(&[], &Ix1(0), &Ix1(1), mode).unwrap(); + can_index_slice::(&[], &Ix1(1), &Ix1(0), mode).unwrap_err(); + can_index_slice::(&[], &Ix1(1), &Ix1(1), mode).unwrap_err(); + can_index_slice::(&[1], &Ix1(1), &Ix1(0), mode).unwrap(); + can_index_slice::(&[1], &Ix1(1), &Ix1(2), mode).unwrap(); + can_index_slice::(&[1], &Ix1(1), &Ix1(-1isize as usize), mode).unwrap(); + can_index_slice::(&[1], &Ix1(2), &Ix1(1), mode).unwrap_err(); + can_index_slice::(&[1, 2], &Ix1(2), &Ix1(0), mode).unwrap_err(); + can_index_slice::(&[1, 2], &Ix1(2), &Ix1(1), mode).unwrap(); + can_index_slice::(&[1, 2], &Ix1(2), &Ix1(-1isize as usize), mode).unwrap(); } #[test] fn can_index_slice_ix2() { - can_index_slice::(&[], &Ix2(0, 0), &Ix2(0, 0)).unwrap(); - can_index_slice::(&[], &Ix2(0, 0), &Ix2(2, 1)).unwrap(); - can_index_slice::(&[], &Ix2(0, 1), &Ix2(0, 0)).unwrap(); - can_index_slice::(&[], &Ix2(0, 1), &Ix2(2, 1)).unwrap(); - can_index_slice::(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap(); - can_index_slice::(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err(); - can_index_slice::(&[1], &Ix2(1, 2), &Ix2(5, 1)).unwrap_err(); - can_index_slice::(&[1, 2], &Ix2(1, 2), &Ix2(5, 1)).unwrap(); - can_index_slice::(&[1, 2], &Ix2(1, 2), &Ix2(5, 2)).unwrap_err(); - can_index_slice::(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1)).unwrap(); - can_index_slice::(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1)).unwrap_err(); + let mode = CanIndexCheckMode::OwnedMutable; + can_index_slice::(&[], &Ix2(0, 0), &Ix2(0, 0), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 0), &Ix2(2, 1), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 1), &Ix2(0, 0), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 1), &Ix2(2, 1), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err(); + can_index_slice::(&[1], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap_err(); + can_index_slice::(&[1, 2], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap(); + can_index_slice::(&[1, 2], &Ix2(1, 2), &Ix2(5, 2), mode).unwrap_err(); + can_index_slice::(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap(); + can_index_slice::(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap_err(); + + // aliasing strides: ok when readonly + can_index_slice::(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::OwnedMutable).unwrap_err(); + can_index_slice::(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::ReadOnly).unwrap(); } #[test] fn can_index_slice_ix3() { - can_index_slice::(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3)).unwrap(); - can_index_slice::(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap_err(); - can_index_slice::(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap(); - can_index_slice::(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap_err(); - can_index_slice::(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap(); + let mode = CanIndexCheckMode::OwnedMutable; + can_index_slice::(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3), mode).unwrap(); + can_index_slice::(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap_err(); + can_index_slice::(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap(); + can_index_slice::(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap_err(); + can_index_slice::(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap(); } #[test] fn can_index_slice_zero_size_elem() { - can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1)).unwrap(); - can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1)).unwrap(); - can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1)).unwrap(); + let mode = CanIndexCheckMode::OwnedMutable; + can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1), mode).unwrap(); + can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1), mode).unwrap(); + can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1), mode).unwrap(); // These might seem okay because the element type is zero-sized, but // there could be a zero-sized type such that the number of instances // in existence are carefully controlled. - can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1)).unwrap_err(); - can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1)).unwrap_err(); + can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err(); + can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1), mode).unwrap_err(); - can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0)).unwrap(); - can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap(); + can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0), mode).unwrap(); + can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap(); // This case would be probably be sound, but that's not entirely clear // and it's not worth the special case code. - can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err(); + can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err(); } quickcheck! { @@ -923,8 +952,8 @@ mod test // Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`. result.is_err() } else { - result == can_index_slice(&data, &dim, &dim.default_strides()) && - result == can_index_slice(&data, &dim, &dim.fortran_strides()) + result == can_index_slice(&data, &dim, &dim.default_strides(), CanIndexCheckMode::OwnedMutable) && + result == can_index_slice(&data, &dim, &dim.fortran_strides(), CanIndexCheckMode::OwnedMutable) } } } diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index 3bdde09b5..260937a90 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -20,8 +20,8 @@ use num_traits::{One, Zero}; use std::mem; use std::mem::MaybeUninit; -use crate::dimension; use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; +use crate::dimension::{self, CanIndexCheckMode}; use crate::error::{self, ShapeError}; use crate::extension::nonnull::nonnull_from_vec_data; use crate::imp_prelude::*; @@ -466,7 +466,7 @@ where { let dim = shape.dim; let is_custom = shape.strides.is_custom(); - dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?; + dimension::can_index_slice_with_strides(&v, &dim, &shape.strides, dimension::CanIndexCheckMode::OwnedMutable)?; if !is_custom && dim.size() != v.len() { return Err(error::incompatible_shapes(&Ix1(v.len()), &dim)); } @@ -510,7 +510,7 @@ where unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec) -> Self { // debug check for issues that indicates wrong use of this constructor - debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok()); + debug_assert!(dimension::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok()); let ptr = nonnull_from_vec_data(&mut v).add(offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides)); ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim) diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index dcbec991b..15f2b9b6b 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -8,8 +8,8 @@ use std::ptr::NonNull; -use crate::dimension; use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; +use crate::dimension::{self, CanIndexCheckMode}; use crate::error::ShapeError; use crate::extension::nonnull::nonnull_debug_checked_from_ptr; use crate::imp_prelude::*; @@ -54,7 +54,7 @@ where D: Dimension fn from_shape_impl(shape: StrideShape, xs: &'a [A]) -> Result { let dim = shape.dim; - dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?; + dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::ReadOnly)?; let strides = shape.strides.strides_for_dim(&dim); unsafe { Ok(Self::new_( @@ -157,7 +157,7 @@ where D: Dimension fn from_shape_impl(shape: StrideShape, xs: &'a mut [A]) -> Result { let dim = shape.dim; - dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?; + dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::OwnedMutable)?; let strides = shape.strides.strides_for_dim(&dim); unsafe { Ok(Self::new_( diff --git a/tests/array.rs b/tests/array.rs index 8f01d0636..4de22794c 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -10,6 +10,7 @@ use defmac::defmac; use itertools::{zip, Itertools}; use ndarray::indices; use ndarray::prelude::*; +use ndarray::ErrorKind; use ndarray::{arr3, rcarr2}; use ndarray::{Slice, SliceInfo, SliceInfoElem}; use num_complex::Complex; @@ -2060,6 +2061,22 @@ fn test_view_from_shape() assert_eq!(a, answer); } +#[test] +fn test_view_from_shape_allow_overlap() +{ + let data = [0, 1, 2]; + let view = ArrayView::from_shape((2, 3).strides((0, 1)), &data).unwrap(); + assert_eq!(view, aview2(&[data; 2])); +} + +#[test] +fn test_view_mut_from_shape_deny_overlap() +{ + let mut data = [0, 1, 2]; + let result = ArrayViewMut::from_shape((2, 3).strides((0, 1)), &mut data); + assert_matches!(result.map_err(|e| e.kind()), Err(ErrorKind::Unsupported)); +} + #[test] fn test_contiguous() {