Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jan 17, 2021
1 parent 3d918ef commit 6c3cad3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 18 deletions.
13 changes: 3 additions & 10 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def _adjust_cat_axis(self, ax):
# But two reasons not to do that:
# - If it happens before plotting, autoscaling messes up the plot limits
# - It would change existing plots from other seaborn functions
if self.var_types[self.cat_axis] != "categorical":
return

data = self.plot_data[self.cat_axis]
if self.facets is not None:
Expand Down Expand Up @@ -4041,7 +4043,7 @@ def catplot(
# Check for attempt to plot onto specific axes and warn
if "ax" in kwargs:
msg = ("catplot is a figure-level function and does not accept "
"target axes. You may wish to try {}".format(kind + "plot"))
f"target axes. You may wish to try {kind}plot")
warnings.warn(msg, UserWarning)
kwargs.pop("ax")

Expand All @@ -4058,15 +4060,6 @@ def catplot(

# XXX Copying a fair amount from displot, which is not ideal

# Check for attempt to plot onto specific axes and warn
if "ax" in kwargs:
msg = (
"`catplot` is a figure-level function and does not accept "
"the ax= paramter. You may wish to try {}plot.".format(kind)
)
warnings.warn(msg, UserWarning)
kwargs.pop("ax")

for var in ["row", "col"]:
# Handle faceting variables that lack name information
if var in p.variables and p.variables[var] is None:
Expand Down
79 changes: 71 additions & 8 deletions seaborn/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .._core import categorical_order
from ..categorical import (
_CategoricalPlotterNew,
catplot,
stripplot,
)
Expand All @@ -27,6 +28,63 @@
from .._testing import assert_plots_equal


PLOT_FUNCS = [
catplot,
stripplot,
]


class TestCategoricalPlotterNew:

@pytest.mark.parametrize(
"func,kwargs",
itertools.product(
PLOT_FUNCS,
[
{"x": "x", "y": "a"},
{"x": "a", "y": "y"},
{"x": "y"},
{"y": "x"},
],
),
)
def test_axis_labels(self, long_df, func, kwargs):

func(data=long_df, **kwargs)

ax = plt.gca()
for axis in "xy":
val = kwargs.get(axis, "")
label_func = getattr(ax, f"get_{axis}label")
assert label_func() == val

@pytest.mark.parametrize("func", PLOT_FUNCS)
def test_empty(self, func):

func()
ax = plt.gca()
assert not ax.collections
assert not ax.patches
assert not ax.lines

def test_redundant_hue_backcompat(self, long_df):

p = _CategoricalPlotterNew(
data=long_df,
variables={"x": "s", "y": "y"},
)

color = None
palette = dict(zip(long_df["s"].unique(), color_palette()))
hue_order = None

palette, _ = p._hue_backcompat(color, palette, hue_order, force_hue=True)

assert p.variables["hue"] == "s"
assert_array_equal(p.plot_data["hue"], p.plot_data["x"])
assert all(isinstance(k, str) for k in palette)


class CategoricalFixture:
"""Test boxplot (also base class for things like violinplots)."""
rs = np.random.RandomState(30)
Expand Down Expand Up @@ -1964,11 +2022,6 @@ def test_attributes(self, long_df):
assert points.get_linewidths().item() == kwargs["linewidth"]
assert tuple(points.get_edgecolors().squeeze()) == to_rgba(kwargs["edgecolor"])

def test_empty(self):

ax = stripplot()
assert not ax.collections

def test_three_strip_points(self):

x = np.arange(3)
Expand All @@ -1978,11 +2031,12 @@ def test_three_strip_points(self):

def test_palette_from_color_deprecation(self, long_df):

color = "C3"
color = (.9, .4, .5)
hex_color = mpl.colors.to_hex(color)

hue_var = "a"
n_hue = long_df[hue_var].nunique()
palette = color_palette(f"dark:{color}", n_hue)
palette = color_palette(f"dark:{hex_color}", n_hue)

with pytest.warns(FutureWarning, match="Setting a gradient palette"):
ax = stripplot(data=long_df, x="z", hue=hue_var, color=color)
Expand Down Expand Up @@ -2892,7 +2946,7 @@ def test_plot_colors(self):
def test_ax_kwarg_removal(self):

f, ax = plt.subplots()
with pytest.warns(UserWarning):
with pytest.warns(UserWarning, match="catplot is a figure-level"):
g = cat.catplot(x="g", y="y", data=self.df, ax=ax)
assert len(ax.collections) == 0
assert len(g.ax.collections) > 0
Expand Down Expand Up @@ -2959,6 +3013,15 @@ def test_share_xy(self):
for ax in g.axes.flat:
assert len(ax.collections) == len(self.df.g.unique())

@pytest.mark.parametrize("var", ["col", "row"])
def test_array_faceter(self, long_df, var):

g1 = catplot(data=long_df, x="y", **{var: "a"})
g2 = catplot(data=long_df, x="y", **{var: long_df["a"].to_numpy()})

for ax1, ax2 in zip(g1.axes.flat, g2.axes.flat):
assert_plots_equal(ax1, ax2)


class TestBoxenPlotter(CategoricalFixture):

Expand Down

0 comments on commit 6c3cad3

Please sign in to comment.