diff --git a/birdman/inference.py b/birdman/inference.py index b33b7d2..a376891 100644 --- a/birdman/inference.py +++ b/birdman/inference.py @@ -16,7 +16,36 @@ def fit_to_inference( dims: dict, posterior_predictive: str = None, log_likelihood: str = None, -): +) -> az.InferenceData: + """Convert a fitted model to an arviz InferenceData object. + + :param fit: Fitted CmdStan model + :type fit: Either CmdStanMCMC or CmdStanVB + + :param chains: Number of chains + :type chains: int + + :param draws: Number of draws + :type draws: int + + :param params: Parameters to include in inference + :type params: Sequence[str] + + :param coords: Coordinates for InferenceData + :type coords: dict + + :param dims: Dimensions for InferenceData + :type dims: dict + + :param posterior_predictive: Name of posterior predictive var in model + :type posterior_predictive: str + + :param log_likelihood: Name of log likelihood var in model + :type log_likelihood: str + + :returns: Model converted to InferenceData + :rtype: az.InferenceData + """ if log_likelihood is not None and log_likelihood not in dims: raise KeyError("Must include dimensions for log-likelihood!") if posterior_predictive is not None and posterior_predictive not in dims: @@ -120,15 +149,33 @@ def concatenate_inferences( return az.concat(*all_group_inferences) -# TODO: Fix docstring def stan_var_to_da( data: np.ndarray, coords: dict, dims: dict, chains: int, draws: int -): - """Convert Stan variable draws to xr.DataArray.""" +) -> xr.DataArray: + """Convert results of stan_var to DataArray. + + :params data: Result of stan_var + :type data: np.ndarray + + :params coords: Coordinates of variables + :type coords: dict + + :params dims: Dimensions of variables + :type dims: dict + + :params chains: Number of chains + :type chains: int + + :params draws: Number of draws + :type draws: int + + :returns: DataArray representation of stan variables + :rtype: xr.DataArray + """ data = np.stack(np.split(data, chains)) coords["draw"] = np.arange(draws) diff --git a/environment.yml b/environment.yml index 72d8bd7..c4b0e5d 100644 --- a/environment.yml +++ b/environment.yml @@ -6,7 +6,7 @@ channels: dependencies: - pandas - numpy - - python=3.7 + - python=3.8 - python-language-server - xarray - patsy @@ -16,4 +16,3 @@ dependencies: - docutils==0.16 - pip: - sphinx-rtd-theme==0.5.1 -prefix: /Users/gibs/miniconda3/envs/birdman