Skip to content

Commit

Permalink
Fix inference dims in VI
Browse files Browse the repository at this point in the history
  • Loading branch information
gibsramen committed Oct 4, 2023
1 parent 395e33d commit 09f8580
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions birdman/inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import List, Sequence, Union

import arviz as az
Expand All @@ -21,18 +22,25 @@ def fit_to_inference(
if posterior_predictive is not None and posterior_predictive not in dims:
raise KeyError("Must include dimensions for posterior predictive!")

# Required because as of writing, CmdStanVB.stan_variable defaults to
# returning the mean rather than the sample
if isinstance(fit, CmdStanVB):
stan_var_fn = partial(fit.stan_variable, mean=False)
else:
stan_var_fn = fit.stan_variable

das = dict()

for param in params:
data = fit.stan_variable(param)
data = stan_var_fn(param)

_dims = dims[param]
_coords = {k: coords[k] for k in _dims}

das[param] = stan_var_to_da(data, _coords, _dims, chains, draws)

if log_likelihood:
data = fit.stan_variable(log_likelihood)
data = stan_var_fn(log_likelihood)

_dims = dims[log_likelihood]
_coords = {k: coords[k] for k in _dims}
Expand All @@ -43,7 +51,7 @@ def fit_to_inference(
ll_ds = None

if posterior_predictive:
data = fit.stan_variable(posterior_predictive)
data = stan_var_fn(posterior_predictive)

_dims = dims[posterior_predictive]
_coords = {k: coords[k] for k in _dims}
Expand Down Expand Up @@ -112,16 +120,15 @@ def concatenate_inferences(
return az.concat(*all_group_inferences)


# TODO: Fix docstring
def stan_var_to_da(
data: np.ndarray,
coords: dict,
dims: dict,
chains: int,
draws: int
):
"""Convert Stan variable draws to xr.DataArray.
"""
"""Convert Stan variable draws to xr.DataArray."""
data = np.stack(np.split(data, chains))

coords["draw"] = np.arange(draws)
Expand Down

0 comments on commit 09f8580

Please sign in to comment.