From 40e45aa280735bbbd4e6986c1bd5364d2631c46c Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Wed, 4 Oct 2023 10:11:03 -0700 Subject: [PATCH 1/3] Fix #90 --- birdman/inference.py | 8 +++----- tests/test_inference.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/birdman/inference.py b/birdman/inference.py index 15fe528..ab3657c 100644 --- a/birdman/inference.py +++ b/birdman/inference.py @@ -84,21 +84,19 @@ def concatenate_inferences( """ group_list = [] group_list.append([x.posterior for x in inf_list]) - group_list.append([x.sample_stats for x in inf_list]) if "log_likelihood" in inf_list[0].groups(): group_list.append([x.log_likelihood for x in inf_list]) if "posterior_predictive" in inf_list[0].groups(): group_list.append([x.posterior_predictive for x in inf_list]) po_ds = xr.concat(group_list[0], concatenation_name) - ss_ds = xr.concat(group_list[1], concatenation_name) - group_dict = {"posterior": po_ds, "sample_stats": ss_ds} + group_dict = {"posterior": po_ds} if "log_likelihood" in inf_list[0].groups(): - ll_ds = xr.concat(group_list[2], concatenation_name) + ll_ds = xr.concat(group_list[1], concatenation_name) group_dict["log_likelihood"] = ll_ds if "posterior_predictive" in inf_list[0].groups(): - pp_ds = xr.concat(group_list[3], concatenation_name) + pp_ds = xr.concat(group_list[2], concatenation_name) group_dict["posterior_predictive"] = pp_ds all_group_inferences = [] diff --git a/tests/test_inference.py b/tests/test_inference.py index cb55ac0..e7bbe79 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,6 +1,9 @@ import numpy as np +import pytest from birdman import inference as mu +from birdman.default_models import NegativeBinomialSingle +from birdman import ModelIterator class TestToInference: @@ -78,3 +81,31 @@ def test_serial_ppll(self, example_model): nb_data = example_model.fit.stan_variable(v) nb_data = np.array(np.split(nb_data, 4, axis=0)) np.testing.assert_array_almost_equal(nb_data, inf_data) + + +@pytest.mark.parametrize("method", ["mcmc", "vi"]) +def test_concat(table_biom, metadata, method): + tbl = table_biom + md = metadata + + model_iterator = ModelIterator( + table=tbl, + model=NegativeBinomialSingle, + formula="host_common_name", + metadata=md, + ) + + infs = [] + for fname, model in model_iterator: + model.compile_model() + model.fit_model(method, num_draws=100) + infs.append(model.to_inference()) + + inf_concat = mu.concatenate_inferences( + infs, + coords={"feature": tbl.ids("observation")}, + ) + print(inf_concat.posterior) + exp_feat_ids = tbl.ids("observation") + feat_ids = inf_concat.posterior.coords["feature"].to_numpy() + assert (exp_feat_ids == feat_ids).all() From 395e33d6c0097e27850e874825375fcb5d5f7928 Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Wed, 4 Oct 2023 10:38:00 -0700 Subject: [PATCH 2/3] Remove vestigial print --- tests/test_inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index e7bbe79..adb3ca6 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -105,7 +105,6 @@ def test_concat(table_biom, metadata, method): infs, coords={"feature": tbl.ids("observation")}, ) - print(inf_concat.posterior) exp_feat_ids = tbl.ids("observation") feat_ids = inf_concat.posterior.coords["feature"].to_numpy() assert (exp_feat_ids == feat_ids).all() From 09f8580b05c4b016d5270d0f4054b53d042abe0c Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Wed, 4 Oct 2023 11:00:33 -0700 Subject: [PATCH 3/3] Fix inference dims in VI --- birdman/inference.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/birdman/inference.py b/birdman/inference.py index ab3657c..b33b7d2 100644 --- a/birdman/inference.py +++ b/birdman/inference.py @@ -1,3 +1,4 @@ +from functools import partial from typing import List, Sequence, Union import arviz as az @@ -21,10 +22,17 @@ def fit_to_inference( if posterior_predictive is not None and posterior_predictive not in dims: raise KeyError("Must include dimensions for posterior predictive!") + # Required because as of writing, CmdStanVB.stan_variable defaults to + # returning the mean rather than the sample + if isinstance(fit, CmdStanVB): + stan_var_fn = partial(fit.stan_variable, mean=False) + else: + stan_var_fn = fit.stan_variable + das = dict() for param in params: - data = fit.stan_variable(param) + data = stan_var_fn(param) _dims = dims[param] _coords = {k: coords[k] for k in _dims} @@ -32,7 +40,7 @@ def fit_to_inference( das[param] = stan_var_to_da(data, _coords, _dims, chains, draws) if log_likelihood: - data = fit.stan_variable(log_likelihood) + data = stan_var_fn(log_likelihood) _dims = dims[log_likelihood] _coords = {k: coords[k] for k in _dims} @@ -43,7 +51,7 @@ def fit_to_inference( ll_ds = None if posterior_predictive: - data = fit.stan_variable(posterior_predictive) + data = stan_var_fn(posterior_predictive) _dims = dims[posterior_predictive] _coords = {k: coords[k] for k in _dims} @@ -112,6 +120,7 @@ def concatenate_inferences( return az.concat(*all_group_inferences) +# TODO: Fix docstring def stan_var_to_da( data: np.ndarray, coords: dict, @@ -119,9 +128,7 @@ def stan_var_to_da( chains: int, draws: int ): - """Convert Stan variable draws to xr.DataArray. - - """ + """Convert Stan variable draws to xr.DataArray.""" data = np.stack(np.split(data, chains)) coords["draw"] = np.arange(draws)