Skip to content

Commit

Permalink
Improve legend for categorical scatterplots (#2828)
Browse files Browse the repository at this point in the history
* Improve legend for categorical scatterplots

* Move legend attribute assignment to fix empty plot

* Don't create axis labels inside plotting functions

* Add slight hack to enable catplot with empty x/y vectors

* Don't set axis limits for empty categorical plot

* Avoid expensive and uncessary computation when stripplot is not dodged

* Add tests
  • Loading branch information
mwaskom authored Jun 2, 2022
1 parent 1e8e843 commit dce3150
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 35 deletions.
88 changes: 57 additions & 31 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
import matplotlib.patches as Patches
import matplotlib.pyplot as plt

from ._oldcore import (
VectorPlotter,
from seaborn._oldcore import (
variable_type,
infer_orient,
categorical_order,
)
from . import utils
from .utils import remove_na, _normal_quantile_func, _draw_figure, _default_color
from .algorithms import bootstrap
from .palettes import color_palette, husl_palette, light_palette, dark_palette
from .axisgrid import FacetGrid, _facet_docs
from seaborn.relational import _RelationalPlotter
from seaborn import utils
from seaborn.utils import remove_na, _normal_quantile_func, _draw_figure, _default_color
from seaborn.algorithms import bootstrap
from seaborn.palettes import color_palette, husl_palette, light_palette, dark_palette
from seaborn.axisgrid import FacetGrid, _facet_docs


__all__ = [
Expand All @@ -39,20 +39,26 @@
]


class _CategoricalPlotterNew(VectorPlotter):
# Subclassing _RelationalPlotter for the legend machinery,
# but probably should move that more centrally
class _CategoricalPlotterNew(_RelationalPlotter):

semantics = "x", "y", "hue", "units"

wide_structure = {"x": "@columns", "y": "@values", "hue": "@columns"}
flat_structure = {"x": "@index", "y": "@values"}

_legend_func = "scatter"
_legend_attributes = ["color"]

def __init__(
self,
data=None,
variables={},
order=None,
orient=None,
require_numeric=False,
legend="auto",
):

super().__init__(data=data, variables=variables)
Expand All @@ -75,6 +81,8 @@ def __init__(
require_numeric=require_numeric,
)

self.legend = legend

# Short-circuit in the case of an empty plot
if not self.has_xy_data:
return
Expand Down Expand Up @@ -172,6 +180,12 @@ def _adjust_cat_axis(self, ax, axis):
if self.var_types[axis] != "categorical":
return

# If both x/y data are empty, the correct way to set up the plot is
# somewhat undefined; because we don't add null category data to the plot in
# this case we don't *have* a categorical axis (yet), so best to just bail.
if self.plot_data[axis].empty:
return

# We can infer the total number of categories (including those from previous
# plots that are not part of the plot we are currently making) from the number
# of ticks, which matplotlib sets up while doing unit conversion. This feels
Expand Down Expand Up @@ -248,8 +262,7 @@ def plot_strips(
for sub_vars, sub_data in self.iter_data(iter_vars,
from_comp_data=True,
allow_empty=True):

if offsets is not None:
if offsets is not None and (offsets != 0).any():
dodge_move = offsets[sub_data["hue"].map(self._hue_map.levels.index)]

jitter_move = jitterer(size=len(sub_data)) if len(sub_data) > 1 else 0
Expand All @@ -272,13 +285,17 @@ def plot_strips(
else:
points.set_edgecolors(edgecolor)

# TODO XXX fully implement legend
show_legend = not self._redundant_hue and self.input_format != "wide"
if "hue" in self.variables and show_legend:
for level in self._hue_map.levels:
color = self._hue_map(level)
ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level)
ax.legend(loc="best", title=self.variables["hue"])
# Finalize the axes details
if self.legend == "auto":
show_legend = not self._redundant_hue and self.input_format != "wide"
else:
show_legend = bool(self.legend)

if show_legend:
self.add_legend_data(ax)
handles, _ = ax.get_legend_handles_labels()
if handles:
ax.legend(title=self.legend_title)

def plot_swarms(
self,
Expand Down Expand Up @@ -361,13 +378,17 @@ def draw(points, renderer, *, center=center):

_draw_figure(ax.figure)

# TODO XXX fully implement legend
show_legend = not self._redundant_hue and self.input_format != "wide"
if "hue" in self.variables and show_legend: # TODO and legend:
for level in self._hue_map.levels:
color = self._hue_map(level)
ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level)
ax.legend(loc="best", title=self.variables["hue"])
# Finalize the axes details
if self.legend == "auto":
show_legend = not self._redundant_hue and self.input_format != "wide"
else:
show_legend = bool(self.legend)

if show_legend:
self.add_legend_data(ax)
handles, _ = ax.get_legend_handles_labels()
if handles:
ax.legend(title=self.legend_title)


class _CategoricalFacetPlotter(_CategoricalPlotterNew):
Expand Down Expand Up @@ -2747,18 +2768,17 @@ def stripplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
jitter=True, dodge=False, orient=None, color=None, palette=None,
size=5, edgecolor="gray", linewidth=0, ax=None,
hue_norm=None, native_scale=False, formatter=None,
hue_norm=None, native_scale=False, formatter=None, legend="auto",
**kwargs
):

# TODO XXX we need to add a legend= param!!!

p = _CategoricalPlotterNew(
data=data,
variables=_CategoricalPlotterNew.get_semantics(locals()),
order=order,
orient=orient,
require_numeric=False,
legend=legend,
)

if ax is None:
Expand Down Expand Up @@ -2869,7 +2889,7 @@ def swarmplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
dodge=False, orient=None, color=None, palette=None,
size=5, edgecolor="gray", linewidth=0, ax=None,
hue_norm=None, native_scale=False, formatter=None, warn_thresh=.05,
hue_norm=None, native_scale=False, formatter=None, legend="auto", warn_thresh=.05,
**kwargs
):

Expand All @@ -2879,6 +2899,7 @@ def swarmplot(
order=order,
orient=orient,
require_numeric=False,
legend=legend,
)

if ax is None:
Expand Down Expand Up @@ -3548,7 +3569,7 @@ def catplot(
units=None, seed=None, order=None, hue_order=None, row_order=None,
col_order=None, kind="strip", height=5, aspect=1,
orient=None, color=None, palette=None,
legend=True, legend_out=True, sharex=True, sharey=True,
legend="auto", legend_out=True, sharex=True, sharey=True,
margin_titles=False, facet_kws=None,
hue_norm=None, native_scale=False, formatter=None,
**kwargs
Expand Down Expand Up @@ -3578,7 +3599,6 @@ def catplot(
refactored_kinds = [
"strip", "swarm",
]

if kind in refactored_kinds:

p = _CategoricalFacetPlotter(
Expand All @@ -3587,6 +3607,7 @@ def catplot(
order=order,
orient=orient,
require_numeric=False,
legend=legend,
)

# XXX Copying a fair amount from displot, which is not ideal
Expand Down Expand Up @@ -3615,12 +3636,17 @@ def catplot(
**facet_kws,
)

# Capture this here because scale_categorical is going to insert a (null)
# x variable even if it is empty. It's not clear whether that needs to
# happen or if disabling that is the cleaner solution.
has_xy_data = p.has_xy_data

if not native_scale or p.var_types[p.cat_axis] == "categorical":
p.scale_categorical(p.cat_axis, order=order, formatter=formatter)

p._attach(g)

if not p.has_xy_data:
if not has_xy_data:
return g

palette, hue_order = p._hue_backcompat(color, palette, hue_order)
Expand Down
2 changes: 1 addition & 1 deletion seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def __init__(
legend=None
):

# TODO this is messy, we want the mapping to be agnoistic about
# TODO this is messy, we want the mapping to be agnostic about
# the kind of plot to draw, but for the time being we need to set
# this information so the SizeMapping can use it
self._default_size_range = (
Expand Down
23 changes: 20 additions & 3 deletions seaborn/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,6 +2028,24 @@ def test_three_points(self):
for point_color in ax.collections[0].get_facecolor():
assert tuple(point_color) == to_rgba("C0")

def test_legend_categorical(self, long_df):

ax = self.func(data=long_df, x="y", y="a", hue="b")
legend_texts = [t.get_text() for t in ax.legend_.texts]
expected = categorical_order(long_df["b"])
assert legend_texts == expected

def test_legend_numeric(self, long_df):

ax = self.func(data=long_df, x="y", y="a", hue="z")
vals = [float(t.get_text()) for t in ax.legend_.texts]
assert (vals[1] - vals[0]) == pytest.approx(vals[2] - vals[1])

def test_legend_disabled(self, long_df):

ax = self.func(data=long_df, x="y", y="a", hue="b", legend=False)
assert ax.legend_ is None

def test_palette_from_color_deprecation(self, long_df):

color = (.9, .4, .5)
Expand Down Expand Up @@ -2085,9 +2103,8 @@ def test_log_scale(self):
dict(data="wide", orient="h"),
dict(data="long", x="x", color="C3"),
dict(data="long", y="y", hue="a", jitter=False),
# TODO XXX full numeric hue legend crashes pinned mpl, disabling for now
# dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5),
# dict(data="long", x="a_cat", y="y", hue="z"),
dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5),
dict(data="long", x="a_cat", y="y", hue="z"),
dict(data="long", x="y", y="s", hue="c", orient="h", dodge=True),
dict(data="long", x="s", y="y", hue="c", native_scale=True),
]
Expand Down

0 comments on commit dce3150

Please sign in to comment.