Skip to content

Commit

Permalink
fix: Fix bug where step size stats were not updated after tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jul 8, 2024
1 parent 11bf46b commit dd7218d
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 194 deletions.
44 changes: 16 additions & 28 deletions src/adapt_strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
use crate::nuts::{SamplerStats, StatTraceBuilder};

pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
step_size: StepSizeStrategy<M, A>,
step_size: StepSizeStrategy,
mass_matrix: A,
options: AdaptOptions<A::Options>,
num_tune: u64,
Expand Down Expand Up @@ -73,7 +73,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<

fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder {
CombinedStatsBuilder {
stats1: self.step_size.new_builder(settings, dim),
stats1: SamplerStats::<M>::new_builder(&self.step_size, settings, dim),
stats2: self.mass_matrix.new_builder(settings, dim),
}
}
Expand All @@ -87,7 +87,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStats<M> for GlobalStrategy<M,

impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
type Potential = A::Potential;
type Collector = CombinedCollector<M, AcceptanceRateCollector<M>, A::Collector>;
type Collector = CombinedCollector<M, AcceptanceRateCollector, A::Collector>;
type Options = AdaptOptions<A::Options>;

fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self {
Expand All @@ -99,7 +99,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
assert!(early_end < num_tune);

Self {
step_size: StepSizeStrategy::new(math, options.dual_average_options, num_tune),
step_size: StepSizeStrategy::new(options.dual_average_options),
mass_matrix: A::new(math, options.mass_matrix_options, num_tune),
options,
num_tune,
Expand All @@ -121,7 +121,6 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
) {
self.mass_matrix.init(math, options, potential, state, rng);
self.step_size.init(math, options, potential, state, rng);
self.step_size.enable();
}

fn adapt<R: Rng + ?Sized>(
Expand All @@ -134,6 +133,8 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
state: &State<M>,
rng: &mut R,
) {
self.step_size.update(&collector.collector1);

if draw >= self.num_tune {
self.tuning = false;
return;
Expand Down Expand Up @@ -172,44 +173,31 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
if did_change {
self.last_update = draw;
}

if is_late {
self.step_size.use_mean_sym();
self.step_size.update_estimator_late();
} else {
self.step_size.update_estimator_early();
}

// First time we change the mass matrix
if did_change & self.has_initial_mass_matrix {
self.has_initial_mass_matrix = false;
self.step_size.init(math, options, potential, state, rng);
} else {
self.step_size.adapt(
math,
options,
potential,
draw,
&collector.collector1,
state,
rng,
);
self.step_size.update_stepsize(potential, false)
}
return;
}

if draw == self.num_tune - 1 {
self.step_size.finalize();
}
self.step_size.adapt(
math,
options,
potential,
draw,
&collector.collector1,
state,
rng,
);
self.step_size.update_estimator_late();
let is_last = draw == self.num_tune - 1;
self.step_size.update_stepsize(potential, is_last);
}

fn new_collector(&self, math: &mut M) -> Self::Collector {
CombinedCollector {
collector1: self.step_size.new_collector(math),
collector1: self.step_size.new_collector(),
collector2: self.mass_matrix.new_collector(math),
_phantom: PhantomData,
}
Expand Down
12 changes: 4 additions & 8 deletions src/stepsize.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::marker::PhantomData;

use crate::{
math_base::Math,
nuts::{Collector, NutsOptions},
Expand Down Expand Up @@ -103,25 +101,23 @@ impl RunningMean {
}
}

pub struct AcceptanceRateCollector<M: Math> {
pub struct AcceptanceRateCollector {
initial_energy: f64,
pub(crate) mean: RunningMean,
pub(crate) mean_sym: RunningMean,
phantom: PhantomData<M>,
}

impl<M: Math> AcceptanceRateCollector<M> {
pub(crate) fn new() -> AcceptanceRateCollector<M> {
impl AcceptanceRateCollector {
pub(crate) fn new() -> AcceptanceRateCollector {
AcceptanceRateCollector {
initial_energy: 0.,
mean: RunningMean::new(),
mean_sym: RunningMean::new(),
phantom: PhantomData,
}
}
}

impl<M: Math> Collector<M> for AcceptanceRateCollector<M> {
impl<M: Math> Collector<M> for AcceptanceRateCollector {
fn register_leapfrog(
&mut self,
_math: &mut M,
Expand Down
Loading

0 comments on commit dd7218d

Please sign in to comment.