
fixes metric hashing #478

Merged: 11 commits, Aug 24, 2021
23 changes: 23 additions & 0 deletions tests/bases/test_hashing.py
@@ -0,0 +1,23 @@
import pytest

from tests.helpers.testers import DummyListMetric, DummyMetric


@pytest.mark.parametrize(
    "metric_cls",
    [
        DummyMetric,
        DummyListMetric,
    ],
)
def test_metric_hashing(metric_cls):
    """Tests that hashes are different.

    See the Metric's hash function for details on why this is required.
    """
    instance_1 = metric_cls()
    instance_2 = metric_cls()

    assert hash(instance_1) != hash(instance_2)
    assert id(instance_1) != id(instance_2)
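
Note that DummyMetric and DummyListMetric are repo-internal test helpers from tests/helpers/testers.py, so they are not importable outside this repository. For reading the test in isolation, here is a minimal sketch of what such helpers could look like; the class bodies below are assumptions, not the actual helper code:

import torch
from torchmetrics import Metric


class DummyMetric(Metric):
    # Hypothetical stand-in: a single scalar-tensor state.
    def __init__(self):
        super().__init__()
        self.add_state("x", default=torch.tensor(0.0), dist_reduce_fx=None)

    def update(self):
        pass

    def compute(self):
        pass


class DummyListMetric(Metric):
    # Hypothetical stand-in: a list state, which exercises the iterable
    # branch of Metric.__hash__.
    def __init__(self):
        super().__init__()
        self.add_state("x", default=[], dist_reduce_fx=None)

    def update(self):
        pass

    def compute(self):
        pass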
7 changes: 6 additions & 1 deletion torchmetrics/metric.py
@@ -502,7 +502,12 @@ def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
         return filtered_kwargs
 
     def __hash__(self) -> int:
-        hash_vals = [self.__class__.__name__]
+        # We need to add the id here, since PyTorch requires a module hash to be unique.
+        # Internally, PyTorch nn.Module relies on that for children discovery
+        # (see https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1544).
+        # For metrics whose state includes tensors this is not a problem, since tensors
+        # hash by memory location, but we cannot rely on that for every metric.
+        hash_vals = [self.__class__.__name__, id(self)]
 
         for key in self._defaults:
             val = getattr(self, key)
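
Why uniqueness matters, end to end: Metric subclasses nn.Module, so attaching two metrics to a parent module registers both as children, and (per the comment above) PyTorch's named_modules() tracks visited modules in a set, which relies on __hash__. A short sketch, assuming the hypothetical DummyMetric from above (any Metric subclass works):

import torch.nn as nn

parent = nn.Module()
parent.metric_a = DummyMetric()  # Metric is an nn.Module, so assignment
parent.metric_b = DummyMetric()  # registers each metric as a child module.

# With id(self) mixed into the hash, two sibling metrics of the same class
# never collide in hash-based containers such as named_modules()'s memo set.
assert hash(parent.metric_a) != hash(parent.metric_b)
assert [name for name, _ in parent.named_modules() if name] == ["metric_a", "metric_b"]

Before this change, the hash was derived only from the class name and the state values, so two freshly constructed metrics of the same class with tensor-free state (such as DummyListMetric's empty list) could hash identically.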