Skip to content

Commit

Permalink
Merge pull request #691 from jturner314/split-chunks
Browse files Browse the repository at this point in the history
Add .split_at() methods for AxisChunksIter/Mut
  • Loading branch information
jturner314 authored Sep 4, 2019
2 parents f2fd1dc + 4bee214 commit c916203
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 49 deletions.
132 changes: 83 additions & 49 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,19 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
};
(left, right)
}

/// Does the same thing as `.next()` but also returns the index of the item
/// relative to the start of the axis.
fn next_with_index(&mut self) -> Option<(usize, *mut A)> {
let index = self.index;
self.next().map(|ptr| (index, ptr))
}

/// Does the same thing as `.next_back()` but also returns the index of the
/// item relative to the start of the axis.
fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> {
self.next_back().map(|ptr| (self.end, ptr))
}
}

impl<A, D> Iterator for AxisIterCore<A, D>
Expand Down Expand Up @@ -1182,9 +1195,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
/// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information.
pub struct AxisChunksIter<'a, A, D> {
iter: AxisIterCore<A, D>,
n_whole_chunks: usize,
/// Dimension of the last (and possibly uneven) chunk
last_dim: D,
/// Index of the partial chunk (the chunk smaller than the specified chunk
/// size due to the axis length not being evenly divisible). If the axis
/// length is evenly divisible by the chunk size, this index is larger than
/// the maximum valid index.
partial_chunk_index: usize,
/// Dimension of the partial chunk.
partial_chunk_dim: D,
life: PhantomData<&'a A>,
}

Expand All @@ -1193,10 +1210,10 @@ clone_bounds!(
AxisChunksIter['a, A, D] {
@copy {
life,
n_whole_chunks,
partial_chunk_index,
}
iter,
last_dim,
partial_chunk_dim,
}
);

Expand Down Expand Up @@ -1233,12 +1250,9 @@ fn chunk_iter_parts<A, D: Dimension>(
let mut inner_dim = v.dim.clone();
inner_dim[axis] = size;

let mut last_dim = v.dim;
last_dim[axis] = if chunk_remainder == 0 {
size
} else {
chunk_remainder
};
let mut partial_chunk_dim = v.dim;
partial_chunk_dim[axis] = chunk_remainder;
let partial_chunk_index = n_whole_chunks;

let iter = AxisIterCore {
index: 0,
Expand All @@ -1249,16 +1263,16 @@ fn chunk_iter_parts<A, D: Dimension>(
ptr: v.ptr,
};

(iter, n_whole_chunks, last_dim)
(iter, partial_chunk_index, partial_chunk_dim)
}

impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> {
pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self {
let (iter, n_whole_chunks, last_dim) = chunk_iter_parts(v, axis, size);
let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size);
AxisChunksIter {
iter,
n_whole_chunks,
last_dim,
partial_chunk_index,
partial_chunk_dim,
life: PhantomData,
}
}
Expand All @@ -1270,30 +1284,49 @@ macro_rules! chunk_iter_impl {
where
D: Dimension,
{
fn get_subview(
&self,
iter_item: Option<*mut A>,
is_uneven: bool,
) -> Option<$array<'a, A, D>> {
iter_item.map(|ptr| {
if !is_uneven {
unsafe {
$array::new_(
ptr,
self.iter.inner_dim.clone(),
self.iter.inner_strides.clone(),
)
}
} else {
unsafe {
$array::new_(
ptr,
self.last_dim.clone(),
self.iter.inner_strides.clone(),
)
}
fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> {
if index != self.partial_chunk_index {
unsafe {
$array::new_(
ptr,
self.iter.inner_dim.clone(),
self.iter.inner_strides.clone(),
)
}
} else {
unsafe {
$array::new_(
ptr,
self.partial_chunk_dim.clone(),
self.iter.inner_strides.clone(),
)
}
})
}
}

/// Splits the iterator at index, yielding two disjoint iterators.
///
/// `index` is relative to the current state of the iterator (which is not
/// necessarily the start of the axis).
///
/// **Panics** if `index` is strictly greater than the iterator's remaining
/// length.
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
Self {
iter: left,
partial_chunk_index: self.partial_chunk_index,
partial_chunk_dim: self.partial_chunk_dim.clone(),
life: self.life,
},
Self {
iter: right,
partial_chunk_index: self.partial_chunk_index,
partial_chunk_dim: self.partial_chunk_dim,
life: self.life,
},
)
}
}

Expand All @@ -1304,9 +1337,9 @@ macro_rules! chunk_iter_impl {
type Item = $array<'a, A, D>;

fn next(&mut self) -> Option<Self::Item> {
let res = self.iter.next();
let is_uneven = self.iter.index > self.n_whole_chunks;
self.get_subview(res, is_uneven)
self.iter
.next_with_index()
.map(|(index, ptr)| self.get_subview(index, ptr))
}

fn size_hint(&self) -> (usize, Option<usize>) {
Expand All @@ -1319,9 +1352,9 @@ macro_rules! chunk_iter_impl {
D: Dimension,
{
fn next_back(&mut self) -> Option<Self::Item> {
let is_uneven = self.iter.end > self.n_whole_chunks;
let res = self.iter.next_back();
self.get_subview(res, is_uneven)
self.iter
.next_back_with_index()
.map(|(index, ptr)| self.get_subview(index, ptr))
}
}

Expand All @@ -1342,18 +1375,19 @@ macro_rules! chunk_iter_impl {
/// for more information.
pub struct AxisChunksIterMut<'a, A, D> {
iter: AxisIterCore<A, D>,
n_whole_chunks: usize,
last_dim: D,
partial_chunk_index: usize,
partial_chunk_dim: D,
life: PhantomData<&'a mut A>,
}

impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> {
pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self {
let (iter, len, last_dim) = chunk_iter_parts(v.into_view(), axis, size);
let (iter, partial_chunk_index, partial_chunk_dim) =
chunk_iter_parts(v.into_view(), axis, size);
AxisChunksIterMut {
iter,
n_whole_chunks: len,
last_dim,
partial_chunk_index,
partial_chunk_dim,
life: PhantomData,
}
}
Expand Down
41 changes: 41 additions & 0 deletions tests/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ use itertools::assert_equal;
use itertools::{enumerate, rev};
use std::iter::FromIterator;

macro_rules! assert_panics {
($body:expr) => {
if let Ok(v) = ::std::panic::catch_unwind(|| $body) {
panic!("assertion failed: should_panic; \
non-panicking result: {:?}", v);
}
};
($body:expr, $($arg:tt)*) => {
if let Ok(_) = ::std::panic::catch_unwind(|| $body) {
panic!($($arg)*);
}
};
}

#[test]
fn double_ended() {
let a = ArcArray::linspace(0., 7., 8);
Expand Down Expand Up @@ -585,6 +599,33 @@ fn axis_chunks_iter_zero_axis_len() {
assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none());
}

#[test]
fn axis_chunks_iter_split_at() {
let mut a = Array2::<usize>::zeros((11, 3));
a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i);
for source in &[
a.slice(s![..0, ..]),
a.slice(s![..1, ..]),
a.slice(s![..5, ..]),
a.slice(s![..10, ..]),
a.slice(s![..11, ..]),
a.slice(s![.., ..0]),
] {
let chunks_iter = source.axis_chunks_iter(Axis(0), 5);
let all_chunks: Vec<_> = chunks_iter.clone().collect();
let n_chunks = chunks_iter.len();
assert_eq!(n_chunks, all_chunks.len());
for index in 0..=n_chunks {
let (left, right) = chunks_iter.clone().split_at(index);
assert_eq!(&all_chunks[..index], &left.collect::<Vec<_>>()[..]);
assert_eq!(&all_chunks[index..], &right.collect::<Vec<_>>()[..]);
}
assert_panics!({
chunks_iter.split_at(n_chunks + 1);
});
}
}

#[test]
fn axis_chunks_iter_mut() {
let a = ArcArray::from_iter(0..24);
Expand Down

0 comments on commit c916203

Please sign in to comment.