From 06668c0ddf745a712c0acc142c0c93de6ccd3bdc Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 8 Jan 2021 11:09:07 +0100 Subject: [PATCH] [Metrics] MetricCollection (#4318) * docs + precision + recall + f_beta + refactor Co-authored-by: Teddy Koker * rebase Co-authored-by: Teddy Koker * fixes Co-authored-by: Teddy Koker * added missing file * docs * docs * extra import * add metric collection * add docs + integration with log_dict * add test * update * update * more test * more test * pep8 * fix doctest * pep8 * add clone method * add clone method * merge-2 * changelog * kwargs filtering and tests * pep8 * fix test * update docs * Update docs/source/metrics.rst Co-authored-by: Roger Shieh * fix docs * fix tests * Apply suggestions from code review Co-authored-by: Jirka Borovec * fix docs * fix doctest * fix doctest * fix doctest * fix doctest Co-authored-by: ananyahjha93 Co-authored-by: Teddy Koker Co-authored-by: Nicki Skafte Co-authored-by: Roger Shieh Co-authored-by: Jirka Borovec --- CHANGELOG.md | 6 ++ docs/source/metrics.rst | 77 ++++++++++++- pytorch_lightning/metrics/__init__.py | 2 +- pytorch_lightning/metrics/metric.py | 117 +++++++++++++++++++- tests/metrics/test_metric.py | 144 ++++++++++++++++++++++++- tests/metrics/test_metric_lightning.py | 52 ++++++++- 6 files changed, 384 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9209d1363326f..3db41991aaed6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `image_gradients` functional metric to compute the image gradients of a given input image. ([#5056](https://github.com/PyTorchLightning/pytorch-lightning/pull/5056)) +- Added `MetricCollection` ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318)) + + +- Added `.clone()` method to metrics ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318)) + + ### Changed - Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218)) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 6a2033d8d63a1..ed6b847f72fa3 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -81,6 +81,7 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True) .. note:: + If using metrics in data parallel mode (dp), the metric update/logging should be done in the ``_step_end`` method (where ```` is either ``training``, ``validation`` or ``test``). This is due to metric states else being destroyed after each forward pass, @@ -99,7 +100,6 @@ If ``on_epoch`` is True, the logger automatically logs the end of epoch metric v self.metric(outputs['preds'], outputs['target']) self.log('metric', self.metric) - This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example: .. code-block:: python @@ -131,7 +131,17 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us Metrics contain internal states that keep track of the data seen so far. Do not mix metric states across training, validation and testing. It is highly recommended to re-initialize the metric per mode as - shown in the examples above. + shown in the examples above. For easy initializing the same metric multiple + times, the ``.clone()`` method can be used: + + .. testcode:: + + def __init__(self): + ... + metric = pl.metrics.Accuracy() + self.train_acc = metric.clone() + self.val_acc = metric.clone() + self.test_acc = metric.clone() .. note:: @@ -240,6 +250,69 @@ In practise this means that: val = metric(pred, target) # this value can be backpropagated val = metric.compute() # this value cannot be backpropagated +**************** +MetricCollection +**************** + +In many cases it is beneficial to evaluate the model output by multiple metrics. +In this case the `MetricCollection` class may come in handy. It accepts a sequence +of metrics and wraps theses into a single callable metric class, with the same +interface as any other metric. + +Example: + +.. testcode:: + + from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall + target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + metric_collection = MetricCollection([ + Accuracy(), + Precision(num_classes=3, average='macro'), + Recall(num_classes=3, average='macro') + ]) + print(metric_collection(preds, target)) + +.. testoutput:: + :options: +NORMALIZE_WHITESPACE + + {'Accuracy': tensor(0.1250), + 'Precision': tensor(0.0667), + 'Recall': tensor(0.1111)} + +Similarly it can also reduce the amount of code required to log multiple metrics +inside your LightningModule + +.. code-block:: python + + def __init__(self): + ... + metrics = pl.metrics.MetricCollection(...) + self.train_metrics = metrics.clone() + self.valid_metrics = metrics.clone() + + def training_step(self, batch, batch_idx): + logits = self(x) + ... + self.train_metrics(logits, y) + # use log_dict instead of log + self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train') + + def validation_step(self, batch, batch_idx): + logits = self(x) + ... + self.valid_metrics(logits, y) + # use log_dict instead of log + self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val') + +.. note:: + + `MetricCollection` as default assumes that all the metrics in the collection + have the same call signature. If this is not the case, input that should be + given to different metrics can given as keyword arguments to the collection. + +.. autoclass:: pytorch_lightning.metrics.MetricCollection + :noindex: ********** Metric API diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 68268c6a3e1d6..72e04a0b987a0 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.metrics.metric import Metric # noqa: F401 +from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401 from pytorch_lightning.metrics.classification import ( # noqa: F401 Accuracy, diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index a21242c3bdc7e..05b719e8a0610 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import inspect from abc import ABC, abstractmethod from collections.abc import Sequence from copy import deepcopy -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -57,6 +58,7 @@ class Metric(nn.Module, ABC): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. default: None """ + def __init__( self, compute_on_step: bool = True, @@ -72,6 +74,7 @@ def __init__( self.dist_sync_fn = dist_sync_fn self._to_sync = True + self._update_signature = inspect.signature(self.update) self.update = self._wrap_update(self.update) self.compute = self._wrap_compute(self.compute) self._computed = None @@ -120,7 +123,7 @@ def add_state( """ if ( not isinstance(default, torch.Tensor) - and not isinstance(default, list) # noqa: W503 + and not isinstance(default, list) # noqa: W503 or (isinstance(default, list) and len(default) != 0) # noqa: W503 ): raise ValueError( @@ -208,9 +211,11 @@ def wrapped_func(*args, **kwargs): return self._computed dist_sync_fn = self.dist_sync_fn - if (dist_sync_fn is None - and torch.distributed.is_available() - and torch.distributed.is_initialized()): + if ( + dist_sync_fn is None + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ): # User provided a bool, so we assume DDP if available dist_sync_fn = gather_all_tensors @@ -250,6 +255,10 @@ def reset(self): else: setattr(self, attr, deepcopy(default)) + def clone(self): + """ Make a copy of the metric """ + return deepcopy(self) + def __getstate__(self): # ignore update and compute functions for pickling return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} @@ -292,3 +301,101 @@ def state_dict(self, *args, **kwargs): current_val = getattr(self, key) state_dict.update({key: current_val}) return state_dict + + +class MetricCollection(nn.ModuleDict): + """ + MetricCollection class can be used to chain metrics that have the same + call pattern into one single class. + + Args: + metrics: One of the following + + * list or tuple: if metrics are passed in as a list, will use the + metrics class name as key for output dict. Therefore, two metrics + of the same class cannot be chained this way. + + * dict: if metrics are passed in as a dict, will use each key in the + dict as key for output dict. Use this format if you want to chain + together multiple of the same metric with different parameters. + + Example (input as list): + + >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall + >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + >>> metrics = MetricCollection([Accuracy(), + ... Precision(num_classes=3, average='macro'), + ... Recall(num_classes=3, average='macro')]) + >>> metrics(preds, target) + {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} + + Example (input as dict): + + >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), + ... 'macro_recall': Recall(num_classes=3, average='macro')}) + >>> metrics(preds, target) + {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} + + """ + def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): + super().__init__() + if isinstance(metrics, dict): + # Check all values are metrics + for name, metric in metrics.items(): + if not isinstance(metric, Metric): + raise ValueError(f'Value {metric} belonging to key {name}' + ' is not an instance of `pl.metrics.Metric`') + self[name] = metric + elif isinstance(metrics, (tuple, list)): + for metric in metrics: + if not isinstance(metric, Metric): + raise ValueError(f'Input {metric} to `MetricCollection` is not a instance' + ' of `pl.metrics.Metric`') + name = metric.__class__.__name__ + if name in self: + raise ValueError(f'Encountered two metrics both named {name}') + self[name] = metric + else: + raise ValueError('Unknown input to MetricCollection.') + + def _filter_kwargs(self, metric: Metric, **kwargs): + """ filter kwargs such that they match the update signature of the metric """ + return {k: v for k, v in kwargs.items() if k in metric._update_signature.parameters.keys()} + + def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 + """ + Iteratively call forward for each metric. Positional arguments (args) will + be passed to every metric in the collection, while keyword arguments (kwargs) + will be filtered based on the signature of the individual metric. + """ + return {k: m(*args, **self._filter_kwargs(m, **kwargs)) for k, m in self.items()} + + def update(self, *args, **kwargs): # pylint: disable=E0202 + """ + Iteratively call update for each metric. Positional arguments (args) will + be passed to every metric in the collection, while keyword arguments (kwargs) + will be filtered based on the signature of the individual metric. + """ + for _, m in self.items(): + m_kwargs = self._filter_kwargs(m, **kwargs) + m.update(*args, **m_kwargs) + + def compute(self) -> Dict[str, Any]: + return {k: m.compute() for k, m in self.items()} + + def reset(self): + """ Iteratively call reset for each metric """ + for _, m in self.items(): + m.reset() + + def clone(self): + """ Make a copy of the metric collection """ + return deepcopy(self) + + def persistent(self, mode: bool = True): + """ Method for post-init to change if metric states should be saved to + its state_dict + """ + for _, m in self.items(): + m.persistent(mode) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 33948204cb054..c3cafa2365267 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -6,8 +6,7 @@ import numpy as np import pytest import torch - -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.metric import Metric, MetricCollection torch.manual_seed(42) @@ -17,7 +16,7 @@ class Dummy(Metric): def __init__(self): super().__init__() - self.add_state("x", torch.tensor(0), dist_reduce_fx=None) + self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None) def update(self): pass @@ -166,7 +165,7 @@ def compute(self): assert a.compute() == 13 -class ToPickle(Dummy): +class DummyMetric1(Dummy): def update(self, x): self.x += x @@ -174,9 +173,17 @@ def compute(self): return self.x +class DummyMetric2(Dummy): + def update(self, y): + self.x -= y + + def compute(self): + return self.x + + def test_pickle(tmpdir): # doesn't tests for DDP - a = ToPickle() + a = DummyMetric1() a.update(1) metric_pickled = pickle.dumps(a) @@ -201,3 +208,130 @@ def test_state_dict(tmpdir): assert metric.state_dict() == OrderedDict(x=0) metric.persistent(False) assert metric.state_dict() == OrderedDict() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.") +def test_device_and_dtype_transfer(tmpdir): + metric = DummyMetric1() + assert metric.x.is_cuda is False + assert metric.x.dtype == torch.float32 + + metric = metric.to(device='cuda') + assert metric.x.is_cuda + + metric = metric.double() + assert metric.x.dtype == torch.float64 + + metric = metric.half() + assert metric.x.dtype == torch.float16 + + +def test_metric_collection(tmpdir): + m1 = DummyMetric1() + m2 = DummyMetric2() + + metric_collection = MetricCollection([m1, m2]) + + # Test correct dict structure + assert len(metric_collection) == 2 + assert metric_collection['DummyMetric1'] == m1 + assert metric_collection['DummyMetric2'] == m2 + + # Test correct initialization + for name, metric in metric_collection.items(): + assert metric.x == 0, f'Metric {name} not initialized correctly' + + # Test every metric gets updated + metric_collection.update(5) + for name, metric in metric_collection.items(): + assert metric.x.abs() == 5, f'Metric {name} not updated correctly' + + # Test compute on each metric + metric_collection.update(-5) + metric_vals = metric_collection.compute() + assert len(metric_vals) == 2 + for name, metric_val in metric_vals.items(): + assert metric_val == 0, f'Metric {name}.compute not called correctly' + + # Test that everything is reset + for name, metric in metric_collection.items(): + assert metric.x == 0, f'Metric {name} not reset correctly' + + # Test pickable + metric_pickled = pickle.dumps(metric_collection) + metric_loaded = pickle.loads(metric_pickled) + assert isinstance(metric_loaded, MetricCollection) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.") +def test_device_and_dtype_transfer_metriccollection(tmpdir): + m1 = DummyMetric1() + m2 = DummyMetric2() + + metric_collection = MetricCollection([m1, m2]) + for _, metric in metric_collection.items(): + assert metric.x.is_cuda is False + assert metric.x.dtype == torch.float32 + + metric_collection = metric_collection.to(device='cuda') + for _, metric in metric_collection.items(): + assert metric.x.is_cuda + + metric_collection = metric_collection.double() + for _, metric in metric_collection.items(): + assert metric.x.dtype == torch.float64 + + metric_collection = metric_collection.half() + for _, metric in metric_collection.items(): + assert metric.x.dtype == torch.float16 + + +def test_metric_collection_wrong_input(tmpdir): + """ Check that errors are raised on wrong input """ + m1 = DummyMetric1() + + # Not all input are metrics (list) + with pytest.raises(ValueError): + _ = MetricCollection([m1, 5]) + + # Not all input are metrics (dict) + with pytest.raises(ValueError): + _ = MetricCollection({'metric1': m1, + 'metric2': 5}) + + # Same metric passed in multiple times + with pytest.raises(ValueError, match='Encountered two metrics both named *.'): + _ = MetricCollection([m1, m1]) + + # Not a list or dict passed in + with pytest.raises(ValueError, match='Unknown input to MetricCollection.'): + _ = MetricCollection(m1) + + +def test_metric_collection_args_kwargs(tmpdir): + """ Check that args and kwargs gets passed correctly in metric collection, + Checks both update and forward method + """ + m1 = DummyMetric1() + m2 = DummyMetric2() + + metric_collection = MetricCollection([m1, m2]) + + # args gets passed to all metrics + metric_collection.update(5) + assert metric_collection['DummyMetric1'].x == 5 + assert metric_collection['DummyMetric2'].x == -5 + metric_collection.reset() + _ = metric_collection(5) + assert metric_collection['DummyMetric1'].x == 5 + assert metric_collection['DummyMetric2'].x == -5 + metric_collection.reset() + + # kwargs gets only passed to metrics that it matches + metric_collection.update(x=10, y=20) + assert metric_collection['DummyMetric1'].x == 10 + assert metric_collection['DummyMetric2'].x == -20 + metric_collection.reset() + _ = metric_collection(x=10, y=20) + assert metric_collection['DummyMetric1'].x == 10 + assert metric_collection['DummyMetric2'].x == -20 diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index ed809c5e8527e..2347cc65f8293 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,7 +1,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.metrics import Metric +from pytorch_lightning.metrics import Metric, MetricCollection from tests.base.boring_model import BoringModel @@ -17,6 +17,18 @@ def compute(self): return self.x +class DiffMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, x): + self.x -= x + + def compute(self): + return self.x + + def test_metric_lightning(tmpdir): class TestModel(BoringModel): def __init__(self): @@ -125,3 +137,41 @@ def training_step(self, batch, batch_idx): output = model(rand_input) script_output = script_model(rand_input) assert torch.allclose(output, script_output) + + +def test_metric_collection_lightning_log(tmpdir): + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.metric = MetricCollection([SumMetric(), DiffMetric()]) + self.sum = 0.0 + self.diff = 0.0 + + def training_step(self, batch, batch_idx): + x = batch + metric_vals = self.metric(x.sum()) + self.sum += x.sum() + self.diff -= x.sum() + self.log_dict({f'{k}_step': v for k, v in metric_vals.items()}) + return self.step(x) + + def training_epoch_end(self, outputs): + metric_vals = self.metric.compute() + self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()}) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + logged = trainer.logged_metrics + assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum) + assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)