Skip to content

Commit

Permalink
Add Plot.label method (#2902)
Browse files Browse the repository at this point in the history
* Add Plot.label method

* Satisfy mypy (I'm not sure I understand why it's confused here)

* Test legend title customization

* Refactor label resolution
  • Loading branch information
mwaskom authored Jul 14, 2022
1 parent 189577e commit 022e4bd
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 12 deletions.
65 changes: 55 additions & 10 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import textwrap
from collections import abc
from collections.abc import Callable, Generator, Hashable
from typing import Any, cast
from typing import Any, Optional, cast

import pandas as pd
from pandas import DataFrame, Series, Index
Expand Down Expand Up @@ -147,6 +147,7 @@ class Plot:

_scales: dict[str, Scale]
_limits: dict[str, tuple[Any, Any]]
_labels: dict[str, str | Callable[[str], str] | None]

_subplot_spec: dict[str, Any] # TODO values type
_facet_spec: FacetSpec
Expand All @@ -172,6 +173,7 @@ def __init__(

self._scales = {}
self._limits = {}
self._labels = {}

self._subplot_spec = {}
self._facet_spec = {}
Expand Down Expand Up @@ -552,8 +554,8 @@ def limit(self, **limits: tuple[Any, Any]) -> Plot:
Keywords correspond to variables defined in the plot, and values are a
(min, max) tuple (where either can be `None` to leave unset).
Limits apply only to the axis scale; data outside the limits are still
used in any stat transforms and added to the plot.
Limits apply only to the axis; data outside the visible range are
still used for any stat transforms and added to the plot.
Behavior for non-coordinate variables is currently undefined.
Expand All @@ -562,6 +564,25 @@ def limit(self, **limits: tuple[Any, Any]) -> Plot:
new._limits.update(limits)
return new

def label(self, **labels: str | Callable[[str], str] | None) -> Plot:
"""
Control the labels used for variables in the plot.
For coordinate variables, this sets the axis label.
For semantic variables, it sets the legend title.
Keywords correspond to variables defined in the plot.
Values can be one of the following types::
- string (used literally)
- function (called on the default label)
- None (disables the label for this variable)
"""
new = self._clone()
new._labels.update(labels)
return new

def configure(
self,
figsize: tuple[float, float] | None = None,
Expand Down Expand Up @@ -768,6 +789,20 @@ def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]:

return common_data, layers

def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str | None:

label: str | None
if var in p._labels:
manual_label = p._labels[var]
if callable(manual_label) and auto_label is not None:
label = manual_label(auto_label)
else:
# mypy needs a lot of help here, I'm not sure why
label = cast(Optional[str], manual_label)
else:
label = auto_label
return label

def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:

# --- Parsing the faceting/pairing parameterization to specify figure grid
Expand Down Expand Up @@ -797,6 +832,9 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
ax = sub["ax"]
for axis in "xy":
axis_key = sub[axis]

# ~~ Axis labels

# TODO Should we make it possible to use only one x/y label for
# all rows/columns in a faceted plot? Maybe using sub{axis}label,
# although the alignments of the labels from that method leaves
Expand All @@ -805,9 +843,12 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
common.names.get(axis_key),
*(layer["data"].names.get(axis_key) for layer in layers)
]
label = next((name for name in names if name is not None), None)
auto_label = next((name for name in names if name is not None), None)
label = self._resolve_label(p, axis_key, auto_label)
ax.set(**{f"{axis}label": label})

# ~~ Decoration visibility

# TODO there should be some override (in Plot.configure?) so that
# tick labels can be shown on interior shared axes
axis_obj = getattr(ax, f"{axis}axis")
Expand Down Expand Up @@ -1151,9 +1192,7 @@ def get_order(var):
df = self._unscale_coords(subplots, df, orient)

grouping_vars = mark._grouping_props + default_grouping_vars
split_generator = self._setup_split_generator(
grouping_vars, df, subplots
)
split_generator = self._setup_split_generator(grouping_vars, df, subplots)

mark._plot(split_generator, scales, orient)

Expand All @@ -1162,7 +1201,7 @@ def get_order(var):
view["ax"].autoscale_view()

if layer["legend"]:
self._update_legend_contents(mark, data, scales)
self._update_legend_contents(p, mark, data, scales)

def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame:
# TODO stricter type on subplots
Expand Down Expand Up @@ -1357,7 +1396,11 @@ def split_generator(keep_na=False) -> Generator:
return split_generator

def _update_legend_contents(
self, mark: Mark, data: PlotData, scales: dict[str, Scale]
self,
p: Plot,
mark: Mark,
data: PlotData,
scales: dict[str, Scale],
) -> None:
"""Add legend artists / labels for one layer in the plot."""
if data.frame.empty and data.frames:
Expand All @@ -1382,7 +1425,9 @@ def _update_legend_contents(
part_vars.append(var)
break
else:
entry = (data.names[var], data.ids[var]), [var], (values, labels)
auto_title = data.names[var]
title = self._resolve_label(p, var, auto_title)
entry = (title, data.ids[var]), [var], (values, labels)
schema.append(entry)

# Second pass, generate an artist corresponding to each value
Expand Down
35 changes: 33 additions & 2 deletions tests/_core/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,8 +999,8 @@ def test_limits(self, long_df):

limit = (-2, 24)
p = Plot(long_df, x="x", y="y").limit(x=limit).plot()
ax1 = p._figure.axes[0]
assert ax1.get_xlim() == limit
ax = p._figure.axes[0]
assert ax.get_xlim() == limit

limit = (np.datetime64("2005-01-01"), np.datetime64("2008-01-01"))
p = Plot(long_df, x="d", y="y").limit(x=limit).plot()
Expand All @@ -1012,6 +1012,30 @@ def test_limits(self, long_df):
ax = p._figure.axes[0]
assert ax.get_xlim() == (0.5, 2.5)

def test_labels_axis(self, long_df):

label = "Y axis"
p = Plot(long_df, x="x", y="y").label(y=label).plot()
ax = p._figure.axes[0]
assert ax.get_ylabel() == label

label = str.capitalize
p = Plot(long_df, x="x", y="y").label(y=label).plot()
ax = p._figure.axes[0]
assert ax.get_ylabel() == "Y"

def test_labels_legend(self, long_df):

m = MockMark()

label = "A"
p = Plot(long_df, x="x", y="y", color="a").add(m).label(color=label).plot()
assert p._figure.legends[0].get_title().get_text() == label

func = str.capitalize
p = Plot(long_df, x="x", y="y", color="a").add(m).label(color=func).plot()
assert p._figure.legends[0].get_title().get_text() == label


class TestFacetInterface:

Expand Down Expand Up @@ -1406,6 +1430,13 @@ def test_limits(self, long_df):
ax1 = p._figure.axes[1]
assert ax1.get_xlim() == limit

def test_labels(self, long_df):

label = "Z"
p = Plot(long_df, y="y").pair(x=["x", "z"]).label(x1=label).plot()
ax1 = p._figure.axes[1]
assert ax1.get_xlabel() == label


class TestLabelVisibility:

Expand Down

0 comments on commit 022e4bd

Please sign in to comment.