Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/lpips_backprop
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Feb 12, 2024
2 parents 7aaf81f + 71089f0 commit a7c233c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed initialize aggregation metrics with default floating type ([#2366](https://github.com/Lightning-AI/torchmetrics/pull/2366))


- Fixed plotting of confusion matrices ([#2358](https://github.com/Lightning-AI/torchmetrics/pull/2358))

---

## [1.3.0] - 2024-01-10
Expand Down
10 changes: 6 additions & 4 deletions src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,17 @@ def plot_confusion_matrix(
fig_label = None
labels = labels or np.arange(n_classes).tolist()

fig, axs = plt.subplots(nrows=rows, ncols=cols) if ax is None else (ax.get_figure(), ax)
fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax)
axs = trim_axs(axs, nb)
for i in range(nb):
ax = axs[i] if rows != 1 and cols != 1 else axs
if fig_label is not None:
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
ax.set_xlabel("Predicted class", fontsize=15)
ax.set_ylabel("True class", fontsize=15)
if i // cols == rows - 1: # bottom row only
ax.set_xlabel("Predicted class", fontsize=15)
if i % cols == 0: # leftmost column only
ax.set_ylabel("True class", fontsize=15)
ax.set_xticks(list(range(n_classes)))
ax.set_yticks(list(range(n_classes)))
ax.set_xticklabels(labels, rotation=45, fontsize=10)
Expand All @@ -259,7 +261,7 @@ def plot_confusion_matrix(
if add_text:
for ii, jj in product(range(n_classes), range(n_classes)):
val = confmat[i, ii, jj] if confmat.ndim == 3 else confmat[ii, jj]
ax.text(jj, ii, str(val.item()), ha="center", va="center", fontsize=15)
ax.text(jj, ii, str(round(val.item(), 2)), ha="center", va="center", fontsize=15)

return fig, axs

Expand Down

0 comments on commit a7c233c

Please sign in to comment.