Skip to content

Commit

Permalink
Merge pull request #34 from contain-rs/fast-count
Browse files Browse the repository at this point in the history
implement faster counting for set operations
  • Loading branch information
pczarn authored Jul 3, 2024
2 parents d538d36 + eecddb3 commit 43660ac
Showing 1 changed file with 55 additions and 8 deletions.
63 changes: 55 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,14 +496,14 @@ impl<B: BitBlock> BitSet<B> {
}
let min = cmp::min(self.bit_vec.len(), other.bit_vec.len());

Intersection(
BlockIter::from_blocks(TwoBitPositions {
Intersection {
iter: BlockIter::from_blocks(TwoBitPositions {
set: self.bit_vec.blocks(),
other: other.bit_vec.blocks(),
merge: bitand,
})
.take(min),
)
}),
n: min,
}
}

/// Iterator over each usize stored in the `self` setminus `other`.
Expand Down Expand Up @@ -884,7 +884,14 @@ pub struct Iter<'a, B: 'a>(BlockIter<Blocks<'a, B>, B>);
#[derive(Clone)]
pub struct Union<'a, B: 'a>(BlockIter<TwoBitPositions<'a, B>, B>);
#[derive(Clone)]
pub struct Intersection<'a, B: 'a>(Take<BlockIter<TwoBitPositions<'a, B>, B>>);
pub struct Intersection<'a, B: 'a> {
iter: BlockIter<TwoBitPositions<'a, B>, B>,
// as an optimization, we compute the maximum possible
// number of elements in the intersection, and count it
// down as we return elements. If we reach zero, we can
// stop.
n: usize,
}
#[derive(Clone)]
pub struct Difference<'a, B: 'a>(BlockIter<TwoBitPositions<'a, B>, B>);
#[derive(Clone)]
Expand Down Expand Up @@ -916,6 +923,10 @@ where
Some(self.head_offset + (B::count_ones(k)))
}

fn count(self) -> usize {
self.head.count_ones() + self.tail.map(|block| block.count_ones()).sum::<usize>()
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
match self.tail.size_hint() {
Expand Down Expand Up @@ -962,6 +973,10 @@ impl<'a, B: BitBlock> Iterator for Iter<'a, B> {
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
#[inline]
fn count(self) -> usize {
self.0.count()
}
}

impl<'a, B: BitBlock> Iterator for Union<'a, B> {
Expand All @@ -975,18 +990,36 @@ impl<'a, B: BitBlock> Iterator for Union<'a, B> {
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
#[inline]
fn count(self) -> usize {
self.0.count()
}
}

impl<'a, B: BitBlock> Iterator for Intersection<'a, B> {
type Item = usize;

#[inline]
fn next(&mut self) -> Option<usize> {
self.0.next()
if self.n != 0 {
self.n -= 1;
self.iter.next()
} else {
None
}
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
// We could invoke self.iter.size_hint() and incorporate that into the hint.
// In practice, that does not seem worthwhile because the lower bound will
// always be zero and the upper bound could only possibly less then n in a
// partially iterated iterator. However, it makes little sense ask for size_hint
// in a partially iterated iterator, so it did not seem worthwhile.
(0, Some(self.n))
}
#[inline]
fn count(self) -> usize {
self.iter.count()
}
}

Expand All @@ -1001,6 +1034,10 @@ impl<'a, B: BitBlock> Iterator for Difference<'a, B> {
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
#[inline]
fn count(self) -> usize {
self.0.count()
}
}

impl<'a, B: BitBlock> Iterator for SymmetricDifference<'a, B> {
Expand All @@ -1014,6 +1051,10 @@ impl<'a, B: BitBlock> Iterator for SymmetricDifference<'a, B> {
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
#[inline]
fn count(self) -> usize {
self.0.count()
}
}

impl<'a, B: BitBlock> IntoIterator for &'a BitSet<B> {
Expand Down Expand Up @@ -1061,12 +1102,14 @@ mod tests {

let idxs: Vec<_> = bit_vec.iter().collect();
assert_eq!(idxs, [0, 2, 3]);
assert_eq!(bit_vec.iter().count(), 3);

let long: BitSet = (0..10000).filter(|&n| n % 2 == 0).collect();
let real: Vec<_> = (0..10000 / 2).map(|x| x * 2).collect();

let idxs: Vec<_> = long.iter().collect();
assert_eq!(idxs, real);
assert_eq!(long.iter().count(), real.len());
}

#[test]
Expand Down Expand Up @@ -1132,6 +1175,7 @@ mod tests {
let expected = [3, 5, 11, 77];
let actual: Vec<_> = a.intersection(&b).collect();
assert_eq!(actual, expected);
assert_eq!(a.intersection(&b).count(), expected.len());
}

#[test]
Expand All @@ -1151,6 +1195,7 @@ mod tests {
let expected = [1, 5, 500];
let actual: Vec<_> = a.difference(&b).collect();
assert_eq!(actual, expected);
assert_eq!(a.difference(&b).count(), expected.len());
}

#[test]
Expand All @@ -1172,6 +1217,7 @@ mod tests {
let expected = [1, 5, 11, 14, 220];
let actual: Vec<_> = a.symmetric_difference(&b).collect();
assert_eq!(actual, expected);
assert_eq!(a.symmetric_difference(&b).count(), expected.len());
}

#[test]
Expand All @@ -1197,6 +1243,7 @@ mod tests {
let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160, 200];
let actual: Vec<_> = a.union(&b).collect();
assert_eq!(actual, expected);
assert_eq!(a.union(&b).count(), expected.len());
}

#[test]
Expand Down

0 comments on commit 43660ac

Please sign in to comment.