From aded37d37962fa682a30f88cd7a8fb67061d568d Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Sun, 15 Dec 2024 08:21:13 +0100
Subject: [PATCH] Make sure mean of the LLPR ensembles is consistent

---
 src/metatrain/utils/llpr.py  | 19 +++++++++++++++++++
 tests/utils/test_additive.py | 12 ++++++------
 tests/utils/test_llpr.py     |  9 +++++++++
 3 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py
index f950d32f6..b0fe0b117 100644
--- a/src/metatrain/utils/llpr.py
+++ b/src/metatrain/utils/llpr.py
@@ -226,6 +226,25 @@ def forward(
             ll_features.block().values,
             ensemble_weights,
         )
+
+        # since we know the exact mean of the ensemble from the model's prediction,
+        # it should be mathematically correct to use it to re-center the ensemble.
+        # Besides making sure that the average is always correct (so that results
+        # will always be consistent between LLPR ensembles and the original model),
+        # this also takes care of additive contributions that are not present in the
+        # last layer, which can be composition, short-range models, a bias in the
+        # last layer, etc.
+        original_name = (
+            name.replace("_ensemble", "").replace("aux::", "")
+            if name.replace("_ensemble", "").replace("aux::", "") in outputs
+            else name.replace("_ensemble", "").replace("mtt::aux::", "")
+        )
+        ensemble_values = (
+            ensemble_values
+            - ensemble_values.mean(dim=1, keepdim=True)
+            + return_dict[original_name].block().values
+        )
+
         property_name = "energy" if name == "energy_ensemble" else "ensemble_member"
         ensemble = TensorMap(
             keys=Labels(
diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py
index 2e28b203a..19e5ccf56 100644
--- a/tests/utils/test_additive.py
+++ b/tests/utils/test_additive.py
@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 
 import metatensor.torch
@@ -272,9 +273,9 @@ def test_remove_additive():
     assert std_after < 100.0 * std_before
 
 
-def test_composition_model_missing_types():
+def test_composition_model_missing_types(caplog):
     """
-    Test the error when there are too many or too types in the dataset
+    Test the error when there are too many types in the dataset
     compared to those declared at initialization.
""" @@ -355,11 +356,10 @@ def test_composition_model_missing_types(): targets={"energy": get_energy_target_info({"unit": "eV"})}, ), ) - with pytest.warns( - UserWarning, - match="do not contain atomic types", - ): + # need to capture the warning from the logger + with caplog.at_level(logging.WARNING): composition_model.train_model(dataset, []) + assert "do not contain atomic types" in caplog.text def test_composition_model_wrong_target(): diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index 46e98f21e..f6ca81731 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -278,11 +278,20 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir): assert "mtt::aux::energy_uncertainty" in outputs assert "energy_ensemble" in outputs + predictions = outputs["energy"].block().values analytical_uncertainty = outputs["mtt::aux::energy_uncertainty"].block().values + ensemble_mean = torch.mean( + outputs["energy_ensemble"].block().values, dim=1, keepdim=True + ) ensemble_uncertainty = torch.var( outputs["energy_ensemble"].block().values, dim=1, keepdim=True ) + print(predictions) + print(ensemble_mean) + print(predictions - ensemble_mean) + + torch.testing.assert_close(predictions, ensemble_mean, rtol=5e-3, atol=0.0) torch.testing.assert_close( analytical_uncertainty, ensemble_uncertainty, rtol=5e-3, atol=0.0 )