Skip to content

Commit

Permalink
Refactor label resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jul 13, 2022
1 parent 9493b0c commit 932b8c4
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,20 @@ def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]:

return common_data, layers

def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str | None:

label: str | None
if var in p._labels:
manual_label = p._labels[var]
if callable(manual_label) and auto_label is not None:
label = manual_label(auto_label)
else:
# mypy needs a lot of help here, I'm not sure why
label = cast(Optional[str], manual_label)
else:
label = auto_label
return label

def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:

# --- Parsing the faceting/pairing parameterization to specify figure grid
Expand Down Expand Up @@ -830,16 +844,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
*(layer["data"].names.get(axis_key) for layer in layers)
]
auto_label = next((name for name in names if name is not None), None)
if axis_key in p._labels:
manual_label = p._labels[axis_key]
label: str | None
if callable(manual_label) and auto_label is not None:
label = manual_label(auto_label)
else:
# mypy needs a lot of help here, I'm not sure why
label = cast(Optional[str], manual_label)
else:
label = auto_label
label = self._resolve_label(p, axis_key, auto_label)
ax.set(**{f"{axis}label": label})

# ~~ Decoration visibility
Expand Down Expand Up @@ -1196,7 +1201,7 @@ def get_order(var):
view["ax"].autoscale_view()

if layer["legend"]:
self._update_legend_contents(mark, data, scales, p._labels)
self._update_legend_contents(p, mark, data, scales)

def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame:
# TODO stricter type on subplots
Expand Down Expand Up @@ -1392,10 +1397,10 @@ def split_generator(keep_na=False) -> Generator:

def _update_legend_contents(
self,
p: Plot,
mark: Mark,
data: PlotData,
scales: dict[str, Scale],
titles: dict[str, str | Callable[[str], str] | None],
) -> None:
"""Add legend artists / labels for one layer in the plot."""
if data.frame.empty and data.frames:
Expand All @@ -1420,18 +1425,8 @@ def _update_legend_contents(
part_vars.append(var)
break
else:
# TODO copied from _setup_figure
auto_title = data.names[var]
if var in titles:
manual_title = titles[var]
title: str | None
if callable(manual_title) and auto_title is not None:
title = manual_title(auto_title)
else:
# mypy needs a lot of help here, I'm not sure why
title = cast(Optional[str], manual_title)
else:
title = auto_title
title = self._resolve_label(p, var, auto_title)
entry = (title, data.ids[var]), [var], (values, labels)
schema.append(entry)

Expand Down

0 comments on commit 932b8c4

Please sign in to comment.