Skip to content

Commit

Permalink
Add a move_legend convenience function
Browse files Browse the repository at this point in the history
This addresses issues discussed in #2280, along with some of the issues in #2231

It is a somewhat hack-ish solution. Because matplotlib legends don't offer public
control over their location, this copies data from an existing legend to a new
object, and then removes the original legend. I am hopeful that there will be
upstream changes that make legend repositioning more natural, but this is
a reasonable stopgap measure to alleviate a common seaborn pain-point.
  • Loading branch information
mwaskom committed Aug 15, 2021
1 parent 091f4c0 commit 57a08c9
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 5 deletions.
9 changes: 5 additions & 4 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,11 @@ Utility functions
:toctree: generated/
:nosignatures:

load_dataset
get_dataset_names
get_data_home
despine
desaturate
move_legend
saturate
desaturate
set_hls_values
load_dataset
get_dataset_names
get_data_home
156 changes: 156 additions & 0 deletions doc/docstrings/move_legend.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "8ec46ad8-bc4c-4ee0-9626-271088c702f9",
"metadata": {
"tags": [
"hide"
]
},
"outputs": [],
"source": [
"import seaborn as sns\n",
"sns.set_theme()\n",
"penguins = sns.load_dataset(\"penguins\")"
]
},
{
"cell_type": "raw",
"id": "008bdd98-88cb-4a81-9f50-9b0e5a357305",
"metadata": {},
"source": [
"For axes-level functions, pass the :class:`matplotlib.axes.Axes` object and provide a new location."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b82e58f9-b15d-4554-bee5-de6a689344a6",
"metadata": {},
"outputs": [],
"source": [
"ax = sns.histplot(penguins, x=\"bill_length_mm\", hue=\"species\")\n",
"sns.move_legend(ax, \"center right\")"
]
},
{
"cell_type": "raw",
"id": "4f2a7f5d-ab39-46c7-87f4-532e607adf0b",
"metadata": {},
"source": [
"Use the `bbox_to_anchor` parameter for more fine-grained control, including moving the legend outside of the axes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed610a98-447a-4459-8342-48abc80330f0",
"metadata": {},
"outputs": [],
"source": [
"ax = sns.histplot(penguins, x=\"bill_length_mm\", hue=\"species\")\n",
"sns.move_legend(ax, \"upper left\", bbox_to_anchor=(1, 1))"
]
},
{
"cell_type": "raw",
"id": "9d2fd766-a806-45d9-949d-1572991cf512",
"metadata": {},
"source": [
"Pass additional :meth:`matplotlib.axes.Axes.legend` parameters to update other properties:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ad4342c-c46e-49e9-98a2-6c88c6fb4c54",
"metadata": {},
"outputs": [],
"source": [
"ax = sns.histplot(penguins, x=\"bill_length_mm\", hue=\"species\")\n",
"sns.move_legend(\n",
" ax, \"lower center\",\n",
" bbox_to_anchor=(.5, 1), ncol=3, title=None, frameon=False,\n",
")"
]
},
{
"cell_type": "raw",
"id": "0d573092-46fd-4a95-b7ed-7e6833823adc",
"metadata": {},
"source": [
"It's also possible to move the legend created by a figure-level function. But when fine-tuning the position, you must bear in mind that the figure will have extra blank space on the right:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b258a9b8-69e5-4d4a-94cb-5b6baddc402b",
"metadata": {},
"outputs": [],
"source": [
"g = sns.displot(\n",
" penguins,\n",
" x=\"bill_length_mm\", hue=\"species\",\n",
" col=\"island\", col_wrap=2, height=3,\n",
")\n",
"sns.move_legend(g, \"upper left\", bbox_to_anchor=(.55, .45))"
]
},
{
"cell_type": "raw",
"id": "c9dc54e2-2c66-412f-ab2a-4f2bc2cb5782",
"metadata": {},
"source": [
"One way to avoid this would be to set `legend_out=False` on the :class:`seaborn.FacetGrid`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06cff408-4cdf-47af-8def-176f3e70ec5a",
"metadata": {},
"outputs": [],
"source": [
"g = sns.displot(\n",
" penguins,\n",
" x=\"bill_length_mm\", hue=\"species\",\n",
" col=\"island\", col_wrap=2, height=3,\n",
" facet_kws=dict(legend_out=False),\n",
")\n",
"sns.move_legend(g, \"upper left\", bbox_to_anchor=(.55, .45), frameon=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b170f20d-22a9-4f7d-917a-d09e10b1f08c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "seaborn-py38-latest",
"language": "python",
"name": "seaborn-py38-latest"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
90 changes: 90 additions & 0 deletions seaborn/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,96 @@ def test_locator_to_legend_entries():
assert str_levels == ['1e-04', '1e-03', '1e-02']


def test_move_legend_matplotlib_objects():

fig, ax = plt.subplots()

colors = "C2", "C5"
labels = "first label", "second label"
title = "the legend"

for color, label in zip(colors, labels):
ax.plot([0, 1], color=color, label=label)
ax.legend(loc="upper right", title=title)
utils._draw_figure(fig)
xfm = ax.transAxes.inverted().transform

# --- Test axes legend

old_pos = xfm(ax.legend_.legendPatch.get_extents())

new_fontsize = 14
utils.move_legend(ax, "lower left", title_fontsize=new_fontsize)
utils._draw_figure(fig)
new_pos = xfm(ax.legend_.legendPatch.get_extents())

assert (new_pos < old_pos).all()
assert ax.legend_.get_title().get_text() == title
assert ax.legend_.get_title().get_size() == new_fontsize

# --- Test title replacement

new_title = "new title"
utils.move_legend(ax, "lower left", title=new_title)
utils._draw_figure(fig)
assert ax.legend_.get_title().get_text() == new_title

# --- Test figure legend

fig.legend(loc="upper right", title=title)
_draw_figure(fig)
xfm = fig.transFigure.inverted().transform
old_pos = xfm(fig.legends[0].legendPatch.get_extents())

utils.move_legend(fig, "lower left", title=new_title)
_draw_figure(fig)

new_pos = xfm(fig.legends[0].legendPatch.get_extents())
assert (new_pos < old_pos).all()
assert fig.legends[0].get_title().get_text() == new_title


def test_move_legend_grid_object(long_df):

from seaborn.axisgrid import FacetGrid

hue_var = "a"
g = FacetGrid(long_df, hue=hue_var)
g.map(plt.plot, "x", "y")

g.add_legend()
_draw_figure(g.figure)

xfm = g.figure.transFigure.inverted().transform
old_pos = xfm(g.legend.legendPatch.get_extents())

fontsize = 20
utils.move_legend(g, "lower left", title_fontsize=fontsize)
_draw_figure(g.figure)

new_pos = xfm(g.legend.legendPatch.get_extents())
assert (new_pos < old_pos).all()
assert g.legend.get_title().get_text() == hue_var
assert g.legend.get_title().get_size() == fontsize

assert g.legend.legendHandles
for i, h in enumerate(g.legend.legendHandles):
assert mpl.colors.to_rgb(h.get_color()) == mpl.colors.to_rgb(f"C{i}")


def test_move_legend_input_checks():

ax = plt.figure().subplots()
with pytest.raises(TypeError):
utils.move_legend(ax.xaxis, "best")

with pytest.raises(ValueError):
utils.move_legend(ax, "best")

with pytest.raises(ValueError):
utils.move_legend(ax.figure, "best")


def check_load_dataset(name):
ds = load_dataset(name, cache=False)
assert(isinstance(ds, pd.DataFrame))
Expand Down
92 changes: 91 additions & 1 deletion seaborn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from matplotlib.cbook import normalize_kwargs


__all__ = ["desaturate", "saturate", "set_hls_values",
__all__ = ["desaturate", "saturate", "set_hls_values", "move_legend",
"despine", "get_dataset_names", "get_data_home", "load_dataset"]


Expand Down Expand Up @@ -390,6 +390,96 @@ def despine(fig=None, ax=None, top=True, right=True, left=False,
ax_i.set_yticks(newticks)


def move_legend(obj, loc, **kwargs):
"""
Recreate a plot's legend at a new location.
The name is a slight misnomer. Matplotlib legends do not expose public
control over their position parameters. So this function creates a new legend,
copying over the data from the original object, which is then removed.
Parameters
----------
obj : the object with the plot
This argument can be either a seaborn or matplotlib object:
- :class:`seaborn.FacetGrid` or :class:`seaborn.PairGrid`
- :class:`matplotlib.axes.Axes` or :class:`matplotlib.figure.Figure`
loc : str or int
Location argument, as in :meth:`matplotlib.axes.Axes.legend`.
kwargs
Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.legend`.
Examples
--------
.. include:: ../docstrings/move_legend.rst
"""
# This is a somewhat hackish solution that will hopefully be obviated by
# upstream improvements to matplotlib legends that make them easier to
# modify after creation.

from seaborn.axisgrid import Grid # Avoid circular import

# Locate the legend object and a method to recreate the legend
if isinstance(obj, Grid):
old_legend = obj.legend
legend_func = obj.figure.legend
elif isinstance(obj, mpl.axes.Axes):
old_legend = obj.legend_
legend_func = obj.legend
elif isinstance(obj, mpl.figure.Figure):
if obj.legends:
old_legend = obj.legends[-1]
else:
old_legend = None
legend_func = obj.legend
else:
err = "`obj` must be a seaborn Grid or matplotlib Axes or Figure instance."
raise TypeError(err)

if old_legend is None:
err = f"{obj} has no legend attached."
raise ValueError(err)

# Extract the components of the legend we need to reuse
handles = old_legend.legendHandles
labels = [t.get_text() for t in old_legend.get_texts()]

# Extract legend properties that can be passed to the recreation method
# (Vexingly, these don't all round-trip)
legend_kws = inspect.signature(mpl.legend.Legend).parameters
props = {k: v for k, v in old_legend.properties().items() if k in legend_kws}

# Delegate default bbox_to_anchor rules to matplotlib
props.pop("bbox_to_anchor")

# Try to propagate the existing title and font properties; respect new ones too
title = props.pop("title")
if "title" in kwargs:
title.set_text(kwargs.pop("title"))
title_kwargs = {k: v for k, v in kwargs.items() if k.startswith("title_")}
for key, val in title_kwargs.items():
title.set(**{key[6:]: val})
kwargs.pop(key)

# Try to respect the frame visibility
kwargs.setdefault("frameon", old_legend.legendPatch.get_visible())

# Remove the old legend and create the new one
props.update(kwargs)
old_legend.remove()
new_legend = legend_func(handles, labels, loc=loc, **props)
new_legend.set_title(title.get_text(), title.get_fontproperties())

# Let the Grid object continue to track the correct legend object
if isinstance(obj, Grid):
obj._legend = new_legend


def _kde_support(data, bw, gridsize, cut, clip):
"""Establish support for a kernel density estimate."""
support_min = max(data.min() - bw * cut, clip[0])
Expand Down

0 comments on commit 57a08c9

Please sign in to comment.