From b9a599b522c99a8d08453034f7aa2f561cb4db0c Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 29 Oct 2022 09:55:53 -0300 Subject: [PATCH 1/7] add weight_predictions --- arviz/stats/stats.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index d22e50be62..fed13690aa 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -23,7 +23,7 @@ from .. import _log from ..data import InferenceData, convert_to_dataset, convert_to_inference_data from ..rcparams import rcParams, ScaleKeyword, ICKeyword -from ..utils import Numba, _numba_var, _var_names, get_coords +from ..utils import Numba, _numba_var, _var_names, get_coords, extract from .density_utils import get_bins as _get_bins from .density_utils import histogram as _histogram from .density_utils import kde as _kde @@ -49,6 +49,7 @@ "r2_score", "summary", "waic", + "weight_predictions", "_calculate_ics", ] @@ -2043,3 +2044,24 @@ def apply_test_function( setattr(out, grp, out_group) return out + + +def weight_predictions(idatas, weights): + len_idatas = [ + len(idata.posterior_predictive.chain) * len(idata.posterior_predictive.draw) + for idata in idatas + ] + + new_samples = (np.min(len_idatas) * weights).astype(int) + + new_idatas = [ + extract(idata, group="posterior_predictive", num_samples=samples).reset_coords() + for samples, idata in zip(new_samples, idatas) + ] + + weight_samples = InferenceData( + posterior_predictive=xr.concat(new_idatas, dim="sample"), + observed_data=idatas[0].observed_data, + ) + + return weight_samples From ffe7592c2ae60112e6dcc64ae720765214e5866c Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 29 Oct 2022 17:47:18 -0300 Subject: [PATCH 2/7] clean, add checks and docstring --- arviz/stats/__init__.py | 1 + arviz/stats/stats.py | 47 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/arviz/stats/__init__.py b/arviz/stats/__init__.py 
index 681b63b1ec..c39f57faae 100644 --- a/arviz/stats/__init__.py +++ b/arviz/stats/__init__.py @@ -20,6 +20,7 @@ "r2_score", "summary", "waic", + "weight_predictions", "ELPDData", "ess", "rhat", diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index fed13690aa..3700da3efc 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -21,9 +21,9 @@ NO_GET_ARGS = True from .. import _log -from ..data import InferenceData, convert_to_dataset, convert_to_inference_data +from ..data import InferenceData, convert_to_dataset, convert_to_inference_data, extract from ..rcparams import rcParams, ScaleKeyword, ICKeyword -from ..utils import Numba, _numba_var, _var_names, get_coords, extract +from ..utils import Numba, _numba_var, _var_names, get_coords from .density_utils import get_bins as _get_bins from .density_utils import histogram as _histogram from .density_utils import kde as _kde @@ -2046,7 +2046,44 @@ def apply_test_function( return out -def weight_predictions(idatas, weights): +def weight_predictions(idatas, weights=None): + """ + Generate weighted posterior predictive samples from a list of InferenceData + and a set of weights. + + Parameters + --------- + datasets: list[InfereneData] + List of :class:`arviz.InferenceData` objects containing the groups `posterior_predictive` + and `observed_data`. Observations should be the same for all InferenceData objects. + weights : array-like, optional + Individual weights for each model. Weights should be positive. If they do not sum up to 1, + they will be normalized. Default, same weight for each model. + Weights can be computed using many different methods including those in + :func:`arviz.compare`. + + Returns + ------- + idata: InferenceData + Output InferenceData object with the groups `posterior_predictive` and `observed_data`. 
+ + See Also + -------- + compare : Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation + """ + if weights is None: + weights = np.ones(len(idatas)) / len(idatas) + weights = np.array(weights, dtype=float) + weights /= weights.sum() + + if not np.all(["posterior_predictive" in idata.groups() for idata in idatas]): + raise ValueError( + "All the InferenceData objects must contain the `posterior_predictive` group" + ) + + if not np.all([idatas[0].observed_data.equals(idata.observed_data) for idata in idatas[1:]]): + raise ValueError("The observed data should be the same for all InferenceData objects") + len_idatas = [ len(idata.posterior_predictive.chain) * len(idata.posterior_predictive.draw) for idata in idatas @@ -2059,9 +2096,9 @@ def weight_predictions(idatas, weights): for samples, idata in zip(new_samples, idatas) ] - weight_samples = InferenceData( + weighted_samples = InferenceData( posterior_predictive=xr.concat(new_idatas, dim="sample"), observed_data=idatas[0].observed_data, ) - return weight_samples + return weighted_samples From b5fd4f9bd7b9c834b8f3d0dfa06bf9268e2640c4 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sun, 30 Oct 2022 06:36:54 -0300 Subject: [PATCH 3/7] checks --- arviz/stats/stats.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 3700da3efc..b775c0ee4c 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -2053,7 +2053,7 @@ def weight_predictions(idatas, weights=None): Parameters --------- - datasets: list[InfereneData] + idatas : list[InfereneData] List of :class:`arviz.InferenceData` objects containing the groups `posterior_predictive` and `observed_data`. Observations should be the same for all InferenceData objects. 
weights : array-like, optional @@ -2071,19 +2071,27 @@ def weight_predictions(idatas, weights=None): -------- compare : Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation """ - if weights is None: - weights = np.ones(len(idatas)) / len(idatas) - weights = np.array(weights, dtype=float) - weights /= weights.sum() + if len(idatas) < 2: + raise ValueError("You should provide a list with at least two InferenceData objects") - if not np.all(["posterior_predictive" in idata.groups() for idata in idatas]): + if not all("posterior_predictive" in idata.groups() for idata in idatas): raise ValueError( "All the InferenceData objects must contain the `posterior_predictive` group" ) - if not np.all([idatas[0].observed_data.equals(idata.observed_data) for idata in idatas[1:]]): + if not all(idatas[0].observed_data.equals(idata.observed_data) for idata in idatas[1:]): raise ValueError("The observed data should be the same for all InferenceData objects") + if weights is None: + weights = np.ones(len(idatas)) / len(idatas) + elif len(idatas) != len(weights): + raise ValueError( + "The number of weights should be the same as the number of InferenceData objects" + ) + + weights = np.asarray(weights, dtype=float) + weights /= weights.sum() + len_idatas = [ len(idata.posterior_predictive.chain) * len(idata.posterior_predictive.draw) for idata in idatas From d5ce86b2670977584cdad0ed8298f09cc98604e4 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 2 Nov 2022 09:10:42 -0300 Subject: [PATCH 4/7] add test --- arviz/tests/base_tests/test_stats.py | 31 +++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index b8c2422346..4a5b4e1739 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -3,7 +3,12 @@ import numpy as np import pytest -from numpy.testing import assert_allclose, assert_array_almost_equal, 
assert_array_equal +from numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_almost_equal, + assert_array_equal, +) from scipy.special import logsumexp from scipy.stats import linregress from xarray import DataArray, Dataset @@ -21,6 +26,7 @@ r2_score, summary, waic, + weight_predictions, _calculate_ics, ) from ...stats.stats import _gpinv @@ -800,3 +806,26 @@ def test_apply_test_function_should_overwrite_error(centered_eight): """Test error when overwrite=False but out_name is already a present variable.""" with pytest.raises(ValueError, match="Should overwrite"): apply_test_function(centered_eight, lambda y, theta: y, out_name_data="obs") + + +def test_weight_predictions(): + idata0 = from_dict( + posterior_predictive={"a": np.random.normal(-1, 1, 1000)}, observed_data={"a": [1]} + ) + idata1 = from_dict( + posterior_predictive={"a": np.random.normal(1, 1, 1000)}, observed_data={"a": [1]} + ) + + new = weight_predictions([idata0, idata1]) + assert ( + idata1.posterior_predictive.mean() + > new.posterior_predictive.mean() + > idata0.posterior_predictive.mean() + ) + assert "posterior_predictive" in new + assert "observed_data" in new + + new = weight_predictions([idata0, idata1], weights=[0.5, 0.5]) + assert_almost_equal(new.posterior_predictive["a"].mean(), 0, decimal=1) + new = weight_predictions([idata0, idata1], weights=[0.9, 0.1]) + assert_almost_equal(new.posterior_predictive["a"].mean(), -0.8, decimal=1) From c660453e6c4c84e439a9f1be8f83cd3d59df204d Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 2 Nov 2022 09:25:14 -0300 Subject: [PATCH 5/7] update changelog --- CHANGELOG.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a2a8f05f0d..b900670ff5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Change Log +## v0.xx.x + +### New features +* Add `weight_predictions` function to allow generation of weighted predictions from two or more InferenceData with
`posterior_predictive` groups and a set of weights ([2147](https://github.com/arviz-devs/arviz/pull/2147)) + +### Maintenance and fixes + +### Deprecation + +### Documentation + + ## v0.13.0 (2022 Oct 22) ### New features From b056cb1c7e64e4c124cba3f5cb08fa8b91ec00e7 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 12 Nov 2022 09:26:42 -0300 Subject: [PATCH 6/7] update per comments --- arviz/stats/stats.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index b775c0ee4c..70f22c210d 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -2053,7 +2053,7 @@ def weight_predictions(idatas, weights=None): Parameters --------- - idatas : list[InfereneData] + idatas : list[InferenceData] List of :class:`arviz.InferenceData` objects containing the groups `posterior_predictive` and `observed_data`. Observations should be the same for all InferenceData objects. weights : array-like, optional @@ -2089,14 +2089,17 @@ def weight_predictions(idatas, weights=None): "The number of weights should be the same as the number of InferenceData objects" ) - weights = np.asarray(weights, dtype=float) + weights = np.array(weights, dtype=float) weights /= weights.sum() len_idatas = [ - len(idata.posterior_predictive.chain) * len(idata.posterior_predictive.draw) + idata.posterior_predictive.dims["chain"] * idata.posterior_predictive.dims["draw"] for idata in idatas ] + if not all(len_idatas): + raise ValueError("at least one of your idatas has 0 samples") + new_samples = (np.min(len_idatas) * weights).astype(int) new_idatas = [ From c16de55a86290a31bc6a96dfa55952ca491257d3 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 12 Nov 2022 09:31:39 -0300 Subject: [PATCH 7/7] update per comments --- arviz/stats/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 70f22c210d..f51a10bf93 100644 --- a/arviz/stats/stats.py +++ 
b/arviz/stats/stats.py @@ -2098,7 +2098,7 @@ def weight_predictions(idatas, weights=None): ] if not all(len_idatas): - raise ValueError("at least one of your idatas has 0 samples") + raise ValueError("At least one of your idatas has 0 samples") new_samples = (np.min(len_idatas) * weights).astype(int)