Update docs on nested metrics #1002

Merged (2 commits) on May 3, 2022
docs/source/pages/overview.rst: 32 changes (28 additions & 4 deletions)
@@ -67,6 +67,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be used
To change this, after initializing the metric, the method ``.persistent(mode)`` can
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.

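A minimal sketch of toggling this behaviour (``MeanSquaredError`` is used purely as an example metric class):

.. code-block:: python

    import torchmetrics

    metric = torchmetrics.MeanSquaredError()
    metric.persistent(True)   # metric states will be included in ``state_dict``
    metric.persistent(False)  # metric states are excluded again (the default)
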
.. note::

    Due to specialized logic around metric states, we generally do **not** recommend
    that metrics are initialized inside other metrics (nested metrics), as this can lead
    to unexpected behaviour. Instead, consider subclassing a metric or using
    ``torchmetrics.MetricCollection``.

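For example, a minimal sketch of the recommended alternative (the specific metrics and tensor shapes here are only illustrative):

.. code-block:: python

    import torch
    import torchmetrics

    # instead of creating one metric inside another metric, group independent
    # metrics in a ``MetricCollection``
    collection = torchmetrics.MetricCollection(
        {
            "mse": torchmetrics.MeanSquaredError(),
            "mae": torchmetrics.MeanAbsoluteError(),
        }
    )

    preds, target = torch.randn(10), torch.randn(10)
    results = collection(preds, target)  # dict with one entry per metric
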
*******************
Metrics and devices
@@ -333,8 +339,8 @@ Metrics and differentiability
*****************************

Metrics support backpropagation if all computations involved in the metric calculation
are differentiable. All modular metric classes have the property ``is_differentiable`` that determines
if a metric is differentiable or not.

However, note that the cached state is detached from the computational
graph and cannot be back-propagated. Not doing this would mean storing the computational
@@ -343,12 +349,30 @@ In practice this means that:

.. code-block:: python

    metric = MyMetric()
    val = metric(pred, target)  # this value can be back-propagated
    val = metric.compute()  # this value cannot be back-propagated

A functional metric is differentiable if its corresponding modular metric is differentiable.
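
As a small sketch of what this means in practice (the regression setup below is made up for illustration):

.. code-block:: python

    import torch
    import torchmetrics

    metric = torchmetrics.MeanSquaredError()
    preds = torch.randn(10, requires_grad=True)
    target = torch.randn(10)

    val = metric(preds, target)  # computed with gradient tracking, so it can be back-propagated
    val.backward()               # gradients flow back to ``preds``

    cached = metric.compute()    # based on the detached states, cannot be back-propagated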

***************************************
Metrics and hyperparameter optimization
***************************************

If you want to directly optimize a metric, it needs to support backpropagation (see the section above).
However, if you are just interested in using a metric for hyperparameter tuning and are not sure
whether the metric should be maximized or minimized, all modular metric classes have the ``higher_is_better``
property that can be used to determine this:

.. code-block:: python

    # returns True because accuracy is optimal when it is maximized
    torchmetrics.Accuracy.higher_is_better

    # returns False because the mean squared error is optimal when it is minimized
    torchmetrics.MeanSquaredError.higher_is_better
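
A hypothetical use of this property when comparing tuning runs (the ``trial_scores`` dictionary below is a made-up placeholder):

.. code-block:: python

    import torchmetrics

    metric_cls = torchmetrics.Accuracy
    # pick the best run according to the direction indicated by the metric
    best = max if metric_cls.higher_is_better else min

    trial_scores = {"lr=0.1": 0.81, "lr=0.01": 0.87}  # made-up results
    best_trial = best(trial_scores, key=trial_scores.get)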

.. _Metric kwargs:

************************