Skip to content

Commit

Permalink
Add optional color map parameter for confusion matrix (#2512)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 23, 2024
1 parent af32fd0 commit 3d52192
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
17 changes: 13 additions & 4 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix
from torchmetrics.utilities.plot import _AX_TYPE, _CMAP_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = [
Expand Down Expand Up @@ -151,6 +151,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand All @@ -160,6 +161,8 @@ def plot(
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html
Returns:
Figure and Axes object
Expand All @@ -181,7 +184,7 @@ def plot(
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax


Expand Down Expand Up @@ -292,6 +295,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand All @@ -301,6 +305,8 @@ def plot(
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html
Returns:
Figure and Axes object
Expand All @@ -322,7 +328,7 @@ def plot(
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax


Expand Down Expand Up @@ -436,6 +442,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand All @@ -445,6 +452,8 @@ def plot(
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html
Returns:
Figure and Axes object
Expand All @@ -466,7 +475,7 @@ def plot(
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax


Expand Down
7 changes: 6 additions & 1 deletion src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@

_PLOT_OUT_TYPE = Tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]]
_AX_TYPE = matplotlib.axes.Axes
_CMAP_TYPE = Union[matplotlib.colors.Colormap, str]

style_change = plt.style.context
else:
_PLOT_OUT_TYPE = Tuple[object, object] # type: ignore[misc]
_AX_TYPE = object
_CMAP_TYPE = object # type: ignore[misc]

from contextlib import contextmanager

Expand Down Expand Up @@ -201,6 +203,7 @@ def plot_confusion_matrix(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[Union[int, str]]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot an confusion matrix.
Expand All @@ -213,6 +216,8 @@ def plot_confusion_matrix(
ax: Axis from a figure. If not provided, a new figure and axis will be created
add_text: if text should be added to each cell with the given value
labels: labels to add the x- and y-axis
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html
Returns:
A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
Expand Down Expand Up @@ -248,7 +253,7 @@ def plot_confusion_matrix(
ax = axs[i] if rows != 1 and cols != 1 else axs
if fig_label is not None:
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach(), cmap=cmap)
if i // cols == rows - 1: # bottom row only
ax.set_xlabel("Predicted class", fontsize=15)
if i % cols == 0: # leftmost column only
Expand Down

0 comments on commit 3d52192

Please sign in to comment.