Skip to content

Commit

Permalink
Store gradient and unconstrained draw if requested
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jun 19, 2023
1 parent 37161ba commit 6af13ea
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/adapt_strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ mod test {
let options = NutsOptions {
maxdepth: 10u64,
store_gradient: true,
store_unconstrained: true,
};

let rng = {
Expand Down
13 changes: 2 additions & 11 deletions src/cpu_sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::thread::JoinHandle;
use thiserror::Error;

use crate::{
adapt_strategy::{DualAverageSettings, GradDiagOptions, GradDiagStrategy},
adapt_strategy::{GradDiagOptions, GradDiagStrategy},
cpu_potential::EuclideanPotential,
mass_matrix::DiagMassMatrix,
nuts::{Chain, NutsChain, NutsError, NutsOptions, SampleStats},
Expand All @@ -30,8 +30,6 @@ pub struct SamplerArgs {
pub max_energy_error: f64,
/// Store detailed information about each divergence in the sampler stats
pub store_divergences: bool,
/// Settings for step size adaptation.
pub step_size_adapt: DualAverageSettings,
/// Settings for mass matrix adaptation.
pub mass_matrix_adapt: GradDiagOptions,
}
Expand All @@ -46,7 +44,6 @@ impl Default for SamplerArgs {
store_gradient: false,
store_unconstrained: false,
store_divergences: true,
step_size_adapt: DualAverageSettings::default(),
mass_matrix_adapt: GradDiagOptions::default(),
}
}
Expand Down Expand Up @@ -89,8 +86,6 @@ pub trait CpuLogpFuncMaker<Func>: Send + Sync
where
Func: CpuLogpFunc,
{
//type Func: CpuLogpFunc;

fn make_logp_func(&self, chain: usize) -> Result<Func, anyhow::Error>;
fn dim(&self) -> usize;
}
Expand Down Expand Up @@ -194,19 +189,15 @@ pub fn new_sampler<F: CpuLogpFunc, R: Rng + ?Sized>(
) -> impl Chain {
use crate::nuts::AdaptStrategy;
let num_tune = settings.num_tune;
//let step_size_adapt = DualAverageStrategy::new(settings.step_size_adapt, num_tune, logp.dim());
//let mass_matrix_adapt =
// ExpWindowDiagAdapt::new(settings.mass_matrix_adapt, num_tune, logp.dim());

let strategy = GradDiagStrategy::new(settings.mass_matrix_adapt, num_tune, logp.dim());

let mass_matrix = DiagMassMatrix::new(logp.dim());
let max_energy_error = settings.max_energy_error;
let potential = EuclideanPotential::new(logp, mass_matrix, max_energy_error, 1f64);

let options = NutsOptions {
maxdepth: settings.maxdepth,
store_gradient: settings.store_gradient,
store_unconstrained: settings.store_unconstrained,
};

let rng = rand::rngs::SmallRng::from_rng(rng).expect("Could not seed rng");
Expand Down
66 changes: 66 additions & 0 deletions src/nuts.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Only needed by the arrow stats builder below; gate it like the other
// arrow2 imports so builds without the "arrow" feature don't pull it in.
#[cfg(feature = "arrow")]
use arrow2::array::{MutableFixedSizeListArray, TryPush};
#[cfg(feature = "arrow")]
use arrow2::{
array::{MutableArray, MutableBooleanArray, MutablePrimitiveArray, StructArray},
Expand Down Expand Up @@ -368,6 +369,7 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
/// Runtime options for a NUTS chain.
pub struct NutsOptions {
    /// Maximum depth of the NUTS trajectory tree for each draw.
    pub maxdepth: u64,
    /// If `true`, store the logp gradient of each draw in the sampler stats.
    pub store_gradient: bool,
    /// If `true`, store each draw's position in the unconstrained space
    /// in the sampler stats.
    pub store_unconstrained: bool,
}

pub(crate) fn draw<P, R, C>(
Expand Down Expand Up @@ -435,6 +437,7 @@ pub(crate) struct NutsSampleStats<HStats: Send + Debug, AdaptStats: Send + Debug
pub chain: u64,
pub draw: u64,
pub gradient: Option<Box<[f64]>>,
pub unconstrained: Option<Box<[f64]>>,
pub potential_stats: HStats,
pub strategy_stats: AdaptStats,
}
Expand All @@ -461,6 +464,8 @@ pub trait SampleStats: Send + Debug {
/// The logp gradient at the location of the draw. This is only stored
/// if NutsOptions.store_gradient is `true`.
fn gradient(&self) -> Option<&[f64]>;
/// The draw in the unconstrained space.
fn unconstrained(&self) -> Option<&[f64]>;
}

impl<H, A> SampleStats for NutsSampleStats<H, A>
Expand Down Expand Up @@ -495,6 +500,9 @@ where
/// The logp gradient at the draw, if `store_gradient` was enabled.
fn gradient(&self) -> Option<&[f64]> {
    // `as_deref` borrows through the `Box<[f64]>`; equivalent to the
    // manual `.as_ref().map(|x| &x[..])` but idiomatic.
    self.gradient.as_deref()
}
/// The draw's position in the unconstrained space, if
/// `store_unconstrained` was enabled.
fn unconstrained(&self) -> Option<&[f64]> {
    // Same idiom as `gradient`: deref the boxed slice in place.
    self.unconstrained.as_deref()
}
}

#[cfg(feature = "arrow")]
Expand All @@ -506,6 +514,8 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
energy: MutablePrimitiveArray<f64>,
chain: MutablePrimitiveArray<u64>,
draw: MutablePrimitiveArray<u64>,
unconstrained: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
gradient: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
hamiltonian: <H::Stats as ArrowRow>::Builder,
adapt: <A::Stats as ArrowRow>::Builder,
}
Expand All @@ -514,6 +524,21 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
fn new_with_capacity(dim: usize, settings: &SamplerArgs) -> Self {
let capacity = (settings.num_tune + settings.num_draws) as usize;

// Allocate a fixed-size-list column per requested optional statistic.
// Each list entry holds one `dim`-length vector per draw.
let gradient = if settings.store_gradient {
    let items = MutablePrimitiveArray::new();
    Some(MutableFixedSizeListArray::new_with_field(items, "item", false, dim))
} else {
    None
};

// BUG FIX: this was gated on `settings.store_gradient` (copy-paste error),
// so the unconstrained column was controlled by the wrong flag.
let unconstrained = if settings.store_unconstrained {
    let items = MutablePrimitiveArray::new();
    Some(MutableFixedSizeListArray::new_with_field(items, "item", false, dim))
} else {
    None
};

Self {
depth: MutablePrimitiveArray::with_capacity(capacity),
maxdepth_reached: MutableBooleanArray::with_capacity(capacity),
Expand All @@ -522,6 +547,8 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
energy: MutablePrimitiveArray::with_capacity(capacity),
chain: MutablePrimitiveArray::with_capacity(capacity),
draw: MutablePrimitiveArray::with_capacity(capacity),
gradient,
unconstrained,
hamiltonian: <H::Stats as ArrowRow>::new_builder(dim, settings),
adapt: <A::Stats as ArrowRow>::new_builder(dim, settings),
}
Expand All @@ -541,6 +568,28 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
self.chain.push(Some(value.chain));
self.draw.push(Some(value.draw));

if let Some(store) = self.gradient.as_mut() {
store
.try_push(
value
.gradient()
.as_ref()
.map(|vals| vals.iter().map(|&x| Some(x)))
)
.unwrap();
}

if let Some(store) = self.unconstrained.as_mut() {
store
.try_push(
value
.unconstrained()
.as_ref()
.map(|vals| vals.iter().map(|&x| Some(x)))
)
.unwrap();
}

self.hamiltonian.append_value(&value.potential_stats);
self.adapt.append_value(&value.strategy_stats);
}
Expand Down Expand Up @@ -579,6 +628,16 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
arrays.extend(adapt.1);
}

if let Some(mut gradient) = self.gradient.take() {
fields.push(Field::new("gradient", gradient.data_type().clone(), true));
arrays.push(gradient.as_box());
}

if let Some(mut unconstrained) = self.unconstrained.take() {
fields.push(Field::new("unconstrained", unconstrained.data_type().clone(), true));
arrays.push(unconstrained.as_box());
}

Some(StructArray::new(DataType::Struct(fields), arrays, None))
}
}
Expand Down Expand Up @@ -737,6 +796,13 @@ where
} else {
None
},
unconstrained: if self.options.store_unconstrained {
let mut unconstrained: Box<[f64]> = vec![0f64; self.potential.dim()].into();
state.write_position(&mut unconstrained);
Some(unconstrained)
} else {
None
},
};
self.strategy.adapt(
&mut self.options,
Expand Down

0 comments on commit 6af13ea

Please sign in to comment.