diff --git a/CHANGELOG.md b/CHANGELOG.md index b62d97e6cd0..d942a583ab8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1605](https://github.com/Lightning-AI/metrics/pull/1605), [#1610](https://github.com/Lightning-AI/metrics/pull/1610), [#1609](https://github.com/Lightning-AI/metrics/pull/1609), + [#1621](https://github.com/Lightning-AI/metrics/pull/1621), ) diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 8004c4f3b54..498f8ff0307 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["PearsonCorrCoef.plot"] def _final_aggregation( @@ -159,3 +164,44 @@ def compute(self) -> Tensor: corr_xy = self.corr_xy n_total = self.n_total return _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import PearsonCorrCoef + >>> metric = PearsonCorrCoef() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import PearsonCorrCoef + >>> metric = PearsonCorrCoef() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/r2.py b/src/torchmetrics/regression/r2.py index 5b2289354af..f041662d9a6 100644 --- a/src/torchmetrics/regression/r2.py +++ b/src/torchmetrics/regression/r2.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["R2Score.plot"] class R2Score(Metric): @@ -129,3 +134,44 @@ def compute(self) -> Tensor: return _r2_score_compute( self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import R2Score + >>> metric = R2Score() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import R2Score + >>> metric = R2Score() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 6c09db268a5..2bb0263cb8e 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, List, Optional, Sequence, Union from torch import Tensor @@ -19,6 +19,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SpearmanCorrCoef.plot"] class SpearmanCorrCoef(Metric): @@ -95,3 +100,44 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _spearman_corrcoef_compute(preds, target) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import SpearmanCorrCoef + >>> metric = SpearmanCorrCoef() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import SpearmanCorrCoef + >>> metric = SpearmanCorrCoef() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/symmetric_mape.py b/src/torchmetrics/regression/symmetric_mape.py index faec0a517de..6c994d0ea38 100644 --- a/src/torchmetrics/regression/symmetric_mape.py +++ b/src/torchmetrics/regression/symmetric_mape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union from torch import Tensor, tensor @@ -20,6 +20,11 @@ _symmetric_mean_absolute_percentage_error_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SymmetricMeanAbsolutePercentageError.plot"] class SymmetricMeanAbsolutePercentageError(Metric): @@ -74,3 +79,44 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute mean absolute percentage error over state.""" return _symmetric_mean_absolute_percentage_error_compute(self.sum_abs_per_error, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError + >>> metric = SymmetricMeanAbsolutePercentageError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError + >>> metric = SymmetricMeanAbsolutePercentageError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/tweedie_deviance.py b/src/torchmetrics/regression/tweedie_deviance.py index 70485af8e3c..53d5175c7fc 100644 --- a/src/torchmetrics/regression/tweedie_deviance.py +++ b/src/torchmetrics/regression/tweedie_deviance.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -21,6 +21,11 @@ _tweedie_deviance_score_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["TweedieDevianceScore.plot"] class TweedieDevianceScore(Metric): @@ -99,3 +104,44 @@ def update(self, preds: Tensor, targets: Tensor) -> None: def compute(self) -> Tensor: """Compute metric.""" return _tweedie_deviance_score_compute(self.sum_deviance_score, self.num_observations) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import TweedieDevianceScore + >>> metric = TweedieDevianceScore() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import TweedieDevianceScore + >>> metric = TweedieDevianceScore() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/wmape.py b/src/torchmetrics/regression/wmape.py index 0df0d3e2d58..d8a55389f13 100644 --- a/src/torchmetrics/regression/wmape.py +++ b/src/torchmetrics/regression/wmape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -21,6 +21,11 @@ _weighted_mean_absolute_percentage_error_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["WeightedMeanAbsolutePercentageError.plot"] class WeightedMeanAbsolutePercentageError(Metric): @@ -75,3 +80,44 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute weighted mean absolute percentage error over state.""" return _weighted_mean_absolute_percentage_error_compute(self.sum_abs_error, self.sum_scale) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError + >>> metric = WeightedMeanAbsolutePercentageError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError + >>> metric = WeightedMeanAbsolutePercentageError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index f9fb0c243d7..8f9f11ba0f3 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -80,6 +80,12 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + PearsonCorrCoef, + R2Score, + SpearmanCorrCoef, + SymmetricMeanAbsolutePercentageError, + TweedieDevianceScore, + WeightedMeanAbsolutePercentageError, ) from torchmetrics.retrieval import RetrievalMRR, RetrievalPrecision, RetrievalRecall, RetrievalRPrecision @@ -293,6 +299,12 @@ pytest.param(MeanAbsoluteError, _rand_input, _rand_input, id="mean absolute error"), pytest.param(MeanAbsolutePercentageError, _rand_input, _rand_input, id="mean absolute percentage error"), pytest.param(partial(MinkowskiDistance, p=3), _rand_input, _rand_input, id="minkowski distance"), + pytest.param(PearsonCorrCoef, _rand_input, _rand_input, id="pearson corr coef"), + pytest.param(R2Score, _rand_input, _rand_input, id="r2 score"), + pytest.param(SpearmanCorrCoef, _rand_input, _rand_input, id="spearman corr coef"), + pytest.param(SymmetricMeanAbsolutePercentageError, _rand_input, _rand_input, id="symmetric mape"), + pytest.param(TweedieDevianceScore, _rand_input, _rand_input, id="tweedie deviance score"), + pytest.param(WeightedMeanAbsolutePercentageError, _rand_input, _rand_input, id="weighted mape"), ], ) @pytest.mark.parametrize("num_vals", [1, 5])