From b5da8e073931991a2c67ec1119d941893f437058 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 18 Sep 2019 20:16:20 -0400 Subject: [PATCH] Limit MultiSlice impls to flat tuples --- src/impl_methods.rs | 2 +- src/impl_views/splitting.rs | 4 +- src/slice.rs | 208 +++++++++++++----------------------- tests/array.rs | 2 +- 4 files changed, 81 insertions(+), 135 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4a4451c83..55084532e 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -381,7 +381,7 @@ where M: MultiSlice<'a, A, D>, S: DataMut, { - unsafe { info.slice_and_deref(self.raw_view_mut()) } + info.multi_slice_move(self.view_mut()) } /// Slice the array, possibly changing the number of dimensions. diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index 529fac152..49bf893c3 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -129,10 +129,10 @@ where /// * if any of the views would intersect (i.e. if any element would appear in multiple slices) /// * if an index is out of bounds or step size is zero /// * if `D` is `IxDyn` and `info` does not match the number of array axes - pub fn multi_slice_move(mut self, info: M) -> M::Output + pub fn multi_slice_move(self, info: M) -> M::Output where M: MultiSlice<'a, A, D>, { - unsafe { info.slice_and_deref(self.raw_view_mut()) } + info.multi_slice_move(self) } } diff --git a/src/slice.rs b/src/slice.rs index 575b60783..faaf16ce0 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -7,7 +7,7 @@ // except according to those terms. use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; -use crate::{ArrayViewMut, Dimension, RawArrayViewMut}; +use crate::{ArrayViewMut, Dimension}; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; @@ -633,188 +633,134 @@ macro_rules! s( /// Slicing information describing multiple mutable, disjoint slices. /// -/// It's unfortunate that we need `'out` and `A` to be parameters of the trait, +/// It's unfortunate that we need `'a` and `A` to be parameters of the trait, /// but they're necessary until Rust supports generic associated types. -/// -/// # Safety -/// -/// Implementers of this trait must ensure that: -/// -/// * `.slice_and_deref()` panics or aborts if the slices would intersect, and -/// -/// * the `.intersects_self()`, `.intersects_indices()`, and -/// `.intersects_other()` implementations are correct. -pub unsafe trait MultiSlice<'out, A, D> +pub trait MultiSlice<'a, A, D> where - A: 'out, + A: 'a, D: Dimension, { - /// The type of the slices created by `.slice_and_deref()`. + /// The type of the slices created by `.multi_slice_move()`. type Output; /// Slice the raw view into multiple raw views, and dereference them. /// /// **Panics** if performing any individual slice panics or if the slices /// are not disjoint (i.e. if they intersect). - /// - /// # Safety - /// - /// The caller must ensure that it is safe to mutably dereference the view - /// using the lifetime `'out`. - unsafe fn slice_and_deref(&self, view: RawArrayViewMut) -> Self::Output; - - /// Returns `true` if slicing an array of the specified `shape` with `self` - /// would result in intersecting slices. - /// - /// If `self.intersects_self(&view.raw_dim())` is `true`, then - /// `self.slice_and_deref(view)` must panic. - fn intersects_self(&self, shape: &D) -> bool; - - /// Returns `true` if any slices created by slicing an array of the - /// specified `shape` with `self` would intersect with the specified - /// indices. - /// - /// Note that even if this returns `false`, `self.intersects_self(shape)` - /// may still return `true`. (`.intersects_indices()` doesn't check for - /// intersections within `self`; it only checks for intersections between - /// `self` and `indices`.) - fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool; - - /// Returns `true` if any slices created by slicing an array of the - /// specified `shape` with `self` would intersect any slices created by - /// slicing the array with `other`. - /// - /// Note that even if this returns `false`, `self.intersects_self(shape)` - /// or `other.intersects_self(shape)` may still return `true`. - /// (`.intersects_other()` doesn't check for intersections within `self` or - /// within `other`; it only checks for intersections between `self` and - /// `other`.) - fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool; + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output; } -unsafe impl<'out, A, D, Do> MultiSlice<'out, A, D> for SliceInfo +impl<'a, A, D> MultiSlice<'a, A, D> for () where - A: 'out, + A: 'a, D: Dimension, - Do: Dimension, { - type Output = ArrayViewMut<'out, A, Do>; - - unsafe fn slice_and_deref(&self, view: RawArrayViewMut) -> Self::Output { - view.slice_move(self).deref_into_view_mut() - } - - fn intersects_self(&self, _shape: &D) -> bool { - false - } - - fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool { - slices_intersect(shape, &*self, indices) - } + type Output = (); - fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool { - other.intersects_indices(shape, &*self) + fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output { + () } } -unsafe impl<'out, A, D> MultiSlice<'out, A, D> for () +impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (SliceInfo,) where - A: 'out, + A: 'a, D: Dimension, + D::SliceArg: Sized, + Do0: Dimension, { - type Output = (); - - unsafe fn slice_and_deref(&self, _view: RawArrayViewMut) -> Self::Output { - () - } + type Output = (ArrayViewMut<'a, A, Do0>,); - fn intersects_self(&self, _shape: &D) -> bool { - false + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { + (view.slice_move(&self.0),) } +} - fn intersects_indices(&self, _shape: &D, _indices: &D::SliceArg) -> bool { - false - } +impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo,) +where + A: 'a, + D: Dimension, + Do0: Dimension, +{ + type Output = (ArrayViewMut<'a, A, Do0>,); - fn intersects_other(&self, _shape: &D, _other: impl MultiSlice<'out, A, D>) -> bool { - false + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { + (view.slice_move(self.0),) } } macro_rules! impl_multislice_tuple { - ($($T:ident,)*) => { - unsafe impl<'out, A, D, $($T,)*> MultiSlice<'out, A, D> for ($($T,)*) + ($($Do:ident,)*) => { + impl<'a, A, D, $($Do,)*> MultiSlice<'a, A, D> for ($(SliceInfo,)*) where - A: 'out, + A: 'a, D: Dimension, - $($T: MultiSlice<'out, A, D>,)* + D::SliceArg: Sized, + $($Do: Dimension,)* { - type Output = ($($T::Output,)*); - - unsafe fn slice_and_deref(&self, view: RawArrayViewMut) -> Self::Output { - assert!(!self.intersects_self(&view.raw_dim())); + type Output = ($(ArrayViewMut<'a, A, $Do>,)*); + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { #[allow(non_snake_case)] - let ($($T,)*) = self; - ($($T.slice_and_deref(view.clone()),)*) - } + let ($($Do,)*) = self; - fn intersects_self(&self, shape: &D) -> bool { - #[allow(non_snake_case)] - let ($($T,)*) = self; - impl_multislice_tuple!(@intersects_self shape, ($($T,)*)) - } + let shape = view.raw_dim(); + assert!(!impl_multislice_tuple!(@intersects_self &shape, ($(&$Do,)*))); - fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool { - #[allow(non_snake_case)] - let ($($T,)*) = self; - $($T.intersects_indices(shape, indices)) ||* + let raw_view = view.into_raw_view_mut(); + unsafe { + ($(raw_view.clone().slice_move(&$Do).deref_into_view_mut(),)*) + } } + } + + impl<'a, A, D, $($Do,)*> MultiSlice<'a, A, D> for ($(&SliceInfo,)*) + where + A: 'a, + D: Dimension, + $($Do: Dimension,)* + { + type Output = ($(ArrayViewMut<'a, A, $Do>,)*); - fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool { + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { #[allow(non_snake_case)] - let ($($T,)*) = self; - $($T.intersects_other(shape, &other)) ||* + let ($($Do,)*) = self; + + let shape = view.raw_dim(); + assert!(!impl_multislice_tuple!(@intersects_self &shape, ($($Do,)*))); + + let raw_view = view.into_raw_view_mut(); + unsafe { + ($(raw_view.clone().slice_move($Do).deref_into_view_mut(),)*) + } } } }; + (@intersects_self $shape:expr, ($head:expr,)) => { - $head.intersects_self($shape) + false }; (@intersects_self $shape:expr, ($head:expr, $($tail:expr,)*)) => { - $head.intersects_self($shape) || - $($head.intersects_other($shape, &$tail)) ||* || - impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*)) + $(slices_intersect($shape, $head, $tail)) ||* + || impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*)) }; } -impl_multislice_tuple!(T0,); -impl_multislice_tuple!(T0, T1,); -impl_multislice_tuple!(T0, T1, T2,); -impl_multislice_tuple!(T0, T1, T2, T3,); -impl_multislice_tuple!(T0, T1, T2, T3, T4,); -impl_multislice_tuple!(T0, T1, T2, T3, T4, T5,); - -unsafe impl<'out, A, D, T> MultiSlice<'out, A, D> for &'_ T + +impl_multislice_tuple!(Do0, Do1,); +impl_multislice_tuple!(Do0, Do1, Do2,); +impl_multislice_tuple!(Do0, Do1, Do2, Do3,); +impl_multislice_tuple!(Do0, Do1, Do2, Do3, Do4,); +impl_multislice_tuple!(Do0, Do1, Do2, Do3, Do4, Do5,); + +impl<'a, A, D, T> MultiSlice<'a, A, D> for &T where - A: 'out, + A: 'a, D: Dimension, - T: MultiSlice<'out, A, D>, + T: MultiSlice<'a, A, D>, { type Output = T::Output; - unsafe fn slice_and_deref(&self, view: RawArrayViewMut) -> Self::Output { - T::slice_and_deref(self, view) - } - - fn intersects_self(&self, shape: &D) -> bool { - T::intersects_self(self, shape) - } - - fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool { - T::intersects_indices(self, shape, indices) - } - - fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool { - T::intersects_other(self, shape, other) + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { + T::multi_slice_move(self, view) } } diff --git a/tests/array.rs b/tests/array.rs index ef70e5c0e..60df74ca7 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -354,7 +354,7 @@ fn test_multislice() { }); let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap(); - assert_eq!(arr.clone().view(), arr.multi_slice_mut(s![.., ..])); + assert_eq!((arr.clone().view_mut(),), arr.multi_slice_mut((s![.., ..],))); test_multislice!(&mut arr, s![0, ..], s![1, ..]); test_multislice!(&mut arr, s![0, ..], s![-1, ..]); test_multislice!(&mut arr, s![0, ..], s![1.., ..]);