Skip to content

Commit

Permalink
Updated docs
Browse files Browse the repository at this point in the history
  • Loading branch information
gibsramen committed Oct 29, 2023
1 parent e1edd3d commit 1af6ba5
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
55 changes: 51 additions & 4 deletions birdman/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,36 @@ def fit_to_inference(
dims: dict,
posterior_predictive: str = None,
log_likelihood: str = None,
):
) -> az.InferenceData:
"""Convert a fitted model to an arviz InferenceData object.
:param fit: Fitted CmdStan model
:type fit: Either CmdStanMCMC or CmdStanVB
:param chains: Number of chains
:type chains: int
:param draws: Number of draws
:type draws: int
:param params: Parameters to include in inference
:type params: Sequence[str]
:param coords: Coordinates for InferenceData
:type coords: dict
:param dims: Dimensions for InferenceData
:type dims: dict
:param posterior_predictive: Name of posterior predictive var in model
:type posterior_predictive: str
:param log_likelihood: Name of log likelihood var in model
:type log_likelihood: str
:returns: Model converted to InferenceData
:rtype: az.InferenceData
"""
if log_likelihood is not None and log_likelihood not in dims:
raise KeyError("Must include dimensions for log-likelihood!")
if posterior_predictive is not None and posterior_predictive not in dims:
Expand Down Expand Up @@ -120,15 +149,33 @@ def concatenate_inferences(
return az.concat(*all_group_inferences)


# TODO: Fix docstring
def stan_var_to_da(
data: np.ndarray,
coords: dict,
dims: dict,
chains: int,
draws: int
):
"""Convert Stan variable draws to xr.DataArray."""
) -> xr.DataArray:
"""Convert results of stan_var to DataArray.
:params data: Result of stan_var
:type data: np.ndarray
:params coords: Coordinates of variables
:type coords: dict
:params dims: Dimensions of variables
:type dims: dict
:params chains: Number of chains
:type chains: int
:params draws: Number of draws
:type draws: int
:returns: DataArray representation of stan variables
:rtype: xr.DataArray
"""
data = np.stack(np.split(data, chains))

coords["draw"] = np.arange(draws)
Expand Down
3 changes: 1 addition & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ channels:
dependencies:
- pandas
- numpy
- python=3.7
- python=3.8
- python-language-server
- xarray
- patsy
Expand All @@ -16,4 +16,3 @@ dependencies:
- docutils==0.16
- pip:
- sphinx-rtd-theme==0.5.1
prefix: /Users/gibs/miniconda3/envs/birdman

0 comments on commit 1af6ba5

Please sign in to comment.