diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 250d1efb0f6800..d17b4518b95532 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -18,15 +18,12 @@ use { /// non-zero weighted indices. #[derive(Clone)] pub struct WeightedShuffle<T> { - arr: Vec<T>, // Underlying array implementing binary indexed tree. + arr: Vec<T>, // Underlying array implementing binary tree. sum: T, // Current sum of weights, excluding already selected indices. + msb: usize, // Most significant bit of indices. zeros: Vec<usize>, // Indices of zero weighted entries. } -// The implementation uses binary indexed tree: -// https://en.wikipedia.org/wiki/Fenwick_tree -// to maintain cumulative sum of weights excluding already selected indices -// over self.arr. impl<T> WeightedShuffle<T> where T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, @@ -34,37 +31,43 @@ where /// If weights are negative or overflow the total sum /// they are treated as zero. pub fn new(name: &'static str, weights: &[T]) -> Self { - let size = weights.len() + 1; let zero = <T as Default>::default(); - let mut arr = vec![zero; size]; + let mut arr = vec![zero; get_tree_size(weights.len())]; + let msb = (arr.len() + 1) >> 2; let mut sum = zero; let mut zeros = Vec::default(); let mut num_negative = 0; let mut num_overflow = 0; - for (mut k, &weight) in (1usize..).zip(weights) { + for (k, &weight) in weights.iter().enumerate() { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs.
if !(weight >= zero) { - zeros.push(k - 1); + zeros.push(k); num_negative += 1; continue; } if weight == zero { - zeros.push(k - 1); + zeros.push(k); continue; } sum = match sum.checked_add(&weight) { Some(val) => val, None => { - zeros.push(k - 1); + zeros.push(k); num_overflow += 1; continue; } }; - while k < size { - arr[k] += weight; - k += k & k.wrapping_neg(); - } + let index = get_mask_bits(msb).fold(0, |index, mask| { + (index << 1) + + if k & mask == 0 { + arr[index] += weight; + 1 + } else { + 2 + } + }); + arr[index] = weight } if num_negative > 0 { datapoint_error!("weighted-shuffle-negative", (name, num_negative, i64)); @@ -72,7 +75,12 @@ where if num_overflow > 0 { datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64)); } - Self { arr, sum, zeros } + Self { + arr, + sum, + msb, + zeros, + } } } @@ -80,54 +88,47 @@ impl<T> WeightedShuffle<T> where T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>, { - // Returns cumulative sum of current weights upto index k (inclusive). - fn cumsum(&self, mut k: usize) -> T { - let mut out = <T as Default>::default(); - while k != 0 { - out += self.arr[k]; - k ^= k & k.wrapping_neg(); - } - out - } - // Removes given weight at index k. - fn remove(&mut self, mut k: usize, weight: T) { + fn remove(&mut self, k: usize, weight: T) { self.sum -= weight; - let size = self.arr.len(); - while k < size { - self.arr[k] -= weight; - k += k & k.wrapping_neg(); - } + let index = get_mask_bits(self.msb).fold(0, |index, mask| { + (index << 1) + + if k & mask == 0 { + self.arr[index] -= weight; + 1 + } else { + 2 + } + }); + self.arr[index] -= weight; } - // Returns smallest index such that self.cumsum(k) > val, + // Returns smallest index such that cumsum of weights[..=k] > val, // along with its respective weight.
- fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) { + fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) { let zero = <T as Default>::default(); debug_assert!(val >= zero); debug_assert!(val < self.sum); - let mut lo = (/*index:*/ 0, /*cumsum:*/ zero); - let mut hi = (self.arr.len() - 1, self.sum); - while lo.0 + 1 < hi.0 { - let k = lo.0 + (hi.0 - lo.0) / 2; - let sum = self.cumsum(k); - if sum <= val { - lo = (k, sum); + let (index, k) = get_mask_bits(self.msb).fold((0, 0), |(index, k), mask| { + if val >= self.arr[k] { + val -= self.arr[k]; + (index | mask, (k << 1) + 2) } else { - hi = (k, sum); + (index, (k << 1) + 1) } - } - debug_assert!(lo.1 <= val); - debug_assert!(hi.1 > val); - (hi.0, hi.1 - lo.1) + }); + (index, self.arr[k]) } - pub fn remove_index(&mut self, index: usize) { + pub fn remove_index(&mut self, k: usize) { let zero = <T as Default>::default(); - let weight = self.cumsum(index + 1) - self.cumsum(index); + let index = get_mask_bits(self.msb).fold(0, |index, mask| { + (index << 1) + if k & mask == 0 { 1 } else { 2 } + }); + let weight = self.arr[index]; if weight != zero { - self.remove(index + 1, weight); - } else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) { + self.remove(k, weight); + } else if let Some(index) = self.zeros.iter().position(|ix| *ix == k) { self.zeros.remove(index); } } @@ -143,7 +144,7 @@ where if self.sum > zero { let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng); let (index, _weight) = WeightedShuffle::search(self, sample); - return Some(index - 1); + return Some(index); } if self.zeros.is_empty() { return None; @@ -164,7 +165,7 @@ where let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng); let (index, weight) = WeightedShuffle::search(&self, sample); self.remove(index, weight); - return Some(index - 1); + return Some(index); } if self.zeros.is_empty() { return None; @@ -176,6 +177,25 @@ where } } +// Maps number of items to the size of the binary tree +// holding those items on the
leaves. +fn get_tree_size(count: usize) -> usize { + let shift = usize::BITS - count.leading_zeros() + + if count.is_power_of_two() || count == 0 { + 0 + } else { + 1 + }; + (1usize << shift) - 1 +} + +fn get_mask_bits(msb: usize) -> impl Iterator<Item = usize> { + debug_assert!(msb.is_power_of_two() || msb == 0); + std::iter::successors((msb != 0).then_some(msb), |&bit| { + (bit != 1).then_some(bit >> 1) + }) +} + #[cfg(test)] mod tests { use { @@ -218,6 +238,39 @@ mod tests { shuffle } + #[test] + fn test_get_tree_size() { + assert_eq!(get_tree_size(0), 0); + assert_eq!(get_tree_size(1), 1); + assert_eq!(get_tree_size(2), 3); + assert_eq!(get_tree_size(3), 7); + assert_eq!(get_tree_size(4), 7); + for count in 5..9 { + assert_eq!(get_tree_size(count), 15); + } + for count in 9..17 { + assert_eq!(get_tree_size(count), 31); + } + for count in 17..33 { + assert_eq!(get_tree_size(count), 63); + } + assert_eq!(get_tree_size((1 << 16) - 1), (1 << 17) - 1); + assert_eq!(get_tree_size(1 << 16), (1 << 17) - 1); + assert_eq!(get_tree_size((1 << 16) + 1), (1 << 18) - 1); + assert_eq!(get_tree_size((1 << 17) - 1), (1 << 18) - 1); + assert_eq!(get_tree_size(1 << 17), (1 << 18) - 1); + assert_eq!(get_tree_size((1 << 17) + 1), (1 << 19) - 1); + } + + #[test] + fn test_get_mask_bits() { + assert_eq!(get_mask_bits(0).next(), None); + assert_eq!(get_mask_bits(1).collect::<Vec<_>>(), [1]); + assert_eq!(get_mask_bits(2).collect::<Vec<_>>(), [2, 1]); + assert_eq!(get_mask_bits(4).collect::<Vec<_>>(), [4, 2, 1]); + assert_eq!(get_mask_bits(8).collect::<Vec<_>>(), [8, 4, 2, 1]); + } + // Asserts that empty weights will return empty shuffle. #[test] fn test_weighted_shuffle_empty_weights() {