Skip to content

Commit

Permalink
Compute unweighted average and error by default for numerical integra…
Browse files Browse the repository at this point in the history
…tion

- Add methods to get statistics without updating the grid
  • Loading branch information
benruijl committed Aug 10, 2024
1 parent 86af1cb commit 8a3df03
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8971,7 +8971,7 @@ impl PythonNumericalIntegrator {
match &self.grid {
Grid::Continuous(cs) => {
let mut a = cs.accumulator.shallow_copy();
a.update_iter();
a.update_iter(false);
Ok((
a.avg,
a.err,
Expand All @@ -8983,7 +8983,7 @@ impl PythonNumericalIntegrator {
}
Grid::Discrete(ds) => {
let mut a = ds.accumulator.shallow_copy();
a.update_iter();
a.update_iter(false);
Ok((
a.avg,
a.err,
Expand Down
70 changes: 53 additions & 17 deletions src/numerical_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use crate::domains::float::{ConstructibleFloat, NumericalFloatComparison, Real};
/// the error and the chi-squared of samples added over multiple
/// iterations.
///
/// Samples can be added using [`StatisticsAccumulator::add_samples()`]. When an iteration of
/// samples is finished, call [`StatisticsAccumulator::update_iter()`], which
/// Samples can be added using [`Self::add_sample()`]. When an iteration of
/// samples is finished, call [`Self::update_iter()`], which
/// updates the average, error and chi-squared over all iterations with the average
/// and error of the current iteration in a weighted fashion.
///
/// This accumulator can be merged with another accumulator using [`StatisticsAccumulator::merge()`] or
/// [`StatisticsAccumulator::merge_samples_no_reset()`]. This is useful when
/// This accumulator can be merged with another accumulator using [`Self::merge_samples()`] or
/// [`Self::merge_samples_no_reset()`]. This is useful when
/// samples are collected in multiple threads.
///
/// The accumulator also stores which samples yielded the highest weight thus far.
Expand All @@ -23,6 +23,8 @@ use crate::domains::float::{ConstructibleFloat, NumericalFloatComparison, Real};
pub struct StatisticsAccumulator<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> {
sum: T,
sum_sq: T,
total_sum: T,
total_sum_sq: T,
weight_sum: T,
avg_sum: T,
pub avg: T,
Expand All @@ -48,6 +50,8 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> StatisticsA
StatisticsAccumulator {
sum: T::new_zero(),
sum_sq: T::new_zero(),
total_sum: T::new_zero(),
total_sum_sq: T::new_zero(),
weight_sum: T::new_zero(),
avg_sum: T::new_zero(),
avg: T::new_zero(),
Expand Down Expand Up @@ -76,6 +80,8 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> StatisticsA
StatisticsAccumulator {
sum: self.sum,
sum_sq: self.sum_sq,
total_sum: self.total_sum,
total_sum_sq: self.total_sum_sq,
weight_sum: self.weight_sum,
avg_sum: self.avg_sum,
avg: self.avg,
Expand Down Expand Up @@ -151,9 +157,12 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> StatisticsA
}
}

/// Process the samples added with `[`Self::add_sample()`]` and
/// Process the samples added with [`Self::add_sample()`] and
/// compute a new average, error, and chi-squared.
pub fn update_iter(&mut self) -> bool {
///
/// When `weighted_average` is `true`, a weighted average and error are computed, using
/// the iteration variances as weights.
pub fn update_iter(&mut self, weighted_average: bool) -> bool {
// TODO: we could be throwing away events that are very rare
if self.new_samples < 2 {
self.cur_iter += 1;
Expand All @@ -163,9 +172,11 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> StatisticsA
self.processed_samples += self.new_samples;
self.num_zero_evaluations += self.new_zero_evaluations;
let n = T::new_from_usize(self.new_samples);
self.sum /= &n;
self.sum_sq /= n * n;
let mut w = (self.sum_sq * n).sqrt();
self.total_sum += self.sum;
self.total_sum_sq += self.sum_sq;
self.sum /= n;
self.sum_sq /= n;
let mut w = self.sum_sq.sqrt();

w = ((w + self.sum) * (w - self.sum)) / (n - T::new_one());
if w == T::new_zero() {
Expand All @@ -178,9 +189,17 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> StatisticsA

self.weight_sum += w;
self.avg_sum += w * self.sum;
let sigma_sq = self.weight_sum.inv();
self.avg = sigma_sq * self.avg_sum;
self.err = sigma_sq.sqrt();

if weighted_average {
let sigma_sq = self.weight_sum.inv();
self.avg = sigma_sq * self.avg_sum;
self.err = sigma_sq.sqrt();
} else {
let n = T::new_from_usize(self.processed_samples);
self.avg = self.total_sum / n;
self.err = ((self.total_sum_sq / n - self.avg * self.avg) / (n - T::new_one())).sqrt();
}

if self.cur_iter == 0 {
self.guess = self.sum;
}
Expand All @@ -199,6 +218,23 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> StatisticsA
true
}

/// Return `(average, error, chi-squared)` as they would be if the current
/// iteration were closed right now, without adding further samples.
///
/// The update is performed on a shallow copy (with `weighted_average = false`),
/// so the accumulator itself is left untouched.
pub fn get_live_estimate(&self) -> (T, T, T) {
    let mut snapshot = self.shallow_copy();
    // Only the statistics stored on the copy are of interest here;
    // the boolean returned by `update_iter` is not used.
    snapshot.update_iter(false);

    let (avg, err, chi_sq) = (snapshot.avg, snapshot.err, snapshot.chi_sq);
    (avg, err, chi_sq)
}

/// Render the live `mean ± sdev` as `mean(sdev)` in a human-readable way
/// with the correct number of digits.
///
/// "Live" means the values are those obtained by closing the current
/// iteration on a copy of this accumulator; the accumulator itself is not
/// modified.
///
/// Based on the Python package [gvar](https://github.com/gplepage/gvar) by Peter Lepage.
pub fn format_live_uncertainty(&self) -> String {
    // `get_live_estimate` performs the shallow copy and the unweighted update;
    // the chi-squared component is not needed for formatting.
    let (mean, sdev, _) = self.get_live_estimate();
    Self::format_uncertainty_impl(mean.to_f64(), sdev.to_f64())
}

/// Format `mean ± sdev` as `mean(sdev)` in a human-readable way with the correct number of digits.
///
/// Based on the Python package [gvar](https://github.com/gplepage/gvar) by Peter Lepage.
Expand Down Expand Up @@ -535,7 +571,7 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> DiscreteGri
}

let acc = &mut bin.accumulator;
acc.update_iter();
acc.update_iter(false);

if acc.processed_samples > 1 {
err_sum += acc.err * T::new_from_usize(acc.processed_samples - 1).sqrt();
Expand Down Expand Up @@ -583,7 +619,7 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> DiscreteGri
bin.pdf /= sum;
}

self.accumulator.update_iter();
self.accumulator.update_iter(false);
}

/// Sample a point from this grid, writing the result in `sample`.
Expand Down Expand Up @@ -768,7 +804,7 @@ impl<T: Real + ConstructibleFloat + Copy + NumericalFloatComparison> ContinuousG
d.update(learning_rate);
}

self.accumulator.update_iter();
self.accumulator.update_iter(false);
}

/// Returns `Ok` when this grid can be merged with another grid,
Expand Down Expand Up @@ -1129,7 +1165,7 @@ mod test {
grid.update(1.5, 1.5);
}

assert_eq!(grid.accumulator.avg, 0.9713543844460519);
assert_eq!(grid.accumulator.err, 0.0009026050146732183)
assert_eq!(grid.accumulator.avg, 0.9718412953459551);
assert_eq!(grid.accumulator.err, 0.0009349254838085983)
}
}

0 comments on commit 8a3df03

Please sign in to comment.