From 6af13ea46f5d59aa3562760181c80eb7f91bc898 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 15 Jun 2023 17:30:08 -0500 Subject: [PATCH] Store gradient and unconstrained draw if requested --- src/adapt_strategy.rs | 1 + src/cpu_sampler.rs | 13 ++------- src/nuts.rs | 66 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 11 deletions(-) diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index c6e05e9..72fdcdd 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -686,6 +686,7 @@ mod test { let options = NutsOptions { maxdepth: 10u64, store_gradient: true, + store_unconstrained: true, }; let rng = { diff --git a/src/cpu_sampler.rs b/src/cpu_sampler.rs index 3c4e017..77477de 100644 --- a/src/cpu_sampler.rs +++ b/src/cpu_sampler.rs @@ -4,7 +4,7 @@ use std::thread::JoinHandle; use thiserror::Error; use crate::{ - adapt_strategy::{DualAverageSettings, GradDiagOptions, GradDiagStrategy}, + adapt_strategy::{GradDiagOptions, GradDiagStrategy}, cpu_potential::EuclideanPotential, mass_matrix::DiagMassMatrix, nuts::{Chain, NutsChain, NutsError, NutsOptions, SampleStats}, @@ -30,8 +30,6 @@ pub struct SamplerArgs { pub max_energy_error: f64, /// Store detailed information about each divergence in the sampler stats pub store_divergences: bool, - /// Settings for step size adaptation. - pub step_size_adapt: DualAverageSettings, /// Settings for mass matrix adaptation. 
pub mass_matrix_adapt: GradDiagOptions, } @@ -46,7 +44,6 @@ impl Default for SamplerArgs { store_gradient: false, store_unconstrained: false, store_divergences: true, - step_size_adapt: DualAverageSettings::default(), mass_matrix_adapt: GradDiagOptions::default(), } } @@ -89,8 +86,6 @@ pub trait CpuLogpFuncMaker: Send + Sync where Func: CpuLogpFunc, { - //type Func: CpuLogpFunc; - fn make_logp_func(&self, chain: usize) -> Result; fn dim(&self) -> usize; } @@ -194,12 +189,7 @@ pub fn new_sampler( ) -> impl Chain { use crate::nuts::AdaptStrategy; let num_tune = settings.num_tune; - //let step_size_adapt = DualAverageStrategy::new(settings.step_size_adapt, num_tune, logp.dim()); - //let mass_matrix_adapt = - // ExpWindowDiagAdapt::new(settings.mass_matrix_adapt, num_tune, logp.dim()); - let strategy = GradDiagStrategy::new(settings.mass_matrix_adapt, num_tune, logp.dim()); - let mass_matrix = DiagMassMatrix::new(logp.dim()); let max_energy_error = settings.max_energy_error; let potential = EuclideanPotential::new(logp, mass_matrix, max_energy_error, 1f64); @@ -207,6 +197,7 @@ pub fn new_sampler( let options = NutsOptions { maxdepth: settings.maxdepth, store_gradient: settings.store_gradient, + store_unconstrained: settings.store_unconstrained, }; let rng = rand::rngs::SmallRng::from_rng(rng).expect("Could not seed rng"); diff --git a/src/nuts.rs b/src/nuts.rs index d2f87af..696e7a8 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "arrow")] use arrow2::array::{MutableFixedSizeListArray, TryPush}; #[cfg(feature = "arrow")] use arrow2::{ array::{MutableArray, MutableBooleanArray, MutablePrimitiveArray, StructArray}, @@ -368,6 +369,7 @@ impl> NutsTree { pub struct NutsOptions { pub maxdepth: u64, pub store_gradient: bool, + pub store_unconstrained: bool, } pub(crate) fn draw( @@ -435,6 +437,7 @@ pub(crate) struct NutsSampleStats>, + pub unconstrained: Option>, pub potential_stats: HStats, pub strategy_stats: AdaptStats, } @@ -461,6 +464,8 @@ pub trait SampleStats: Send +
Debug { /// The logp gradient at the location of the draw. This is only stored /// if NutsOptions.store_gradient is `true`. fn gradient(&self) -> Option<&[f64]>; + /// The draw in the unconstrained space. + fn unconstrained(&self) -> Option<&[f64]>; } impl SampleStats for NutsSampleStats @@ -495,6 +500,9 @@ where fn gradient(&self) -> Option<&[f64]> { self.gradient.as_ref().map(|x| &x[..]) } + fn unconstrained(&self) -> Option<&[f64]> { + self.unconstrained.as_ref().map(|x| &x[..]) + } } #[cfg(feature = "arrow")] @@ -506,6 +514,8 @@ pub struct StatsBuilder { energy: MutablePrimitiveArray, chain: MutablePrimitiveArray, draw: MutablePrimitiveArray, + unconstrained: Option>>, + gradient: Option>>, hamiltonian: ::Builder, adapt: ::Builder, } @@ -514,6 +524,21 @@ impl StatsBuilder { fn new_with_capacity(dim: usize, settings: &SamplerArgs) -> Self { let capacity = (settings.num_tune + settings.num_draws) as usize; + + let gradient = if settings.store_gradient { + let items = MutablePrimitiveArray::new(); + Some(MutableFixedSizeListArray::new_with_field(items, "item", false, dim)) + } else { + None + }; + + let unconstrained = if settings.store_unconstrained { + let items = MutablePrimitiveArray::new(); + Some(MutableFixedSizeListArray::new_with_field(items, "item", false, dim)) + } else { + None + }; + Self { depth: MutablePrimitiveArray::with_capacity(capacity), maxdepth_reached: MutableBooleanArray::with_capacity(capacity), @@ -522,6 +547,8 @@ impl StatsBuilder { energy: MutablePrimitiveArray::with_capacity(capacity), chain: MutablePrimitiveArray::with_capacity(capacity), draw: MutablePrimitiveArray::with_capacity(capacity), + gradient, + unconstrained, hamiltonian: ::new_builder(dim, settings), adapt: ::new_builder(dim, settings), } @@ -541,6 +568,28 @@ impl ArrowBuilder ArrowBuilder = vec![0f64; self.potential.dim()].into(); + state.write_position(&mut unconstrained); + Some(unconstrained) + } else { + None + }, }; self.strategy.adapt( &mut
self.options,