Add .split_at() methods for AxisChunksIter/Mut #691
IMO, it's easier to understand and work with the implementation of these iterators using `partial_chunk_index` and `partial_chunk_dim` than `n_whole_chunks` and `last_dim`.
/// size due to the axis length not being evenly divisible). If the axis
/// length is evenly divisible by the chunk size, this index is larger than
/// the maximum valid index.
partial_chunk_index: usize,
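To make the doc comment above concrete, here is a minimal sketch of how the partial chunk's index and length could be derived from the axis length and chunk size. The helper names `partial_chunk_info` and `n_chunks` are hypothetical and are not part of ndarray; the sketch only models the arithmetic for a single axis.

```rust
/// Hypothetical helper (not ndarray API): returns
/// `(partial_chunk_index, partial_chunk_len)` for an axis of length
/// `axis_len` divided into chunks of `chunk_size` elements.
fn partial_chunk_info(axis_len: usize, chunk_size: usize) -> (usize, usize) {
    assert!(chunk_size != 0, "chunk size must be nonzero");
    let n_whole_chunks = axis_len / chunk_size;
    let remainder = axis_len % chunk_size;
    // Whole chunks occupy indices 0..n_whole_chunks, so a partial chunk
    // (if there is one) sits at index `n_whole_chunks`.
    (n_whole_chunks, remainder)
}

/// Hypothetical helper: total number of chunks the iterator would yield.
fn n_chunks(axis_len: usize, chunk_size: usize) -> usize {
    let (partial_chunk_index, partial_chunk_len) = partial_chunk_info(axis_len, chunk_size);
    if partial_chunk_len == 0 {
        // Evenly divisible: valid indices are 0..partial_chunk_index, so
        // `partial_chunk_index` is larger than the maximum valid index.
        partial_chunk_index
    } else {
        partial_chunk_index + 1
    }
}
```

For an axis of length 13 and chunk size 3, this gives a partial chunk at index 4 with length 1. For length 12, the index is still 4, but only indices 0 through 3 are valid, so the iterator never reaches it.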
Would it be beneficial to rephrase this as an `Option`, to make it clearer that we might (or might not) have a partial chunk? Something along the lines of:
pub struct AxisChunksIter<'a, A, D> {
    iter: AxisIterCore<A, D>,
    partial_chunk: Option<PartialChunk<D>>,
    life: PhantomData<&'a A>,
}

struct PartialChunk<D> {
    partial_chunk_index: usize,
    partial_chunk_dim: D,
}
I don't think it makes sense to use both the `Option` variant and the value of `partial_chunk_index` to represent whether or not there's a partial chunk. (The biggest reason is that I prefer data structures where there's a single source of truth, rather than having to keep multiple things in sync. There might also be a small performance cost to accessing `partial_chunk_index` through the `Option` (since accessing it requires checking whether the `Option` is the `Some` variant), but we'd need to test to determine if that would really be noticeable.) IMO, putting the fields in an `Option` would be additional complication over the current approach without much benefit.
It would be reasonable to eliminate `partial_chunk_index` and just use the `Option` variant to represent the presence of a partial chunk, like this:
pub struct AxisChunksIter<'a, A, D> {
    iter: AxisIterCore<A, D>,
    partial_chunk: Option<D>,
    life: PhantomData<&'a A>,
}
or to always store the shape of the last chunk (regardless of whether or not it's a partial chunk):
pub struct AxisChunksIter<'a, A, D> {
    iter: AxisIterCore<A, D>,
    last_chunk_dim: D,
    life: PhantomData<&'a A>,
}
These approaches have two disadvantages since they rely on checking whether the iterator is at its end to handle the partial chunk instead of checking whether the current index is equal to `partial_chunk_index`:
- `.split_at()` needs to check whether or not the partial chunk is in the left piece and determine `partial_chunk` or `last_chunk_dim` of the left piece accordingly. (The partial chunk is in the left piece when `index == self.iter.len()`.)
- `.next_back()` needs to set `partial_chunk` to `None` or `last_chunk_dim` to `self.iter.inner_dim` each time it's called.
So, I'd rather keep the current approach and add more comments if necessary to make it clear.
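To illustrate the extra bookkeeping that the `last_chunk_dim` alternative would need, here is a rough one-dimensional model. It is only a sketch: `LastDimChunks` and its fields are hypothetical, it tracks chunk lengths instead of array views, and it is not ndarray's code.

```rust
/// Hypothetical 1-D model of the `last_chunk_dim` alternative.
/// `next()` and `next_back()` return chunk lengths rather than views.
#[derive(Debug, Clone, PartialEq)]
struct LastDimChunks {
    index: usize,          // next chunk index
    end: usize,            // one past the last chunk index
    chunk_len: usize,      // length of a full chunk
    last_chunk_len: usize, // length of the chunk at `end - 1`
}

impl LastDimChunks {
    fn new(axis_len: usize, chunk_len: usize) -> Self {
        assert!(chunk_len != 0, "chunk size must be nonzero");
        let rem = axis_len % chunk_len;
        let end = axis_len / chunk_len + (rem != 0) as usize;
        let last_chunk_len = if rem == 0 { chunk_len } else { rem };
        Self { index: 0, end, chunk_len, last_chunk_len }
    }

    fn len(&self) -> usize {
        self.end - self.index
    }

    fn next(&mut self) -> Option<usize> {
        if self.index >= self.end {
            return None;
        }
        // The special case is detected by being at the end of the iterator.
        let at_last = self.index == self.end - 1;
        self.index += 1;
        Some(if at_last { self.last_chunk_len } else { self.chunk_len })
    }

    fn next_back(&mut self) -> Option<usize> {
        if self.index >= self.end {
            return None;
        }
        let len = self.last_chunk_len;
        self.end -= 1;
        // Disadvantage: the field must be reset on every call, because the
        // new last chunk is always a full chunk.
        self.last_chunk_len = self.chunk_len;
        Some(len)
    }

    fn split_at(self, index: usize) -> (Self, Self) {
        assert!(index <= self.len());
        let mid = self.index + index;
        // Disadvantage: the left piece keeps the original last chunk only
        // when it takes every remaining chunk.
        let left_last = if index == self.len() { self.last_chunk_len } else { self.chunk_len };
        let left = Self { index: self.index, end: mid, chunk_len: self.chunk_len, last_chunk_len: left_last };
        let right = Self { index: mid, end: self.end, chunk_len: self.chunk_len, last_chunk_len: self.last_chunk_len };
        (left, right)
    }
}
```

Both special cases disappear when the partial chunk is identified by a fixed index, since neither `.split_at()` nor `.next_back()` then has to touch the partial-chunk fields.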
},
Self {
    iter: right,
    partial_chunk_index: self.partial_chunk_index,
I haven't read the whole code unfortunately (what's not visible in the diff). Why doesn't this `partial_chunk_index` require adjusting? The right part of the iter now starts at `index`, so I'd expect this to be offset by `- index`.
Here's an example:
use ndarray::prelude::*;

fn main() {
    let a: Array1<i32> = (0..13).collect();
    let mut iter = a.axis_chunks_iter(Axis(0), 3);
    iter.next(); // skip the first chunk so that we consider a partially-consumed iterator
    println!("before_split = {:#?}", iter);
    let (left, right) = iter.split_at(2);
    println!("left = {:#?}", left);
    println!("right = {:#?}", right);
}
which gives the output
before_split = AxisChunksIter {
    iter: AxisIterCore {
        index: 1,
        end: 5,
        stride: 3,
        inner_dim: [3],
        inner_strides: [1],
        ptr: 0x00005634af728b40,
    },
    partial_chunk_index: 4,
    partial_chunk_dim: [1],
    life: PhantomData,
}
left = AxisChunksIter {
    iter: AxisIterCore {
        index: 1,
        end: 3,
        stride: 3,
        inner_dim: [3],
        inner_strides: [1],
        ptr: 0x00005634af728b40,
    },
    partial_chunk_index: 4,
    partial_chunk_dim: [1],
    life: PhantomData,
}
right = AxisChunksIter {
    iter: AxisIterCore {
        index: 3,
        end: 5,
        stride: 3,
        inner_dim: [3],
        inner_strides: [1],
        ptr: 0x00005634af728b40,
    },
    partial_chunk_index: 4,
    partial_chunk_dim: [1],
    life: PhantomData,
}
We can visualize the situation like this:
              0  1  2  3  4
before split:    ^          |
after split:     ^    |^    |
The `^`s represent the `index`es and the `|`s represent the `end`s of the iterators. (The `|`s appear just before the corresponding `end` indices.) There are 4 full chunks (indices `0..=3`) and 1 partial chunk (index `4`). Note that all indices are relative to the start of the axis, so any given index value represents the same location before and after the split. This is why `partial_chunk_index` is the same before and after splitting. Before splitting, the index of the partial chunk is `4`, and it stays `4` in the split pieces. (The left piece will never actually reach index `4` since its `end` is `3`; that's the desired behavior.)
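The same reasoning can be checked with a small, self-contained model that does not depend on ndarray. This is only a sketch: `Chunks1D` is a hypothetical one-dimensional stand-in that yields slices, with field names chosen to mirror the debug output above.

```rust
/// Hypothetical 1-D model of a chunk iterator that keeps all indices
/// relative to the start of the axis.
#[derive(Debug, Clone)]
struct Chunks1D<'a> {
    data: &'a [i32],            // always the whole axis; never re-sliced by split_at
    index: usize,               // next chunk index
    end: usize,                 // one past the last chunk index of this piece
    chunk_len: usize,           // length of a full chunk
    partial_chunk_index: usize, // absolute index of the partial chunk
    partial_chunk_len: usize,   // length of the partial chunk (0 if none)
}

impl<'a> Chunks1D<'a> {
    fn new(data: &'a [i32], chunk_len: usize) -> Self {
        assert!(chunk_len != 0, "chunk size must be nonzero");
        let n_whole = data.len() / chunk_len;
        let rem = data.len() % chunk_len;
        let end = if rem == 0 { n_whole } else { n_whole + 1 };
        Self { data, index: 0, end, chunk_len, partial_chunk_index: n_whole, partial_chunk_len: rem }
    }

    /// `index` counts from the current position, as in the example above.
    fn split_at(self, index: usize) -> (Self, Self) {
        assert!(index <= self.end - self.index);
        let mid = self.index + index;
        // Only `end` (left) and `index` (right) change; the partial-chunk
        // fields are copied unchanged into both pieces.
        let left = Self { end: mid, ..self.clone() };
        let right = Self { index: mid, ..self };
        (left, right)
    }
}

impl<'a> Iterator for Chunks1D<'a> {
    type Item = &'a [i32];

    fn next(&mut self) -> Option<&'a [i32]> {
        if self.index >= self.end {
            return None;
        }
        let data: &'a [i32] = self.data;
        let start = self.index * self.chunk_len;
        // The partial chunk is recognized by its absolute index.
        let len = if self.index == self.partial_chunk_index {
            self.partial_chunk_len
        } else {
            self.chunk_len
        };
        self.index += 1;
        Some(&data[start..start + len])
    }
}
```

Building the model over `0..13` with chunk length 3, calling `next()` once, and then `split_at(2)` gives a left piece with `index: 1, end: 3` and a right piece with `index: 3, end: 5`, both with `partial_chunk_index: 4`. The left piece yields `[3, 4, 5]` and `[6, 7, 8]`; the right piece yields `[9, 10, 11]` and the partial chunk `[12]`.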
Thanks. If the only use of index is counting up to the partial_chunk_index, it makes total sense.
`index` is also used in `AxisIterCore` (which `AxisChunksIter` wraps) to compute the pointer of each element/chunk and to check for the end of the iterator; see `AxisIterCore`'s implementation of `.next()`. (`.split_at()` on `AxisIterCore` doesn't change the `ptr` value; `ptr` always corresponds to the start of the axis. This was part of #669.)
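As a rough illustration of that description, the sketch below models the core iterator with plain integer offsets instead of raw pointers. `CoreModel`, `base`, and `next_offset` are hypothetical names; `base` plays the role of `ptr`, and the real `AxisIterCore` may differ in its details.

```rust
/// Hypothetical, safe stand-in for the core iterator described above.
#[derive(Debug, Clone, Copy)]
struct CoreModel {
    index: usize,  // next index, relative to the start of the axis
    end: usize,    // one past the last index of this piece
    stride: isize, // distance between consecutive elements/chunks
    base: isize,   // offset of the start of the axis (plays the role of `ptr`)
}

impl CoreModel {
    /// Returns the offset of the next element/chunk, or `None` at the end.
    fn next_offset(&mut self) -> Option<isize> {
        if self.index >= self.end {
            return None; // end-of-iterator check uses `index`
        }
        // The location is computed from the fixed base and the current index.
        let offset = self.base + self.index as isize * self.stride;
        self.index += 1;
        Some(offset)
    }

    /// Splits without touching `base`; only `end` and `index` change.
    fn split_at(self, index: usize) -> (Self, Self) {
        assert!(index <= self.end - self.index);
        let mid = self.index + index;
        (Self { end: mid, ..self }, Self { index: mid, ..self })
    }
}
```

Because `base` is shared and `index` is absolute, a given index value maps to the same offset in the original iterator and in either split piece.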
Nice! Remember that I trust your judgment @jturner314. I have read the PR - it's not that 🙂 - I mean that I trust you to review and merge your own PRs, so you can do that when you think it is appropriate (which is probably almost all the time).
Thanks for reviewing this @bluss! I generally like to get a review from someone before merging, but thanks for the vote of confidence. I'm comfortable merging my own PRs without a review when necessary.
This adds `.split_at()` methods for `AxisChunksIter` and `AxisChunksIterMut`. Once this is merged, it will be straightforward to implement #639 in terms of `.split_at()`.