Skip to content

Commit

Permalink
Make VI default
Browse files Browse the repository at this point in the history
  • Loading branch information
gibsramen committed Sep 18, 2023
1 parent d4fd94a commit 752109e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion birdman/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def add_parameters(self, param_dict: dict = None):

def fit_model(
self,
method: str = "mcmc",
method: str = "vi",
num_draws: int = 500,
mcmc_warmup: int = None,
mcmc_chains: int = 4,
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def model():
metadata=md,
)
nb.compile_model()
nb.fit_model(mcmc_chains=4, num_draws=100)
nb.fit_model(method="mcmc", mcmc_chains=4, num_draws=100)
return nb


Expand Down Expand Up @@ -76,7 +76,7 @@ def single_feat_model():
)

nb.compile_model()
nb.fit_model(mcmc_chains=4, num_draws=100)
nb.fit_model(method="mcmc", mcmc_chains=4, num_draws=100)

return nb

Expand Down
2 changes: 1 addition & 1 deletion tests/test_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_custom_model(table_biom, metadata):
},
)
custom_model.compile_model()
custom_model.fit_model(num_draws=100, mcmc_chains=4, seed=42)
custom_model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4, seed=42)
inference = custom_model.to_inference()

assert set(inference.groups()) == {"posterior"}
Expand Down
6 changes: 3 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_nb_lme(self, table_biom, metadata):
metadata=md,
)
nb_lme.compile_model()
nb_lme.fit_model(num_draws=100)
nb_lme.fit_model(method="mcmc", num_draws=100)

inf = nb_lme.to_inference()
post = inf.posterior
Expand All @@ -68,7 +68,7 @@ def test_single_feat(self, table_biom, metadata):
metadata=md,
)
nb.compile_model()
nb.fit_model(num_draws=100)
nb.fit_model(method="mcmc", num_draws=100)


class TestToInference:
Expand Down Expand Up @@ -159,5 +159,5 @@ def test_iteration_fit(self, table_biom, metadata):

for fit, model in model_iterator:
model.compile_model()
model.fit_model(num_draws=100, mcmc_chains=4, seed=42)
model.fit_model(method="mcmc", num_draws=100, mcmc_chains=4, seed=42)
_ = model.to_inference()

0 comments on commit 752109e

Please sign in to comment.