Adds indexing operation to Metric class #142

Merged (9 commits) on Mar 29, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -44,6 +44,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `BootStrapper` to easily calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101))


- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))


### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
1 change: 1 addition & 0 deletions docs/source/pages/overview.rst
@@ -193,6 +193,7 @@ This pattern is implemented for the following operators (with ``a`` being metric
* Inversion (``~a``)
* Negative Value (``neg(a)``)
* Positive Value (``pos(a)``)
* Indexing (``a[0]``)

.. note::

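To make the new operator concrete: indexing a metric does not touch its state eagerly; like the other arithmetic operations, it returns a lazy CompositionalMetric whose compute() applies the index to the wrapped metric's result. A minimal sketch of the intended usage; the VectorMetric below is a hypothetical example metric, not part of the library:

import torch
from torchmetrics import Metric


class VectorMetric(Metric):
    """Hypothetical metric whose compute() returns a vector."""

    def __init__(self):
        super().__init__()
        self.add_state("sums", default=torch.zeros(3), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.sums += x

    def compute(self) -> torch.Tensor:
        return self.sums


metric = VectorMetric()
first = metric[0]  # a CompositionalMetric; nothing is evaluated yet
metric.update(torch.tensor([1.0, 2.0, 3.0]))
print(first.compute())  # tensor(1.) -- index applied to compute()'s output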
24 changes: 20 additions & 4 deletions tests/bases/test_composition.py
@@ -325,8 +325,10 @@ def test_metrics_rfloordiv(first_operand, expected_result):
    assert torch.allclose(expected_result, final_rfloordiv.compute())


@pytest.mark.parametrize(["first_operand", "expected_result"],
[pytest.param(tensor([2, 2, 2]), tensor(12), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))])
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[pytest.param(tensor([2, 2, 2]), tensor(12), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))],
)
def test_metrics_rmatmul(first_operand, expected_result):
    second_operand = DummyMetric([2, 2, 2])

@@ -337,8 +339,10 @@ def test_metrics_rmatmul(first_operand, expected_result):
    assert torch.allclose(expected_result, final_rmatmul.compute())


@pytest.mark.parametrize(["first_operand", "expected_result"],
[pytest.param(tensor(2), tensor(2), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))])
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[pytest.param(tensor(2), tensor(2), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))],
)
def test_metrics_rmod(first_operand, expected_result):
    second_operand = DummyMetric(5)

@@ -497,6 +501,18 @@ def test_metrics_pos():
    assert torch.allclose(tensor(1), final_pos.compute())


@pytest.mark.parametrize(
    ["value", "idx", "expected_result"],
    [([1, 2, 3], 1, tensor(2)), ([[0, 1], [2, 3]], (1, 0), tensor(2)), ([[0, 1], [2, 3]], 1, tensor([2, 3]))],
)
def test_metrics_getitem(value, idx, expected_result):
    first_metric = DummyMetric(value)

    final_getitem = first_metric[idx]
    assert isinstance(final_getitem, CompositionalMetric)
    assert torch.allclose(expected_result, final_getitem.compute())


def test_compositional_metrics_update():

    compos = DummyMetric(5) + DummyMetric(4)
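For reference, DummyMetric is a helper defined in this test module: roughly, a metric whose compute() just returns the value it was constructed with. A minimal sketch of such a helper, an approximation rather than the repository's exact definition:

import torch
from torchmetrics import Metric


class DummyMetric(Metric):
    """Toy metric that returns a fixed value from compute(), used to test composition."""

    def __init__(self, val):
        super().__init__()
        self._val = val

    def update(self, *args, **kwargs) -> None:
        pass  # nothing to accumulate; the returned value is fixed

    def compute(self):
        return torch.tensor(self._val)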
1 change: 1 addition & 0 deletions tests/wrappers/test_bootstrapping.py
@@ -32,6 +32,7 @@ class TestBootStrapper(BootStrapper):
    """ For testing purposes, we subclass the bootstrapper class so we can get the exact permutation
    the class is creating
    """

    def update(self, *args) -> None:
        self.out = []
        for idx in range(self.num_bootstraps):
14 changes: 7 additions & 7 deletions torchmetrics/metric.py
@@ -126,10 +126,7 @@ def add_state(
            ValueError:
                If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``None``.
        """
-        if (
-            not isinstance(default, Tensor) and not isinstance(default, list)  # noqa: W503
-            or (isinstance(default, list) and len(default) != 0)  # noqa: W503
-        ):
+        if (not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default)):
            raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")

        if dist_reduce_fx == "sum":
@@ -304,7 +301,7 @@ def persistent(self, mode: bool = False):
        for key in self._persistent.keys():
            self._persistent[key] = mode

-    def state_dict(self, destination=None, prefix='', keep_vars=False):
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
        destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        # Register metric states to be part of the state_dict
        for key in self._defaults.keys():
@@ -342,7 +339,7 @@ def __hash__(self):
            val = getattr(self, key)
            # Special case: allow list values, so long
            # as their elements are hashable
-            if hasattr(val, '__iter__') and not isinstance(val, Tensor):
+            if hasattr(val, "__iter__") and not isinstance(val, Tensor):
                hash_vals.extend(val)
            else:
                hash_vals.append(val)
@@ -449,6 +446,9 @@ def __neg__(self):
    def __pos__(self):
        return CompositionalMetric(torch.abs, self, None)

    def __getitem__(self, idx):
        return CompositionalMetric(lambda x: x[idx], self, None)


def _neg(tensor: Tensor):
    return -torch.abs(tensor)
@@ -530,6 +530,6 @@ def persistent(self, mode: bool = False) -> None:

    def __repr__(self) -> str:
        _op_metrics = f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
-        repr_str = (self.__class__.__name__ + _op_metrics)
+        repr_str = self.__class__.__name__ + _op_metrics

        return repr_str
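Because the op is only applied in compute(), the new __getitem__ composes lazily with the other operators; indexing a composed metric simply builds a nested CompositionalMetric. A small illustration, reusing the DummyMetric sketch above:

a = DummyMetric([[0, 1], [2, 3]])
b = DummyMetric([[1, 1], [1, 1]])

row = (a + b)[1]      # still a CompositionalMetric; nothing evaluated yet
print(row.compute())  # tensor([3, 4]) -- (a + b) computed first, then indexed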
13 changes: 3 additions & 10 deletions torchmetrics/wrappers/bootstrapping.py
@@ -22,10 +22,7 @@
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7


-def _bootstrap_sampler(
-    size: int,
-    sampling_strategy: str = 'poisson'
-) -> Tensor:
+def _bootstrap_sampler(size: int, sampling_strategy: str = 'poisson') -> Tensor:
    """ Resample a tensor along its first dimension with replacement
    Args:
        size: number of samples
@@ -38,14 +35,10 @@
    """
    if sampling_strategy == 'poisson':
        p = torch.distributions.Poisson(1)
-        n = p.sample((size,))
+        n = p.sample((size, ))
        return torch.arange(size).repeat_interleave(n.long(), dim=0)
    elif sampling_strategy == 'multinomial':
-        idx = torch.multinomial(
-            torch.ones(size),
-            num_samples=size,
-            replacement=True
-        )
+        idx = torch.multinomial(torch.ones(size), num_samples=size, replacement=True)
        return idx
    raise ValueError('Unknown sampling strategy')

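For intuition about the two strategies: 'poisson' draws a repeat count n_i ~ Poisson(1) for every original index, so the resampled length varies around size, while 'multinomial' draws exactly size indices uniformly with replacement. A quick illustration of calling the private helper as defined in this diff:

import torch
from torchmetrics.wrappers.bootstrapping import _bootstrap_sampler

torch.manual_seed(0)

# Poisson: index i appears n_i ~ Poisson(1) times; output length is random.
print(_bootstrap_sampler(5, sampling_strategy='poisson'))

# Multinomial: exactly 5 indices drawn uniformly with replacement.
print(_bootstrap_sampler(5, sampling_strategy='multinomial'))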