Skip to content

Commit

Permalink
Adding support for plot() in image metrics (#1480)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
4 people authored Feb 23, 2023
1 parent 3df4e1b commit b3915a3
Show file tree
Hide file tree
Showing 12 changed files with 600 additions and 18 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for plotting of metrics through `.plot()` method (
[#1328](https://github.com/Lightning-AI/metrics/pull/1328),
[#1481](https://github.com/Lightning-AI/metrics/pull/1481)
[#1481](https://github.com/Lightning-AI/metrics/pull/1481),
[#1480](https://github.com/Lightning-AI/metrics/pull/1480)
)


Expand Down
148 changes: 147 additions & 1 deletion examples/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,147 @@ def confusion_matrix_example():
return fig, ax


if __name__ == "__main__":
def spectral_distortion_index_example():
"""Plot spectral distortion index example example."""
from torchmetrics.image.d_lambda import SpectralDistortionIndex

p = lambda: torch.rand([16, 3, 16, 16])
t = lambda: torch.rand([16, 3, 16, 16])

# plot single value
metric = SpectralDistortionIndex()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = SpectralDistortionIndex()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def error_relative_global_dimensionless_synthesis():
"""Plot error relative global dimensionless synthesis example."""
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis

p = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
t = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))

# plot single value
metric = ErrorRelativeGlobalDimensionlessSynthesis()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = ErrorRelativeGlobalDimensionlessSynthesis()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def peak_signal_noise_ratio():
"""Plot peak signal noise ratio example."""
from torchmetrics.image.psnr import PeakSignalNoiseRatio

p = lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]])
t = lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]])

# plot single value
metric = PeakSignalNoiseRatio()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = PeakSignalNoiseRatio()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def spectral_angle_mapper():
"""Plot spectral angle mapper example."""
from torchmetrics.image.sam import SpectralAngleMapper

p = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
t = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))

# plot single value
metric = SpectralAngleMapper()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = SpectralAngleMapper()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def structural_similarity_index_measure():
"""Plot structural similarity index measure example."""
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
t = lambda: p() * 0.75

# plot single value
metric = StructuralSimilarityIndexMeasure()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = StructuralSimilarityIndexMeasure()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def multiscale_structural_similarity_index_measure():
"""Plot multiscale structural similarity index measure example."""
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure

p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
t = lambda: p() * 0.75

# plot single value
metric = MultiScaleStructuralSimilarityIndexMeasure()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = MultiScaleStructuralSimilarityIndexMeasure()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


def universal_image_quality_index():
"""Plot universal image quality index example."""
from torchmetrics.image.uqi import UniversalImageQualityIndex

p = lambda: torch.rand([16, 1, 16, 16])
t = lambda: p() * 0.75

# plot single value
metric = UniversalImageQualityIndex()
metric.update(p(), t())
fig, ax = metric.plot()

# plot multiple values
metric = UniversalImageQualityIndex()
vals = [metric(p(), t()) for _ in range(10)]
fig, ax = metric.plot(vals)

return fig, ax


if __name__ == "__main__":
metrics_func = {
"accuracy": accuracy_example,
"pesq": pesq_example,
Expand All @@ -235,6 +374,13 @@ def confusion_matrix_example():
"stoi": stoi_example,
"mean_squared_error": mean_squared_error_example,
"confusion_matrix": confusion_matrix_example,
"spectral_distortion_index": spectral_distortion_index_example,
"error_relative_global_dimensionless_synthesis": error_relative_global_dimensionless_synthesis,
"peak_signal_noise_ratio": peak_signal_noise_ratio,
"spectral_angle_mapper": spectral_angle_mapper,
"structural_similarity_index_measure": structural_similarity_index_measure,
"multiscale_structural_similarity_index_measure": multiscale_structural_similarity_index_measure,
"universal_image_quality_index": universal_image_quality_index,
}

parser = argparse.ArgumentParser(description="Example script for plotting metrics.")
Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,7 @@ def plot(
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
fig: Figure object
ax: Axes object
Figure and Axes object
Raises:
ModuleNotFoundError:
Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
If no value is provided, will automatically call `metric.compute` and plot that result.
Returns:
fig: Figure object
ax: Axes object
Figure and Axes object
Raises:
ModuleNotFoundError:
Expand Down
63 changes: 62 additions & 1 deletion src/torchmetrics/image/d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from typing import Any, Dict, List, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal
Expand All @@ -20,6 +20,11 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SpectralDistortionIndex.plot"]


class SpectralDistortionIndex(Metric):
Expand Down Expand Up @@ -95,3 +100,59 @@ def compute(self) -> Tensor:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _spectral_distortion_index_compute(preds, target, self.p, self.reduction)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val,
ax=ax,
higher_is_better=self.higher_is_better,
name=self.__class__.__name__,
lower_bound=0.0,
upper_bound=1.0,
)
return fig, ax
61 changes: 60 additions & 1 deletion src/torchmetrics/image/ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Union
from typing import Any, Dict, List, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal
Expand All @@ -21,6 +21,11 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ErrorRelativeGlobalDimensionlessSynthesis.plot"]


class ErrorRelativeGlobalDimensionlessSynthesis(Metric):
Expand Down Expand Up @@ -94,3 +99,57 @@ def compute(self) -> Tensor:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _ergas_compute(preds, target, self.ratio, self.reduction)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val,
ax=ax,
higher_is_better=self.higher_is_better,
name=self.__class__.__name__,
lower_bound=0.0,
upper_bound=1.0,
)
return fig, ax
Loading

0 comments on commit b3915a3

Please sign in to comment.