From dce315003b5bfba69e4bfe0a9d8e3ed9c46ea7fa Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 1 Jun 2022 21:09:10 -0400 Subject: [PATCH] Improve legend for categorical scatterplots (#2828) * Improve legend for categorical scatterplots * Move legend attribute assignment to fix empty plot * Don't create axis labels inside plotting functions * Add slight hack to enable catplot with empty x/y vectors * Don't set axis limits for empty categorical plot * Avoid expensive and uncessary computation when stripplot is not dodged * Add tests --- seaborn/categorical.py | 88 ++++++++++++++++++++----------- seaborn/relational.py | 2 +- seaborn/tests/test_categorical.py | 23 ++++++-- 3 files changed, 78 insertions(+), 35 deletions(-) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 15c3e8b09e..c32acd0ad1 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -18,17 +18,17 @@ import matplotlib.patches as Patches import matplotlib.pyplot as plt -from ._oldcore import ( - VectorPlotter, +from seaborn._oldcore import ( variable_type, infer_orient, categorical_order, ) -from . import utils -from .utils import remove_na, _normal_quantile_func, _draw_figure, _default_color -from .algorithms import bootstrap -from .palettes import color_palette, husl_palette, light_palette, dark_palette -from .axisgrid import FacetGrid, _facet_docs +from seaborn.relational import _RelationalPlotter +from seaborn import utils +from seaborn.utils import remove_na, _normal_quantile_func, _draw_figure, _default_color +from seaborn.algorithms import bootstrap +from seaborn.palettes import color_palette, husl_palette, light_palette, dark_palette +from seaborn.axisgrid import FacetGrid, _facet_docs __all__ = [ @@ -39,13 +39,18 @@ ] -class _CategoricalPlotterNew(VectorPlotter): +# Subclassing _RelationalPlotter for the legend machinery, +# but probably should move that more centrally +class _CategoricalPlotterNew(_RelationalPlotter): semantics = "x", "y", "hue", "units" wide_structure = {"x": "@columns", "y": "@values", "hue": "@columns"} flat_structure = {"x": "@index", "y": "@values"} + _legend_func = "scatter" + _legend_attributes = ["color"] + def __init__( self, data=None, @@ -53,6 +58,7 @@ def __init__( order=None, orient=None, require_numeric=False, + legend="auto", ): super().__init__(data=data, variables=variables) @@ -75,6 +81,8 @@ def __init__( require_numeric=require_numeric, ) + self.legend = legend + # Short-circuit in the case of an empty plot if not self.has_xy_data: return @@ -172,6 +180,12 @@ def _adjust_cat_axis(self, ax, axis): if self.var_types[axis] != "categorical": return + # If both x/y data are empty, the correct way to set up the plot is + # somewhat undefined; because we don't add null category data to the plot in + # this case we don't *have* a categorical axis (yet), so best to just bail. + if self.plot_data[axis].empty: + return + # We can infer the total number of categories (including those from previous # plots that are not part of the plot we are currently making) from the number # of ticks, which matplotlib sets up while doing unit conversion. This feels @@ -248,8 +262,7 @@ def plot_strips( for sub_vars, sub_data in self.iter_data(iter_vars, from_comp_data=True, allow_empty=True): - - if offsets is not None: + if offsets is not None and (offsets != 0).any(): dodge_move = offsets[sub_data["hue"].map(self._hue_map.levels.index)] jitter_move = jitterer(size=len(sub_data)) if len(sub_data) > 1 else 0 @@ -272,13 +285,17 @@ def plot_strips( else: points.set_edgecolors(edgecolor) - # TODO XXX fully implement legend - show_legend = not self._redundant_hue and self.input_format != "wide" - if "hue" in self.variables and show_legend: - for level in self._hue_map.levels: - color = self._hue_map(level) - ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level) - ax.legend(loc="best", title=self.variables["hue"]) + # Finalize the axes details + if self.legend == "auto": + show_legend = not self._redundant_hue and self.input_format != "wide" + else: + show_legend = bool(self.legend) + + if show_legend: + self.add_legend_data(ax) + handles, _ = ax.get_legend_handles_labels() + if handles: + ax.legend(title=self.legend_title) def plot_swarms( self, @@ -361,13 +378,17 @@ def draw(points, renderer, *, center=center): _draw_figure(ax.figure) - # TODO XXX fully implement legend - show_legend = not self._redundant_hue and self.input_format != "wide" - if "hue" in self.variables and show_legend: # TODO and legend: - for level in self._hue_map.levels: - color = self._hue_map(level) - ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level) - ax.legend(loc="best", title=self.variables["hue"]) + # Finalize the axes details + if self.legend == "auto": + show_legend = not self._redundant_hue and self.input_format != "wide" + else: + show_legend = bool(self.legend) + + if show_legend: + self.add_legend_data(ax) + handles, _ = ax.get_legend_handles_labels() + if handles: + ax.legend(title=self.legend_title) class _CategoricalFacetPlotter(_CategoricalPlotterNew): @@ -2747,18 +2768,17 @@ def stripplot( data=None, *, x=None, y=None, hue=None, order=None, hue_order=None, jitter=True, dodge=False, orient=None, color=None, palette=None, size=5, edgecolor="gray", linewidth=0, ax=None, - hue_norm=None, native_scale=False, formatter=None, + hue_norm=None, native_scale=False, formatter=None, legend="auto", **kwargs ): - # TODO XXX we need to add a legend= param!!! - p = _CategoricalPlotterNew( data=data, variables=_CategoricalPlotterNew.get_semantics(locals()), order=order, orient=orient, require_numeric=False, + legend=legend, ) if ax is None: @@ -2869,7 +2889,7 @@ def swarmplot( data=None, *, x=None, y=None, hue=None, order=None, hue_order=None, dodge=False, orient=None, color=None, palette=None, size=5, edgecolor="gray", linewidth=0, ax=None, - hue_norm=None, native_scale=False, formatter=None, warn_thresh=.05, + hue_norm=None, native_scale=False, formatter=None, legend="auto", warn_thresh=.05, **kwargs ): @@ -2879,6 +2899,7 @@ def swarmplot( order=order, orient=orient, require_numeric=False, + legend=legend, ) if ax is None: @@ -3548,7 +3569,7 @@ def catplot( units=None, seed=None, order=None, hue_order=None, row_order=None, col_order=None, kind="strip", height=5, aspect=1, orient=None, color=None, palette=None, - legend=True, legend_out=True, sharex=True, sharey=True, + legend="auto", legend_out=True, sharex=True, sharey=True, margin_titles=False, facet_kws=None, hue_norm=None, native_scale=False, formatter=None, **kwargs @@ -3578,7 +3599,6 @@ def catplot( refactored_kinds = [ "strip", "swarm", ] - if kind in refactored_kinds: p = _CategoricalFacetPlotter( @@ -3587,6 +3607,7 @@ def catplot( order=order, orient=orient, require_numeric=False, + legend=legend, ) # XXX Copying a fair amount from displot, which is not ideal @@ -3615,12 +3636,17 @@ def catplot( **facet_kws, ) + # Capture this here because scale_categorical is going to insert a (null) + # x variable even if it is empty. It's not clear whether that needs to + # happen or if disabling that is the cleaner solution. + has_xy_data = p.has_xy_data + if not native_scale or p.var_types[p.cat_axis] == "categorical": p.scale_categorical(p.cat_axis, order=order, formatter=formatter) p._attach(g) - if not p.has_xy_data: + if not has_xy_data: return g palette, hue_order = p._hue_backcompat(color, palette, hue_order) diff --git a/seaborn/relational.py b/seaborn/relational.py index 1ac2f3c93a..f6e376d8b3 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -524,7 +524,7 @@ def __init__( legend=None ): - # TODO this is messy, we want the mapping to be agnoistic about + # TODO this is messy, we want the mapping to be agnostic about # the kind of plot to draw, but for the time being we need to set # this information so the SizeMapping can use it self._default_size_range = ( diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index fc7fc5e571..aa7317525c 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -2028,6 +2028,24 @@ def test_three_points(self): for point_color in ax.collections[0].get_facecolor(): assert tuple(point_color) == to_rgba("C0") + def test_legend_categorical(self, long_df): + + ax = self.func(data=long_df, x="y", y="a", hue="b") + legend_texts = [t.get_text() for t in ax.legend_.texts] + expected = categorical_order(long_df["b"]) + assert legend_texts == expected + + def test_legend_numeric(self, long_df): + + ax = self.func(data=long_df, x="y", y="a", hue="z") + vals = [float(t.get_text()) for t in ax.legend_.texts] + assert (vals[1] - vals[0]) == pytest.approx(vals[2] - vals[1]) + + def test_legend_disabled(self, long_df): + + ax = self.func(data=long_df, x="y", y="a", hue="b", legend=False) + assert ax.legend_ is None + def test_palette_from_color_deprecation(self, long_df): color = (.9, .4, .5) @@ -2085,9 +2103,8 @@ def test_log_scale(self): dict(data="wide", orient="h"), dict(data="long", x="x", color="C3"), dict(data="long", y="y", hue="a", jitter=False), - # TODO XXX full numeric hue legend crashes pinned mpl, disabling for now - # dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5), - # dict(data="long", x="a_cat", y="y", hue="z"), + dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5), + dict(data="long", x="a_cat", y="y", hue="z"), dict(data="long", x="y", y="s", hue="c", orient="h", dodge=True), dict(data="long", x="s", y="y", hue="c", native_scale=True), ]