Skip to content

Commit

Permalink
Merge pull request #44 from gibsramen/restructure-inf
Browse files: browse the repository at this point in the history
Refactor conversion to InferenceData
  • Loading branch information
gibsramen authored May 28, 2021
2 parents bdd1753 + 53bd058 commit 2b26ead
Show file tree
Hide file tree
Showing 12 changed files with 419 additions and 307 deletions.
150 changes: 51 additions & 99 deletions birdman/default_models.py
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,7 @@
import os
from pkg_resources import resource_filename

import arviz as az
import biom
import dask_jobqueue
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -108,51 +106,26 @@ def __init__(
}
self.add_parameters(param_dict)

def to_inference_object(
self,
dask_cluster: dask_jobqueue.JobQueueCluster = None,
jobs: int = 4
) -> az.InferenceData:
"""Convert fitted Stan model into ``arviz`` InferenceData object.
:param dask_cluster: Dask jobqueue to run parallel jobs (optional)
:type dask_cluster: dask_jobqueue.JobQueueCluster, optional
:param jobs: Number of jobs to run in parallel, defaults to 4
:type jobs: int
:returns: ``arviz`` InferenceData object with selected values
:rtype: az.InferenceData
"""
dims = {
"beta": ["covariate", "feature"],
"phi": ["feature"],
"log_lhood": ["tbl_sample", "feature"],
"y_predict": ["tbl_sample", "feature"]
}
coords = {
"covariate": self.colnames,
"feature": self.feature_names,
"tbl_sample": self.sample_names
}

# TODO: May want to allow not passing PP/LL/OD in the future
args = dict()
if self.parallelize_across == "chains":
args["alr_params"] = ["beta"]
else:
args["dask_cluster"] = dask_cluster
args["jobs"] = jobs
inf = super().to_inference_object(
self.specify_model(
params=["beta", "phi"],
dims=dims,
coords=coords,
posterior_predictive="y_predict",
log_likelihood="log_lhood",
dims={
"beta": ["covariate", "feature"],
"phi": ["feature"],
"log_lhood": ["tbl_sample", "feature"],
"y_predict": ["tbl_sample", "feature"]
},
coords={
"covariate": self.colnames,
"feature": self.feature_names,
"tbl_sample": self.sample_names
},
include_observed_data=True,
**args
posterior_predictive="y_predict",
log_likelihood="log_lhood"
)
return inf

if self.parallelize_across == "chains":
self.specifications["alr_params"] = ["beta"]


class NegativeBinomialLME(RegressionModel):
Expand Down Expand Up @@ -259,37 +232,28 @@ def __init__(
}
self.add_parameters(param_dict)

def to_inference_object(self) -> az.InferenceData:
"""Convert fitted Stan model into ``arviz`` InferenceData object.
:returns: ``arviz`` InferenceData object with selected values
:rtype: az.InferenceData
"""
dims = {
"beta": ["covariate", "feature"],
"phi": ["feature"],
"subj_int": ["group", "feature"],
"log_lhood": ["tbl_sample", "feature"],
"y_predict": ["tbl_sample", "feature"]
}
coords = {
"covariate": self.colnames,
"feature": self.feature_names,
"tbl_sample": self.sample_names,
"group": self.groups
}

# TODO: May want to allow not passing PP/LL/OD in the future
inf = super().to_inference_object(
params=["beta", "phi"],
dims=dims,
coords=coords,
posterior_predictive="y_predict",
log_likelihood="log_lhood",
alr_params=["beta", "subj_int"],
self.specify_model(
params=["beta", "phi", "subj_int"],
dims={
"beta": ["covariate", "feature"],
"phi": ["feature"],
"subj_int": ["group", "feature"],
"log_lhood": ["tbl_sample", "feature"],
"y_predict": ["tbl_sample", "feature"]
},
coords={
"covariate": self.colnames,
"feature": self.feature_names,
"tbl_sample": self.sample_names,
"group": self.groups
},
include_observed_data=True,
posterior_predictive="y_predict",
log_likelihood="log_lhood"
)
return inf

if self.parallelize_across == "chains":
self.specifications["alr_params"] = ["beta", "subj_int"]


class Multinomial(RegressionModel):
Expand Down Expand Up @@ -363,31 +327,19 @@ def __init__(
}
self.add_parameters(param_dict)

def to_inference_object(self) -> az.InferenceData:
"""Convert fitted Stan model into ``arviz`` InferenceData object.
:returns: ``arviz`` InferenceData object with selected values
:rtype: az.InferenceData
"""
dims = {
"beta": ["covariate", "feature"],
"log_lhood": ["tbl_sample"],
"y_predict": ["tbl_sample", "feature"]
}
coords = {
"covariate": self.colnames,
"feature": self.feature_names,
"tbl_sample": self.sample_names
}

# TODO: May want to allow not passing PP/LL/OD in the future
inf = super().to_inference_object(
params=["beta", "phi"],
dims=dims,
coords=coords,
alr_params=["beta"],
posterior_predictive="y_predict",
log_likelihood="log_lhood",
self.specify_model(
params=["beta"],
dims={
"beta": ["covariate", "feature"],
"log_lhood": ["tbl_sample"],
"y_predict": ["tbl_sample", "feature"]
},
coords={
"covariate": self.colnames,
"feature": self.feature_names,
"tbl_sample": self.sample_names,
},
include_observed_data=True,
posterior_predictive="y_predict",
log_likelihood="log_lhood"
)
return inf
Loading

0 comments on commit 2b26ead

Please sign in to comment.