Skip to content

Commit

Permalink
Merge pull request #84 from gibsramen/vi
Browse files Browse the repository at this point in the history
Model fitting refactor
  • Loading branch information
gibsramen authored Sep 12, 2023
2 parents ebc6f03 + 3befd9e commit 306c77c
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 227 deletions.
63 changes: 0 additions & 63 deletions birdman/default_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,6 @@ class NegativeBinomial(TableModel):
:param metadata: Metadata for design matrix
:type metadata: pd.DataFrame
:param num_iter: Number of posterior sample draws, defaults to 500
:type num_iter: int
:param num_warmup: Number of posterior draws used for warmup, defaults to
num_iter
:type num_warmup: int
:param chains: Number of chains to use in MCMC, defaults to 4
:type chains: int
:param seed: Random seed to use for sampling, defaults to 42
:type seed: float
:param beta_prior: Standard deviation for normally distributed prior values
of beta, defaults to 5.0
:type beta_prior: float
Expand All @@ -82,10 +69,6 @@ def __init__(
table: biom.table.Table,
formula: str,
metadata: pd.DataFrame,
num_iter: int = 500,
num_warmup: int = None,
chains: int = 4,
seed: float = 42,
beta_prior: float = 5.0,
inv_disp_sd: float = 0.5,
):
Expand All @@ -94,10 +77,6 @@ def __init__(
super().__init__(
table=table,
model_path=filepath,
num_iter=num_iter,
num_warmup=num_warmup,
chains=chains,
seed=seed,
)
self.create_regression(formula=formula, metadata=metadata)

Expand Down Expand Up @@ -171,19 +150,6 @@ class NegativeBinomialSingle(SingleFeatureModel):
:param metadata: Metadata for design matrix
:type metadata: pd.DataFrame
:param num_iter: Number of posterior sample draws, defaults to 500
:type num_iter: int
:param num_warmup: Number of posterior draws used for warmup, defaults to
num_iter
:type num_warmup: int
:param chains: Number of chains to use in MCMC, defaults to 4
:type chains: int
:param seed: Random seed to use for sampling, defaults to 42
:type seed: float
:param beta_prior: Standard deviation for normally distributed prior values
of beta, defaults to 5.0
:type beta_prior: float
Expand All @@ -198,10 +164,6 @@ def __init__(
feature_id: str,
formula: str,
metadata: pd.DataFrame,
num_iter: int = 500,
num_warmup: int = None,
chains: int = 4,
seed: float = 42,
beta_prior: float = 5.0,
inv_disp_sd: float = 0.5,
):
Expand All @@ -211,10 +173,6 @@ def __init__(
table=table,
feature_id=feature_id,
model_path=filepath,
num_iter=num_iter,
num_warmup=num_warmup,
chains=chains,
seed=seed,
)
self.create_regression(formula=formula, metadata=metadata)

Expand Down Expand Up @@ -290,19 +248,6 @@ class NegativeBinomialLME(TableModel):
:param metadata: Metadata for design matrix
:type metadata: pd.DataFrame
:param num_iter: Number of posterior sample draws, defaults to 500
:type num_iter: int
:param num_warmup: Number of posterior draws used for warmup, defaults to
num_iter
:type num_warmup: int
:param chains: Number of chains to use in MCMC, defaults to 4
:type chains: int
:param seed: Random seed to use for sampling, defaults to 42
:type seed: float
:param beta_prior: Standard deviation for normally distributed prior values
of beta, defaults to 5.0
:type beta_prior: float
Expand All @@ -321,10 +266,6 @@ def __init__(
formula: str,
group_var: str,
metadata: pd.DataFrame,
num_iter: int = 500,
num_warmup: int = None,
chains: int = 4,
seed: float = 42,
beta_prior: float = 5.0,
inv_disp_sd: float = 0.5,
group_var_prior: float = 1.0
Expand All @@ -333,10 +274,6 @@ def __init__(
super().__init__(
table=table,
model_path=filepath,
num_iter=num_iter,
num_warmup=num_warmup,
chains=chains,
seed=seed,
)
self.create_regression(formula=formula, metadata=metadata)

Expand Down
19 changes: 2 additions & 17 deletions birdman/model_util.py → birdman/inference.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
from typing import List, Sequence

import arviz as az
from cmdstanpy import CmdStanMCMC
import xarray as xr

from .util import _drop_data


def full_fit_to_inference(
fit: CmdStanMCMC,
Expand Down Expand Up @@ -164,19 +165,3 @@ def concatenate_inferences(
all_group_inferences.append(group_inf)

return az.concat(*all_group_inferences)


def _drop_data(
    dataset: xr.Dataset,
    vars_to_drop: Sequence[str],
) -> xr.Dataset:
    """Remove variables and their matching ``<var>_dim_<N>`` dims.

    Returns a new dataset without the given data variables and without
    any dimension whose name matches ``<var>_dim_<digit>``.
    """
    trimmed = dataset.drop_vars(vars_to_drop)
    # TODO: Figure out how to do this more cleanly
    stale_dims = [
        dim
        for var in vars_to_drop
        for dim in trimmed.dims
        if re.match(f"{var}_dim_\\d", dim)
    ]
    return trimmed.drop_dims(stale_dims)
163 changes: 85 additions & 78 deletions birdman/model_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from abc import ABC, abstractmethod
from math import ceil
from typing import Sequence
import warnings

import arviz as az
import biom
from cmdstanpy import CmdStanModel
import pandas as pd
from patsy import dmatrix

from .model_util import full_fit_to_inference, single_feature_fit_to_inference
from .inference import full_fit_to_inference, single_feature_fit_to_inference


class BaseModel(ABC):
Expand All @@ -20,36 +19,12 @@ class BaseModel(ABC):
:param model_path: Filepath to Stan model
:type model_path: str
:param num_iter: Number of posterior sample draws, defaults to 500
:type num_iter: int
:param num_warmup: Number of posterior draws used for warmup, defaults to
num_iter
:type num_warmup: int
:param chains: Number of chains to use in MCMC, defaults to 4
:type chains: int
:param seed: Random seed to use for sampling, defaults to 42
:type seed: float
"""
def __init__(
self,
table: biom.table.Table,
model_path: str,
num_iter: int = 500,
num_warmup: int = None,
chains: int = 4,
seed: float = 42,
):
self.num_iter = num_iter
if num_warmup is None:
self.num_warmup = num_iter
else:
self.num_warmup = num_warmup
self.chains = chains
self.seed = seed
self.sample_names = table.ids(axis="sample")
self.model_path = model_path
self.sm = None
Expand Down Expand Up @@ -139,50 +114,98 @@ def add_parameters(self, param_dict: dict = None):

def fit_model(
    self,
    method: str = "mcmc",
    num_draws: int = 500,
    mcmc_warmup: int = None,
    mcmc_chains: int = 4,
    vi_iter: int = 1000,
    vi_grad_samples: int = 40,
    vi_require_converged: bool = False,
    seed: float = 42,
    mcmc_kwargs: dict = None,
    vi_kwargs: dict = None
):
    """Fit BIRDMAn model.

    :param method: Method by which to fit model, either 'mcmc' (default)
        for Markov Chain Monte Carlo or 'vi' for Variational Inference
    :type method: str

    :param num_draws: Number of output draws to sample from the posterior,
        default is 500
    :type num_draws: int

    :param mcmc_warmup: Number of warmup iterations for MCMC sampling,
        default is the same as num_draws
    :type mcmc_warmup: int

    :param mcmc_chains: Number of Markov chains to use for sampling,
        default is 4
    :type mcmc_chains: int

    :param vi_iter: Number of ADVI iterations to use for VI, default is
        1000
    :type vi_iter: int

    :param vi_grad_samples: Number of MC draws for computing the gradient,
        default is 40
    :type vi_grad_samples: int

    :param vi_require_converged: Whether or not to raise an error if Stan
        reports that "The algorithm may not have converged", default is
        False
    :type vi_require_converged: bool

    :param seed: Random seed to use for sampling, default is 42
    :type seed: int

    :param mcmc_kwargs: kwargs to pass into CmdStanModel.sample
    :param vi_kwargs: kwargs to pass into CmdStanModel.variational

    :raises ValueError: If ``method`` is neither 'mcmc' nor 'vi'
    """
    if method == "mcmc":
        mcmc_kwargs = mcmc_kwargs or dict()
        # Fall back to num_draws when no warmup length was given, per the
        # documented contract. (The previous `mcmc_warmup or mcmc_warmup`
        # was a no-op that passed iter_warmup=None through to CmdStanPy.)
        mcmc_warmup = mcmc_warmup or num_draws

        self.fit = self.sm.sample(
            chains=mcmc_chains,
            parallel_chains=mcmc_chains,
            data=self.dat,
            iter_warmup=mcmc_warmup,
            iter_sampling=num_draws,
            seed=seed,
            **mcmc_kwargs
        )
    elif method == "vi":
        vi_kwargs = vi_kwargs or dict()

        self.fit = self.sm.variational(
            data=self.dat,
            iter=vi_iter,
            output_samples=num_draws,
            grad_samples=vi_grad_samples,
            require_converged=vi_require_converged,
            seed=seed,
            **vi_kwargs
        )
    else:
        raise ValueError("method must be either 'mcmc' or 'vi'")

@abstractmethod
def to_inference(self):
    """Convert fitted model to az.InferenceData.

    Must be implemented by concrete subclasses (e.g. ``TableModel``).

    :returns: Inference object built from the fitted model
    :rtype: az.InferenceData
    """

def _check_fit_for_inf(self):
if self.fit is None:
raise ValueError("Model has not been fit!")

# if already Inference, just return
if isinstance(self.fit, az.InferenceData):
return self.fit

if not self.specified:
raise ValueError("Model has not been specified!")


class TableModel(BaseModel):
"""Fit a model on the entire table at once."""
Expand All @@ -199,15 +222,7 @@ def to_inference(self) -> az.InferenceData:
:returns: ``arviz`` InferenceData object with selected values
:rtype: az.InferenceData
"""
if self.fit is None:
raise ValueError("Model has not been fit!")

# if already Inference, just return
if isinstance(self.fit, az.InferenceData):
return self.fit

if not self.specified:
raise ValueError("Model has not been specified!")
self._check_fit_for_inf()

inference = full_fit_to_inference(
fit=self.fit,
Expand Down Expand Up @@ -252,15 +267,7 @@ def to_inference(self) -> az.InferenceData:
:returns: ``arviz`` InferenceData object with selected values
:rtype: az.InferenceData
"""
if self.fit is None:
raise ValueError("Model has not been fit!")

# if already Inference, just return
if isinstance(self.fit, az.InferenceData):
return self.fit

if not self.specified:
raise ValueError("Model has not been specified!")
self._check_fit_for_inf()

inference = single_feature_fit_to_inference(
fit=self.fit,
Expand Down
17 changes: 17 additions & 0 deletions birdman/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import re
from typing import Sequence

import xarray as xr


def _drop_data(dataset: xr.Dataset, vars_to_drop: Sequence[str]) -> xr.Dataset:
"""Drop data and associated dimensions from inference group."""
new_dataset = dataset.drop_vars(vars_to_drop)
# TODO: Figure out how to do this more cleanly
dims_to_drop = []
for var in vars_to_drop:
for dim in new_dataset.dims:
if re.match(f"{var}_dim_\\d", dim):
dims_to_drop.append(dim)
new_dataset = new_dataset.drop_dims(dims_to_drop)
return new_dataset
Loading

0 comments on commit 306c77c

Please sign in to comment.