WIP Rewrite sample stats to use apache arrow #4

Merged
merged 11 commits on Jun 29, 2023
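
Summary: this PR replaces the AsSampleStatVec/SampleStatValue mechanism for sampler statistics with arrow2-based builders. Each stats type now implements ArrowRow, whose associated ArrowBuilder appends one value per draw into mutable Arrow arrays and finalizes them into an arrow2 StructArray, all gated behind a new optional "arrow" cargo feature.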
25 changes: 15 additions & 10 deletions Cargo.toml
@@ -23,25 +23,30 @@ codegen-units = 1
[dependencies]
rand = { version = "0.8.5", features = ["small_rng"] }
rand_distr = "0.4.3"
multiversion = "0.7.0"
itertools = "0.10.3"
crossbeam = "0.8.1"
thiserror = "1.0.31"
rayon = "1.5.3"
ndarray = "0.15.4"
multiversion = "0.7.2"
itertools = "0.11.0"
crossbeam = "0.8.2"
thiserror = "1.0.40"
rayon = "1.7.0"
arrow2 = { version = "0.17.2", optional = true }
rand_chacha = "0.3.1"
anyhow = "1.0.71"

[dev-dependencies]
proptest = "1.0.0"
pretty_assertions = "1.2.1"
criterion = "0.4.0"
nix = "0.26.1"
proptest = "1.2.0"
pretty_assertions = "1.3.0"
criterion = "0.5.1"
nix = "0.26.2"
approx = "0.5.1"
ndarray = "0.15.6"

[[bench]]
name = "sample"
harness = false

[features]
nightly = ["simd_support"]
default = ["arrow"]

simd_support = []
arrow = ["dep:arrow2"]
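
Note on the feature wiring above: arrow2 is declared optional, default = ["arrow"] turns it on by default, and arrow = ["dep:arrow2"] pulls in the dependency only when the feature is enabled, so a build with --no-default-features compiles without arrow2. All arrow-specific code is therefore cfg-gated; a minimal sketch of the pattern (hypothetical helper, not from this PR):

#[cfg(feature = "arrow")]
use arrow2::array::{MutableArray, MutablePrimitiveArray};

// Compiled only when the arrow feature is active (the default build);
// cargo build --no-default-features skips this item and the dependency.
#[cfg(feature = "arrow")]
fn f64_column(values: &[f64]) -> Box<dyn arrow2::array::Array> {
    let mut arr = MutablePrimitiveArray::<f64>::new();
    for &v in values {
        arr.push(Some(v));
    }
    arr.as_box()
}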
232 changes: 182 additions & 50 deletions src/adapt_strategy.rs
@@ -1,19 +1,25 @@
use std::{fmt::Debug, marker::PhantomData};

#[cfg(feature = "arrow")]
use arrow2::{
array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush},
datatypes::{DataType, Field},
};
use itertools::izip;

use crate::{
cpu_potential::{CpuLogpFunc, EuclideanPotential},
mass_matrix::{
DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance,
},
nuts::{
AdaptStrategy, AsSampleStatVec, Collector, Hamiltonian, NutsOptions, SampleStatItem,
SampleStatValue,
},
mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance},
nuts::{AdaptStrategy, Collector, Hamiltonian, NutsOptions},
stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions},
DivergenceInfo,
};

#[cfg(feature = "arrow")]
use crate::nuts::{ArrowBuilder, ArrowRow};
#[cfg(feature = "arrow")]
use crate::SamplerArgs;

const LOWER_LIMIT: f64 = 1e-10f64;
const UPPER_LIMIT: f64 = 1e10f64;

@@ -36,22 +42,55 @@ impl<F, M> DualAverageStrategy<F, M> {
}
}


#[derive(Debug, Clone, Copy)]
pub struct DualAverageStats {
pub step_size_bar: f64,
pub mean_tree_accept: f64,
pub n_steps: u64,
}

impl AsSampleStatVec for DualAverageStats {
fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
vec.push(("step_size_bar", SampleStatValue::F64(self.step_size_bar)));
vec.push((
"mean_tree_accept",
SampleStatValue::F64(self.mean_tree_accept),
));
vec.push(("n_steps", SampleStatValue::U64(self.n_steps)));
#[cfg(feature = "arrow")]
pub struct DualAverageStatsBuilder {
step_size_bar: MutablePrimitiveArray<f64>,
mean_tree_accept: MutablePrimitiveArray<f64>,
n_steps: MutablePrimitiveArray<u64>,
}

#[cfg(feature = "arrow")]
impl ArrowBuilder<DualAverageStats> for DualAverageStatsBuilder {
fn append_value(&mut self, value: &DualAverageStats) {
self.step_size_bar.push(Some(value.step_size_bar));
self.mean_tree_accept.push(Some(value.mean_tree_accept));
self.n_steps.push(Some(value.n_steps));
}

fn finalize(mut self) -> Option<StructArray> {
let fields = vec![
Field::new("step_size_bar", DataType::Float64, false),
Field::new("mean_tree_accept", DataType::Float64, false),
Field::new("n_steps", DataType::UInt64, false),
];

let arrays = vec![
self.step_size_bar.as_box(),
self.mean_tree_accept.as_box(),
self.n_steps.as_box(),
];

Some(StructArray::new(DataType::Struct(fields), arrays, None))
}
}

#[cfg(feature = "arrow")]
impl ArrowRow for DualAverageStats {
type Builder = DualAverageStatsBuilder;

fn new_builder(_dim: usize, _settings: &SamplerArgs) -> Self::Builder {
Self::Builder {
step_size_bar: MutablePrimitiveArray::new(),
mean_tree_accept: MutablePrimitiveArray::new(),
n_steps: MutablePrimitiveArray::new(),
}
}
}
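
The intended builder lifecycle, sketched under the assumption that the caller owns a slice of per-draw stats (hypothetical helper; the actual driver loop lives elsewhere in the crate):

#[cfg(feature = "arrow")]
use crate::nuts::{ArrowBuilder, ArrowRow};

#[cfg(feature = "arrow")]
fn dual_average_stats_to_arrow(
    draws: &[DualAverageStats],
    settings: &SamplerArgs,
) -> Option<StructArray> {
    // dim is irrelevant for this stats type, so any value works here.
    let mut builder = DualAverageStats::new_builder(0, settings);
    for stats in draws {
        builder.append_value(stats);
    }
    // One StructArray with step_size_bar, mean_tree_accept and n_steps columns.
    builder.finalize()
}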

@@ -138,19 +177,11 @@ impl<F: CpuLogpFunc, M: MassMatrix> AdaptStrategy for DualAverageStrategy<F, M>
}

/// Settings for mass matrix adaptation
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, Default)]
pub struct DiagAdaptExpSettings {
pub store_mass_matrix: bool,
}

impl Default for DiagAdaptExpSettings {
fn default() -> Self {
Self {
store_mass_matrix: false,
}
}
}

pub(crate) struct ExpWindowDiagAdapt<F> {
dim: usize,
exp_variance_draw: RunningVariance,
@@ -220,18 +251,68 @@ impl<F: CpuLogpFunc> ExpWindowDiagAdapt<F> {
}
}


#[derive(Clone, Debug)]
pub struct ExpWindowDiagAdaptStats {
pub mass_matrix_inv: Option<Box<[f64]>>,
}

impl AsSampleStatVec for ExpWindowDiagAdaptStats {
fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
vec.push((
"mass_matrix_inv",
SampleStatValue::OptionArray(self.mass_matrix_inv.clone()),
));
#[cfg(feature = "arrow")]
pub struct ExpWindowDiagAdaptStatsBuilder {
mass_matrix_inv: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
}

#[cfg(feature = "arrow")]
impl ArrowBuilder<ExpWindowDiagAdaptStats> for ExpWindowDiagAdaptStatsBuilder {
fn append_value(&mut self, value: &ExpWindowDiagAdaptStats) {
if let Some(store) = self.mass_matrix_inv.as_mut() {
store
.try_push(
value
.mass_matrix_inv
.as_ref()
.map(|vals| vals.iter().map(|&x| Some(x))),
)
.unwrap();
}
}

fn finalize(self) -> Option<StructArray> {
if let Some(mut store) = self.mass_matrix_inv {
let fields = vec![Field::new(
"mass_matrix_inv",
store.data_type().clone(),
true,
)];

let arrays = vec![store.as_box()];

Some(StructArray::new(DataType::Struct(fields), arrays, None))
} else {
None
}
}
}

#[cfg(feature = "arrow")]
impl ArrowRow for ExpWindowDiagAdaptStats {
type Builder = ExpWindowDiagAdaptStatsBuilder;

fn new_builder(dim: usize, settings: &SamplerArgs) -> Self::Builder {
if settings
.mass_matrix_adapt
.mass_matrix_options
.store_mass_matrix
{
let items = MutablePrimitiveArray::new();
let values = MutableFixedSizeListArray::new_with_field(items, "item", false, dim);
Self::Builder {
mass_matrix_inv: Some(values),
}
} else {
Self::Builder {
mass_matrix_inv: None,
}
}
}
}
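
MutableFixedSizeListArray stores one dim-length row per draw, and try_push(None) appends a null row; this is also why new_builder can skip allocation entirely when store_mass_matrix is off. A standalone sketch of the push behavior, assuming arrow2 0.17:

use arrow2::array::{MutableFixedSizeListArray, MutablePrimitiveArray, TryPush};

fn push_mass_matrix_row(
    store: &mut MutableFixedSizeListArray<MutablePrimitiveArray<f64>>,
    row: Option<&[f64]>,
) {
    // Errors only if a Some row does not match the fixed list length.
    store
        .try_push(row.map(|vals| vals.iter().map(|&x| Some(x))))
        .unwrap();
}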

@@ -260,16 +341,19 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
state: &<Self::Potential as Hamiltonian>::State,
) {
self.exp_variance_draw.add_sample(state.q.iter().copied());
self.exp_variance_draw_bg.add_sample(state.q.iter().copied());
self.exp_variance_grad.add_sample(state.grad.iter().copied());
self.exp_variance_grad_bg.add_sample(state.grad.iter().copied());
self.exp_variance_draw_bg
.add_sample(state.q.iter().copied());
self.exp_variance_grad
.add_sample(state.grad.iter().copied());
self.exp_variance_grad_bg
.add_sample(state.grad.iter().copied());

potential.mass_matrix.update_diag(
state.grad.iter().map(|&grad| {
Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT))
})
state
.grad
.iter()
.map(|&grad| Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT))),
);

}
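
The reformatted call above is behavior-identical to the old one: each initial diagonal entry of the mass matrix is 1/|grad_i|, clamped into [LOWER_LIMIT, UPPER_LIMIT] so that zero or extreme gradients cannot produce degenerate scales. A worked example of the heuristic in plain Rust:

fn init_diag(grad: &[f64]) -> Vec<f64> {
    grad.iter()
        .map(|&g| g.abs().recip().clamp(1e-10, 1e10))
        .collect()
}

// init_diag(&[2.0, 0.0, -4.0]) == [0.5, 1e10, 0.25]:
// 0.0 maps to +inf, which the clamp pulls down to the upper limit.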

fn adapt(
@@ -303,7 +387,6 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
}
}


pub(crate) struct GradDiagStrategy<F: CpuLogpFunc> {
step_size: DualAverageStrategy<F, DiagMassMatrix>,
mass_matrix: ExpWindowDiagAdapt<F>,
@@ -332,8 +415,6 @@ impl Default for GradDiagOptions {
dual_average_options: DualAverageSettings::default(),
mass_matrix_options: DiagAdaptExpSettings::default(),
early_window: 0.3,
//step_size_window: 0.08,
//step_size_window: 0.15,
step_size_window: 0.2,
mass_matrix_switch_freq: 60,
early_mass_matrix_switch_freq: 10,
@@ -345,7 +426,7 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
type Potential = EuclideanPotential<F, DiagMassMatrix>;
type Collector = CombinedCollector<
AcceptanceRateCollector<<EuclideanPotential<F, DiagMassMatrix> as Hamiltonian>::State>,
DrawGradCollector
DrawGradCollector,
>;
type Stats = CombinedStats<DualAverageStats, ExpWindowDiagAdaptStats>;
type Options = GradDiagOptions;
@@ -404,14 +485,16 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
self.mass_matrix.update_estimators(&collector.collector2);
}
self.mass_matrix.update_potential(potential);
self.step_size.adapt(options, potential, draw, &collector.collector1);
self.step_size
.adapt(options, potential, draw, &collector.collector1);
return;
}

if draw == self.num_tune - 1 {
self.step_size.finalize();
}
self.step_size.adapt(options, potential, draw, &collector.collector1);
self.step_size
.adapt(options, potential, draw, &collector.collector1);
}

fn new_collector(&self) -> Self::Collector {
@@ -438,17 +521,65 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
}
}

#[cfg(feature = "arrow")]
#[derive(Debug, Clone)]
pub struct CombinedStats<D1: Debug + ArrowRow, D2: Debug + ArrowRow> {
pub stats1: D1,
pub stats2: D2,
}

#[cfg(not(feature = "arrow"))]
#[derive(Debug, Clone)]
pub struct CombinedStats<D1: Debug, D2: Debug> {
pub stats1: D1,
pub stats2: D2,
}

impl<D1: AsSampleStatVec, D2: AsSampleStatVec> AsSampleStatVec for CombinedStats<D1, D2> {
fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
self.stats1.add_to_vec(vec);
self.stats2.add_to_vec(vec);
#[cfg(feature = "arrow")]
pub struct CombinedStatsBuilder<D1: ArrowRow, D2: ArrowRow> {
stats1: D1::Builder,
stats2: D2::Builder,
}

#[cfg(feature = "arrow")]
impl<D1: Debug + ArrowRow, D2: Debug + ArrowRow> ArrowRow for CombinedStats<D1, D2> {
type Builder = CombinedStatsBuilder<D1, D2>;

fn new_builder(dim: usize, settings: &SamplerArgs) -> Self::Builder {
Self::Builder {
stats1: D1::new_builder(dim, settings),
stats2: D2::new_builder(dim, settings),
}
}
}

#[cfg(feature = "arrow")]
impl<D1: Debug + ArrowRow, D2: Debug + ArrowRow> ArrowBuilder<CombinedStats<D1, D2>>
for CombinedStatsBuilder<D1, D2>
{
fn append_value(&mut self, value: &CombinedStats<D1, D2>) {
self.stats1.append_value(&value.stats1);
self.stats2.append_value(&value.stats2);
}

fn finalize(self) -> Option<StructArray> {
match (self.stats1.finalize(), self.stats2.finalize()) {
(None, None) => None,
(Some(stats1), None) => Some(stats1),
(None, Some(stats2)) => Some(stats2),
(Some(stats1), Some(stats2)) => {
let mut data1 = stats1.into_data();
let data2 = stats2.into_data();

assert!(data1.2.is_none());
assert!(data2.2.is_none());

data1.0.extend(data2.0);
data1.1.extend(data2.1);

Some(StructArray::new(DataType::Struct(data1.0), data1.1, None))
}
}
}
}
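
finalize flattens the two per-strategy StructArrays into a single one: into_data() yields the tuple (Vec<Field>, Vec<Box<dyn Array>>, Option<Bitmap>), the asserts check that neither side carries a validity bitmap the merge would silently drop, and the field and array vectors are concatenated. An equivalent standalone sketch:

use arrow2::array::StructArray;
use arrow2::datatypes::DataType;

fn merge_structs(a: StructArray, b: StructArray) -> StructArray {
    let (mut fields, mut arrays, validity_a) = a.into_data();
    let (fields_b, arrays_b, validity_b) = b.into_data();
    // Neither input may have struct-level nulls, matching the asserts above.
    assert!(validity_a.is_none() && validity_b.is_none());
    fields.extend(fields_b);
    arrays.extend(arrays_b);
    StructArray::new(DataType::Struct(fields), arrays, None)
}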

@@ -468,7 +599,7 @@ where
&mut self,
start: &Self::State,
end: &Self::State,
divergence_info: Option<&dyn crate::nuts::DivergenceInfo>,
divergence_info: Option<&DivergenceInfo>,
) {
self.collector1
.register_leapfrog(start, end, divergence_info);
@@ -555,6 +686,7 @@ mod test {
let options = NutsOptions {
maxdepth: 10u64,
store_gradient: true,
store_unconstrained: true,
};

let rng = {