Skip to content

Commit

Permalink
update multinomial model
Browse files Browse the repository at this point in the history
  • Loading branch information
gibsramen committed Mar 31, 2021
1 parent fcf2e61 commit c5ef11f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 7 deletions.
31 changes: 30 additions & 1 deletion birdman/default_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,36 @@ def __init__(
seed, parallelize_across="chains")

param_dict = {
"depth": table.sum(axis="sample").astype(int), # sampling depths
"B_p": beta_prior,
"depth": table.sum(axis="sample").astype(int)
}
self.add_parameters(param_dict)

def to_inference_object(self) -> az.InferenceData:
"""Convert fitted Stan model into ``arviz`` InferenceData object.
:returns: ``arviz`` InferenceData object with selected values
:rtype: az.InferenceData
"""
dims = {
"beta": ["covariate", "feature"],
"log_lhood": ["tbl_sample"],
"y_predict": ["tbl_sample", "feature"]
}
coords = {
"covariate": self.colnames,
"feature": self.feature_names,
"tbl_sample": self.sample_names
}

# TODO: May want to allow not passing PP/LL/OD in the future
inf = super().to_inference_object(
params=["beta", "phi"],
dims=dims,
coords=coords,
alr_params=["beta"],
posterior_predictive="y_predict",
log_likelihood="log_lhood",
include_observed_data=True,
)
return inf
15 changes: 11 additions & 4 deletions birdman/templates/multinomial.stan
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@ parameters {
}

transformed parameters {
matrix[N, D-1] lam;
matrix[N, D] lam_clr;
vector[N] z;
simplex[D] theta[N];

lam = x * beta;
lam_clr = append_col(to_vector(rep_array(0, N)), lam);
lam_clr = append_col(to_vector(rep_array(0, N)), x*beta);
for (n in 1:N){
theta[n] = softmax(to_vector(lam_clr[n,]));
}
Expand All @@ -38,3 +35,13 @@ model {
target += multinomial_lpmf(y[n,] | to_vector(theta[n,]));
}
}

generated quantities {
int y_predict[N, D];
vector[N] log_lhood;

for (n in 1:N){
y_predict[n,] = multinomial_rng(theta[n], depth[n]);
log_lhood[n] = multinomial_lpmf(y[n,] | to_vector(theta[n,]));
}
}
15 changes: 13 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from birdman import NegativeBinomial, NegativeBinomialLME
from birdman import Multinomial, NegativeBinomial, NegativeBinomialLME

TEMPLATES = resource_filename("birdman", "templates")

Expand Down Expand Up @@ -56,11 +56,22 @@ def test_nb_lme(self, table_biom, metadata):

inf = nb_lme.to_inference_object()
post = inf.posterior
print(post)
assert post["subj_int"].dims == ("chain", "draw", "group", "feature")
assert post["subj_int"].shape == (4, 100, 3, 28)
assert (post.coords["group"].values == ["G0", "G1", "G2"]).all()

def test_mult(self, table_biom, metadata):
md = metadata.copy()
np.random.seed(42)
mult = Multinomial(
table=table_biom,
formula="host_common_name",
metadata=md,
num_iter=100,
)
mult.compile_model()
mult.fit_model()


class TestToInference:
def test_serial_to_inference(self, example_model):
Expand Down

0 comments on commit c5ef11f

Please sign in to comment.