From 6144eb3b248b7fa315bb9afeb96c690a5d747001 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 3 May 2022 16:57:41 +0200
Subject: [PATCH] Update docs on nested metrics (#1002)

* Apply suggestions from code review

Co-authored-by: Ethan Harris

---
 docs/source/pages/overview.rst | 32 ++++++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst
index a2811d0e0e1..9436fc51a13 100644
--- a/docs/source/pages/overview.rst
+++ b/docs/source/pages/overview.rst
@@ -67,6 +67,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
     To change this, after initializing the metric, the method ``.persistent(mode)`` can
     be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
 
+.. note::
+
+    Due to specialized logic around metric states, we generally do **not** recommend
+    that metrics be initialized inside other metrics (nested metrics), as this can lead
+    to unexpected behaviour. Instead, consider subclassing a metric or using
+    ``torchmetrics.MetricCollection``.
 
 *******************
 Metrics and devices
 *******************
@@ -333,8 +339,8 @@ Metrics and differentiability
 *****************************
 
 Metrics support backpropagation, if all computations involved in the metric calculation
-are differentiable. All modular metrics have a property that determines if a metric is
-differentiable or not.
+are differentiable. All modular metric classes have the ``is_differentiable`` property,
+which indicates whether a metric is differentiable or not.
 
 However, note that the cached state is detached from the computational graph and cannot
 be back-propagated. Not doing this would mean storing the computational
@@ -343,12 +349,30 @@ In practise this means that:
 
 .. code-block:: python
 
+    MyMetric.is_differentiable  # returns True if the metric is differentiable
     metric = MyMetric()
-    val = metric(pred, target) # this value can be back-propagated
-    val = metric.compute() # this value cannot be back-propagated
+    val = metric(pred, target)  # this value can be back-propagated
+    val = metric.compute()  # this value cannot be back-propagated
 
 A functional metric is differentiable if its corresponding modular metric is differentiable.
 
+***************************************
+Metrics and hyperparameter optimization
+***************************************
+
+If you want to directly optimize a metric, it needs to support backpropagation (see the section above).
+However, if you are just interested in using a metric for hyperparameter tuning and are not sure
+whether the metric should be maximized or minimized, all modular metric classes have the
+``higher_is_better`` property that can be used to determine this:
+
+.. code-block:: python
+
+    # returns True because accuracy is optimal when it is maximized
+    torchmetrics.Accuracy.higher_is_better
+
+    # returns False because the mean squared error is optimal when it is minimized
+    torchmetrics.MeanSquaredError.higher_is_better
+
 .. _Metric kwargs:
 
 ************************
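The note added in the first hunk names the two recommended alternatives to nesting metrics without showing them. The following is a minimal sketch of both patterns, not part of the patch itself; the ``MyAccuracy`` subclass and its custom threshold are purely illustrative.

.. code-block:: python

    import torch
    import torchmetrics

    # Alternative 1: subclass an existing metric instead of initializing
    # it inside another metric. ``MyAccuracy`` and the 0.25 threshold are
    # hypothetical examples, not part of torchmetrics.
    class MyAccuracy(torchmetrics.Accuracy):
        def update(self, preds, target):
            # apply a custom threshold, then delegate to the parent update
            super().update((preds > 0.25).int(), target)

    # Alternative 2: group several metrics with ``MetricCollection``, which
    # manages the state of all wrapped metrics jointly.
    collection = torchmetrics.MetricCollection(
        [torchmetrics.Accuracy(), torchmetrics.Precision(), torchmetrics.Recall()]
    )

    preds, target = torch.rand(10), torch.randint(2, (10,))
    results = collection(preds, target)  # dict mapping metric name to value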
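To make the differentiability hunk concrete, here is a small sketch of what can and cannot be back-propagated; ``MeanSquaredError`` stands in for ``MyMetric`` only because it is a known differentiable metric.

.. code-block:: python

    import torch
    import torchmetrics

    metric = torchmetrics.MeanSquaredError()
    assert metric.is_differentiable

    preds = torch.randn(10, requires_grad=True)
    target = torch.randn(10)

    val = metric(preds, target)  # forward value: attached to the graph
    val.backward()               # gradients flow back into ``preds``

    # the cached state is detached, so the accumulated result is not
    assert not metric.compute().requires_grad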
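Finally, a sketch of how the ``higher_is_better`` property from the new hyperparameter section might feed a tuning setup; the ``objective_direction`` helper is made up for illustration and is not part of torchmetrics.

.. code-block:: python

    import torchmetrics

    def objective_direction(metric_cls):
        # translate the property into the "maximize"/"minimize" convention
        # that many hyperparameter tuners expect
        return "maximize" if metric_cls.higher_is_better else "minimize"

    objective_direction(torchmetrics.Accuracy)          # -> "maximize"
    objective_direction(torchmetrics.MeanSquaredError)  # -> "minimize"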