-
Notifications
You must be signed in to change notification settings - Fork 310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Parallel Iterator for AxisChunksIter #639
Conversation
I happen to need the same feature! When are we going to merge this? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this. Parallel tools are always nice to have.
.into_par_iter() | ||
.map(|x| x.sum()) | ||
.sum(); | ||
println!("{:?}", a.slice(s![..10, ..5])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason why you always print a slice of the array? You have 4 tests and you do it 4 times. I don't see the gain in the context of a test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copied the tests for AxisIter, which include this println. cargo test
captures stdout for passing tests, so this will only print if the test fails.
src/iterators/mod.rs
Outdated
@@ -1312,6 +1352,26 @@ impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> { | |||
life: PhantomData, | |||
} | |||
} | |||
pub fn split_at(self, index: usize) -> (Self, Self) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see that this 19-lines block is exactly identical to the other split_at
above. This may be "normal" (I'm not a Rust expert at all), but it looks wrong! Is there any tool to avoid this? Templating split_chunk_iter
to return the right types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe your idea would work, but this is my first contribution to ndarray, so I chose to mirror the existing code style for AxisIter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nilgoyette I believe it can be done using trait inheritance. AxisIterMut
can be a sub trait of AxisIter
and AxisChunksIterMut
can be a sub trait of AxisChunksIter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can get there with a more lightweight solution: we can have a generic private function and just call it in each of the method. What do you think @nitsky?
src/iterators/mod.rs
Outdated
(AxisIterCore<A, D>, usize, D), | ||
(AxisIterCore<A, D>, usize, D), | ||
) { | ||
let left_n_whole_chunks = index; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation isn't quite right in the case where the axis is not evenly divisible by the chunk size and the specified index
is equal to the number of chunks. Here's a test that fails with this implementation:
#[test]
fn axis_chunks_split_at() {
let mut a = Array2::<usize>::zeros((11, 3));
a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i);
for source in &[
a.slice(s![..1, ..]),
a.slice(s![..5, ..]),
a.slice(s![..10, ..]),
a.slice(s![..11, ..]),
] {
let chunks_iter = source.axis_chunks_iter(Axis(0), 5);
let all_chunks: Vec<_> = chunks_iter.clone().collect();
let n_chunks = chunks_iter.len();
assert_eq!(n_chunks, all_chunks.len());
for index in 0..=n_chunks {
let (left, right) = chunks_iter.clone().split_at(index);
assert_eq!(&all_chunks[..index], &left.collect::<Vec<_>>()[..]);
assert_eq!(&all_chunks[index..], &right.collect::<Vec<_>>()[..]);
}
assert_panics!({
chunks_iter.split_at(n_chunks + 1);
});
}
}
One way to fix the implementation is this:
fn split_chunk_iter<A, D: Dimension>(
iter: AxisIterCore<A, D>,
n_whole_chunks: usize,
last_dim: D,
index: usize,
) -> (
(AxisIterCore<A, D>, usize, D),
(AxisIterCore<A, D>, usize, D),
) {
// Note: `index` is checked to be `<= iter.len` in `iter.split_at(index)`.
if index > n_whole_chunks {
// In this case, the entire iterator stays in the left piece; the right
// piece has length zero.
let (left, right) = iter.split_at(index);
debug_assert_eq!(right.len, 0);
(
(left, n_whole_chunks, last_dim.clone()),
(right, 0, last_dim),
)
} else {
// In this case, the right iterator contains the last chunk (and
// possibly more chunks before it).
let left_n_whole_chunks = index;
let right_n_whole_chunks = n_whole_chunks - left_n_whole_chunks;
let left_last_dim = iter.inner_dim.clone();
let right_last_dim = last_dim;
let (left, right) = iter.split_at(index);
(
(left, left_n_whole_chunks, left_last_dim),
(right, right_n_whole_chunks, right_last_dim),
)
}
}
IMO, a cleaner way to implement this is to keep track not of n_whole_chunks
but of the index corresponding to the partial chunk, and then implement .split_at()
like this:
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
AxisChunksIter {
iter: left,
partial_chunk_index: self.partial_chunk_index,
partial_chunk_dim: self.partial_chunk_dim.clone(),
life: self.life,
},
AxisChunksIter {
iter: right,
partial_chunk_index: self.partial_chunk_index,
partial_chunk_dim: self.partial_chunk_dim,
life: self.life,
},
)
}
Once #669 is merged, I can provide a more complete suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've created #691 to add (hopefully correct) .split_at()
implementations. Once that is merged, it should be straightforward to update this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jturner314 great thanks! I'll update after that is merged.
@jturner314 this is ready for review now. Thanks for taking care of the implementation in #691. |
Hey - sorry for the late notice. From ndarray 0.13, the crate inside ./parallel is deprecated - it will be removed, as soon as we have made the deprecated-marking release of it when ndarray 0.13 itself goes live. The place where this change needs to be made is inside src/parallel and inside tests/par_*.rs |
Maybe there's a point to making the change both places? It makes for a nicer deprecation, but the two modules have already diverged, so I don't think we need to pursue that. |
Hi @bluss, I believe this PR includes changes in both places, with the parallel iterator declared in parallel/src/par.rs and in src/parallel/par.rs. I believe the tests and documentation appear in both places as well. Can you double check and let me know if I’m mistaken? |
@nitsky It looks fine, it changes in both places even if one would be enough. We can take it from here. |
Okay, thanks!! |
Thank you for this 🙂 |
This PR implements rayon parallelization for AxisChunksIter. See #89.