From 606d5fda947023d80a9b0c00ec2f3d84ff6ae50e Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Tue, 12 Sep 2023 14:27:01 -0700 Subject: [PATCH 1/6] Update to_inference --- birdman/diagnostics.py | 1 + birdman/inference.py | 187 ++++++++++++++++++++++++++----------- birdman/model_base.py | 12 ++- tests/test_custom_model.py | 2 +- tests/test_inference.py | 6 +- tests/test_model.py | 2 +- 6 files changed, 150 insertions(+), 60 deletions(-) diff --git a/birdman/diagnostics.py b/birdman/diagnostics.py index 48fd572..ea4987f 100644 --- a/birdman/diagnostics.py +++ b/birdman/diagnostics.py @@ -102,6 +102,7 @@ def r2_score(inference_object: az.InferenceData) -> pd.Series: :returns: Bayesian :math:`R^2` & standard deviation :rtype: pd.Series """ + if "observed_data" not in inference_object.groups(): raise ValueError("Inference data is missing observed data!") diff --git a/birdman/inference.py b/birdman/inference.py index 995c14d..bb92783 100644 --- a/birdman/inference.py +++ b/birdman/inference.py @@ -1,66 +1,63 @@ from typing import List, Sequence import arviz as az -from cmdstanpy import CmdStanMCMC +from cmdstanpy import CmdStanMCMC, CmdStanVB +import numpy as np import xarray as xr from .util import _drop_data -def full_fit_to_inference( - fit: CmdStanMCMC, - params: Sequence[str], - coords: dict, - dims: dict, - alr_params: Sequence[str] = None, - posterior_predictive: str = None, - log_likelihood: str = None, -) -> az.InferenceData: - """Convert fitted Stan model into inference object. - - :param fit: Fitted model - :type params: CmdStanMCMC - - :param params: Posterior fitted parameters to include - :type params: Sequence[str] - - :param coords: Mapping of entries in dims to labels - :type coords: dict - - :param dims: Dimensions of parameters in the model - :type dims: dict - - :param alr_params: Parameters to convert from ALR to CLR - :type alr_params: Sequence[str], optional - - :param posterior_predictive: Name of posterior predictive values from - Stan model to include in ``arviz`` InferenceData object - :type posterior_predictive: str, optional - - :param log_likelihood: Name of log likelihood values from Stan model - to include in ``arviz`` InferenceData object - :type log_likelihood: str, optional - - :returns: ``arviz`` InferenceData object with selected values - :rtype: az.InferenceData - """ - if log_likelihood is not None and log_likelihood not in dims: - raise KeyError("Must include dimensions for log-likelihood!") - if posterior_predictive is not None and posterior_predictive not in dims: - raise KeyError("Must include dimensions for posterior predictive!") - - inference = az.from_cmdstanpy( - fit, - coords=coords, - log_likelihood=log_likelihood, - posterior_predictive=posterior_predictive, - dims=dims - ) - - vars_to_drop = set(inference.posterior.data_vars).difference(params) - inference.posterior = _drop_data(inference.posterior, vars_to_drop) - - return inference +# def full_fit_to_inference( +# fit: CmdStanMCMC, +# params: Sequence[str], +# coords: dict, +# dims: dict, +# posterior_predictive: str = None, +# log_likelihood: str = None, +# ) -> az.InferenceData: +# """Convert fitted Stan model into inference object. +# +# :param fit: Fitted model +# :type params: CmdStanMCMC +# +# :param params: Posterior fitted parameters to include +# :type params: Sequence[str] +# +# :param coords: Mapping of entries in dims to labels +# :type coords: dict +# +# :param dims: Dimensions of parameters in the model +# :type dims: dict +# +# :param posterior_predictive: Name of posterior predictive values from +# Stan model to include in ``arviz`` InferenceData object +# :type posterior_predictive: str, optional +# +# :param log_likelihood: Name of log likelihood values from Stan model +# to include in ``arviz`` InferenceData object +# :type log_likelihood: str, optional +# +# :returns: ``arviz`` InferenceData object with selected values +# :rtype: az.InferenceData +# """ +# if log_likelihood is not None and log_likelihood not in dims: +# raise KeyError("Must include dimensions for log-likelihood!") +# if posterior_predictive is not None and posterior_predictive not in dims: +# raise KeyError("Must include dimensions for posterior predictive!") +# +# inference = az.from_cmdstanpy( +# fit, +# coords=coords, +# log_likelihood=log_likelihood, +# posterior_predictive=posterior_predictive, +# dims=dims +# ) +# +# vars_to_drop = set(inference.posterior.data_vars).difference(params) +# inference.posterior = _drop_data(inference.posterior, vars_to_drop) +# +# return inference def single_feature_fit_to_inference( @@ -165,3 +162,83 @@ def concatenate_inferences( all_group_inferences.append(group_inf) return az.concat(*all_group_inferences) + + +def stan_var_to_da( + data: np.ndarray, + coords: dict, + dims: dict, + chains: int, + draws: int +): + """Convert Stan variable draws to xr.DataArray. + + """ + data = np.stack(np.split(data, chains)) + + coords["draw"] = np.arange(draws) + coords["chain"] = np.arange(chains) + dims = ["chain", "draw"] + dims + + da = xr.DataArray( + data, + coords=coords, + dims=dims, + ) + return da + + +def full_fit_to_inference( + fit: CmdStanVB, + chains: int, + draws: int, + params: Sequence[str], + coords: dict, + dims: dict, + posterior_predictive: str = None, + log_likelihood: str = None, +): + if log_likelihood is not None and log_likelihood not in dims: + raise KeyError("Must include dimensions for log-likelihood!") + if posterior_predictive is not None and posterior_predictive not in dims: + raise KeyError("Must include dimensions for posterior predictive!") + + das = dict() + + for param in params: + data = fit.stan_variable(param) + + _dims = dims[param] + _coords = {k: coords[k] for k in _dims} + + das[param] = stan_var_to_da(data, _coords, _dims, chains, draws) + + if log_likelihood: + data = fit.stan_variable(log_likelihood) + + _dims = dims[log_likelihood] + _coords = {k: coords[k] for k in _dims} + + ll_da = stan_var_to_da(data, _coords, _dims, chains, draws) + ll_ds = xr.Dataset({log_likelihood: ll_da}) + else: + ll_ds = None + + if posterior_predictive: + data = fit.stan_variable(posterior_predictive) + + _dims = dims[posterior_predictive] + _coords = {k: coords[k] for k in _dims} + + pp_da = stan_var_to_da(data, _coords, _dims, chains, draws) + pp_ds = xr.Dataset({posterior_predictive: pp_da}) + else: + pp_ds = None + + inf = az.InferenceData( + posterior=xr.Dataset(das), + log_likelihood=ll_ds, + posterior_predictive=pp_ds + ) + + return inf diff --git a/birdman/model_base.py b/birdman/model_base.py index 1911f8f..9c90b52 100644 --- a/birdman/model_base.py +++ b/birdman/model_base.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod +from functools import partial from math import ceil from typing import Sequence import arviz as az import biom -from cmdstanpy import CmdStanModel +from cmdstanpy import CmdStanModel, CmdStanMCMC, CmdStanVB import pandas as pd from patsy import dmatrix @@ -167,6 +168,9 @@ def fit_model( mcmc_kwargs = mcmc_kwargs or dict() mcmc_warmup = mcmc_warmup or mcmc_warmup + self.num_chains = mcmc_chains + self.num_draws = num_draws + self.fit = self.sm.sample( chains=mcmc_chains, parallel_chains=mcmc_chains, @@ -179,6 +183,9 @@ def fit_model( elif method == "vi": vi_kwargs = vi_kwargs or dict() + self.num_chains = 1 + self.num_draws = num_draws + self.fit = self.sm.variational( data=self.dat, iter=vi_iter, @@ -226,6 +233,8 @@ def to_inference(self) -> az.InferenceData: inference = full_fit_to_inference( fit=self.fit, + chains=self.num_chains, + draws=self.num_draws, params=self.params, coords=self.coords, dims=self.dims, @@ -246,6 +255,7 @@ def to_inference(self) -> az.InferenceData: } ) inference = az.concat(inference, obs) + return inference diff --git a/tests/test_custom_model.py b/tests/test_custom_model.py index 2da859a..5aa3c02 100644 --- a/tests/test_custom_model.py +++ b/tests/test_custom_model.py @@ -43,7 +43,7 @@ def test_custom_model(table_biom, metadata): custom_model.fit_model(num_draws=100, mcmc_chains=4, seed=42) inference = custom_model.to_inference() - assert set(inference.groups()) == {"posterior", "sample_stats"} + assert set(inference.groups()) == {"posterior"} ds = inference.posterior assert ds.coords._names == {"chain", "covariate", "draw", "feature_alr"} diff --git a/tests/test_inference.py b/tests/test_inference.py index 06e89bc..8402802 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -28,6 +28,8 @@ def dataset_comparison(self, model, ds): def test_serial_to_inference(self, example_model): inf = mu.full_fit_to_inference( fit=example_model.fit, + chains=4, + draws=100, coords={ "feature": example_model.feature_names, "feature_alr": example_model.feature_names[1:], @@ -41,7 +43,6 @@ def test_serial_to_inference(self, example_model): "y_predict": ["tbl_sample", "feature"] }, params=["beta_var", "inv_disp"], - alr_params=["beta_var"] ) self.dataset_comparison(example_model, inf.posterior) @@ -51,6 +52,8 @@ class TestPPLL: def test_serial_ppll(self, example_model): inf = mu.full_fit_to_inference( fit=example_model.fit, + chains=4, + draws=100, coords={ "feature": example_model.feature_names, "feature_alr": example_model.feature_names[1:], @@ -64,7 +67,6 @@ def test_serial_ppll(self, example_model): "y_predict": ["sample", "feature"] }, params=["beta_var", "inv_disp"], - alr_params=["beta_var"], posterior_predictive="y_predict", log_likelihood="log_lhood", ) diff --git a/tests/test_model.py b/tests/test_model.py index d8d6131..9caa32a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -74,7 +74,7 @@ def test_single_feat(self, table_biom, metadata): class TestToInference: def test_serial_to_inference(self, example_model): inference_data = example_model.to_inference() - target_groups = {"posterior", "sample_stats", "log_likelihood", + target_groups = {"posterior", "log_likelihood", "posterior_predictive", "observed_data"} assert set(inference_data.groups()) == target_groups From 389e88a170d0d7c5999168cd69478d77b95b8819 Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Wed, 13 Sep 2023 12:48:36 -0700 Subject: [PATCH 2/6] Update single feature fit --- birdman/default_models.py | 3 +- birdman/inference.py | 75 +++++++-------------------------------- birdman/model_base.py | 2 ++ 3 files changed, 17 insertions(+), 63 deletions(-) diff --git a/birdman/default_models.py b/birdman/default_models.py index 6eeef9c..1ec485c 100644 --- a/birdman/default_models.py +++ b/birdman/default_models.py @@ -192,7 +192,8 @@ def __init__( dims={ "beta_var": ["covariate"], "log_lhood": ["tbl_sample"], - "y_predict": ["tbl_sample"] + "y_predict": ["tbl_sample"], + "inv_disp": [] }, coords={ "covariate": self.colnames, diff --git a/birdman/inference.py b/birdman/inference.py index bb92783..5b319bf 100644 --- a/birdman/inference.py +++ b/birdman/inference.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from typing import List, Sequence, Union import arviz as az from cmdstanpy import CmdStanMCMC, CmdStanVB @@ -8,60 +8,10 @@ from .util import _drop_data -# def full_fit_to_inference( -# fit: CmdStanMCMC, -# params: Sequence[str], -# coords: dict, -# dims: dict, -# posterior_predictive: str = None, -# log_likelihood: str = None, -# ) -> az.InferenceData: -# """Convert fitted Stan model into inference object. -# -# :param fit: Fitted model -# :type params: CmdStanMCMC -# -# :param params: Posterior fitted parameters to include -# :type params: Sequence[str] -# -# :param coords: Mapping of entries in dims to labels -# :type coords: dict -# -# :param dims: Dimensions of parameters in the model -# :type dims: dict -# -# :param posterior_predictive: Name of posterior predictive values from -# Stan model to include in ``arviz`` InferenceData object -# :type posterior_predictive: str, optional -# -# :param log_likelihood: Name of log likelihood values from Stan model -# to include in ``arviz`` InferenceData object -# :type log_likelihood: str, optional -# -# :returns: ``arviz`` InferenceData object with selected values -# :rtype: az.InferenceData -# """ -# if log_likelihood is not None and log_likelihood not in dims: -# raise KeyError("Must include dimensions for log-likelihood!") -# if posterior_predictive is not None and posterior_predictive not in dims: -# raise KeyError("Must include dimensions for posterior predictive!") -# -# inference = az.from_cmdstanpy( -# fit, -# coords=coords, -# log_likelihood=log_likelihood, -# posterior_predictive=posterior_predictive, -# dims=dims -# ) -# -# vars_to_drop = set(inference.posterior.data_vars).difference(params) -# inference.posterior = _drop_data(inference.posterior, vars_to_drop) -# -# return inference - - def single_feature_fit_to_inference( - fit: CmdStanMCMC, + fit: Union[CmdStanMCMC, CmdStanVB], + chains: int, + draws: int, params: Sequence[str], coords: dict, dims: dict, @@ -100,15 +50,16 @@ def single_feature_fit_to_inference( if "feature" in v: v.remove("feature") - feat_inf = az.from_cmdstanpy( - posterior=fit, - posterior_predictive=posterior_predictive, + feat_inf = full_fit_to_inference( + fit, + chains, + draws, + params, + _coords, + _dims, log_likelihood=log_likelihood, - coords=_coords, - dims=_dims + posterior_predictive=posterior_predictive ) - vars_to_drop = set(feat_inf.posterior.data_vars).difference(params) - feat_inf.posterior = _drop_data(feat_inf.posterior, vars_to_drop) return feat_inf @@ -189,7 +140,7 @@ def stan_var_to_da( def full_fit_to_inference( - fit: CmdStanVB, + fit: Union[CmdStanMCMC, CmdStanVB], chains: int, draws: int, params: Sequence[str], diff --git a/birdman/model_base.py b/birdman/model_base.py index 9c90b52..85f3363 100644 --- a/birdman/model_base.py +++ b/birdman/model_base.py @@ -281,6 +281,8 @@ def to_inference(self) -> az.InferenceData: inference = single_feature_fit_to_inference( fit=self.fit, + chains=self.num_chains, + draws=self.num_draws, params=self.params, coords=self.coords, dims=self.dims, From ed3d08c922d2f4e9048d5607add894d712ed9ba2 Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Wed, 13 Sep 2023 13:01:47 -0700 Subject: [PATCH 3/6] Cleanup fit_to_inf --- Makefile | 2 +- birdman/inference.py | 133 ++++++++++++---------------------------- birdman/model_base.py | 9 ++- tests/test_inference.py | 4 +- 4 files changed, 45 insertions(+), 103 deletions(-) diff --git a/Makefile b/Makefile index 57e99d7..7aca32a 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ stylecheck: flake8 birdman/*.py tests/*.py setup.py pytest: - pytest + pytest tests -W ignore::FutureWarning documentation: cd docs && make html diff --git a/birdman/inference.py b/birdman/inference.py index 5b319bf..15fe528 100644 --- a/birdman/inference.py +++ b/birdman/inference.py @@ -5,10 +5,8 @@ import numpy as np import xarray as xr -from .util import _drop_data - -def single_feature_fit_to_inference( +def fit_to_inference( fit: Union[CmdStanMCMC, CmdStanVB], chains: int, draws: int, @@ -17,50 +15,51 @@ def single_feature_fit_to_inference( dims: dict, posterior_predictive: str = None, log_likelihood: str = None, -) -> az.InferenceData: - """Convert single feature fit to InferenceData. +): + if log_likelihood is not None and log_likelihood not in dims: + raise KeyError("Must include dimensions for log-likelihood!") + if posterior_predictive is not None and posterior_predictive not in dims: + raise KeyError("Must include dimensions for posterior predictive!") - :param fit: Single feature fit with CmdStanPy - :type fit: cmdstanpy.CmdStanMCMC + das = dict() - :param params: Posterior fitted parameters to include - :type params: Sequence[str] + for param in params: + data = fit.stan_variable(param) - :param coords: Coordinates to use for annotating Inference dims - :type coords: dict + _dims = dims[param] + _coords = {k: coords[k] for k in _dims} - :param dims: Dimensions of parameters in fitted model - :type dims: dict + das[param] = stan_var_to_da(data, _coords, _dims, chains, draws) - :param posterior_predictive: Name of variable holding PP values - :type posterior_predictive: str + if log_likelihood: + data = fit.stan_variable(log_likelihood) - :param log_likelihood: Name of variable holding LL values - :type log_likelihood: str + _dims = dims[log_likelihood] + _coords = {k: coords[k] for k in _dims} - :returns: InferenceData object of single feature - :rtype: az.InferenceData - """ - _coords = coords.copy() - if "feature" in coords: - _coords.pop("feature") - - _dims = dims.copy() - for k, v in _dims.items(): - if "feature" in v: - v.remove("feature") - - feat_inf = full_fit_to_inference( - fit, - chains, - draws, - params, - _coords, - _dims, - log_likelihood=log_likelihood, - posterior_predictive=posterior_predictive + ll_da = stan_var_to_da(data, _coords, _dims, chains, draws) + ll_ds = xr.Dataset({log_likelihood: ll_da}) + else: + ll_ds = None + + if posterior_predictive: + data = fit.stan_variable(posterior_predictive) + + _dims = dims[posterior_predictive] + _coords = {k: coords[k] for k in _dims} + + pp_da = stan_var_to_da(data, _coords, _dims, chains, draws) + pp_ds = xr.Dataset({posterior_predictive: pp_da}) + else: + pp_ds = None + + inf = az.InferenceData( + posterior=xr.Dataset(das), + log_likelihood=ll_ds, + posterior_predictive=pp_ds ) - return feat_inf + + return inf def concatenate_inferences( @@ -137,59 +136,3 @@ def stan_var_to_da( dims=dims, ) return da - - -def full_fit_to_inference( - fit: Union[CmdStanMCMC, CmdStanVB], - chains: int, - draws: int, - params: Sequence[str], - coords: dict, - dims: dict, - posterior_predictive: str = None, - log_likelihood: str = None, -): - if log_likelihood is not None and log_likelihood not in dims: - raise KeyError("Must include dimensions for log-likelihood!") - if posterior_predictive is not None and posterior_predictive not in dims: - raise KeyError("Must include dimensions for posterior predictive!") - - das = dict() - - for param in params: - data = fit.stan_variable(param) - - _dims = dims[param] - _coords = {k: coords[k] for k in _dims} - - das[param] = stan_var_to_da(data, _coords, _dims, chains, draws) - - if log_likelihood: - data = fit.stan_variable(log_likelihood) - - _dims = dims[log_likelihood] - _coords = {k: coords[k] for k in _dims} - - ll_da = stan_var_to_da(data, _coords, _dims, chains, draws) - ll_ds = xr.Dataset({log_likelihood: ll_da}) - else: - ll_ds = None - - if posterior_predictive: - data = fit.stan_variable(posterior_predictive) - - _dims = dims[posterior_predictive] - _coords = {k: coords[k] for k in _dims} - - pp_da = stan_var_to_da(data, _coords, _dims, chains, draws) - pp_ds = xr.Dataset({posterior_predictive: pp_da}) - else: - pp_ds = None - - inf = az.InferenceData( - posterior=xr.Dataset(das), - log_likelihood=ll_ds, - posterior_predictive=pp_ds - ) - - return inf diff --git a/birdman/model_base.py b/birdman/model_base.py index 85f3363..1953be1 100644 --- a/birdman/model_base.py +++ b/birdman/model_base.py @@ -1,15 +1,14 @@ from abc import ABC, abstractmethod -from functools import partial from math import ceil from typing import Sequence import arviz as az import biom -from cmdstanpy import CmdStanModel, CmdStanMCMC, CmdStanVB +from cmdstanpy import CmdStanModel import pandas as pd from patsy import dmatrix -from .inference import full_fit_to_inference, single_feature_fit_to_inference +from .inference import fit_to_inference class BaseModel(ABC): @@ -231,7 +230,7 @@ def to_inference(self) -> az.InferenceData: """ self._check_fit_for_inf() - inference = full_fit_to_inference( + inference = fit_to_inference( fit=self.fit, chains=self.num_chains, draws=self.num_draws, @@ -279,7 +278,7 @@ def to_inference(self) -> az.InferenceData: """ self._check_fit_for_inf() - inference = single_feature_fit_to_inference( + inference = fit_to_inference( fit=self.fit, chains=self.num_chains, draws=self.num_draws, diff --git a/tests/test_inference.py b/tests/test_inference.py index 8402802..cb55ac0 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -26,7 +26,7 @@ def dataset_comparison(self, model, ds): assert (ds.coords["chain"] == [0, 1, 2, 3]).all() def test_serial_to_inference(self, example_model): - inf = mu.full_fit_to_inference( + inf = mu.fit_to_inference( fit=example_model.fit, chains=4, draws=100, @@ -50,7 +50,7 @@ def test_serial_to_inference(self, example_model): # Posterior predictive & log likelihood class TestPPLL: def test_serial_ppll(self, example_model): - inf = mu.full_fit_to_inference( + inf = mu.fit_to_inference( fit=example_model.fit, chains=4, draws=100, From d4fd94ac2b2010173e6f8e5b65d34dfd196182e6 Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Fri, 15 Sep 2023 13:12:55 -0700 Subject: [PATCH 4/6] Fix GH action --- .github/workflows/main.yml | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7aa835b..8d26380 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 @@ -39,7 +39,14 @@ jobs: - name: Install conda packages shell: bash -l {0} - run: mamba install -c conda-forge biom-format patsy pytest xarray scikit-bio flake8 arviz cmdstanpy + run: mamba install -c conda-forge biom-format patsy pytest xarray scikit-bio flake8 arviz + + # Temp req before CmdStanPy cuts a new release + - name: Install develop branch of CmdStanPy + shell: bash -l {0} + run: > + pip install git+https://github.com/stan-dev/cmdstanpy.git@develop; + install_cmdstan - name: Install BIRDMAn shell: bash -l {0} From 752109e227ffff2662dd62c18033d11f51bedb90 Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Mon, 18 Sep 2023 11:46:18 -0700 Subject: [PATCH 5/6] Make VI default --- birdman/model_base.py | 2 +- tests/conftest.py | 4 ++-- tests/test_custom_model.py | 2 +- tests/test_model.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/birdman/model_base.py b/birdman/model_base.py index 1953be1..ceb7d25 100644 --- a/birdman/model_base.py +++ b/birdman/model_base.py @@ -114,7 +114,7 @@ def add_parameters(self, param_dict: dict = None): def fit_model( self, - method: str = "mcmc", + method: str = "vi", num_draws: int = 500, mcmc_warmup: int = None, mcmc_chains: int = 4, diff --git a/tests/conftest.py b/tests/conftest.py index 4ee42cb..b3d2020 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,7 +47,7 @@ def model(): metadata=md, ) nb.compile_model() - nb.fit_model(mcmc_chains=4, num_draws=100) + nb.fit_model(method="mcmc", mcmc_chains=4, num_draws=100) return nb @@ -76,7 +76,7 @@ def single_feat_model(): ) nb.compile_model() - nb.fit_model(mcmc_chains=4, num_draws=100) + nb.fit_model(method="mcmc", mcmc_chains=4, num_draws=100) return nb diff --git a/tests/test_custom_model.py b/tests/test_custom_model.py index 5aa3c02..d8691b4 100644 --- a/tests/test_custom_model.py +++ b/tests/test_custom_model.py @@ -40,7 +40,7 @@ def test_custom_model(table_biom, metadata): }, ) custom_model.compile_model() - custom_model.fit_model(num_draws=100, mcmc_chains=4, seed=42) + custom_model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4, seed=42) inference = custom_model.to_inference() assert set(inference.groups()) == {"posterior"} diff --git a/tests/test_model.py b/tests/test_model.py index 9caa32a..d7839ce 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -49,7 +49,7 @@ def test_nb_lme(self, table_biom, metadata): metadata=md, ) nb_lme.compile_model() - nb_lme.fit_model(num_draws=100) + nb_lme.fit_model(method="mcmc", num_draws=100) inf = nb_lme.to_inference() post = inf.posterior @@ -68,7 +68,7 @@ def test_single_feat(self, table_biom, metadata): metadata=md, ) nb.compile_model() - nb.fit_model(num_draws=100) + nb.fit_model(method="mcmc", num_draws=100) class TestToInference: @@ -159,5 +159,5 @@ def test_iteration_fit(self, table_biom, metadata): for fit, model in model_iterator: model.compile_model() - model.fit_model(num_draws=100, mcmc_chains=4, seed=42) + model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4, seed=42) _ = model.to_inference() From d6e4fc88c1fb47eacd431c2ec6ac858dc130a2a5 Mon Sep 17 00:00:00 2001 From: Gibraan Rahman Date: Mon, 18 Sep 2023 11:56:09 -0700 Subject: [PATCH 6/6] Fix style --- tests/test_custom_model.py | 3 ++- tests/test_model.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_custom_model.py b/tests/test_custom_model.py index d8691b4..988b0f5 100644 --- a/tests/test_custom_model.py +++ b/tests/test_custom_model.py @@ -40,7 +40,8 @@ def test_custom_model(table_biom, metadata): }, ) custom_model.compile_model() - custom_model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4, seed=42) + custom_model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4, + seed=42) inference = custom_model.to_inference() assert set(inference.groups()) == {"posterior"} diff --git a/tests/test_model.py b/tests/test_model.py index d7839ce..d1efdee 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -159,5 +159,6 @@ def test_iteration_fit(self, table_biom, metadata): for fit, model in model_iterator: model.compile_model() - model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4, seed=42) + model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4, + seed=42) _ = model.to_inference()