Skip to content

Commit

Permalink
Debug CI plot (#1878)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Jul 4, 2023
1 parent 13f85d0 commit 8c5fab8
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def test_plot_methods(metric_class: object, preds: Callable, target: Callable, n

assert isinstance(fig, plt.Figure)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close(fig)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -699,17 +700,17 @@ def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0

assert isinstance(fig, plt.Figure)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close(fig)


@pytest.mark.skipif(not hasattr(torch, "inference_mode"), reason="`inference_mode` is not supported")
def test_plot_methods_special_text_metrics():
"""Test the plot method for text metrics that does not fit the default testing format."""
metric = BERTScore()
with torch.inference_mode():
metric.update(_text_input_1(), _text_input_2())
fig, ax = metric.plot()
metric.update(_text_input_1(), _text_input_2())
fig, ax = metric.plot()
assert isinstance(fig, plt.Figure)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close(fig)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -782,6 +783,7 @@ def test_plot_methods_retrieval(metric_class, preds, target, indexes, num_vals):

assert isinstance(fig, plt.Figure)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close(fig)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -821,6 +823,7 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label
cond1 = isinstance(axs, matplotlib.axes.Axes)
cond2 = isinstance(axs, np.ndarray) and all(isinstance(a, matplotlib.axes.Axes) for a in axs)
assert cond1 or cond2
plt.close(fig)


@pytest.mark.parametrize("together", [True, False])
Expand Down Expand Up @@ -859,6 +862,7 @@ def test_plot_method_collection(together, num_vals):
fig, ax = plt.subplots(nrows=len(m_collection) + 1, ncols=1)
with pytest.raises(ValueError, match="Expected argument `ax` to be a sequence of matplotlib axis objects with.*"):
m_collection.plot(ax=ax.tolist())
plt.close(fig)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -915,6 +919,7 @@ def test_plot_method_curve_metrics(metric_class, preds, target, thresholds, scor
fig, ax = metric.plot(score=score)
assert isinstance(fig, plt.Figure)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close(fig)


def test_tracker_plotter():
Expand All @@ -927,3 +932,4 @@ def test_tracker_plotter():
fig, ax = tracker.plot() # plot all epochs
assert isinstance(fig, plt.Figure)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close(fig)

0 comments on commit 8c5fab8

Please sign in to comment.