From e810c86c26d8ae78d6c295e975f303f526d98057 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 17 Jan 2021 17:49:29 -0500 Subject: [PATCH] Improve test coverage --- seaborn/categorical.py | 13 ++----- seaborn/tests/test_categorical.py | 56 +++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 8131d6d859..96edf42157 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -229,6 +229,8 @@ def _adjust_cat_axis(self, ax): # But two reasons not to do that: # - If it happens before plotting, autoscaling messes up the plot limits # - It would change existing plots from other seaborn functions + if self.var_types[self.cat_axis] != "categorical": + return data = self.plot_data[self.cat_axis] if self.facets is not None: @@ -4041,7 +4043,7 @@ def catplot( # Check for attempt to plot onto specific axes and warn if "ax" in kwargs: msg = ("catplot is a figure-level function and does not accept " - "target axes. You may wish to try {}".format(kind + "plot")) + f"target axes. You may wish to try {kind}plot") warnings.warn(msg, UserWarning) kwargs.pop("ax") @@ -4058,15 +4060,6 @@ def catplot( # XXX Copying a fair amount from displot, which is not ideal - # Check for attempt to plot onto specific axes and warn - if "ax" in kwargs: - msg = ( - "`catplot` is a figure-level function and does not accept " - "the ax= paramter. You may wish to try {}plot.".format(kind) - ) - warnings.warn(msg, UserWarning) - kwargs.pop("ax") - for var in ["row", "col"]: # Handle faceting variables that lack name information if var in p.variables and p.variables[var] is None: diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index eb147c6def..3e2060d66a 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -27,6 +27,46 @@ from .._testing import assert_plots_equal +PLOT_FUNCS = [ + catplot, + stripplot, +] + + +class TestCategoricalPlotterNew: + + @pytest.mark.parametrize( + "func,kwargs", + itertools.product( + PLOT_FUNCS, + [ + {"x": "x", "y": "a"}, + {"x": "a", "y": "y"}, + {"x": "y"}, + {"y": "x"}, + ], + ), + ) + def test_axis_labels(self, long_df, func, kwargs): + + func(data=long_df, **kwargs) + + ax = plt.gca() + for axis in "xy": + val = kwargs.get(axis, "") + label_func = getattr(ax, f"get_{axis}label") + assert label_func() == val + + @pytest.mark.parametrize("func", PLOT_FUNCS) + def test_empty(self, func): + + func() + ax = plt.gca() + assert not ax.collections + assert not ax.patches + assert not ax.lines + + class CategoricalFixture: """Test boxplot (also base class for things like violinplots).""" rs = np.random.RandomState(30) @@ -1964,11 +2004,6 @@ def test_attributes(self, long_df): assert points.get_linewidths().item() == kwargs["linewidth"] assert tuple(points.get_edgecolors().squeeze()) == to_rgba(kwargs["edgecolor"]) - def test_empty(self): - - ax = stripplot() - assert not ax.collections - def test_three_strip_points(self): x = np.arange(3) @@ -2892,7 +2927,7 @@ def test_plot_colors(self): def test_ax_kwarg_removal(self): f, ax = plt.subplots() - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match="catplot is a figure-level"): g = cat.catplot(x="g", y="y", data=self.df, ax=ax) assert len(ax.collections) == 0 assert len(g.ax.collections) > 0 @@ -2959,6 +2994,15 @@ def test_share_xy(self): for ax in g.axes.flat: assert len(ax.collections) == len(self.df.g.unique()) + @pytest.mark.parametrize("var", ["col", "row"]) + def test_array_faceter(self, long_df, var): + + g1 = catplot(data=long_df, x="y", **{var: "a"}) + g2 = catplot(data=long_df, x="y", **{var: long_df["a"].to_numpy()}) + + for ax1, ax2 in zip(g1.axes.flat, g2.axes.flat): + assert_plots_equal(ax1, ax2) + class TestBoxenPlotter(CategoricalFixture):