Skip to content

Commit

Permalink
[Refactor] Classification 6/n (#1163)
Browse files Browse the repository at this point in the history
* some idea

* implementation stuff

* error message

* somethings working

* working

* working multilabel precision recall

* docstrings

* working average precision

* working roc

* working auroc

* working precision at recall

* docs

* init files

* beginning doctest

* more docs

* more docs

* more docs

* more docs

* more docs

* more docs

* correction to math

* fix mypy

* add suggestions

* change default from 100 to None

* fix

* fix

* some fixes

* suggestions for stancld

* try fixing

* try fix

* try again

* more fixing

* another fix

* skip half + cpu test for old versions

* fix link

* another fix

* nan safety

* another fix

* skip non working cpu half tests

* skip non working cpu + half tests
  • Loading branch information
SkafteNicki authored and Borda committed Sep 13, 2022
1 parent c25680c commit 0f7a102
Show file tree
Hide file tree
Showing 27 changed files with 6,523 additions and 680 deletions.
36 changes: 36 additions & 0 deletions docs/source/classification/auroc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,44 @@ ________________
.. autoclass:: torchmetrics.AUROC
:noindex:

BinaryAUROC
^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryAUROC
:noindex:

MulticlassAUROC
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassAUROC
:noindex:

MultilabelAUROC
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelAUROC
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.auroc
:noindex:

binary_auroc
^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_auroc
:noindex:

multiclass_auroc
^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_auroc
:noindex:

multilabel_auroc
^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_auroc
:noindex:
36 changes: 36 additions & 0 deletions docs/source/classification/average_precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,44 @@ ________________
.. autoclass:: torchmetrics.AveragePrecision
:noindex:

BinaryAveragePrecision
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryAveragePrecision
:noindex:

MulticlassAveragePrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassAveragePrecision
:noindex:

MultilabelAveragePrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelAveragePrecision
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.average_precision
:noindex:

binary_average_precision
^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_average_precision
:noindex:

multiclass_average_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_average_precision
:noindex:

multilabel_average_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_average_precision
:noindex:
36 changes: 36 additions & 0 deletions docs/source/classification/precision_recall_curve.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,44 @@ ________________
.. autoclass:: torchmetrics.PrecisionRecallCurve
:noindex:

BinaryPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryPrecisionRecallCurve
:noindex:

MulticlassPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassPrecisionRecallCurve
:noindex:

MultilabelPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelPrecisionRecallCurve
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.precision_recall_curve
:noindex:

binary_precision_recall_curve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_precision_recall_curve
:noindex:

multiclass_precision_recall_curve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_precision_recall_curve
:noindex:

multilabel_precision_recall_curve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_precision_recall_curve
:noindex:
50 changes: 50 additions & 0 deletions docs/source/classification/recall_at_fixed_precision.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
.. customcarditem::
:header: Recall At Fixed Precision
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

#########################
Recall At Fixed Precision
#########################

Module Interface
________________

BinaryRecallAtFixedPrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryRecallAtFixedPrecision
:noindex:

MulticlassRecallAtFixedPrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassRecallAtFixedPrecision
:noindex:

MultilabelRecallAtFixedPrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelRecallAtFixedPrecision
:noindex:

Functional Interface
____________________

binary_recall_at_fixed_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_recall_at_fixed_precision
:noindex:

multiclass_recall_at_fixed_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_recall_at_fixed_precision
:noindex:

multilabel_recall_at_fixed_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_recall_at_fixed_precision
:noindex:
36 changes: 36 additions & 0 deletions docs/source/classification/roc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,44 @@ ________________
.. autoclass:: torchmetrics.ROC
:noindex:

BinaryROC
^^^^^^^^^

.. autoclass:: torchmetrics.BinaryROC
:noindex:

MulticlassROC
^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassROC
:noindex:

MultilabelROC
^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelROC
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.roc
:noindex:

binary_roc
^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_roc
:noindex:

multiclass_roc
^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_roc
:noindex:

multilabel_roc
^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_roc
:noindex:
30 changes: 30 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
Accuracy,
AveragePrecision,
BinaryAccuracy,
BinaryAUROC,
BinaryAveragePrecision,
BinaryCohenKappa,
BinaryConfusionMatrix,
BinaryF1Score,
Expand All @@ -35,7 +37,10 @@
BinaryJaccardIndex,
BinaryMatthewsCorrCoef,
BinaryPrecision,
BinaryPrecisionRecallCurve,
BinaryRecall,
BinaryRecallAtFixedPrecision,
BinaryROC,
BinarySpecificity,
BinaryStatScores,
BinnedAveragePrecision,
Expand All @@ -56,6 +61,8 @@
LabelRankingLoss,
MatthewsCorrCoef,
MulticlassAccuracy,
MulticlassAUROC,
MulticlassAveragePrecision,
MulticlassCohenKappa,
MulticlassConfusionMatrix,
MulticlassF1Score,
Expand All @@ -64,10 +71,15 @@
MulticlassJaccardIndex,
MulticlassMatthewsCorrCoef,
MulticlassPrecision,
MulticlassPrecisionRecallCurve,
MulticlassRecall,
MulticlassRecallAtFixedPrecision,
MulticlassROC,
MulticlassSpecificity,
MulticlassStatScores,
MultilabelAccuracy,
MultilabelAUROC,
MultilabelAveragePrecision,
MultilabelConfusionMatrix,
MultilabelCoverageError,
MultilabelExactMatch,
Expand All @@ -77,9 +89,12 @@
MultilabelJaccardIndex,
MultilabelMatthewsCorrCoef,
MultilabelPrecision,
MultilabelPrecisionRecallCurve,
MultilabelRankingAveragePrecision,
MultilabelRankingLoss,
MultilabelRecall,
MultilabelRecallAtFixedPrecision,
MultilabelROC,
MultilabelSpecificity,
MultilabelStatScores,
Precision,
Expand Down Expand Up @@ -155,6 +170,21 @@
"MultilabelAccuracy",
"AUC",
"AUROC",
"BinaryAUROC",
"BinaryAveragePrecision",
"BinaryPrecisionRecallCurve",
"BinaryRecallAtFixedPrecision",
"BinaryROC",
"MultilabelROC",
"MulticlassAUROC",
"MulticlassAveragePrecision",
"MulticlassPrecisionRecallCurve",
"MulticlassRecallAtFixedPrecision",
"MulticlassROC",
"MultilabelAUROC",
"MultilabelAveragePrecision",
"MultilabelPrecisionRecallCurve",
"MultilabelRecallAtFixedPrecision",
"AveragePrecision",
"BinnedAveragePrecision",
"BinnedPrecisionRecallCurve",
Expand Down
23 changes: 19 additions & 4 deletions src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
MulticlassConfusionMatrix,
MultilabelConfusionMatrix,
)
from torchmetrics.classification.precision_recall_curve import ( # noqa: F401 isort:skip
PrecisionRecallCurve,
BinaryPrecisionRecallCurve,
MulticlassPrecisionRecallCurve,
MultilabelPrecisionRecallCurve,
)
from torchmetrics.classification.stat_scores import ( # noqa: F401 isort:skip
BinaryStatScores,
MulticlassStatScores,
Expand All @@ -31,8 +37,13 @@
MultilabelAccuracy,
)
from torchmetrics.classification.auc import AUC # noqa: F401
from torchmetrics.classification.auroc import AUROC # noqa: F401
from torchmetrics.classification.avg_precision import AveragePrecision # noqa: F401
from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC # noqa: F401
from torchmetrics.classification.average_precision import ( # noqa: F401
AveragePrecision,
BinaryAveragePrecision,
MulticlassAveragePrecision,
MultilabelAveragePrecision,
)
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401
Expand Down Expand Up @@ -80,7 +91,6 @@
Precision,
Recall,
)
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from torchmetrics.classification.ranking import ( # noqa: F401
CoverageError,
LabelRankingAveragePrecision,
Expand All @@ -89,7 +99,12 @@
MultilabelRankingAveragePrecision,
MultilabelRankingLoss,
)
from torchmetrics.classification.roc import ROC # noqa: F401
from torchmetrics.classification.recall_at_fixed_precision import ( # noqa: F401
BinaryRecallAtFixedPrecision,
MulticlassRecallAtFixedPrecision,
MultilabelRecallAtFixedPrecision,
)
from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC # noqa: F401
from torchmetrics.classification.specificity import ( # noqa: F401
BinarySpecificity,
MulticlassSpecificity,
Expand Down
Loading

0 comments on commit 0f7a102

Please sign in to comment.