Skip to content

Commit

Permalink
Fix #90
Browse files Browse the repository at this point in the history
  • Loading branch information
gibsramen committed Oct 4, 2023
1 parent 35cf431 commit 40e45aa
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
8 changes: 3 additions & 5 deletions birdman/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,19 @@ def concatenate_inferences(
"""
group_list = []
group_list.append([x.posterior for x in inf_list])
group_list.append([x.sample_stats for x in inf_list])
if "log_likelihood" in inf_list[0].groups():
group_list.append([x.log_likelihood for x in inf_list])
if "posterior_predictive" in inf_list[0].groups():
group_list.append([x.posterior_predictive for x in inf_list])

po_ds = xr.concat(group_list[0], concatenation_name)
ss_ds = xr.concat(group_list[1], concatenation_name)
group_dict = {"posterior": po_ds, "sample_stats": ss_ds}
group_dict = {"posterior": po_ds}

if "log_likelihood" in inf_list[0].groups():
ll_ds = xr.concat(group_list[2], concatenation_name)
ll_ds = xr.concat(group_list[1], concatenation_name)
group_dict["log_likelihood"] = ll_ds
if "posterior_predictive" in inf_list[0].groups():
pp_ds = xr.concat(group_list[3], concatenation_name)
pp_ds = xr.concat(group_list[2], concatenation_name)
group_dict["posterior_predictive"] = pp_ds

all_group_inferences = []
Expand Down
31 changes: 31 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pytest

from birdman import inference as mu
from birdman.default_models import NegativeBinomialSingle
from birdman import ModelIterator


class TestToInference:
Expand Down Expand Up @@ -78,3 +81,31 @@ def test_serial_ppll(self, example_model):
nb_data = example_model.fit.stan_variable(v)
nb_data = np.array(np.split(nb_data, 4, axis=0))
np.testing.assert_array_almost_equal(nb_data, inf_data)


@pytest.mark.parametrize("method", ["mcmc", "vi"])
def test_concat(table_biom, metadata, method):
tbl = table_biom
md = metadata

model_iterator = ModelIterator(
table=tbl,
model=NegativeBinomialSingle,
formula="host_common_name",
metadata=md,
)

infs = []
for fname, model in model_iterator:
model.compile_model()
model.fit_model(method, num_draws=100)
infs.append(model.to_inference())

inf_concat = mu.concatenate_inferences(
infs,
coords={"feature": tbl.ids("observation")},
)
print(inf_concat.posterior)
exp_feat_ids = tbl.ids("observation")
feat_ids = inf_concat.posterior.coords["feature"].to_numpy()
assert (exp_feat_ids == feat_ids).all()

0 comments on commit 40e45aa

Please sign in to comment.