diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 2a0ec4b0d..28d2e9b2c 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -544,7 +544,6 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> { } /// Returns `true` iff the slices intersect. -#[allow(dead_code)] pub fn slices_intersect( dim: &D, indices1: &D::SliceArg, diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 97193de36..027f5a8af 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -28,6 +28,7 @@ use crate::iter::{ AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, }; +use crate::slice::MultiSlice; use crate::stacking::stack; use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex}; @@ -350,6 +351,39 @@ where self.view_mut().slice_move(info) } + /// Return multiple disjoint, sliced, mutable views of the array. + /// + /// See [*Slicing*](#slicing) for full documentation. + /// See also [`SliceInfo`] and [`D::SliceArg`]. + /// + /// [`SliceInfo`]: struct.SliceInfo.html + /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// + /// **Panics** if any of the following occur: + /// + /// * if any of the views would intersect (i.e. if any element would appear in multiple slices) + /// * if an index is out of bounds or step size is zero + /// * if `D` is `IxDyn` and `info` does not match the number of array axes + /// + /// # Example + /// + /// ``` + /// use ndarray::{arr2, s}; + /// + /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); + /// let (mut edges, mut middle) = a.multi_slice_mut((s![.., ..;2], s![.., 1])); + /// edges.fill(1); + /// middle.fill(0); + /// assert_eq!(a, arr2(&[[1, 0, 1], [1, 0, 1]])); + /// ``` + pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output + where + M: MultiSlice<'a, A, D>, + S: DataMut, + { + info.multi_slice_move(self.view_mut()) + } + /// Slice the array, possibly changing the number of dimensions. /// /// See [*Slicing*](#slicing) for full documentation. diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index e9fcfcb64..dcfb04b86 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -7,6 +7,7 @@ // except according to those terms. use crate::imp_prelude::*; +use crate::slice::MultiSlice; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> @@ -109,4 +110,29 @@ where (left.deref_into_view_mut(), right.deref_into_view_mut()) } } + + /// Split the view into multiple disjoint slices. + /// + /// This is similar to [`.multi_slice_mut()`], but `.multi_slice_move()` + /// consumes `self` and produces views with lifetimes matching that of + /// `self`. + /// + /// See [*Slicing*](#slicing) for full documentation. + /// See also [`SliceInfo`] and [`D::SliceArg`]. + /// + /// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut + /// [`SliceInfo`]: struct.SliceInfo.html + /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// + /// **Panics** if any of the following occur: + /// + /// * if any of the views would intersect (i.e. if any element would appear in multiple slices) + /// * if an index is out of bounds or step size is zero + /// * if `D` is `IxDyn` and `info` does not match the number of array axes + pub fn multi_slice_move(self, info: M) -> M::Output + where + M: MultiSlice<'a, A, D>, + { + info.multi_slice_move(self) + } } diff --git a/src/lib.rs b/src/lib.rs index 35d1e1aab..680dc8557 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -475,6 +475,13 @@ pub type Ixs = isize; /// [`.slice_move()`]: #method.slice_move /// [`.slice_collapse()`]: #method.slice_collapse /// +/// It's possible to take multiple simultaneous *mutable* slices with +/// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only) +/// [`.multi_slice_move()`]. +/// +/// [`.multi_slice_mut()`]: #method.multi_slice_mut +/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move +/// /// ``` /// extern crate ndarray; /// @@ -525,6 +532,20 @@ pub type Ixs = isize; /// [12, 11, 10]]); /// assert_eq!(f, g); /// assert_eq!(f.shape(), &[2, 3]); +/// +/// // Let's take two disjoint, mutable slices of a matrix with +/// // +/// // - One containing all the even-index columns in the matrix +/// // - One containing all the odd-index columns in the matrix +/// let mut h = arr2(&[[0, 1, 2, 3], +/// [4, 5, 6, 7]]); +/// let (s0, s1) = h.multi_slice_mut((s![.., ..;2], s![.., 1..;2])); +/// let i = arr2(&[[0, 2], +/// [4, 6]]); +/// let j = arr2(&[[1, 3], +/// [5, 7]]); +/// assert_eq!(s0, i); +/// assert_eq!(s1, j); /// } /// ``` /// diff --git a/src/slice.rs b/src/slice.rs index 1d0dfa2b0..58bff48b5 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -5,8 +5,9 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. +use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; -use crate::Dimension; +use crate::{ArrayViewMut, Dimension}; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; @@ -629,3 +630,103 @@ macro_rules! s( &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] }; ); + +/// Slicing information describing multiple mutable, disjoint slices. +/// +/// It's unfortunate that we need `'a` and `A` to be parameters of the trait, +/// but they're necessary until Rust supports generic associated types. +pub trait MultiSlice<'a, A, D> +where + A: 'a, + D: Dimension, +{ + /// The type of the slices created by `.multi_slice_move()`. + type Output; + + /// Split the view into multiple disjoint slices. + /// + /// **Panics** if performing any individual slice panics or if the slices + /// are not disjoint (i.e. if they intersect). + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output; +} + +impl<'a, A, D> MultiSlice<'a, A, D> for () +where + A: 'a, + D: Dimension, +{ + type Output = (); + + fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {} +} + +impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo,) +where + A: 'a, + D: Dimension, + Do0: Dimension, +{ + type Output = (ArrayViewMut<'a, A, Do0>,); + + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { + (view.slice_move(self.0),) + } +} + +macro_rules! impl_multislice_tuple { + ([$($but_last:ident)*] $last:ident) => { + impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last); + }; + (@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => { + impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo,)*) + where + A: 'a, + D: Dimension, + $($all: Dimension,)* + { + type Output = ($(ArrayViewMut<'a, A, $all>,)*); + + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { + #[allow(non_snake_case)] + let ($($all,)*) = self; + + let shape = view.raw_dim(); + assert!(!impl_multislice_tuple!(@intersects_self &shape, ($($all,)*))); + + let raw_view = view.into_raw_view_mut(); + unsafe { + ( + $(raw_view.clone().slice_move($but_last).deref_into_view_mut(),)* + raw_view.slice_move($last).deref_into_view_mut(), + ) + } + } + } + }; + (@intersects_self $shape:expr, ($head:expr,)) => { + false + }; + (@intersects_self $shape:expr, ($head:expr, $($tail:expr,)*)) => { + $(slices_intersect($shape, $head, $tail)) ||* + || impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*)) + }; +} + +impl_multislice_tuple!([Do0] Do1); +impl_multislice_tuple!([Do0 Do1] Do2); +impl_multislice_tuple!([Do0 Do1 Do2] Do3); +impl_multislice_tuple!([Do0 Do1 Do2 Do3] Do4); +impl_multislice_tuple!([Do0 Do1 Do2 Do3 Do4] Do5); + +impl<'a, A, D, T> MultiSlice<'a, A, D> for &T +where + A: 'a, + D: Dimension, + T: MultiSlice<'a, A, D>, +{ + type Output = T::Output; + + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { + T::multi_slice_move(self, view) + } +} diff --git a/tests/array.rs b/tests/array.rs index ea374bc7c..db54b7e5f 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -15,6 +15,20 @@ use ndarray::{arr3, rcarr2}; use ndarray::{Slice, SliceInfo, SliceOrIndex}; use std::iter::FromIterator; +macro_rules! assert_panics { + ($body:expr) => { + if let Ok(v) = ::std::panic::catch_unwind(|| $body) { + panic!("assertion failed: should_panic; \ + non-panicking result: {:?}", v); + } + }; + ($body:expr, $($arg:tt)*) => { + if let Ok(_) = ::std::panic::catch_unwind(|| $body) { + panic!($($arg)*); + } + }; +} + #[test] fn test_matmul_arcarray() { let mut A = ArcArray::::zeros((2, 3)); @@ -328,6 +342,82 @@ fn test_slice_collapse_with_indices() { assert_eq!(vi, Array3::from_elem((1, 1, 1), elem)); } +#[test] +fn test_multislice() { + macro_rules! do_test { + ($arr:expr, $($s:expr),*) => { + { + let arr = $arr; + let copy = arr.clone(); + assert_eq!( + arr.multi_slice_mut(($($s,)*)), + ($(copy.clone().slice_mut($s),)*) + ); + } + }; + } + + let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap(); + + assert_eq!( + (arr.clone().view_mut(),), + arr.multi_slice_mut((s![.., ..],)), + ); + assert_eq!(arr.multi_slice_mut(()), ()); + do_test!(&mut arr, s![0, ..]); + do_test!(&mut arr, s![0, ..], s![1, ..]); + do_test!(&mut arr, s![0, ..], s![-1, ..]); + do_test!(&mut arr, s![0, ..], s![1.., ..]); + do_test!(&mut arr, s![1, ..], s![..;2, ..]); + do_test!(&mut arr, s![..2, ..], s![2.., ..]); + do_test!(&mut arr, s![1..;2, ..], s![..;2, ..]); + do_test!(&mut arr, s![..;-2, ..], s![..;2, ..]); + do_test!(&mut arr, s![..;12, ..], s![3..;3, ..]); + do_test!(&mut arr, s![3, ..], s![..-1;-2, ..]); + do_test!(&mut arr, s![0, ..], s![1, ..], s![2, ..]); + do_test!(&mut arr, s![0, ..], s![1, ..], s![2, ..], s![3, ..]); +} + +#[test] +fn test_multislice_intersecting() { + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![3, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![3.., ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![..;3, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![..;6, ..], s![3..;3, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![2, ..], s![..-1;-2, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![4, ..], s![3, ..], s![3, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![4, ..], s![3, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![3, ..], s![4, ..])); + }); + assert_panics!({ + let mut arr = Array2::::zeros((8, 6)); + arr.multi_slice_mut((s![3, ..], s![3, ..], s![4, ..], s![3, ..])); + }); +} + #[should_panic] #[test] fn index_out_of_bounds() {