diff --git a/birdman/inference.py b/birdman/inference.py index 15fe528..b33b7d2 100644 --- a/birdman/inference.py +++ b/birdman/inference.py @@ -1,3 +1,4 @@ +from functools import partial from typing import List, Sequence, Union import arviz as az @@ -21,10 +22,17 @@ def fit_to_inference( if posterior_predictive is not None and posterior_predictive not in dims: raise KeyError("Must include dimensions for posterior predictive!") + # Required because as of writing, CmdStanVB.stan_variable defaults to + # returning the mean rather than the sample + if isinstance(fit, CmdStanVB): + stan_var_fn = partial(fit.stan_variable, mean=False) + else: + stan_var_fn = fit.stan_variable + das = dict() for param in params: - data = fit.stan_variable(param) + data = stan_var_fn(param) _dims = dims[param] _coords = {k: coords[k] for k in _dims} @@ -32,7 +40,7 @@ def fit_to_inference( das[param] = stan_var_to_da(data, _coords, _dims, chains, draws) if log_likelihood: - data = fit.stan_variable(log_likelihood) + data = stan_var_fn(log_likelihood) _dims = dims[log_likelihood] _coords = {k: coords[k] for k in _dims} @@ -43,7 +51,7 @@ def fit_to_inference( ll_ds = None if posterior_predictive: - data = fit.stan_variable(posterior_predictive) + data = stan_var_fn(posterior_predictive) _dims = dims[posterior_predictive] _coords = {k: coords[k] for k in _dims} @@ -84,21 +92,19 @@ def concatenate_inferences( """ group_list = [] group_list.append([x.posterior for x in inf_list]) - group_list.append([x.sample_stats for x in inf_list]) if "log_likelihood" in inf_list[0].groups(): group_list.append([x.log_likelihood for x in inf_list]) if "posterior_predictive" in inf_list[0].groups(): group_list.append([x.posterior_predictive for x in inf_list]) po_ds = xr.concat(group_list[0], concatenation_name) - ss_ds = xr.concat(group_list[1], concatenation_name) - group_dict = {"posterior": po_ds, "sample_stats": ss_ds} + group_dict = {"posterior": po_ds} if "log_likelihood" in inf_list[0].groups(): - ll_ds = xr.concat(group_list[2], concatenation_name) + ll_ds = xr.concat(group_list[1], concatenation_name) group_dict["log_likelihood"] = ll_ds if "posterior_predictive" in inf_list[0].groups(): - pp_ds = xr.concat(group_list[3], concatenation_name) + pp_ds = xr.concat(group_list[2], concatenation_name) group_dict["posterior_predictive"] = pp_ds all_group_inferences = [] @@ -114,6 +120,7 @@ def concatenate_inferences( return az.concat(*all_group_inferences) +# TODO: Fix docstring def stan_var_to_da( data: np.ndarray, coords: dict, @@ -121,9 +128,7 @@ def stan_var_to_da( chains: int, draws: int ): - """Convert Stan variable draws to xr.DataArray. - - """ + """Convert Stan variable draws to xr.DataArray.""" data = np.stack(np.split(data, chains)) coords["draw"] = np.arange(draws) diff --git a/tests/test_inference.py b/tests/test_inference.py index cb55ac0..adb3ca6 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,6 +1,9 @@ import numpy as np +import pytest from birdman import inference as mu +from birdman.default_models import NegativeBinomialSingle +from birdman import ModelIterator class TestToInference: @@ -78,3 +81,30 @@ def test_serial_ppll(self, example_model): nb_data = example_model.fit.stan_variable(v) nb_data = np.array(np.split(nb_data, 4, axis=0)) np.testing.assert_array_almost_equal(nb_data, inf_data) + + +@pytest.mark.parametrize("method", ["mcmc", "vi"]) +def test_concat(table_biom, metadata, method): + tbl = table_biom + md = metadata + + model_iterator = ModelIterator( + table=tbl, + model=NegativeBinomialSingle, + formula="host_common_name", + metadata=md, + ) + + infs = [] + for fname, model in model_iterator: + model.compile_model() + model.fit_model(method, num_draws=100) + infs.append(model.to_inference()) + + inf_concat = mu.concatenate_inferences( + infs, + coords={"feature": tbl.ids("observation")}, + ) + exp_feat_ids = tbl.ids("observation") + feat_ids = inf_concat.posterior.coords["feature"].to_numpy() + assert (exp_feat_ids == feat_ids).all()