Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve legend for categorical scatterplots #2828

Merged
merged 7 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 57 additions & 31 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
import matplotlib.patches as Patches
import matplotlib.pyplot as plt

from ._oldcore import (
VectorPlotter,
from seaborn._oldcore import (
variable_type,
infer_orient,
categorical_order,
)
from . import utils
from .utils import remove_na, _normal_quantile_func, _draw_figure, _default_color
from .algorithms import bootstrap
from .palettes import color_palette, husl_palette, light_palette, dark_palette
from .axisgrid import FacetGrid, _facet_docs
from seaborn.relational import _RelationalPlotter
from seaborn import utils
from seaborn.utils import remove_na, _normal_quantile_func, _draw_figure, _default_color
from seaborn.algorithms import bootstrap
from seaborn.palettes import color_palette, husl_palette, light_palette, dark_palette
from seaborn.axisgrid import FacetGrid, _facet_docs


__all__ = [
Expand All @@ -39,20 +39,26 @@
]


class _CategoricalPlotterNew(VectorPlotter):
# Subclassing _RelationalPlotter for the legend machinery,
# but probably should move that more centrally
class _CategoricalPlotterNew(_RelationalPlotter):

semantics = "x", "y", "hue", "units"

wide_structure = {"x": "@columns", "y": "@values", "hue": "@columns"}
flat_structure = {"x": "@index", "y": "@values"}

_legend_func = "scatter"
_legend_attributes = ["color"]

def __init__(
self,
data=None,
variables={},
order=None,
orient=None,
require_numeric=False,
legend="auto",
):

super().__init__(data=data, variables=variables)
Expand All @@ -75,6 +81,8 @@ def __init__(
require_numeric=require_numeric,
)

self.legend = legend

# Short-circuit in the case of an empty plot
if not self.has_xy_data:
return
Expand Down Expand Up @@ -172,6 +180,12 @@ def _adjust_cat_axis(self, ax, axis):
if self.var_types[axis] != "categorical":
return

# If both x/y data are empty, the correct way to set up the plot is
# somewhat undefined; because we don't add null category data to the plot in
# this case we don't *have* a categorical axis (yet), so best to just bail.
if self.plot_data[axis].empty:
return

# We can infer the total number of categories (including those from previous
# plots that are not part of the plot we are currently making) from the number
# of ticks, which matplotlib sets up while doing unit conversion. This feels
Expand Down Expand Up @@ -248,8 +262,7 @@ def plot_strips(
for sub_vars, sub_data in self.iter_data(iter_vars,
from_comp_data=True,
allow_empty=True):

if offsets is not None:
if offsets is not None and (offsets != 0).any():
dodge_move = offsets[sub_data["hue"].map(self._hue_map.levels.index)]

jitter_move = jitterer(size=len(sub_data)) if len(sub_data) > 1 else 0
Expand All @@ -272,13 +285,17 @@ def plot_strips(
else:
points.set_edgecolors(edgecolor)

# TODO XXX fully implement legend
show_legend = not self._redundant_hue and self.input_format != "wide"
if "hue" in self.variables and show_legend:
for level in self._hue_map.levels:
color = self._hue_map(level)
ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level)
ax.legend(loc="best", title=self.variables["hue"])
# Finalize the axes details
if self.legend == "auto":
show_legend = not self._redundant_hue and self.input_format != "wide"
else:
show_legend = bool(self.legend)

if show_legend:
self.add_legend_data(ax)
handles, _ = ax.get_legend_handles_labels()
if handles:
ax.legend(title=self.legend_title)

def plot_swarms(
self,
Expand Down Expand Up @@ -361,13 +378,17 @@ def draw(points, renderer, *, center=center):

_draw_figure(ax.figure)

# TODO XXX fully implement legend
show_legend = not self._redundant_hue and self.input_format != "wide"
if "hue" in self.variables and show_legend: # TODO and legend:
for level in self._hue_map.levels:
color = self._hue_map(level)
ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level)
ax.legend(loc="best", title=self.variables["hue"])
# Finalize the axes details
if self.legend == "auto":
show_legend = not self._redundant_hue and self.input_format != "wide"
else:
show_legend = bool(self.legend)

if show_legend:
self.add_legend_data(ax)
handles, _ = ax.get_legend_handles_labels()
if handles:
ax.legend(title=self.legend_title)


class _CategoricalFacetPlotter(_CategoricalPlotterNew):
Expand Down Expand Up @@ -2747,18 +2768,17 @@ def stripplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
jitter=True, dodge=False, orient=None, color=None, palette=None,
size=5, edgecolor="gray", linewidth=0, ax=None,
hue_norm=None, native_scale=False, formatter=None,
hue_norm=None, native_scale=False, formatter=None, legend="auto",
**kwargs
):

# TODO XXX we need to add a legend= param!!!

p = _CategoricalPlotterNew(
data=data,
variables=_CategoricalPlotterNew.get_semantics(locals()),
order=order,
orient=orient,
require_numeric=False,
legend=legend,
)

if ax is None:
Expand Down Expand Up @@ -2869,7 +2889,7 @@ def swarmplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
dodge=False, orient=None, color=None, palette=None,
size=5, edgecolor="gray", linewidth=0, ax=None,
hue_norm=None, native_scale=False, formatter=None, warn_thresh=.05,
hue_norm=None, native_scale=False, formatter=None, legend="auto", warn_thresh=.05,
**kwargs
):

Expand All @@ -2879,6 +2899,7 @@ def swarmplot(
order=order,
orient=orient,
require_numeric=False,
legend=legend,
)

if ax is None:
Expand Down Expand Up @@ -3548,7 +3569,7 @@ def catplot(
units=None, seed=None, order=None, hue_order=None, row_order=None,
col_order=None, kind="strip", height=5, aspect=1,
orient=None, color=None, palette=None,
legend=True, legend_out=True, sharex=True, sharey=True,
legend="auto", legend_out=True, sharex=True, sharey=True,
margin_titles=False, facet_kws=None,
hue_norm=None, native_scale=False, formatter=None,
**kwargs
Expand Down Expand Up @@ -3578,7 +3599,6 @@ def catplot(
refactored_kinds = [
"strip", "swarm",
]

if kind in refactored_kinds:

p = _CategoricalFacetPlotter(
Expand All @@ -3587,6 +3607,7 @@ def catplot(
order=order,
orient=orient,
require_numeric=False,
legend=legend,
)

# XXX Copying a fair amount from displot, which is not ideal
Expand Down Expand Up @@ -3615,12 +3636,17 @@ def catplot(
**facet_kws,
)

# Capture this here because scale_categorical is going to insert a (null)
# x variable even if it is empty. It's not clear whether that needs to
# happen or if disabling that is the cleaner solution.
has_xy_data = p.has_xy_data

if not native_scale or p.var_types[p.cat_axis] == "categorical":
p.scale_categorical(p.cat_axis, order=order, formatter=formatter)

p._attach(g)

if not p.has_xy_data:
if not has_xy_data:
return g

palette, hue_order = p._hue_backcompat(color, palette, hue_order)
Expand Down
2 changes: 1 addition & 1 deletion seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def __init__(
legend=None
):

# TODO this is messy, we want the mapping to be agnoistic about
# TODO this is messy, we want the mapping to be agnostic about
# the kind of plot to draw, but for the time being we need to set
# this information so the SizeMapping can use it
self._default_size_range = (
Expand Down
23 changes: 20 additions & 3 deletions seaborn/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,6 +2028,24 @@ def test_three_points(self):
for point_color in ax.collections[0].get_facecolor():
assert tuple(point_color) == to_rgba("C0")

def test_legend_categorical(self, long_df):

ax = self.func(data=long_df, x="y", y="a", hue="b")
legend_texts = [t.get_text() for t in ax.legend_.texts]
expected = categorical_order(long_df["b"])
assert legend_texts == expected

def test_legend_numeric(self, long_df):

ax = self.func(data=long_df, x="y", y="a", hue="z")
vals = [float(t.get_text()) for t in ax.legend_.texts]
assert (vals[1] - vals[0]) == pytest.approx(vals[2] - vals[1])

def test_legend_disabled(self, long_df):

ax = self.func(data=long_df, x="y", y="a", hue="b", legend=False)
assert ax.legend_ is None

def test_palette_from_color_deprecation(self, long_df):

color = (.9, .4, .5)
Expand Down Expand Up @@ -2085,9 +2103,8 @@ def test_log_scale(self):
dict(data="wide", orient="h"),
dict(data="long", x="x", color="C3"),
dict(data="long", y="y", hue="a", jitter=False),
# TODO XXX full numeric hue legend crashes pinned mpl, disabling for now
# dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5),
# dict(data="long", x="a_cat", y="y", hue="z"),
dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5),
dict(data="long", x="a_cat", y="y", hue="z"),
dict(data="long", x="y", y="s", hue="c", orient="h", dodge=True),
dict(data="long", x="s", y="y", hue="c", native_scale=True),
]
Expand Down