Skip to content

Commit

Permalink
Make sure mean of the LLPR ensembles is consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 15, 2024
1 parent 9277267 commit 53197a0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/metatrain/utils/llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,25 @@ def forward(
ll_features.block().values,
ensemble_weights,
)

# since we know the exact mean of the ensemble from the model's prediction,
# it should be mathematically correct to use it to re-center the ensemble.
# Besides making sure that the average is always correct (so that results
# will always be consistent between LLPR ensembles and the original model),
# this also takes care of additive contributions that are not present in the
# last layer, which can be composition, short-range models, a bias in the
# last layer, etc.
original_name = (
name.replace("_ensemble", "").replace("aux::", "")
if name.replace("_ensemble", "").replace("aux::", "") in outputs
else name.replace("_ensemble", "").replace("mtt::aux::", "")
)
ensemble_values = (
ensemble_values
- ensemble_values.mean(dim=1, keepdim=True)
+ return_dict[original_name].block().values
)

property_name = "energy" if name == "energy_ensemble" else "ensemble_member"
ensemble = TensorMap(
keys=Labels(
Expand Down
9 changes: 9 additions & 0 deletions tests/utils/test_llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,20 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir):
assert "mtt::aux::energy_uncertainty" in outputs
assert "energy_ensemble" in outputs

predictions = outputs["energy"].block().values
analytical_uncertainty = outputs["mtt::aux::energy_uncertainty"].block().values
ensemble_mean = torch.mean(
outputs["energy_ensemble"].block().values, dim=1, keepdim=True
)
ensemble_uncertainty = torch.var(
outputs["energy_ensemble"].block().values, dim=1, keepdim=True
)

print(predictions)
print(ensemble_mean)
print(predictions - ensemble_mean)

torch.testing.assert_close(predictions, ensemble_mean, rtol=5e-3, atol=0.0)
torch.testing.assert_close(
analytical_uncertainty, ensemble_uncertainty, rtol=5e-3, atol=0.0
)

0 comments on commit 53197a0

Please sign in to comment.