Skip to content

Commit

Permalink
Merge pull request #86 from gibsramen/excise-builder
Browse files Browse the repository at this point in the history
Add InferenceData support to VI
  • Loading branch information
gibsramen authored Sep 21, 2023
2 parents 306c77c + d6e4fc8 commit 35cf431
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 109 deletions.
11 changes: 9 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:

strategy:
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand All @@ -39,7 +39,14 @@ jobs:

- name: Install conda packages
shell: bash -l {0}
run: mamba install -c conda-forge biom-format patsy pytest xarray scikit-bio flake8 arviz cmdstanpy
run: mamba install -c conda-forge biom-format patsy pytest xarray scikit-bio flake8 arviz

# Temp req before CmdStanPy cuts a new release
- name: Install develop branch of CmdStanPy
shell: bash -l {0}
run: >
pip install git+https://github.com/stan-dev/cmdstanpy.git@develop;
install_cmdstan
- name: Install BIRDMAn
shell: bash -l {0}
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ stylecheck:
flake8 birdman/*.py tests/*.py setup.py

pytest:
pytest
pytest tests -W ignore::FutureWarning

documentation:
cd docs && make html
3 changes: 2 additions & 1 deletion birdman/default_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ def __init__(
dims={
"beta_var": ["covariate"],
"log_lhood": ["tbl_sample"],
"y_predict": ["tbl_sample"]
"y_predict": ["tbl_sample"],
"inv_disp": []
},
coords={
"covariate": self.colnames,
Expand Down
1 change: 1 addition & 0 deletions birdman/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def r2_score(inference_object: az.InferenceData) -> pd.Series:
:returns: Bayesian :math:`R^2` & standard deviation
:rtype: pd.Series
"""

if "observed_data" not in inference_object.groups():
raise ValueError("Inference data is missing observed data!")

Expand Down
149 changes: 60 additions & 89 deletions birdman/inference.py
Original file line number Diff line number Diff line change
@@ -1,118 +1,65 @@
from typing import List, Sequence
from typing import List, Sequence, Union

import arviz as az
from cmdstanpy import CmdStanMCMC
from cmdstanpy import CmdStanMCMC, CmdStanVB
import numpy as np
import xarray as xr

from .util import _drop_data


def full_fit_to_inference(
fit: CmdStanMCMC,
def fit_to_inference(
fit: Union[CmdStanMCMC, CmdStanVB],
chains: int,
draws: int,
params: Sequence[str],
coords: dict,
dims: dict,
alr_params: Sequence[str] = None,
posterior_predictive: str = None,
log_likelihood: str = None,
) -> az.InferenceData:
"""Convert fitted Stan model into inference object.
:param fit: Fitted model
:type fit: CmdStanMCMC
:param params: Posterior fitted parameters to include
:type params: Sequence[str]
:param coords: Mapping of entries in dims to labels
:type coords: dict
:param dims: Dimensions of parameters in the model
:type dims: dict
:param alr_params: Parameters to convert from ALR to CLR
:type alr_params: Sequence[str], optional
:param posterior_predictive: Name of posterior predictive values from
Stan model to include in ``arviz`` InferenceData object
:type posterior_predictive: str, optional
:param log_likelihood: Name of log likelihood values from Stan model
to include in ``arviz`` InferenceData object
:type log_likelihood: str, optional
:returns: ``arviz`` InferenceData object with selected values
:rtype: az.InferenceData
"""
):
if log_likelihood is not None and log_likelihood not in dims:
raise KeyError("Must include dimensions for log-likelihood!")
if posterior_predictive is not None and posterior_predictive not in dims:
raise KeyError("Must include dimensions for posterior predictive!")

inference = az.from_cmdstanpy(
fit,
coords=coords,
log_likelihood=log_likelihood,
posterior_predictive=posterior_predictive,
dims=dims
)

vars_to_drop = set(inference.posterior.data_vars).difference(params)
inference.posterior = _drop_data(inference.posterior, vars_to_drop)
das = dict()

return inference
for param in params:
data = fit.stan_variable(param)

_dims = dims[param]
_coords = {k: coords[k] for k in _dims}

def single_feature_fit_to_inference(
fit: CmdStanMCMC,
params: Sequence[str],
coords: dict,
dims: dict,
posterior_predictive: str = None,
log_likelihood: str = None,
) -> az.InferenceData:
"""Convert single feature fit to InferenceData.
das[param] = stan_var_to_da(data, _coords, _dims, chains, draws)

:param fit: Single feature fit with CmdStanPy
:type fit: cmdstanpy.CmdStanMCMC
if log_likelihood:
data = fit.stan_variable(log_likelihood)

:param params: Posterior fitted parameters to include
:type params: Sequence[str]
_dims = dims[log_likelihood]
_coords = {k: coords[k] for k in _dims}

:param coords: Coordinates to use for annotating Inference dims
:type coords: dict
ll_da = stan_var_to_da(data, _coords, _dims, chains, draws)
ll_ds = xr.Dataset({log_likelihood: ll_da})
else:
ll_ds = None

:param dims: Dimensions of parameters in fitted model
:type dims: dict
if posterior_predictive:
data = fit.stan_variable(posterior_predictive)

:param posterior_predictive: Name of variable holding PP values
:type posterior_predictive: str
_dims = dims[posterior_predictive]
_coords = {k: coords[k] for k in _dims}

:param log_likelihood: Name of variable holding LL values
:type log_likelihood: str
pp_da = stan_var_to_da(data, _coords, _dims, chains, draws)
pp_ds = xr.Dataset({posterior_predictive: pp_da})
else:
pp_ds = None

:returns: InferenceData object of single feature
:rtype: az.InferenceData
"""
_coords = coords.copy()
if "feature" in coords:
_coords.pop("feature")

_dims = dims.copy()
for k, v in _dims.items():
if "feature" in v:
v.remove("feature")

feat_inf = az.from_cmdstanpy(
posterior=fit,
posterior_predictive=posterior_predictive,
log_likelihood=log_likelihood,
coords=_coords,
dims=_dims
inf = az.InferenceData(
posterior=xr.Dataset(das),
log_likelihood=ll_ds,
posterior_predictive=pp_ds
)
vars_to_drop = set(feat_inf.posterior.data_vars).difference(params)
feat_inf.posterior = _drop_data(feat_inf.posterior, vars_to_drop)
return feat_inf

return inf


def concatenate_inferences(
Expand Down Expand Up @@ -165,3 +112,27 @@ def concatenate_inferences(
all_group_inferences.append(group_inf)

return az.concat(*all_group_inferences)


def stan_var_to_da(
    data: np.ndarray,
    coords: dict,
    dims: Sequence[str],
    chains: int,
    draws: int
) -> "xr.DataArray":
    """Convert Stan variable draws to xr.DataArray.

    The first axis of ``data`` is split into ``chains`` equally sized pieces
    which are stacked along a new leading axis, so the result is labelled
    with ``chain`` and ``draw`` dimensions followed by ``dims``.

    :param data: Draws of a single Stan variable; the first axis is assumed
        to hold ``chains * draws`` entries (TODO confirm against caller)
    :type data: np.ndarray

    :param coords: Mapping of entries in dims to labels; this mapping is
        copied and never modified
    :type coords: dict

    :param dims: Names of the variable's own dimensions (may be empty)
    :type dims: Sequence[str]

    :param chains: Number of chains to split the draws into
    :type chains: int

    :param draws: Number of draws per chain
    :type draws: int

    :returns: Labelled array with dimensions ``chain``, ``draw``, ``*dims``
    :rtype: xr.DataArray

    :raises ValueError: From ``np.split`` if the first axis of ``data``
        cannot be divided into ``chains`` equal pieces
    """
    # (chains * draws, ...) -> (chains, draws, ...)
    stacked = np.stack(np.split(data, chains))

    # Build a fresh mapping so the caller's coords are left untouched
    full_coords = {
        **coords,
        "draw": np.arange(draws),
        "chain": np.arange(chains),
    }
    # list() lets dims be any sequence (list, tuple, ...), not only a list
    full_dims = ["chain", "draw"] + list(dims)

    return xr.DataArray(
        stacked,
        coords=full_coords,
        dims=full_dims,
    )
19 changes: 15 additions & 4 deletions birdman/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
from patsy import dmatrix

from .inference import full_fit_to_inference, single_feature_fit_to_inference
from .inference import fit_to_inference


class BaseModel(ABC):
Expand Down Expand Up @@ -114,7 +114,7 @@ def add_parameters(self, param_dict: dict = None):

def fit_model(
self,
method: str = "mcmc",
method: str = "vi",
num_draws: int = 500,
mcmc_warmup: int = None,
mcmc_chains: int = 4,
Expand Down Expand Up @@ -167,6 +167,9 @@ def fit_model(
mcmc_kwargs = mcmc_kwargs or dict()
mcmc_warmup = mcmc_warmup or mcmc_warmup

self.num_chains = mcmc_chains
self.num_draws = num_draws

self.fit = self.sm.sample(
chains=mcmc_chains,
parallel_chains=mcmc_chains,
Expand All @@ -179,6 +182,9 @@ def fit_model(
elif method == "vi":
vi_kwargs = vi_kwargs or dict()

self.num_chains = 1
self.num_draws = num_draws

self.fit = self.sm.variational(
data=self.dat,
iter=vi_iter,
Expand Down Expand Up @@ -224,8 +230,10 @@ def to_inference(self) -> az.InferenceData:
"""
self._check_fit_for_inf()

inference = full_fit_to_inference(
inference = fit_to_inference(
fit=self.fit,
chains=self.num_chains,
draws=self.num_draws,
params=self.params,
coords=self.coords,
dims=self.dims,
Expand All @@ -246,6 +254,7 @@ def to_inference(self) -> az.InferenceData:
}
)
inference = az.concat(inference, obs)

return inference


Expand All @@ -269,8 +278,10 @@ def to_inference(self) -> az.InferenceData:
"""
self._check_fit_for_inf()

inference = single_feature_fit_to_inference(
inference = fit_to_inference(
fit=self.fit,
chains=self.num_chains,
draws=self.num_draws,
params=self.params,
coords=self.coords,
dims=self.dims,
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def model():
metadata=md,
)
nb.compile_model()
nb.fit_model(mcmc_chains=4, num_draws=100)
nb.fit_model(method="mcmc", mcmc_chains=4, num_draws=100)
return nb


Expand Down Expand Up @@ -76,7 +76,7 @@ def single_feat_model():
)

nb.compile_model()
nb.fit_model(mcmc_chains=4, num_draws=100)
nb.fit_model(method="mcmc", mcmc_chains=4, num_draws=100)

return nb

Expand Down
5 changes: 3 additions & 2 deletions tests/test_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def test_custom_model(table_biom, metadata):
},
)
custom_model.compile_model()
custom_model.fit_model(num_draws=100, mcmc_chains=4, seed=42)
custom_model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4,
seed=42)
inference = custom_model.to_inference()

assert set(inference.groups()) == {"posterior", "sample_stats"}
assert set(inference.groups()) == {"posterior"}
ds = inference.posterior

assert ds.coords._names == {"chain", "covariate", "draw", "feature_alr"}
Expand Down
10 changes: 6 additions & 4 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def dataset_comparison(self, model, ds):
assert (ds.coords["chain"] == [0, 1, 2, 3]).all()

def test_serial_to_inference(self, example_model):
inf = mu.full_fit_to_inference(
inf = mu.fit_to_inference(
fit=example_model.fit,
chains=4,
draws=100,
coords={
"feature": example_model.feature_names,
"feature_alr": example_model.feature_names[1:],
Expand All @@ -41,16 +43,17 @@ def test_serial_to_inference(self, example_model):
"y_predict": ["tbl_sample", "feature"]
},
params=["beta_var", "inv_disp"],
alr_params=["beta_var"]
)
self.dataset_comparison(example_model, inf.posterior)


# Posterior predictive & log likelihood
class TestPPLL:
def test_serial_ppll(self, example_model):
inf = mu.full_fit_to_inference(
inf = mu.fit_to_inference(
fit=example_model.fit,
chains=4,
draws=100,
coords={
"feature": example_model.feature_names,
"feature_alr": example_model.feature_names[1:],
Expand All @@ -64,7 +67,6 @@ def test_serial_ppll(self, example_model):
"y_predict": ["sample", "feature"]
},
params=["beta_var", "inv_disp"],
alr_params=["beta_var"],
posterior_predictive="y_predict",
log_likelihood="log_lhood",
)
Expand Down
Loading

0 comments on commit 35cf431

Please sign in to comment.