-
Notifications
You must be signed in to change notification settings - Fork 412
/
multioutput.py
203 lines (172 loc) · 9.27 KB
/
multioutput.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Optional, Union
import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import ModuleList
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.abstract import WrapperMetric
# Skip the plotting doctest when matplotlib is not available in the environment.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["MultioutputWrapper.plot"]
def _get_nan_indices(*tensors: Tensor) -> Tensor:
"""Get indices of rows along dim 0 which have NaN values."""
if len(tensors) == 0:
raise ValueError("Must pass at least one tensor as argument")
sentinel = tensors[0]
nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool, device=sentinel.device)
for tensor in tensors:
permuted_tensor = tensor.flatten(start_dim=1)
nan_idxs |= torch.any(torch.isnan(permuted_tensor), dim=1)
return nan_idxs
class MultioutputWrapper(WrapperMetric):
    """Wrap a base metric to enable it to support multiple outputs.

    Several torchmetrics metrics, such as :class:`~torchmetrics.regression.spearman.SpearmanCorrCoef`, lack support
    for multioutput mode. This class wraps such metrics to support computing one metric per output.
    Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs.
    This means if you set ``num_outputs`` to 2, ``.compute()`` will return a Tensor of dimension
    ``(2, ...)`` where ``...`` represents the dimensions the metric returns when not wrapped.

    In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude
    fashion, dealing with missing labels (or other data). When ``remove_nans`` is passed, the class will, upon each
    update and for each output, remove every row that contains a NaN in any of the tensors passed to ``update`` before
    forwarding them to the underlying metric. For example, a user can wrap
    :class:`torchmetrics.regression.r2.R2Score` with 2 outputs, one of which occasionally has missing labels encoded
    as NaN, and each per-output ``R2Score`` will only ever see the rows where all inputs are present.

    Args:
        base_metric: Metric being wrapped.
        num_outputs: Expected dimensionality of the output dimension.
            This parameter is used to determine the number of distinct metrics we need to track.
        output_dim:
            Dimension on which output is expected. Note that while this provides some flexibility, the output dimension
            must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels
            can have a different number of dimensions than the predictions. This can be worked around if the output
            dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.
        remove_nans:
            Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying
            metric. Proper operation requires all tensors passed to update to have dimension ``(N, ...)`` where N
            represents the length of the batch or dataset being passed in.
        squeeze_outputs:
            If ``True``, will squeeze the 1-item dimensions left after ``index_select`` is applied.
            This is sometimes unnecessary but harmless for metrics such as `R2Score` but useful
            for certain classification metrics that can't handle additional 1-item dimensions.

    Example:
        >>> # Mimic R2Score in `multioutput`, `raw_values` mode:
        >>> import torch
        >>> from torchmetrics.wrappers import MultioutputWrapper
        >>> from torchmetrics.regression import R2Score
        >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
        >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
        >>> r2score = MultioutputWrapper(R2Score(), 2)
        >>> r2score(preds, target)
        tensor([0.9654, 0.9082])

    """

    is_differentiable: bool = False

    def __init__(
        self,
        base_metric: Metric,
        num_outputs: int,
        output_dim: int = -1,
        remove_nans: bool = True,
        squeeze_outputs: bool = True,
    ) -> None:
        super().__init__()
        # One independent copy of the base metric per output so their states never mix.
        self.metrics = ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
        self.output_dim = output_dim
        self.remove_nans = remove_nans
        self.squeeze_outputs = squeeze_outputs

    def _get_args_kwargs_by_output(
        self, *args: Tensor, **kwargs: Tensor
    ) -> list[tuple[Sequence[Tensor], dict[str, Tensor]]]:
        """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out.

        Returns:
            One ``(selected_args, selected_kwargs)`` pair per wrapped metric, where every tensor has been
            sliced down to the i-th output along ``self.output_dim``.

        """
        args_kwargs_by_output = []
        for i in range(len(self.metrics)):
            # Slice out the i-th output from every tensor in args/kwargs (keeps a size-1 output dim).
            selected_args = apply_to_collection(
                args, Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
            )
            selected_kwargs = apply_to_collection(
                kwargs, Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
            )
            if self.remove_nans:
                # Drop any row that has a NaN in *any* input tensor, so all tensors stay aligned.
                args_kwargs = selected_args + tuple(selected_kwargs.values())
                nan_idxs = _get_nan_indices(*args_kwargs)
                selected_args = [arg[~nan_idxs] for arg in selected_args]
                selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()}
            if self.squeeze_outputs:
                # Remove the leftover size-1 output dimension for metrics that can't handle it.
                selected_args = [arg.squeeze(self.output_dim) for arg in selected_args]
                selected_kwargs = {k: v.squeeze(self.output_dim) for k, v in selected_kwargs.items()}
            args_kwargs_by_output.append((selected_args, selected_kwargs))
        return args_kwargs_by_output

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Update each underlying metric with the corresponding output."""
        reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
        for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
            metric.update(*selected_args, **selected_kwargs)

    def compute(self) -> Tensor:
        """Compute metrics.

        Returns:
            A tensor of shape ``(num_outputs, ...)`` stacking each underlying metric's result.

        """
        return torch.stack([m.compute() for m in self.metrics], 0)

    @torch.jit.unused
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Call underlying forward methods and aggregate the results if they're non-null.

        We override this method to ensure that state variables get copied over on the underlying metrics.
        """
        reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
        results = [
            metric(*selected_args, **selected_kwargs)
            for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs)
        ]
        # Underlying metrics may return None from forward (e.g. when full-state sync is deferred);
        # in that case there is nothing to stack.
        if results[0] is None:
            return None
        return torch.stack(results, 0)

    def reset(self) -> None:
        """Reset all underlying metrics."""
        for metric in self.metrics:
            metric.reset()
        super().reset()

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.wrappers import MultioutputWrapper
            >>> from torchmetrics.regression import R2Score
            >>> metric = MultioutputWrapper(R2Score(), 2)
            >>> metric.update(torch.randn(20, 2), torch.randn(20, 2))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.wrappers import MultioutputWrapper
            >>> from torchmetrics.regression import R2Score
            >>> metric = MultioutputWrapper(R2Score(), 2)
            >>> values = [ ]
            >>> for _ in range(3):
            ...     values.append(metric(torch.randn(20, 2), torch.randn(20, 2)))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)