Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi_slice_* methods (supports flat tuples only) #717

Merged
1 change: 0 additions & 1 deletion src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,6 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> {
}

/// Returns `true` iff the slices intersect.
#[allow(dead_code)]
pub fn slices_intersect<D: Dimension>(
dim: &D,
indices1: &D::SliceArg,
Expand Down
34 changes: 34 additions & 0 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::iter::{
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
};
use crate::slice::MultiSlice;
use crate::stacking::stack;
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};

Expand Down Expand Up @@ -350,6 +351,39 @@ where
self.view_mut().slice_move(info)
}

/// Return multiple disjoint, sliced, mutable views of the array.
///
/// See [*Slicing*](#slicing) for full documentation.
/// See also [`SliceInfo`] and [`D::SliceArg`].
///
/// [`SliceInfo`]: struct.SliceInfo.html
/// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg
///
/// **Panics** if any of the following occur:
///
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
/// * if an index is out of bounds or step size is zero
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
///
/// # Example
///
/// ```
/// use ndarray::{arr2, s};
///
/// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
/// let (mut edges, mut middle) = a.multi_slice_mut((s![.., ..;2], s![.., 1]));
/// edges.fill(1);
/// middle.fill(0);
/// assert_eq!(a, arr2(&[[1, 0, 1], [1, 0, 1]]));
/// ```
pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output
where
M: MultiSlice<'a, A, D>,
S: DataMut,
{
info.multi_slice_move(self.view_mut())
}

/// Slice the array, possibly changing the number of dimensions.
///
/// See [*Slicing*](#slicing) for full documentation.
Expand Down
26 changes: 26 additions & 0 deletions src/impl_views/splitting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// except according to those terms.

use crate::imp_prelude::*;
use crate::slice::MultiSlice;

/// Methods for read-only array views.
impl<'a, A, D> ArrayView<'a, A, D>
Expand Down Expand Up @@ -109,4 +110,29 @@ where
(left.deref_into_view_mut(), right.deref_into_view_mut())
}
}

/// Split the view into multiple disjoint slices.
///
/// This is similar to [`.multi_slice_mut()`], but `.multi_slice_move()`
/// consumes `self` and produces views with lifetimes matching that of
/// `self`.
///
/// See [*Slicing*](#slicing) for full documentation.
/// See also [`SliceInfo`] and [`D::SliceArg`].
///
/// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut
/// [`SliceInfo`]: struct.SliceInfo.html
/// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg
///
/// **Panics** if any of the following occur:
///
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
/// * if an index is out of bounds or step size is zero
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
pub fn multi_slice_move<M>(self, info: M) -> M::Output
where
M: MultiSlice<'a, A, D>,
{
info.multi_slice_move(self)
}
}
21 changes: 21 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,13 @@ pub type Ixs = isize;
/// [`.slice_move()`]: #method.slice_move
/// [`.slice_collapse()`]: #method.slice_collapse
///
/// It's possible to take multiple simultaneous *mutable* slices with
/// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only)
/// [`.multi_slice_move()`].
///
/// [`.multi_slice_mut()`]: #method.multi_slice_mut
/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know about you, but I don't mind transitioning to nightly-only rustdoc links (because they render correctly in docs.rs).. Importance of rendering the links when the user generates docs on stable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be fine, although I'm having trouble getting nightly-style links to work correctly. The following don't work:

[`.multi_slice_mut()`]
[`multi_slice_mut()`]

If I add an explicit destination, multi_slice_mut works fine, but I'm not sure how to make multi_slice_move work:

This works:
[`.multi_slice_mut()`]

[`.multi_slice_mut()`]: ArrayBase::multi_slice_mut()


This doesn't:
[`.multi_slice_move()`]

[`.multi_slice_move()`]: ArrayViewMut::multi_slice_move()

///
/// ```
/// extern crate ndarray;
///
Expand Down Expand Up @@ -525,6 +532,20 @@ pub type Ixs = isize;
/// [12, 11, 10]]);
/// assert_eq!(f, g);
/// assert_eq!(f.shape(), &[2, 3]);
///
/// // Let's take two disjoint, mutable slices of a matrix with
/// //
/// // - One containing all the even-index columns in the matrix
/// // - One containing all the odd-index columns in the matrix
/// let mut h = arr2(&[[0, 1, 2, 3],
/// [4, 5, 6, 7]]);
/// let (s0, s1) = h.multi_slice_mut((s![.., ..;2], s![.., 1..;2]));
/// let i = arr2(&[[0, 2],
/// [4, 6]]);
/// let j = arr2(&[[1, 3],
/// [5, 7]]);
/// assert_eq!(s0, i);
/// assert_eq!(s1, j);
/// }
/// ```
///
Expand Down
103 changes: 102 additions & 1 deletion src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use crate::dimension::slices_intersect;
use crate::error::{ErrorKind, ShapeError};
use crate::Dimension;
use crate::{ArrayViewMut, Dimension};
use std::fmt;
use std::marker::PhantomData;
use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
Expand Down Expand Up @@ -629,3 +630,103 @@ macro_rules! s(
&*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*]
};
);

/// Slicing information describing multiple mutable, disjoint slices.
///
/// It's unfortunate that we need `'a` and `A` to be parameters of the trait,
/// but they're necessary until Rust supports generic associated types.
pub trait MultiSlice<'a, A, D>
where
A: 'a,
D: Dimension,
{
/// The type of the slices created by `.multi_slice_move()`.
type Output;

/// Split the view into multiple disjoint slices.
///
/// **Panics** if performing any individual slice panics or if the slices
/// are not disjoint (i.e. if they intersect).
fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output;
}

impl<'a, A, D> MultiSlice<'a, A, D> for ()
where
A: 'a,
D: Dimension,
{
type Output = ();

fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {}
}

impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo<D::SliceArg, Do0>,)
where
A: 'a,
D: Dimension,
Do0: Dimension,
{
type Output = (ArrayViewMut<'a, A, Do0>,);

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
(view.slice_move(self.0),)
}
}

macro_rules! impl_multislice_tuple {
([$($but_last:ident)*] $last:ident) => {
impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last);
};
(@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => {
impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo<D::SliceArg, $all>,)*)
where
A: 'a,
D: Dimension,
$($all: Dimension,)*
{
type Output = ($(ArrayViewMut<'a, A, $all>,)*);

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
#[allow(non_snake_case)]
let ($($all,)*) = self;

let shape = view.raw_dim();
assert!(!impl_multislice_tuple!(@intersects_self &shape, ($($all,)*)));

let raw_view = view.into_raw_view_mut();
unsafe {
(
$(raw_view.clone().slice_move($but_last).deref_into_view_mut(),)*
raw_view.slice_move($last).deref_into_view_mut(),
)
}
}
}
};
(@intersects_self $shape:expr, ($head:expr,)) => {
false
};
(@intersects_self $shape:expr, ($head:expr, $($tail:expr,)*)) => {
$(slices_intersect($shape, $head, $tail)) ||*
|| impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*))
};
}

impl_multislice_tuple!([Do0] Do1);
impl_multislice_tuple!([Do0 Do1] Do2);
impl_multislice_tuple!([Do0 Do1 Do2] Do3);
impl_multislice_tuple!([Do0 Do1 Do2 Do3] Do4);
impl_multislice_tuple!([Do0 Do1 Do2 Do3 Do4] Do5);

impl<'a, A, D, T> MultiSlice<'a, A, D> for &T
where
A: 'a,
D: Dimension,
T: MultiSlice<'a, A, D>,
{
type Output = T::Output;

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
T::multi_slice_move(self, view)
}
}
90 changes: 90 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ use ndarray::{arr3, rcarr2};
use ndarray::{Slice, SliceInfo, SliceOrIndex};
use std::iter::FromIterator;

macro_rules! assert_panics {
($body:expr) => {
if let Ok(v) = ::std::panic::catch_unwind(|| $body) {
panic!("assertion failed: should_panic; \
non-panicking result: {:?}", v);
}
};
($body:expr, $($arg:tt)*) => {
if let Ok(_) = ::std::panic::catch_unwind(|| $body) {
panic!($($arg)*);
}
};
}

#[test]
fn test_matmul_arcarray() {
let mut A = ArcArray::<usize, _>::zeros((2, 3));
Expand Down Expand Up @@ -328,6 +342,82 @@ fn test_slice_collapse_with_indices() {
assert_eq!(vi, Array3::from_elem((1, 1, 1), elem));
}

#[test]
fn test_multislice() {
macro_rules! do_test {
($arr:expr, $($s:expr),*) => {
{
let arr = $arr;
let copy = arr.clone();
assert_eq!(
arr.multi_slice_mut(($($s,)*)),
($(copy.clone().slice_mut($s),)*)
);
}
};
}

let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap();

assert_eq!(
(arr.clone().view_mut(),),
arr.multi_slice_mut((s![.., ..],)),
);
assert_eq!(arr.multi_slice_mut(()), ());
do_test!(&mut arr, s![0, ..]);
do_test!(&mut arr, s![0, ..], s![1, ..]);
do_test!(&mut arr, s![0, ..], s![-1, ..]);
do_test!(&mut arr, s![0, ..], s![1.., ..]);
do_test!(&mut arr, s![1, ..], s![..;2, ..]);
do_test!(&mut arr, s![..2, ..], s![2.., ..]);
do_test!(&mut arr, s![1..;2, ..], s![..;2, ..]);
do_test!(&mut arr, s![..;-2, ..], s![..;2, ..]);
do_test!(&mut arr, s![..;12, ..], s![3..;3, ..]);
do_test!(&mut arr, s![3, ..], s![..-1;-2, ..]);
do_test!(&mut arr, s![0, ..], s![1, ..], s![2, ..]);
do_test!(&mut arr, s![0, ..], s![1, ..], s![2, ..], s![3, ..]);
}

#[test]
fn test_multislice_intersecting() {
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![3.., ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![..;3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![..;6, ..], s![3..;3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![2, ..], s![..-1;-2, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![4, ..], s![3, ..], s![3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![4, ..], s![3, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![3, ..], s![4, ..]));
});
assert_panics!({
let mut arr = Array2::<u8>::zeros((8, 6));
arr.multi_slice_mut((s![3, ..], s![3, ..], s![4, ..], s![3, ..]));
});
}

#[should_panic]
#[test]
fn index_out_of_bounds() {
Expand Down