Merge branch 'master' into master
williamFalcon authored Mar 14, 2021
2 parents 28f2d6c + 5d48e90 commit 4dfae65
Showing 5 changed files with 119 additions and 65 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))


### Changed

70 changes: 35 additions & 35 deletions docs/source/pages/implement.rst
@@ -55,35 +55,35 @@ In practise this means that:
Internal implementation details
-------------------------------

This section briefly describe how metrics work internally. We encourage looking at the source code for more info.
This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
Internally, Lightning wraps the user-defined ``update()`` and ``compute()`` methods. We do this to automatically
synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the
following internally:

1. Clears computed cache
2. Calls user-defined ``update()``
1. Clears computed cache.
2. Calls user-defined ``update()``.

Simiarly, calling ``compute()`` does the following internally
Similarly, calling ``compute()`` does the following internally:

1. Syncs metric states between processes
2. Reduce gathered metric states
3. Calls the user defined ``compute()`` method on the gathered metric states
4. Cache computed result
1. Syncs metric states between processes.
2. Reduces the gathered metric states.
3. Calls the user-defined ``compute()`` method on the gathered metric states.
4. Caches the computed result.
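
Conceptually, this wrapping can be pictured with the following simplified sketch. This is **not** the actual
torchmetrics source; the attribute and helper names are made up purely to illustrate the two call sequences above:

.. code-block:: python

    # simplified, illustrative sketch of the wrapping described above
    def wrapped_update(self, *args, **kwargs):
        self._cached_result = None          # 1. clear the computed cache
        self._user_update(*args, **kwargs)  # 2. call the user-defined update()

    def wrapped_compute(self):
        if self._cached_result is None:
            self._sync_and_reduce_states()              # 1.-2. sync and reduce states across processes
            self._cached_result = self._user_compute()  # 3. user-defined compute() on the gathered states
        return self._cached_result                      # 4. cached result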

From a user's standpoint this has one important side-effect: computed results are cached. This means that no
matter how many times ``compute`` is called one after another, it will continue to return the same result.
The cache is first emptied on the next call to ``update``.
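
As a small illustration of this caching behaviour (using the built-in ``Accuracy`` metric; the values are arbitrary):

.. code-block:: python

    import torch
    from torchmetrics import Accuracy

    metric = Accuracy()
    metric.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
    first = metric.compute()
    second = metric.compute()  # returns the cached result, nothing is recomputed
    assert first == second
    metric.update(torch.tensor([1]), torch.tensor([1]))  # the cached result is cleared again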

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achives this by combining calls
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way (assuming metric is initialized with ``compute_on_step=True``):

1. Calls ``update()`` to update the global metric states (for accumulation over multiple batches)
2. Caches the global state
3. Calls ``reset()`` to clear global metric state
4. Calls ``update()`` to update local metric state
5. Calls ``compute()`` to calculate metric for current batch
6. Restores the global state
1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
3. Calls ``reset()`` to clear global metric state.
4. Calls ``update()`` to update local metric state.
5. Calls ``compute()`` to calculate metric for current batch.
6. Restores the global state.
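
A minimal sketch of what this means in practice (again using ``Accuracy`` purely for illustration):

.. code-block:: python

    import torch
    from torchmetrics import Accuracy

    metric = Accuracy()
    # forward returns the metric evaluated on this batch only...
    batch_acc = metric(torch.tensor([0, 1]), torch.tensor([0, 0]))
    # ...while the global state keeps accumulating across batches
    metric(torch.tensor([1, 1]), torch.tensor([1, 1]))
    total_acc = metric.compute()  # accuracy accumulated over both batches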

This procedure has the consequence of calling the user-defined ``update`` **twice** during a single
forward call (once to update the global statistics and once to get the batch statistics).
@@ -97,18 +97,18 @@ forward call (one to update global statistics and one for getting the batch stat
Contributing your metric to Torchmetrics
----------------------------------------

Wanting to contribute the metric you have implement? Great, we are always open to adding more metrics to Torchmetrics
Want to contribute the metric you have implemented? Great, we are always open to adding more metrics to ``torchmetrics``
as long as they serve a general purpose. However, to keep all our metrics consistent, we request that the implementation
and tests be formatted in the following way:

1. Start by reading our `contribution guidelines <https://torchmetrics.readthedocs.io//en/latest/generated/CONTRIBUTING.html>`_
2. First implement the functional backend. This takes cares of all logic that does into the metric. The code should
to put into single file placed under ``torchmetrics/functional/"domain"/"new_metric".py`` where ``domain`` is the type of
metric (classification, regression, nlp ect) and ``new_metric`` is the name of the metric. In this file should be the
1. Start by reading our `contribution guidelines <https://torchmetrics.readthedocs.io//en/latest/generated/CONTRIBUTING.html>`_.
2. First implement the functional backend. This takes care of all the logic that goes into the metric. The code should
be put into a single file placed under ``torchmetrics/functional/"domain"/"new_metric".py``, where ``domain`` is the type of
metric (classification, regression, nlp, etc.) and ``new_metric`` is the name of the metric. In this file, there should be the
following three functions (a sketch of such a file follows the list below):

1. ``_new_metric_update(...)``: everything that has to do with type/shape checking and all logic required before distributed syncing needs to go here.
2. ``_new_metric_compute(...)``: all remaining logic
2. ``_new_metric_compute(...)``: all remaining logic.
3. ``new_metric(...)``: essentially wraps the ``_update`` and ``_compute`` private functions into one public function that
makes up the functional interface for the metric.
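
A minimal sketch of what such a functional file could look like, using a hypothetical ``my_metric`` (the file path,
function names and the simple absolute-error logic are illustrative only):

.. code-block:: python

    # hypothetical file: torchmetrics/functional/regression/my_metric.py
    from typing import Tuple

    import torch
    from torch import Tensor


    def _my_metric_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        # type/shape checking and everything required before distributed syncing
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")
        return torch.sum(torch.abs(preds - target)), torch.tensor(target.numel())


    def _my_metric_compute(sum_abs_error: Tensor, n_obs: Tensor) -> Tensor:
        # all remaining logic
        return sum_abs_error / n_obs


    def my_metric(preds: Tensor, target: Tensor) -> Tensor:
        # public functional interface wrapping the two private helpers
        sum_abs_error, n_obs = _my_metric_update(preds, target)
        return _my_metric_compute(sum_abs_error, n_obs)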

@@ -118,33 +118,33 @@ and tests gets formatted in the following way:

3. In a corresponding file placed in ``torchmetrics/"domain"/"new_metric".py`` create the module interface (a sketch follows the list below):

1. Create a new module metric by subclassing ``torchmetrics.Metric``
1. Create a new module metric by subclassing ``torchmetrics.Metric``.
2. In the ``__init__`` of the module, call ``self.add_state`` for as many metric states as are needed for the metric to
proper accumulate metric statistics
3. The module interface should essentially call the private ``_new_metric_update(...)`` in its `update` method and simiarly the
properly accumulate metric statistics.
3. The module interface should essentially call the private ``_new_metric_update(...)`` in its ``update`` method and similarly the
``_new_metric_compute(...)`` function in its ``compute`` method. No logic should really be implemented in the module interface;
we do this to avoid maintaining duplicate code.
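
Continuing the hypothetical ``my_metric`` example from above, the module interface could be sketched roughly as
follows (the state names and constructor arguments are illustrative):

.. code-block:: python

    # hypothetical file: torchmetrics/regression/my_metric.py
    import torch
    from torch import Tensor

    from torchmetrics import Metric
    from torchmetrics.functional.regression.my_metric import _my_metric_compute, _my_metric_update


    class MyMetric(Metric):
        def __init__(self, compute_on_step: bool = True, dist_sync_on_step: bool = False):
            super().__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step)
            # metric states that are synced and reduced across processes
            self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("n_obs", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: Tensor, target: Tensor) -> None:
            sum_abs_error, n_obs = _my_metric_update(preds, target)
            self.sum_abs_error += sum_abs_error
            self.n_obs += n_obs

        def compute(self) -> Tensor:
            return _my_metric_compute(self.sum_abs_error, self.n_obs)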

.. note::
The module `Accuracy <https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/classification/accuracy.py>`_
metric that correspond to the above functional example showcases these steps
metric that corresponds to the above functional example showcases these steps.

4. Remember to add binding to the different relevant ``__init__`` files
4. Remember to add bindings to the relevant ``__init__`` files.

5. Testing is key to keeping torchmetrics trustworty. This is why we have a very rigid testing protocol. This means
that we in most cases require the metric to be tested against some other commen framework (``sklearn``, ``scipy`` ect).
5. Testing is key to keeping ``torchmetrics`` trustworthy. This is why we have a very rigid testing protocol. This means
that, in most cases, we require the metric to be tested against some other common framework (``sklearn``, ``scipy``, etc.); a sketch of such a test file follows this list.

1. Create a testing file in ``tests/"domain"/test_"new_metric".py``. Only one file is needed as it is intended to test
both the functional and module interface
2. In that file, start by defining a number of test inputs that your metric should be evaluated on
both the functional and module interface.
2. In that file, start by defining a number of test inputs that your metric should be evaluated on.
3. Create a test class ``class NewMetric(MetricTester)`` that inherits from ``tests.helpers.testers.MetricTester``.
This test class should essentially implement the ``test_"new_metric"_class`` and ``test_"new_metric"_fn`` methods, which
test the module interface and the functional interface, respectively.
4. The testclass should be parametrized (using ``@pytest.mark.parametrize``) by the different test inputs defined initiallly.
Additionally, the ``test_"new_metric"_class`` method should also be parametrized with an `ddp` parameter such that it gets
tested in a distributed setting. If your metric has additionally parameters, then make sure to also parametrize these
such that different combinations of input and parameters gets tested.
5. (optional) Ff your metrics raises any exceptions, please add tests that showcases this
4. The test class should be parametrized (using ``@pytest.mark.parametrize``) by the different test inputs defined initially.
Additionally, the ``test_"new_metric"_class`` method should also be parametrized with a ``ddp`` parameter so that it gets
tested in a distributed setting. If your metric has additional parameters, make sure to also parametrize these
such that different combinations of inputs and parameters get tested.
5. (optional) If your metric raises any exceptions, please add tests that showcase this.
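
A rough sketch of how such a test file could be laid out for the hypothetical ``my_metric`` (the
``run_class_metric_test`` and ``run_functional_metric_test`` helpers are assumed to behave as in the existing test
files; see the accuracy tests linked in the note below for the authoritative pattern):

.. code-block:: python

    # hypothetical file: tests/regression/test_my_metric.py
    import pytest
    import torch
    from sklearn.metrics import mean_absolute_error

    from tests.helpers.testers import MetricTester
    from torchmetrics.functional.regression.my_metric import my_metric  # hypothetical functional interface
    from torchmetrics.regression.my_metric import MyMetric  # hypothetical module interface

    preds = torch.rand(4, 32)
    target = torch.rand(4, 32)


    def _sk_metric(preds, target):
        return mean_absolute_error(target.numpy(), preds.numpy())


    @pytest.mark.parametrize("preds, target", [(preds, target)])
    class TestMyMetric(MetricTester):

        @pytest.mark.parametrize("ddp", [True, False])
        def test_my_metric_class(self, preds, target, ddp):
            self.run_class_metric_test(ddp, preds, target, MyMetric, _sk_metric, dist_sync_on_step=False)

        def test_my_metric_fn(self, preds, target):
            self.run_functional_metric_test(preds, target, my_metric, _sk_metric)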

.. note::
The `test file for accuracy <https://github.com/PyTorchLightning/metrics/blob/master/tests/classification/test_accuracy.py>`_ metric
44 changes: 24 additions & 20 deletions docs/source/pages/overview.rst
@@ -195,27 +195,31 @@ Example:
Similarly, it can also reduce the amount of code required to log multiple metrics
inside your ``LightningModule``

.. code-block:: python
.. testcode::

    from torchmetrics import Accuracy, MetricCollection, Precision, Recall

    def __init__(self):
        ...
        metrics = pl.metrics.MetricCollection(...)
        self.train_metrics = metrics.clone()
        self.valid_metrics = metrics.clone()
    def training_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.train_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')
    def validation_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.valid_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
    class MyModule():
        def __init__(self):
            metrics = MetricCollection([Accuracy(), Precision(), Recall()])
            self.train_metrics = metrics.clone(prefix='train_')
            self.valid_metrics = metrics.clone(prefix='val_')

        def training_step(self, batch, batch_idx):
            logits = self(x)
            # ...
            output = self.train_metrics(logits, y)
            # use log_dict instead of log
            # metrics are logged with keys: train_Accuracy, train_Precision and train_Recall
            self.log_dict(output)

        def validation_step(self, batch, batch_idx):
            logits = self(x)
            # ...
            output = self.valid_metrics(logits, y)
            # use log_dict instead of log
            # metrics are logged with keys: val_Accuracy, val_Precision and val_Recall
            self.log_dict(output)

.. note::

25 changes: 25 additions & 0 deletions tests/bases/test_collections.py
@@ -130,3 +130,28 @@ def test_metric_collection_args_kwargs(tmpdir):
    _ = metric_collection(x=10, y=20)
    assert metric_collection['DummyMetricSum'].x == 10
    assert metric_collection['DummyMetricDiff'].x == -20


def test_metric_collection_prefix_arg(tmpdir):
    """ Test that the prefix arg alters the keywords in the output"""
    m1 = DummyMetricSum()
    m2 = DummyMetricDiff()
    names = ['DummyMetricSum', 'DummyMetricDiff']

    metric_collection = MetricCollection([m1, m2], prefix='prefix_')

    # test forward
    out = metric_collection(5)
    for name in names:
        assert f"prefix_{name}" in out, 'prefix argument not working as intended with forward method'

    # test compute
    out = metric_collection.compute()
    for name in names:
        assert f"prefix_{name}" in out, 'prefix argument not working as intended with compute method'

    # test clone
    new_metric_collection = metric_collection.clone(prefix='new_prefix_')
    out = new_metric_collection(5)
    for name in names:
        assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'
43 changes: 33 additions & 10 deletions torchmetrics/collections.py
@@ -13,7 +13,7 @@
# limitations under the License.

from copy import deepcopy
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from torch import nn

@@ -36,6 +36,8 @@ class MetricCollection(nn.ModuleDict):
            dict as key for output dict. Use this format if you want to chain
            together multiple of the same metric with different parameters.
        prefix: a string to append in front of the keys of the output dict
    Example (input as list):
        >>> import torch
        >>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall
@@ -58,8 +60,11 @@ class MetricCollection(nn.ModuleDict):
        >>> metrics.persistent()
    """

    def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
    def __init__(
        self,
        metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
        prefix: Optional[str] = None
    ):
        super().__init__()
        if isinstance(metrics, dict):
            # Check all values are metrics
@@ -84,13 +89,15 @@ def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]
        else:
            raise ValueError("Unknown input to MetricCollection.")

        self.prefix = self._check_prefix_arg(prefix)

    def forward(self, *args, **kwargs) -> Dict[str, Any]:  # pylint: disable=E0202
        """
        Iteratively call forward for each metric. Positional arguments (args) will
        be passed to every metric in the collection, while keyword arguments (kwargs)
        will be filtered based on the signature of the individual metric.
        """
        return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
        return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}

    def update(self, *args, **kwargs):  # pylint: disable=E0202
        """
@@ -103,20 +110,36 @@ def update(self, *args, **kwargs):  # pylint: disable=E0202
            m.update(*args, **m_kwargs)

    def compute(self) -> Dict[str, Any]:
        return {k: m.compute() for k, m in self.items()}
        return {self._set_prefix(k): m.compute() for k, m in self.items()}

    def reset(self):
    def reset(self) -> None:
        """ Iteratively call reset for each metric """
        for _, m in self.items():
            m.reset()

    def clone(self):
        """ Make a copy of the metric collection """
        return deepcopy(self)
    def clone(self, prefix: Optional[str] = None) -> 'MetricCollection':
        """ Make a copy of the metric collection
        Args:
            prefix: a string to append in front of the metric keys
        """
        mc = deepcopy(self)
        mc.prefix = self._check_prefix_arg(prefix)
        return mc

    def persistent(self, mode: bool = True):
    def persistent(self, mode: bool = True) -> None:
        """Method for post-init to change if metric states should be saved to
        its state_dict
        """
        for _, m in self.items():
            m.persistent(mode)

    def _set_prefix(self, k: str) -> str:
        return k if self.prefix is None else self.prefix + k

    def _check_prefix_arg(self, prefix: str) -> Optional[str]:
        if prefix is not None:
            if isinstance(prefix, str):
                return prefix
            else:
                raise ValueError('Expected input `prefix` to be a string')
        return None
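
As a quick sketch of how the new ``prefix`` argument is intended to be used, based on the implementation above
(the metric choice and values are illustrative):

import torch
from torchmetrics import Accuracy, MetricCollection

metrics = MetricCollection({'acc': Accuracy()}, prefix='train_')
preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])
output = metrics(preds, target)             # returns a dict with the key 'train_acc'
val_metrics = metrics.clone(prefix='val_')  # same metric, keys now prefixed with 'val_'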
