Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implements weighted shuffle using N-ary tree #259

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 77 additions & 55 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ use {
std::ops::{AddAssign, Sub, SubAssign},
};

// Layout of the implicit N-ary tree stored in a flat array:
// - the children of internal node `index` live at
//   (index << BIT_SHIFT) + 1 ..= (index << BIT_SHIFT) + FANOUT;
// - the parent of any non-root node `index` is (index - 1) >> BIT_SHIFT.
//
// log2 of the number of children per internal node.
const BIT_SHIFT: usize = 4;
// Number of children of each internal node.
const FANOUT: usize = 1usize << BIT_SHIFT;
// Mask for the low BIT_SHIFT bits of a node index. `index & BIT_MASK` is the
// node's 1-based position among its siblings, where 0 denotes the right-most
// (FANOUT'th) child.
const BIT_MASK: usize = (1usize << BIT_SHIFT) - 1;

/// Implements an iterator where indices are shuffled according to their
/// weights:
/// - Returned indices are unique in the range [0, weights.len()).
Expand All @@ -18,12 +26,13 @@ use {
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
// Underlying array implementing binary tree.
// tree[i] is the sum of weights in the left sub-tree of node i.
tree: Vec<T>,
// Underlying array implementing the tree.
// tree[i][j] is the sum of all weights in the j'th sub-tree of node i.
tree: Vec<[T; FANOUT - 1]>,
// Current sum of all weights, excluding already sampled ones.
weight: T,
zeros: Vec<usize>, // Indices of zero weighted entries.
// Indices of zero weighted entries.
zeros: Vec<usize>,
}

impl<T> WeightedShuffle<T>
Expand All @@ -34,7 +43,7 @@ where
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let zero = <T as Default>::default();
let mut tree = vec![zero; get_tree_size(weights.len())];
let mut tree = vec![[zero; FANOUT - 1]; get_tree_size(weights.len())];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
Expand All @@ -59,12 +68,14 @@ where
continue;
}
};
let mut index = tree.len() + k;
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = tree.len() + k; // leaf node
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
tree[index] += weight;
tree[index][offset - 1] += weight;
}
}
}
Expand All @@ -88,54 +99,73 @@ where
{
// Removes given weight at index k.
fn remove(&mut self, k: usize, weight: T) {
debug_assert!(self.weight >= weight);
self.weight -= weight;
let mut index = self.tree.len() + k;
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = self.tree.len() + k; // leaf node
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
self.tree[index] -= weight;
debug_assert!(self.tree[index][offset - 1] >= weight);
self.tree[index][offset - 1] -= weight;
}
}
}

// Returns smallest index such that cumsum of weights[..=k] > val,
// Returns the smallest index k such that the sum of weights[..=k] > val,
// along with its respective weight.
fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
let zero = <T as Default>::default();
debug_assert!(val >= zero);
debug_assert!(val < self.weight);
let mut index = 0;
// Traverse the tree downwards from the root while maintaining the
// weight of the subtree which contains the target leaf node.
let mut index = 0; // root
let mut weight = self.weight;
while index < self.tree.len() {
if val < self.tree[index] {
weight = self.tree[index];
index = (index << 1) + 1;
} else {
weight -= self.tree[index];
val -= self.tree[index];
index = (index << 1) + 2;
'outer: while index < self.tree.len() {
for (j, &node) in self.tree[index].iter().enumerate() {
if val < node {
// Traverse to the j+1 subtree of self.tree[index].
weight = node;
index = (index << BIT_SHIFT) + j + 1;
continue 'outer;
} else {
debug_assert!(weight >= node);
weight -= node;
val -= node;
}
}
// Traverse to the right-most subtree of self.tree[index].
index = (index << BIT_SHIFT) + FANOUT;
}
(index - self.tree.len(), weight)
}

pub fn remove_index(&mut self, k: usize) {
let mut index = self.tree.len() + k;
// Traverse the tree from the leaf node upwards to the root, while
// maintaining the sum of weights of subtrees *not* containing the leaf
// node.
let mut index = self.tree.len() + k; // leaf node
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
if self.tree[index] != weight {
self.remove(k, self.tree[index] - weight);
if self.tree[index][offset - 1] != weight {
self.remove(k, self.tree[index][offset - 1] - weight);
} else {
self.remove_zero(k);
}
return;
}
weight += self.tree[index];
// The leaf node is in the right-most subtree of self.tree[index].
for &node in &self.tree[index] {
weight += node;
}
}
// The leaf node is the right-most node of the whole tree.
if self.weight != weight {
self.remove(k, self.weight - weight);
} else {
Expand Down Expand Up @@ -193,17 +223,16 @@ where
}
}

// Maps number of items to the "internal" size of the binary tree "implicitly"
// holding those items on the leaves.
// Returns the number of internal (non-leaf) nodes in the tree whose leaves
// implicitly hold `count` items; only internal nodes are stored in the array.
fn get_tree_size(count: usize) -> usize {
let shift = usize::BITS
- count.leading_zeros()
- if count.is_power_of_two() && count != 1 {
1
} else {
0
};
(1usize << shift) - 1
let mut size = if count == 1 { 1 } else { 0 };
let mut nodes = 1;
while nodes < count {
size += nodes;
nodes *= FANOUT;
}
size
}

#[cfg(test)]
Expand Down Expand Up @@ -251,25 +280,18 @@ mod tests {
#[test]
fn test_get_tree_size() {
assert_eq!(get_tree_size(0), 0);
assert_eq!(get_tree_size(1), 1);
assert_eq!(get_tree_size(2), 1);
assert_eq!(get_tree_size(3), 3);
assert_eq!(get_tree_size(4), 3);
for count in 5..9 {
assert_eq!(get_tree_size(count), 7);
for count in 1..=16 {
assert_eq!(get_tree_size(count), 1);
}
for count in 17..=256 {
assert_eq!(get_tree_size(count), 1 + 16);
}
for count in 9..17 {
assert_eq!(get_tree_size(count), 15);
for count in 257..=4096 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16);
}
for count in 17..33 {
assert_eq!(get_tree_size(count), 31);
for count in 4097..=65536 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16 + 16 * 16 * 16);
}
assert_eq!(get_tree_size((1 << 16) - 1), (1 << 16) - 1);
assert_eq!(get_tree_size(1 << 16), (1 << 16) - 1);
assert_eq!(get_tree_size((1 << 16) + 1), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 17) - 1), (1 << 17) - 1);
assert_eq!(get_tree_size(1 << 17), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 17) + 1), (1 << 18) - 1);
}

// Asserts that empty weights will return empty shuffle.
Expand Down
Loading