From c5ef11fb19fe723642ffb20353565cdd7ed0cc24 Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Wed, 31 Mar 2021 15:59:50 -0700 Subject: [PATCH] update multinomial model --- birdman/default_models.py | 31 +++++++++++++++++++++++++++++- birdman/templates/multinomial.stan | 15 +++++++++++---- tests/test_model.py | 15 +++++++++++++-- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/birdman/default_models.py b/birdman/default_models.py index c52abe6..25ad98e 100644 --- a/birdman/default_models.py +++ b/birdman/default_models.py @@ -305,7 +305,36 @@ def __init__( seed, parallelize_across="chains") param_dict = { - "depth": table.sum(axis="sample").astype(int), # sampling depths "B_p": beta_prior, + "depth": table.sum(axis="sample").astype(int) } self.add_parameters(param_dict) + + def to_inference_object(self) -> az.InferenceData: + """Convert fitted Stan model into ``arviz`` InferenceData object. + + :returns: ``arviz`` InferenceData object with selected values + :rtype: az.InferenceData + """ + dims = { + "beta": ["covariate", "feature"], + "log_lhood": ["tbl_sample"], + "y_predict": ["tbl_sample", "feature"] + } + coords = { + "covariate": self.colnames, + "feature": self.feature_names, + "tbl_sample": self.sample_names + } + + # TODO: May want to allow not passing PP/LL/OD in the future + inf = super().to_inference_object( + params=["beta", "phi"], + dims=dims, + coords=coords, + alr_params=["beta"], + posterior_predictive="y_predict", + log_likelihood="log_lhood", + include_observed_data=True, + ) + return inf diff --git a/birdman/templates/multinomial.stan b/birdman/templates/multinomial.stan index c4e494d..69e9cbc 100644 --- a/birdman/templates/multinomial.stan +++ b/birdman/templates/multinomial.stan @@ -14,13 +14,10 @@ parameters { } transformed parameters { - matrix[N, D-1] lam; matrix[N, D] lam_clr; - vector[N] z; simplex[D] theta[N]; - lam = x * beta; - lam_clr = append_col(to_vector(rep_array(0, N)), lam); + lam_clr = append_col(to_vector(rep_array(0, N)), x*beta); for (n in 1:N){ theta[n] = softmax(to_vector(lam_clr[n,])); } @@ -38,3 +35,13 @@ model { target += multinomial_lpmf(y[n,] | to_vector(theta[n,])); } } + +generated quantities { + int y_predict[N, D]; + vector[N] log_lhood; + + for (n in 1:N){ + y_predict[n,] = multinomial_rng(theta[n], depth[n]); + log_lhood[n] = multinomial_lpmf(y[n,] | to_vector(theta[n,])); + } +} diff --git a/tests/test_model.py b/tests/test_model.py index 16ead0e..93a39d5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,7 @@ import numpy as np -from birdman import NegativeBinomial, NegativeBinomialLME +from birdman import Multinomial, NegativeBinomial, NegativeBinomialLME TEMPLATES = resource_filename("birdman", "templates") @@ -56,11 +56,22 @@ def test_nb_lme(self, table_biom, metadata): inf = nb_lme.to_inference_object() post = inf.posterior - print(post) assert post["subj_int"].dims == ("chain", "draw", "group", "feature") assert post["subj_int"].shape == (4, 100, 3, 28) assert (post.coords["group"].values == ["G0", "G1", "G2"]).all() + def test_mult(self, table_biom, metadata): + md = metadata.copy() + np.random.seed(42) + mult = Multinomial( + table=table_biom, + formula="host_common_name", + metadata=md, + num_iter=100, + ) + mult.compile_model() + mult.fit_model() + class TestToInference: def test_serial_to_inference(self, example_model):