Skip to content

Commit

Permalink
Rollup merge of rust-lang#126879 - the8472:next-chunk-filter-drop, r=…
Browse files Browse the repository at this point in the history
…cuviper

fix Drop items getting leaked in Filter::next_chunk

The optimization only makes sense for non-drop elements anyway. Use the default implementation for items that are Drop instead.

It also simplifies the implementation.

fixes rust-lang#126872
tracking issue rust-lang#98326
  • Loading branch information
matthiaskrgr authored Jun 26, 2024
2 parents 5aedb8a + ff33a66 commit c14a130
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 45 deletions.
90 changes: 45 additions & 45 deletions core/src/iter/adapters/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedF
use crate::num::NonZero;
use crate::ops::Try;
use core::array;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::mem::MaybeUninit;
use core::ops::ControlFlow;

/// An iterator that filters the elements of `iter` with `predicate`.
Expand All @@ -27,6 +27,42 @@ impl<I, P> Filter<I, P> {
}
}

impl<I, P> Filter<I, P>
where
I: Iterator,
P: FnMut(&I::Item) -> bool,
{
#[inline]
fn next_chunk_dropless<const N: usize>(
&mut self,
) -> Result<[I::Item; N], array::IntoIter<I::Item, N>> {
let mut array: [MaybeUninit<I::Item>; N] = [const { MaybeUninit::uninit() }; N];
let mut initialized = 0;

let result = self.iter.try_for_each(|element| {
let idx = initialized;
// branchless index update combined with unconditionally copying the value even when
// it is filtered reduces branching and dependencies in the loop.
initialized = idx + (self.predicate)(&element) as usize;
// SAFETY: Loop conditions ensure the index is in bounds.
unsafe { array.get_unchecked_mut(idx) }.write(element);

if initialized < N { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
});

match result {
ControlFlow::Break(()) => {
// SAFETY: The loop above is only explicitly broken when the array has been fully initialized
Ok(unsafe { MaybeUninit::array_assume_init(array) })
}
ControlFlow::Continue(()) => {
// SAFETY: The range is in bounds since the loop breaks when reaching N elements.
Err(unsafe { array::IntoIter::new_unchecked(array, 0..initialized) })
}
}
}
}

#[stable(feature = "core_impl_debug", since = "1.9.0")]
impl<I: fmt::Debug, P> fmt::Debug for Filter<I, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -64,52 +100,16 @@ where
fn next_chunk<const N: usize>(
&mut self,
) -> Result<[Self::Item; N], array::IntoIter<Self::Item, N>> {
let mut array: [MaybeUninit<Self::Item>; N] = [const { MaybeUninit::uninit() }; N];

struct Guard<'a, T> {
array: &'a mut [MaybeUninit<T>],
initialized: usize,
}

impl<T> Drop for Guard<'_, T> {
#[inline]
fn drop(&mut self) {
if const { crate::mem::needs_drop::<T>() } {
// SAFETY: self.initialized is always <= N, which also is the length of the array.
unsafe {
core::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
self.array.get_unchecked_mut(..self.initialized),
));
}
}
// avoid codegen for the dead branch
let fun = const {
if crate::mem::needs_drop::<I::Item>() {
array::iter_next_chunk::<I::Item, N>
} else {
Self::next_chunk_dropless::<N>
}
}

let mut guard = Guard { array: &mut array, initialized: 0 };

let result = self.iter.try_for_each(|element| {
let idx = guard.initialized;
guard.initialized = idx + (self.predicate)(&element) as usize;

// SAFETY: Loop conditions ensure the index is in bounds.
unsafe { guard.array.get_unchecked_mut(idx) }.write(element);

if guard.initialized < N { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
});
};

let guard = ManuallyDrop::new(guard);

match result {
ControlFlow::Break(()) => {
// SAFETY: The loop above is only explicitly broken when the array has been fully initialized
Ok(unsafe { MaybeUninit::array_assume_init(array) })
}
ControlFlow::Continue(()) => {
let initialized = guard.initialized;
// SAFETY: The range is in bounds since the loop breaks when reaching N elements.
Err(unsafe { array::IntoIter::new_unchecked(array, 0..initialized) })
}
}
fun(self)
}

#[inline]
Expand Down
13 changes: 13 additions & 0 deletions core/tests/iter/adapters/filter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use core::iter::*;
use std::rc::Rc;

#[test]
fn test_iterator_filter_count() {
Expand Down Expand Up @@ -50,3 +51,15 @@ fn test_double_ended_filter() {
assert_eq!(it.next().unwrap(), &2);
assert_eq!(it.next_back(), None);
}

#[test]
fn test_next_chunk_does_not_leak() {
let drop_witness: [_; 5] = std::array::from_fn(|_| Rc::new(()));

let v = (0..5).map(|i| drop_witness[i].clone()).collect::<Vec<_>>();
let _ = v.into_iter().filter(|_| false).next_chunk::<1>();

for ref w in drop_witness {
assert_eq!(Rc::strong_count(w), 1);
}
}

0 comments on commit c14a130

Please sign in to comment.