From 269865c6bf395671ff170c6e0cd9b1644de9c1fa Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 16:39:38 -0400 Subject: [PATCH 1/4] Clarify behavior of AxisChunksIter/Mut IMO, it's easier to understand and work with the implementation of these iterators using `partial_chunk_index` and `partial_chunk_dim` than `n_whole_chunks` and `last_dim`. --- src/iterators/mod.rs | 93 +++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 49 deletions(-) diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 0f15c4b50..4d11b5577 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -1182,9 +1182,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { /// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information. pub struct AxisChunksIter<'a, A, D> { iter: AxisIterCore<A, D>, - n_whole_chunks: usize, - /// Dimension of the last (and possibly uneven) chunk - last_dim: D, + /// Index of the partial chunk (the chunk smaller than the specified chunk + /// size due to the axis length not being evenly divisible). If the axis + /// length is evenly divisible by the chunk size, this index is larger than + /// the maximum valid index. + partial_chunk_index: usize, + /// Dimension of the partial chunk.
+ partial_chunk_dim: D, life: PhantomData<&'a A>, } @@ -1193,10 +1197,10 @@ clone_bounds!( AxisChunksIter['a, A, D] { @copy { life, - n_whole_chunks, + partial_chunk_index, } iter, - last_dim, + partial_chunk_dim, } ); @@ -1233,12 +1237,9 @@ fn chunk_iter_parts( let mut inner_dim = v.dim.clone(); inner_dim[axis] = size; - let mut last_dim = v.dim; - last_dim[axis] = if chunk_remainder == 0 { - size - } else { - chunk_remainder - }; + let mut partial_chunk_dim = v.dim; + partial_chunk_dim[axis] = chunk_remainder; + let partial_chunk_index = n_whole_chunks; let iter = AxisIterCore { index: 0, @@ -1249,16 +1250,16 @@ fn chunk_iter_parts( ptr: v.ptr, }; - (iter, n_whole_chunks, last_dim) + (iter, partial_chunk_index, partial_chunk_dim) } impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> { pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self { - let (iter, n_whole_chunks, last_dim) = chunk_iter_parts(v, axis, size); + let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size); AxisChunksIter { iter, - n_whole_chunks, - last_dim, + partial_chunk_index, + partial_chunk_dim, life: PhantomData, } } @@ -1270,30 +1271,24 @@ macro_rules! chunk_iter_impl { where D: Dimension, { - fn get_subview( - &self, - iter_item: Option<*mut A>, - is_uneven: bool, - ) -> Option<$array<'a, A, D>> { - iter_item.map(|ptr| { - if !is_uneven { - unsafe { - $array::new_( - ptr, - self.iter.inner_dim.clone(), - self.iter.inner_strides.clone(), - ) - } - } else { - unsafe { - $array::new_( - ptr, - self.last_dim.clone(), - self.iter.inner_strides.clone(), - ) - } + fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> { + if index != self.partial_chunk_index { + unsafe { + $array::new_( + ptr, + self.iter.inner_dim.clone(), + self.iter.inner_strides.clone(), + ) } - }) + } else { + unsafe { + $array::new_( + ptr, + self.partial_chunk_dim.clone(), + self.iter.inner_strides.clone(), + ) + } + } } } @@ -1304,9 +1299,8 @@ macro_rules! 
chunk_iter_impl { type Item = $array<'a, A, D>; fn next(&mut self) -> Option<Self::Item> { - let res = self.iter.next(); - let is_uneven = self.iter.index > self.n_whole_chunks; - self.get_subview(res, is_uneven) + let index = self.iter.index; + self.iter.next().map(|ptr| self.get_subview(index, ptr)) } fn size_hint(&self) -> (usize, Option<usize>) { @@ -1319,9 +1313,9 @@ macro_rules! chunk_iter_impl { D: Dimension, { fn next_back(&mut self) -> Option<Self::Item> { - let is_uneven = self.iter.end > self.n_whole_chunks; - let res = self.iter.next_back(); - self.get_subview(res, is_uneven) + self.iter + .next_back() + .map(|ptr| self.get_subview(self.iter.end, ptr)) } } @@ -1342,18 +1336,19 @@ macro_rules! chunk_iter_impl { /// for more information. pub struct AxisChunksIterMut<'a, A, D> { iter: AxisIterCore<A, D>, - n_whole_chunks: usize, - last_dim: D, + partial_chunk_index: usize, + partial_chunk_dim: D, life: PhantomData<&'a mut A>, } impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> { pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self { - let (iter, len, last_dim) = chunk_iter_parts(v.into_view(), axis, size); + let (iter, partial_chunk_index, partial_chunk_dim) = + chunk_iter_parts(v.into_view(), axis, size); AxisChunksIterMut { iter, - n_whole_chunks: len, - last_dim, + partial_chunk_index, + partial_chunk_dim, life: PhantomData, } } From af9d94926619c967d0fffd2dafcdb10f48e4d854 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 17:02:06 -0400 Subject: [PATCH 2/4] Move some logic into AxisIterCore --- src/iterators/mod.rs | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 4d11b5577..9c9dbed11 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -825,6 +825,19 @@ impl<A, D: Dimension> AxisIterCore<A, D> { }; (left, right) } + + /// Does the same thing as `.next()` but also returns the index of the item + /// relative to the start of the axis.
+ fn next_with_index(&mut self) -> Option<(usize, *mut A)> { + let index = self.index; + self.next().map(|ptr| (index, ptr)) + } + + /// Does the same thing as `.next_back()` but also returns the index of the + /// item relative to the start of the axis. + fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> { + self.next_back().map(|ptr| (self.end, ptr)) + } } impl<A, D> Iterator for AxisIterCore<A, D> @@ -1299,8 +1312,9 @@ macro_rules! chunk_iter_impl { type Item = $array<'a, A, D>; fn next(&mut self) -> Option<Self::Item> { - let index = self.iter.index; - self.iter.next().map(|ptr| self.get_subview(index, ptr)) + self.iter + .next_with_index() + .map(|(index, ptr)| self.get_subview(index, ptr)) } fn size_hint(&self) -> (usize, Option<usize>) { @@ -1314,8 +1328,8 @@ macro_rules! chunk_iter_impl { { fn next_back(&mut self) -> Option<Self::Item> { self.iter - .next_back() - .map(|ptr| self.get_subview(self.iter.end, ptr)) + .next_back_with_index() + .map(|(index, ptr)| self.get_subview(index, ptr)) } } From b26ec6c49b43170ae00eae0b2932dcaf62319fd7 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 16:40:40 -0400 Subject: [PATCH 3/4] Add split_at methods for AxisChunksIter/Mut --- src/iterators/mod.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 9c9dbed11..882eaaa77 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -1303,6 +1303,31 @@ macro_rules! chunk_iter_impl { } } } + + /// Splits the iterator at index, yielding two disjoint iterators. + /// + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length.
+ pub fn split_at(self, index: usize) -> (Self, Self) { + let (left, right) = self.iter.split_at(index); + ( + Self { + iter: left, + partial_chunk_index: self.partial_chunk_index, + partial_chunk_dim: self.partial_chunk_dim.clone(), + life: self.life, + }, + Self { + iter: right, + partial_chunk_index: self.partial_chunk_index, + partial_chunk_dim: self.partial_chunk_dim, + life: self.life, + }, + ) + } } impl<'a, A, D> Iterator for $iter<'a, A, D> From 4bee214a2e229465170642fb65bc81ff15f70d60 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 16:45:31 -0400 Subject: [PATCH 4/4] Add more tests for AxisChunksIter --- tests/iterators.rs | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/iterators.rs b/tests/iterators.rs index 6408b2f8d..325aa9797 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -13,6 +13,20 @@ use itertools::assert_equal; use itertools::{enumerate, rev}; use std::iter::FromIterator; +macro_rules! assert_panics { + ($body:expr) => { + if let Ok(v) = ::std::panic::catch_unwind(|| $body) { + panic!("assertion failed: should_panic; \ + non-panicking result: {:?}", v); + } + }; + ($body:expr, $($arg:tt)*) => { + if let Ok(_) = ::std::panic::catch_unwind(|| $body) { + panic!($($arg)*); + } + }; +} + #[test] fn double_ended() { let a = ArcArray::linspace(0., 7., 8); @@ -585,6 +599,33 @@ fn axis_chunks_iter_zero_axis_len() { assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none()); } +#[test] +fn axis_chunks_iter_split_at() { + let mut a = Array2::<usize>::zeros((11, 3)); + a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i); + for source in &[ + a.slice(s![..0, ..]), + a.slice(s![..1, ..]), + a.slice(s![..5, ..]), + a.slice(s![..10, ..]), + a.slice(s![..11, ..]), + a.slice(s![.., ..0]), + ] { + let chunks_iter = source.axis_chunks_iter(Axis(0), 5); + let all_chunks: Vec<_> = chunks_iter.clone().collect(); + let n_chunks = chunks_iter.len(); + assert_eq!(n_chunks, all_chunks.len()); + for
index in 0..=n_chunks { + let (left, right) = chunks_iter.clone().split_at(index); + assert_eq!(&all_chunks[..index], &left.collect::<Vec<_>>()[..]); + assert_eq!(&all_chunks[index..], &right.collect::<Vec<_>>()[..]); + } + assert_panics!({ + chunks_iter.split_at(n_chunks + 1); + }); + } +} + #[test] fn axis_chunks_iter_mut() { let a = ArcArray::from_iter(0..24);