From 53197a06dc2ddf7a73913a91c19be03262fa50ea Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sun, 15 Dec 2024 08:21:13 +0100 Subject: [PATCH] Make sure mean of the LLPR ensembles is consistent --- src/metatrain/utils/llpr.py | 19 +++++++++++++++++++ tests/utils/test_llpr.py | 9 +++++++++ 2 files changed, 28 insertions(+) diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py index f950d32f6..b0fe0b117 100644 --- a/src/metatrain/utils/llpr.py +++ b/src/metatrain/utils/llpr.py @@ -226,6 +226,25 @@ def forward( ll_features.block().values, ensemble_weights, ) + + # Since the model's prediction gives the exact mean of the ensemble, we use + # it to re-center the ensemble members around that prediction. + # Besides making sure that the average is always correct (so that results + # will always be consistent between LLPR ensembles and the original model), + # this also takes care of additive contributions that are not present in the + # last layer, such as composition, short-range models, or a bias in the + # last layer. 
+ original_name = ( + name.replace("_ensemble", "").replace("aux::", "") + if name.replace("_ensemble", "").replace("aux::", "") in outputs + else name.replace("_ensemble", "").replace("mtt::aux::", "") + ) + ensemble_values = ( + ensemble_values + - ensemble_values.mean(dim=1, keepdim=True) + + return_dict[original_name].block().values + ) + property_name = "energy" if name == "energy_ensemble" else "ensemble_member" ensemble = TensorMap( keys=Labels( diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index 46e98f21e..f6ca81731 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -278,11 +278,20 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir): assert "mtt::aux::energy_uncertainty" in outputs assert "energy_ensemble" in outputs + predictions = outputs["energy"].block().values analytical_uncertainty = outputs["mtt::aux::energy_uncertainty"].block().values + ensemble_mean = torch.mean( + outputs["energy_ensemble"].block().values, dim=1, keepdim=True + ) ensemble_uncertainty = torch.var( outputs["energy_ensemble"].block().values, dim=1, keepdim=True ) + print(predictions) + print(ensemble_mean) + print(predictions - ensemble_mean) + + torch.testing.assert_close(predictions, ensemble_mean, rtol=5e-3, atol=0.0) torch.testing.assert_close( analytical_uncertainty, ensemble_uncertainty, rtol=5e-3, atol=0.0 )