diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index bd6b0e40..47f17e3d 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -1,4 +1,4 @@ -name: Clippy and Tests +name: Clippy checks on: push: @@ -14,7 +14,7 @@ jobs: steps: - name: Checkout project - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install stable toolchain uses: dtolnay/rust-toolchain@stable @@ -29,34 +29,3 @@ jobs: run: | cargo install cargo-rdme cargo rdme --check --no-fail-on-warnings - - tests: - name: Tests - runs-on: ubuntu-latest - - steps: - - name: Checkout project - uses: actions/checkout@v3 - - - name: Install stable toolchain - uses: dtolnay/rust-toolchain@stable - - - name: Cache dependencies - uses: Swatinem/rust-cache@v2 - - - name: Install nextest - uses: taiki-e/install-action@nextest - - - name: Run tests on default features - run: cargo nextest run --no-default-features --no-fail-fast --failure-output=immediate-final - - - name: Run tests with sync feature - run: > - cargo nextest run - --features sync - --filter-expr 'test(util::sync::tests::share_gradual_taiko)' - --filter-expr 'test(taiko::difficulty::gradual::tests::next_and_nth)' - --no-fail-fast --failure-output=immediate-final - - - name: Run doctests - run: cargo test --no-default-features --doc diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..22f39428 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,101 @@ +name: Tests + +on: + push: + branches: + - main + - next + pull_request: + +jobs: + doc: + name: Doc tests + runs-on: ubuntu-latest + + steps: + - name: Checkout project + uses: actions/checkout@v4 + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Run doctests + run: > + cargo test + --doc + --no-default-features --features compact_strains + + default: + name: Default tests + runs-on: ubuntu-latest + + steps: + - name: Checkout project + uses: actions/checkout@v4 + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Install nextest + uses: taiki-e/install-action@nextest + + - name: Run all tests + run: > + cargo nextest run + --no-default-features --features compact_strains + --no-fail-fast --failure-output=immediate-final + + sync: + name: Test sync feature + runs-on: ubuntu-latest + + steps: + - name: Checkout project + uses: actions/checkout@v4 + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Install nextest + uses: taiki-e/install-action@nextest + + - name: Run specific tests + run: > + cargo nextest run + --features sync + --filter-expr 'test(util::sync::tests::share_gradual_taiko)' + --filter-expr 'test(taiko::difficulty::gradual::tests::next_and_nth)' + --no-fail-fast --failure-output=immediate-final + + non_compact: + name: Test without compact_strains feature + runs-on: ubuntu-latest + + steps: + - name: Checkout project + uses: actions/checkout@v4 + + - name: Install stable toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Install nextest + uses: taiki-e/install-action@nextest + + - name: Run integration tests + run: > + cargo nextest run + --no-default-features + --test '*' + --no-fail-fast --failure-output=immediate-final diff --git a/Cargo.toml b/Cargo.toml index 683b4e3b..0ef3fbd6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,8 @@ description = "Difficulty and performance calculation for osu!" keywords = ["osu", "pp", "stars", "performance", "osu!"] [features] -default = [] +default = ["compact_strains"] +compact_strains = [] sync = [] tracing = ["rosu-map/tracing"] diff --git a/README.md b/README.md index fec84865..dde2e4a7 100644 --- a/README.md +++ b/README.md @@ -125,11 +125,12 @@ Calculating performances: Median: 44.13µs | Mean: 45.53µs ### Features -| Flag | Description | Dependencies -| --------- | ----------- | ------------ -| `default` | No features | -| `sync` | Some gradual calculation types can only be shared across threads if this feature is enabled. This adds a performance penalty so only enable this if really needed. | -| `tracing` | Any error encountered during beatmap decoding will be logged through `tracing::error`. If this feature is **not** enabled, errors will be ignored. | [`tracing`] +| Flag | Description | Dependencies +| ----------------- | ------------------------------------- | ------------ +| `default` | Enables the `compact_strains` feature | +| `compact_strains` | Storing internal strain values in a plain Vec introduces an out-of-memory risk on maliciously long maps (see [/b/3739922](https://osu.ppy.sh/b/3739922)). This feature stores strains more compactly, but comes with a ~5% loss in performance. | +| `sync` | Some gradual calculation types can only be shared across threads if this feature is enabled. This adds a performance penalty so only enable this if really needed. | +| `tracing` | Any error encountered during beatmap decoding will be logged through `tracing::error`. If this feature is **not** enabled, errors will be ignored. | [`tracing`] ### Bindings diff --git a/src/any/difficulty/skills.rs b/src/any/difficulty/skills.rs index 6f0b6e02..916d8232 100644 --- a/src/any/difficulty/skills.rs +++ b/src/any/difficulty/skills.rs @@ -1,3 +1,5 @@ +use crate::util::strains_vec::StrainsVec; + pub fn strain_decay(ms: f64, strain_decay_base: f64) -> f64 { strain_decay_base.powf(ms / 1000.0) } @@ -27,7 +29,7 @@ pub trait ISkill { pub struct StrainSkill { pub curr_section_peak: f64, pub curr_section_end: f64, - pub strain_peaks: Vec, + pub strain_peaks: StrainsVec, } impl Default for StrainSkill { @@ -36,7 +38,7 @@ impl Default for StrainSkill { curr_section_peak: 0.0, curr_section_end: 0.0, // mean=386.81 | median=279 - strain_peaks: Vec::with_capacity(256), + strain_peaks: StrainsVec::with_capacity(256), } } } @@ -53,7 +55,7 @@ impl StrainSkill { self.curr_section_peak = initial_strain; } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { let mut strain_peaks = self.strain_peaks; strain_peaks.push(self.curr_section_peak); @@ -65,10 +67,8 @@ impl StrainSkill { let mut weight = 1.0; let mut peaks = self.get_curr_strain_peaks(); - peaks.retain(|&strain| strain > 0.0); - peaks.sort_by(|a, b| b.total_cmp(a)); - for strain in peaks { + for strain in peaks.sorted_non_zero_iter() { difficulty += strain * weight; weight *= decay_weight; } @@ -95,7 +95,7 @@ impl StrainDecaySkill { self.inner.start_new_section_from(initial_strain); } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } diff --git a/src/catch/difficulty/skills/movement.rs b/src/catch/difficulty/skills/movement.rs index 13d2a5fe..4300abd2 100644 --- a/src/catch/difficulty/skills/movement.rs +++ b/src/catch/difficulty/skills/movement.rs @@ -4,6 +4,7 @@ use crate::{ skills::{strain_decay, ISkill, Skill, StrainDecaySkill}, }, catch::difficulty::object::CatchDifficultyObject, + util::strains_vec::StrainsVec, }; const ABSOLUTE_PLAYER_POSITIONING_ERROR: f32 = 16.0; @@ -107,7 +108,7 @@ impl Movement { dist_addition / weighted_strain_time } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } diff --git a/src/catch/strains.rs b/src/catch/strains.rs index da3822d5..ff133417 100644 --- a/src/catch/strains.rs +++ b/src/catch/strains.rs @@ -20,6 +20,6 @@ pub fn strains(difficulty: &Difficulty, converted: &CatchBeatmap<'_>) -> CatchSt let DifficultyValues { movement, .. } = DifficultyValues::calculate(difficulty, converted); CatchStrains { - movement: movement.get_curr_strain_peaks(), + movement: movement.get_curr_strain_peaks().into_vec(), } } diff --git a/src/lib.rs b/src/lib.rs index 5401a349..3d72fdd0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -121,11 +121,12 @@ //! //! ## Features //! -//! | Flag | Description | Dependencies -//! | --------- | ----------- | ------------ -//! | `default` | No features | -//! | `sync` | Some gradual calculation types can only be shared across threads if this feature is enabled. This adds a performance penalty so only enable this if really needed. | -//! | `tracing` | Any error encountered during beatmap decoding will be logged through `tracing::error`. If this feature is **not** enabled, errors will be ignored. | [`tracing`] +//! | Flag | Description | Dependencies +//! | ----------------- | ------------------------------------- | ------------ +//! | `default` | Enables the `compact_strains` feature | +//! | `compact_strains` | Storing internal strain values in a plain Vec introduces an out-of-memory risk on maliciously long maps (see [/b/3739922](https://osu.ppy.sh/b/3739922)). This feature stores strains more compactly, but comes with a ~5% loss in performance. | +//! | `sync` | Some gradual calculation types can only be shared across threads if this feature is enabled. This adds a performance penalty so only enable this if really needed. | +//! | `tracing` | Any error encountered during beatmap decoding will be logged through `tracing::error`. If this feature is **not** enabled, errors will be ignored. | [`tracing`] //! //! ## Bindings //! diff --git a/src/mania/difficulty/skills/strain.rs b/src/mania/difficulty/skills/strain.rs index b698cbc8..ab4a139d 100644 --- a/src/mania/difficulty/skills/strain.rs +++ b/src/mania/difficulty/skills/strain.rs @@ -4,6 +4,7 @@ use crate::{ skills::{strain_decay, ISkill, Skill, StrainDecaySkill}, }, mania::difficulty::object::ManiaDifficultyObject, + util::strains_vec::StrainsVec, }; const INDIVIDUAL_DECAY_BASE: f64 = 0.125; @@ -37,7 +38,7 @@ impl Strain { } } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } diff --git a/src/mania/strains.rs b/src/mania/strains.rs index 86d2c5b5..ad9b5135 100644 --- a/src/mania/strains.rs +++ b/src/mania/strains.rs @@ -20,6 +20,6 @@ pub fn strains(difficulty: &Difficulty, converted: &ManiaBeatmap<'_>) -> ManiaSt let values = DifficultyValues::calculate(difficulty, converted); ManiaStrains { - strains: values.strain.get_curr_strain_peaks(), + strains: values.strain.get_curr_strain_peaks().into_vec(), } } diff --git a/src/osu/difficulty/skills/aim.rs b/src/osu/difficulty/skills/aim.rs index 8ffd9cf8..2ceff6f6 100644 --- a/src/osu/difficulty/skills/aim.rs +++ b/src/osu/difficulty/skills/aim.rs @@ -6,7 +6,7 @@ use crate::{ skills::{strain_decay, ISkill, Skill}, }, osu::difficulty::object::OsuDifficultyObject, - util::float_ext::FloatExt, + util::{float_ext::FloatExt, strains_vec::StrainsVec}, }; use super::strain::OsuStrainSkill; @@ -30,7 +30,7 @@ impl Aim { } } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } diff --git a/src/osu/difficulty/skills/flashlight.rs b/src/osu/difficulty/skills/flashlight.rs index 0ac9bf58..afdc964d 100644 --- a/src/osu/difficulty/skills/flashlight.rs +++ b/src/osu/difficulty/skills/flashlight.rs @@ -6,7 +6,7 @@ use crate::{ skills::{strain_decay, ISkill, Skill, StrainSkill}, }, osu::{difficulty::object::OsuDifficultyObject, object::OsuObjectKind}, - util::mods::Mods, + util::{mods::Mods, strains_vec::StrainsVec}, }; use super::strain::OsuStrainSkill; @@ -33,7 +33,7 @@ impl Flashlight { } } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } @@ -48,8 +48,7 @@ impl Flashlight { } fn static_difficulty_value(skill: StrainSkill) -> f64 { - skill.get_curr_strain_peaks().into_iter().sum::() - * OsuStrainSkill::DIFFICULTY_MULTIPLER + skill.get_curr_strain_peaks().sum() * OsuStrainSkill::DIFFICULTY_MULTIPLER } } diff --git a/src/osu/difficulty/skills/speed.rs b/src/osu/difficulty/skills/speed.rs index 80e716f4..7b4260e0 100644 --- a/src/osu/difficulty/skills/speed.rs +++ b/src/osu/difficulty/skills/speed.rs @@ -6,6 +6,7 @@ use crate::{ skills::{strain_decay, ISkill, Skill}, }, osu::difficulty::object::OsuDifficultyObject, + util::strains_vec::StrainsVec, }; use super::strain::OsuStrainSkill; @@ -37,7 +38,7 @@ impl Speed { } } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } diff --git a/src/osu/difficulty/skills/strain.rs b/src/osu/difficulty/skills/strain.rs index 9cf107f9..48045494 100644 --- a/src/osu/difficulty/skills/strain.rs +++ b/src/osu/difficulty/skills/strain.rs @@ -1,4 +1,4 @@ -use crate::any::difficulty::skills::StrainSkill; +use crate::{any::difficulty::skills::StrainSkill, util::strains_vec::StrainsVec}; #[derive(Clone, Default)] pub struct OsuStrainSkill { @@ -21,7 +21,7 @@ impl OsuStrainSkill { self.inner.start_new_section_from(initial_strain); } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } @@ -36,10 +36,8 @@ impl OsuStrainSkill { let mut weight = 1.0; let mut peaks = self.get_curr_strain_peaks(); - peaks.retain(|&strain| strain > 0.0); - peaks.sort_by(|a, b| b.total_cmp(a)); - let peaks_iter = peaks.iter_mut().take(reduced_section_count); + let peaks_iter = peaks.sorted_non_zero_iter_mut().take(reduced_section_count); for (i, strain) in peaks_iter.enumerate() { let clamped = f64::from((i as f32 / reduced_section_count as f32).clamp(0.0, 1.0)); @@ -47,9 +45,9 @@ impl OsuStrainSkill { *strain *= lerp(reduced_strain_baseline, 1.0, scale); } - peaks.sort_by(|a, b| b.total_cmp(a)); + peaks.sort_desc(); - for strain in peaks { + for strain in peaks.iter() { difficulty += strain * weight; weight *= decay_weight; } diff --git a/src/osu/strains.rs b/src/osu/strains.rs index fe602e6c..a9b0ef38 100644 --- a/src/osu/strains.rs +++ b/src/osu/strains.rs @@ -38,9 +38,9 @@ pub fn strains(difficulty: &Difficulty, converted: &OsuBeatmap<'_>) -> OsuStrain } = DifficultyValues::calculate(difficulty, converted); OsuStrains { - aim: aim.get_curr_strain_peaks(), - aim_no_sliders: aim_no_sliders.get_curr_strain_peaks(), - speed: speed.get_curr_strain_peaks(), - flashlight: flashlight.get_curr_strain_peaks(), + aim: aim.get_curr_strain_peaks().into_vec(), + aim_no_sliders: aim_no_sliders.get_curr_strain_peaks().into_vec(), + speed: speed.get_curr_strain_peaks().into_vec(), + flashlight: flashlight.get_curr_strain_peaks().into_vec(), } } diff --git a/src/taiko/difficulty/skills/color.rs b/src/taiko/difficulty/skills/color.rs index bd3e70fb..4d4c52a0 100644 --- a/src/taiko/difficulty/skills/color.rs +++ b/src/taiko/difficulty/skills/color.rs @@ -12,7 +12,10 @@ use crate::{ }, object::{TaikoDifficultyObject, TaikoDifficultyObjects}, }, - util::sync::{RefCount, Weak}, + util::{ + strains_vec::StrainsVec, + sync::{RefCount, Weak}, + }, }; const SKILL_MULTIPLIER: f64 = 0.12; @@ -43,7 +46,7 @@ impl Color { ColorEvaluator::evaluate_diff_of(curr) } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } diff --git a/src/taiko/difficulty/skills/peaks.rs b/src/taiko/difficulty/skills/peaks.rs index 6c5b9dcd..3d8d900d 100644 --- a/src/taiko/difficulty/skills/peaks.rs +++ b/src/taiko/difficulty/skills/peaks.rs @@ -61,9 +61,8 @@ impl Peaks { let zip = color_peaks .iter() - .copied() - .zip(rhythm_peaks.iter().copied()) - .zip(stamina_peaks.iter().copied()); + .zip(rhythm_peaks.iter()) + .zip(stamina_peaks.iter()); for ((mut color_peak, mut rhythm_peak), mut stamina_peak) in zip { color_peak *= COLOR_SKILL_MULTIPLIER; diff --git a/src/taiko/difficulty/skills/rhythm.rs b/src/taiko/difficulty/skills/rhythm.rs index d44e1b4f..b48388ec 100644 --- a/src/taiko/difficulty/skills/rhythm.rs +++ b/src/taiko/difficulty/skills/rhythm.rs @@ -12,7 +12,7 @@ use crate::{ }, object::HitType, }, - util::{float_ext::FloatExt, limited_queue::LimitedQueue}, + util::{float_ext::FloatExt, limited_queue::LimitedQueue, strains_vec::StrainsVec}, }; const SKILL_MULTIPLIER: f64 = 10.0; @@ -142,7 +142,7 @@ impl Rhythm { self.curr_strain() } - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } diff --git a/src/taiko/difficulty/skills/stamina.rs b/src/taiko/difficulty/skills/stamina.rs index 0931e44b..7b49c29d 100644 --- a/src/taiko/difficulty/skills/stamina.rs +++ b/src/taiko/difficulty/skills/stamina.rs @@ -7,6 +7,7 @@ use crate::{ difficulty::object::{TaikoDifficultyObject, TaikoDifficultyObjects}, object::HitType, }, + util::strains_vec::StrainsVec, }; const SKILL_MULTIPLIER: f64 = 1.1; @@ -18,7 +19,7 @@ pub struct Stamina { } impl Stamina { - pub fn get_curr_strain_peaks(self) -> Vec { + pub fn get_curr_strain_peaks(self) -> StrainsVec { self.inner.get_curr_strain_peaks() } diff --git a/src/taiko/strains.rs b/src/taiko/strains.rs index 519ea143..4807bbb2 100644 --- a/src/taiko/strains.rs +++ b/src/taiko/strains.rs @@ -24,8 +24,8 @@ pub fn strains(difficulty: &Difficulty, converted: &TaikoBeatmap<'_>) -> TaikoSt let values = DifficultyValues::calculate(difficulty, converted); TaikoStrains { - color: values.peaks.color.get_curr_strain_peaks(), - rhythm: values.peaks.rhythm.get_curr_strain_peaks(), - stamina: values.peaks.stamina.get_curr_strain_peaks(), + color: values.peaks.color.get_curr_strain_peaks().into_vec(), + rhythm: values.peaks.rhythm.get_curr_strain_peaks().into_vec(), + stamina: values.peaks.stamina.get_curr_strain_peaks().into_vec(), } } diff --git a/src/util/mod.rs b/src/util/mod.rs index b3d3af0e..4514eb98 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -5,4 +5,5 @@ pub mod map_or_attrs; pub mod mods; pub mod random; pub mod sort; +pub mod strains_vec; pub mod sync; diff --git a/src/util/strains_vec.rs b/src/util/strains_vec.rs new file mode 100644 index 00000000..5fd844a5 --- /dev/null +++ b/src/util/strains_vec.rs @@ -0,0 +1,391 @@ +pub use inner::*; + +#[cfg(feature = "compact_strains")] +mod inner { + use std::{iter::Copied, slice::Iter}; + + use self::entry::StrainsEntry; + + /// A specialized `Vec` where all entries must be non-negative. + /// + /// It is compact in the sense that zeros are not stored directly but instead + /// as amount of times they appear consecutively. + /// + /// For cases with few consecutive zeros, this type generally reduces + /// performance slightly. However, for edge cases like `/b/3739922` the length + /// of the list is massively reduced, preventing out-of-memory issues. + #[derive(Clone)] + pub struct StrainsVec { + inner: Vec, + len: usize, + #[cfg(debug_assertions)] + // Ensures that methods are used correctly + has_zero: bool, + } + + impl StrainsVec { + /// Constructs a new, empty [`StrainsVec`] with at least the specified + /// capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self { + inner: Vec::with_capacity(capacity), + len: 0, + #[cfg(debug_assertions)] + has_zero: false, + } + } + + /// Returns the number of elements. + pub const fn len(&self) -> usize { + self.len + } + + /// Appends an element to the back. + pub fn push(&mut self, value: f64) { + if value.to_bits() > 0 { + self.inner.push(StrainsEntry::new_value(value)); + } else if let Some(last) = self.inner.last_mut().filter(|e| e.is_zero()) { + last.incr_zero_count(); + } else { + self.inner.push(StrainsEntry::new_zero()); + + #[cfg(debug_assertions)] + { + self.has_zero = true; + } + } + + self.len += 1; + } + + /// Sorts the entries in descending order. + pub fn sort_desc(&mut self) { + #[cfg(debug_assertions)] + debug_assert!(!self.has_zero); + + self.inner.sort_by(|a, b| b.value().total_cmp(&a.value())); + } + + /// Removes all zero entries + pub fn retain_non_zero(&mut self) { + self.inner.retain(StrainsEntry::is_value); + + #[cfg(debug_assertions)] + { + self.has_zero = false; + } + } + + /// Removes all zeros and sorts the remaining entries in descending order. + pub fn retain_non_zero_and_sort(&mut self) { + self.retain_non_zero(); + self.sort_desc(); + } + + /// Iterator over the raw entries, assuming that there are no zeros. + /// + /// Panics if there are zeros. + pub fn non_zero_iter(&self) -> impl ExactSizeIterator + '_ { + #[cfg(debug_assertions)] + debug_assert!(!self.has_zero); + + self.inner.iter().copied().map(StrainsEntry::value) + } + + /// Same as [`StrainsVec::retain_non_zero_and_sort`] followed by + /// [`StrainsVec::iter`] but the resulting iterator is faster + /// because it doesn't need to check whether entries are zero. + pub fn sorted_non_zero_iter(&mut self) -> impl ExactSizeIterator + '_ { + self.retain_non_zero_and_sort(); + + self.non_zero_iter() + } + + /// Removes all zeros, sorts the remaining entries in descending order, and + /// returns an iterator over mutable references to the values. + pub fn sorted_non_zero_iter_mut(&mut self) -> impl ExactSizeIterator { + self.retain_non_zero_and_sort(); + + self.inner.iter_mut().map(StrainsEntry::as_value_mut) + } + + /// Sum up all values. + pub fn sum(&self) -> f64 { + self.inner + .iter() + .copied() + .filter(StrainsEntry::is_value) + .fold(0.0, |sum, e| sum + e.value()) + } + + /// Returns an iterator over the [`StrainsVec`]. + pub fn iter(&self) -> StrainsIter<'_> { + StrainsIter::new(self) + } + + /// Allocates a new `Vec` to store all values, including zeros. + pub fn into_vec(self) -> Vec { + let mut vec = Vec::with_capacity(self.len); + vec.extend(self.iter()); + + vec + } + } + + pub struct StrainsIter<'a> { + inner: Copied>, + curr: Option, + len: usize, + } + + impl<'a> StrainsIter<'a> { + pub fn new(vec: &'a StrainsVec) -> Self { + let mut inner = vec.inner.iter().copied(); + let curr = inner.next(); + + Self { + inner, + curr, + len: vec.len, + } + } + } + + impl<'a> Iterator for StrainsIter<'a> { + type Item = f64; + + fn next(&mut self) -> Option { + loop { + let curr = self.curr.as_mut()?; + + if curr.is_value() { + let value = curr.value(); + self.curr = self.inner.next(); + self.len -= 1; + + return Some(value); + } else if curr.zero_count() > 0 { + curr.decr_zero_count(); + self.len -= 1; + + return Some(0.0); + } + + self.curr = self.inner.next(); + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + + (len, Some(len)) + } + } + + impl ExactSizeIterator for StrainsIter<'_> { + fn len(&self) -> usize { + self.len + } + } + + /// Private module to hide internal fields. + mod entry { + /// Either a positive `f64` or an amount of consecutive `0.0`. + /// + /// If the first bit is not set, i.e. the sign bit of a `f64` indicates + /// that it's positive, the union represents that `f64`. Otherwise, the + /// first bit is ignored and the union represents a `u64`. + #[derive(Copy, Clone)] + pub union StrainsEntry { + value: f64, + zero_count: u64, + } + + impl StrainsEntry { + const ZERO_COUNT_MASK: u64 = u64::MAX >> 1; + + pub fn new_value(value: f64) -> Self { + debug_assert!( + value.is_sign_positive(), + "attempted to create negative strain entry, please report as a bug" + ); + + Self { value } + } + + pub const fn new_zero() -> Self { + Self { + zero_count: !Self::ZERO_COUNT_MASK + 1, + } + } + + pub fn is_zero(self) -> bool { + unsafe { self.value.is_sign_negative() } + } + + // Requiring `self` as a reference improves ergonomics for passing this + // method as argument to higher-order functions. + #[allow(clippy::trivially_copy_pass_by_ref)] + pub fn is_value(&self) -> bool { + !self.is_zero() + } + + pub fn value(self) -> f64 { + debug_assert!(self.is_value()); + + unsafe { self.value } + } + + pub fn as_value_mut(&mut self) -> &mut f64 { + debug_assert!(self.is_value()); + + unsafe { &mut self.value } + } + + pub fn zero_count(self) -> u64 { + debug_assert!(self.is_zero()); + + unsafe { self.zero_count & Self::ZERO_COUNT_MASK } + } + + pub fn incr_zero_count(&mut self) { + debug_assert!(self.is_zero()); + + unsafe { + self.zero_count += 1; + } + } + + pub fn decr_zero_count(&mut self) { + debug_assert!(self.is_zero()); + + unsafe { + self.zero_count -= 1; + } + } + } + } + + #[cfg(test)] + mod tests { + use proptest::prelude::*; + + use crate::util::float_ext::FloatExt; + + use super::*; + + proptest! { + #[test] + fn expected(mut values in prop::collection::vec(prop::option::of(0.0..1_000.0), 0..1_000)) { + let mut vec = StrainsVec::with_capacity(values.len()); + + let mut additional_zeros = 0; + let mut prev_zero = false; + let mut sum = 0.0; + + for opt in values.iter().copied() { + if let Some(value) = opt { + vec.push(value); + prev_zero = false; + sum += value; + } else { + vec.push(0.0); + + if prev_zero { + additional_zeros += 1; + } + + prev_zero = true; + } + } + + assert_eq!(vec.len(), values.len()); + assert_eq!(vec.inner.len(), values.len() - additional_zeros); + assert!(vec.sum().eq(sum)); + assert!(vec.iter().eq(values.iter().copied().map(|opt| opt.unwrap_or(0.0)))); + + values.retain(Option::is_some); + + values.sort_by(|a, b| { + let (Some(a), Some(b)) = (a, b) else { unreachable!() }; + + b.total_cmp(a) + }); + + assert!(vec.sorted_non_zero_iter().eq(values.into_iter().flatten())); + } + } + } +} + +#[cfg(not(feature = "compact_strains"))] +mod inner { + use std::{ + iter::Copied, + slice::{Iter, IterMut}, + }; + + /// Plain wrapper around `Vec` because the `compact_strains` feature + /// is disabled. + #[derive(Clone)] + pub struct StrainsVec { + inner: Vec, + } + + impl StrainsVec { + pub fn with_capacity(capacity: usize) -> Self { + Self { + inner: Vec::with_capacity(capacity), + } + } + + pub fn len(&self) -> usize { + self.inner.len() + } + + pub fn push(&mut self, value: f64) { + self.inner.push(value); + } + + pub fn sort_desc(&mut self) { + self.inner.sort_by(|a, b| b.total_cmp(a)); + } + + pub fn retain_non_zero(&mut self) { + self.inner.retain(|&a| a > 0.0); + } + + pub fn retain_non_zero_and_sort(&mut self) { + self.retain_non_zero(); + self.sort_desc(); + } + + pub fn non_zero_iter(&self) -> Copied> { + self.inner.iter().copied() + } + + pub fn sorted_non_zero_iter(&mut self) -> Copied> { + self.retain_non_zero_and_sort(); + + self.non_zero_iter() + } + + pub fn sorted_non_zero_iter_mut(&mut self) -> IterMut<'_, f64> { + self.retain_non_zero_and_sort(); + + self.inner.iter_mut() + } + + pub fn sum(&self) -> f64 { + self.inner.iter().copied().sum() + } + + pub fn iter(&self) -> Copied> { + self.inner.iter().copied() + } + + pub fn into_vec(self) -> Vec { + self.inner + } + } +}