Classification metrics overhaul: input formatting standardization (1/n) (#4837)

* Add stuff

* Change metrics documentation layout

* Change testing utils

* Replace len(*.shape) with *.ndim

* More descriptive error message for input formatting

* Replace movedim with permute

* Style changes in error messages

* More error message style improvements

* Fix typo in docs

* Add more descriptive variable names in utils

* Change internal var names

* Break down error checking for inputs into separate functions

* Remove the (N, ..., C) option in MD-MC

* Simplify select_topk

* Remove detach for inputs

* Fix typos

* Update pytorch_lightning/metrics/classification/utils.py

Co-authored-by: Jirka Borovec <[email protected]>

* Update docs/source/metrics.rst

Co-authored-by: Jirka Borovec <[email protected]>

* Minor error message changes

* Update pytorch_lightning/metrics/utils.py

Co-authored-by: Jirka Borovec <[email protected]>

* Reuse case from validation in formatting

* Refactor code in _input_format_classification

* Small improvements

* PEP 8

* Update pytorch_lightning/metrics/classification/utils.py

Co-authored-by: Rohit Gupta <[email protected]>

* Update pytorch_lightning/metrics/classification/utils.py

Co-authored-by: Rohit Gupta <[email protected]>

* Update docs/source/metrics.rst

Co-authored-by: Rohit Gupta <[email protected]>

* Update pytorch_lightning/metrics/classification/utils.py

Co-authored-by: Rohit Gupta <[email protected]>

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>

* Alphabetical reordering of regression metrics

* Change default value of top_k and add error checking

* Extract basic validation into separate function

* Update description of parameters in input formatting

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <[email protected]>

* Check that probabilities in preds sum to 1 (for MC)

* Fix coverage

* Minor changes

* Fix edge case and simplify testing

Co-authored-by: Teddy Koker <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
6 people authored Dec 7, 2020
1 parent 02152c1 commit fedc0d1
Showing 6 changed files with 1,058 additions and 220 deletions.
229 changes: 145 additions & 84 deletions docs/source/metrics.rst
@@ -196,53 +196,76 @@ Metric API
.. autoclass:: pytorch_lightning.metrics.Metric
:noindex:

*************
Class metrics
*************
***************************
Class vs Functional Metrics
***************************

The functional metrics follow a simple paradigm: input in, output out. This means they don't provide any advanced mechanisms for syncing across DDP nodes or aggregating over batches; they simply compute the metric value for the given inputs.

Also, the integration with other parts of PyTorch Lightning will never be as tight as with the class-based interface.
If you only need to compute the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface.
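
As a rough sketch of the difference (the tensors below are arbitrary illustrations, and the imports assume the ``Accuracy`` class and the functional ``accuracy`` documented in the sections that follow):

.. testcode::

    import torch
    from pytorch_lightning.metrics.classification import Accuracy
    from pytorch_lightning.metrics.functional.classification import accuracy

    preds = torch.tensor([0, 2, 1, 3])
    target = torch.tensor([0, 1, 2, 3])

    # Functional interface: stateless, returns the value for exactly these inputs
    batch_value = accuracy(preds, target)

    # Class interface: keeps internal state, so results accumulate over batches
    # (and are synced across processes when running distributed)
    metric = Accuracy()
    metric(preds, target)                               # update with the first batch
    metric(torch.tensor([1, 0]), torch.tensor([1, 1]))  # update with another batch
    total_value = metric.compute()                      # value over all batches seen so far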

**********************
Classification Metrics
----------------------
**********************

Accuracy
~~~~~~~~
Input types
-----------

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
:noindex:
For the purposes of classification metrics, inputs (predictions and targets) are split
into these categories (``N`` stands for the batch size and ``C`` for the number of classes):

Precision
~~~~~~~~~
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
:header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
:widths: 20, 10, 10, 10, 10

.. autoclass:: pytorch_lightning.metrics.classification.Precision
:noindex:
"Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
"Multi-class", "(N,)", "``int``", "(N,)", "``int``"
"Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
"Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
"Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"

Recall
~~~~~~
.. note::
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N,)``.

.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of the different input types:

FBeta
~~~~~
.. testcode::

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:
# Binary inputs
binary_preds = torch.tensor([0.6, 0.1, 0.9])
binary_target = torch.tensor([1, 0, 1])

F1
~~
# Multi-class inputs
mc_preds = torch.tensor([0, 2, 1])
mc_target = torch.tensor([0, 1, 2])

.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:
# Multi-class inputs with probabilities
mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
mc_target_probs = torch.tensor([0, 1, 2])

ConfusionMatrix
~~~~~~~~~~~~~~~
# Multi-label inputs
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:
In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary or multi-label, for example when both predictions and targets are 1d
binary tensors. Or it could be the other way around: you may want to treat binary/multi-label
inputs as 2-class (multi-dimensional) multi-class inputs.

PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~
For these cases, the metrics where this distinction would make a difference expose the
``is_multiclass`` argument.
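
To make the multi-dimensional multi-class case and the ambiguity above more concrete, here is a purely illustrative sketch; the tensor values are made up and no metric is called:

.. testcode::

    import torch

    # Multi-dimensional multi-class inputs, e.g. per-pixel labels for 2 samples with 3 positions each
    mdmc_preds = torch.tensor([[0, 2, 1], [2, 1, 0]])
    mdmc_target = torch.tensor([[0, 1, 1], [2, 2, 0]])

    # The same predictions as per-class probabilities, shape (N, C, ...) = (2, 3, 3);
    # along the class dimension (dim=1) each position sums to 1
    mdmc_preds_probs = torch.tensor([
        [[0.7, 0.1, 0.3], [0.2, 0.2, 0.6], [0.1, 0.7, 0.1]],
        [[0.1, 0.2, 0.6], [0.2, 0.5, 0.3], [0.7, 0.3, 0.1]],
    ])

    # Ambiguous case: 1d integer tensors holding only 0s and 1s could be read either as
    # binary inputs or as 2-class multi-class inputs; this is the distinction that
    # is_multiclass resolves
    ambiguous_preds = torch.tensor([1, 0, 1, 1])
    ambiguous_target = torch.tensor([0, 0, 1, 1])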

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
Class Metrics (Classification)
------------------------------

Accuracy
~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
:noindex:

AveragePrecision
@@ -251,67 +274,51 @@ AveragePrecision
.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
:noindex:

ROC
~~~
ConfusionMatrix
~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.ROC
.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:

Regression Metrics
------------------

MeanSquaredError
~~~~~~~~~~~~~~~~
F1
~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError
.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:

FBeta
~~~~~

MeanAbsoluteError
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError
.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:

Precision
~~~~~~~~~

MeanSquaredLogError
~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError
.. autoclass:: pytorch_lightning.metrics.classification.Precision
:noindex:

PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~

ExplainedVariance
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
:noindex:

Recall
~~~~~~

PSNR
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.PSNR
.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:

ROC
~~~

SSIM
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
.. autoclass:: pytorch_lightning.metrics.classification.ROC
:noindex:

******************
Functional Metrics
******************

The functional metrics follow the simple paradigm input in, output out. This means they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.

Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface.
If you only need to compute the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface.

Classification
--------------
Functional Metrics (Classification)
-----------------------------------

accuracy [func]
~~~~~~~~~~~~~~~
@@ -417,6 +424,12 @@ recall [func]
.. autofunction:: pytorch_lightning.metrics.functional.classification.recall
:noindex:

select_topk [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.utils.select_topk
:noindex:
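
As a quick orientation, a small illustrative call is sketched below; the ``topk`` keyword and its default are assumed here, so treat the autogenerated signature above as authoritative:

.. testcode::

    import torch
    from pytorch_lightning.metrics.utils import select_topk

    probs = torch.tensor([[0.1, 0.7, 0.2],
                          [0.5, 0.3, 0.2]])

    # Mark the 2 largest entries in each row with 1 and everything else with 0,
    # i.e. the expected result is [[0, 1, 1], [1, 1, 0]]
    topk_mask = select_topk(probs, topk=2)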


stat_scores [func]
~~~~~~~~~~~~~~~~~~
@@ -445,9 +458,57 @@ to_onehot [func]
.. autofunction:: pytorch_lightning.metrics.utils.to_onehot
:noindex:

******************
Regression Metrics
******************

Class Metrics (Regression)
--------------------------

Regression
----------
ExplainedVariance
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
:noindex:


MeanAbsoluteError
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError
:noindex:


MeanSquaredError
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError
:noindex:


MeanSquaredLogError
~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError
:noindex:


PSNR
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.PSNR
:noindex:


SSIM
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
:noindex:


Functional Metrics (Regression)
-------------------------------

explained_variance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -470,17 +531,17 @@ mean_squared_error [func]
:noindex:


psnr [func]
~~~~~~~~~~~
mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.psnr
.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error
:noindex:


mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
psnr [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error
.. autofunction:: pytorch_lightning.metrics.functional.psnr
:noindex:


Expand All @@ -490,22 +551,22 @@ ssim [func]
.. autofunction:: pytorch_lightning.metrics.functional.ssim
:noindex:


***
NLP
---
***

bleu_score [func]
~~~~~~~~~~~~~~~~~
-----------------

.. autofunction:: pytorch_lightning.metrics.functional.nlp.bleu_score
:noindex:


********
Pairwise
--------
********

embedding_similarity [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~
---------------------------

.. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity
:noindex: