Skip to content

Commit

Permalink
Integrate Subplots into Plot
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jul 17, 2021
1 parent 841a3c9 commit 3c07f98
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 163 deletions.
197 changes: 51 additions & 146 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import itertools
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

from seaborn._core.rules import categorical_order, variable_type
from seaborn._core.data import PlotData
from seaborn._core.subplots import Subplots
from seaborn._core.mappings import GroupMapping, HueMapping
from seaborn._core.scales import (
ScaleWrapper,
Expand Down Expand Up @@ -205,7 +205,7 @@ def facet(
self,
col: VariableSpec = None,
row: VariableSpec = None,
col_order: OrderSpec = None,
col_order: OrderSpec = None, # TODO single order param
row_order: OrderSpec = None,
wrap: int | None = None,
data: DataSource = None,
Expand Down Expand Up @@ -448,155 +448,60 @@ def _setup_figure(self, pyplot: bool = False) -> None:
)
)

# Reject specs that pair and facet on (or wrap to) the same figure dimension
overlaps = {"x": ["columns", "rows"], "y": ["rows", "columns"]}
for pair_axis, (facet_dim, wrap_dim) in overlaps.items():

if pair_axis not in self._pairspec:
continue
elif facet_dim[:3] in setup_data:
err = f"Cannot facet on the {facet_dim} while pairing on {pair_axis}."
elif wrap_dim[:3] in setup_data and self._facetspec.get("wrap"):
err = f"Cannot wrap the {wrap_dim} while pairing on {pair_axis}."
elif wrap_dim[:3] in setup_data and self._pairspec.get("wrap"):
err = f"Cannot wrap the {facet_dim} while faceting on {wrap_dim}."
else:
continue
raise RuntimeError(err) # TODO what err class? Define PlotSpecError?

# --- Subplot grid parameterization

# TODO this method is getting quite long and complicated.
# I'd like to break it up, although adding a bunch more methods
# on Plot will make it hard to navigate. Should there be some sort of
# container class for the figure/subplots where that logic lives?

# TODO build this from self._subplotspec?
subplot_spec = {}

figure_dimensions = {}
for dim, axis in zip(["col", "row"], ["x", "y"]):

if dim in setup_data:
figure_dimensions[dim] = categorical_order(
setup_data.frame[dim], self._facetspec.get(f"{dim}_order"),
)
elif axis in self._pairspec:
figure_dimensions[dim] = self._pairspec[axis]
else:
figure_dimensions[dim] = [None]

subplot_spec[f"n{dim}s"] = len(figure_dimensions[dim])

if not self._pairspec.get("cartesian", True):
# TODO we need to re-enable axis/tick labels even when sharing
subplot_spec["nrows"] = 1

wrap = self._facetspec.get("wrap", self._pairspec.get("wrap"))
if wrap is not None:
wrap_dim = "row" if subplot_spec["nrows"] > 1 else "col"
flow_dim = {"row": "col", "col": "row"}[wrap_dim]
n_subplots = subplot_spec[f"n{wrap_dim}s"]
flow = int(np.ceil(n_subplots / wrap))
subplot_spec[f"n{wrap_dim}s"] = wrap
subplot_spec[f"n{flow_dim}s"] = flow
else:
n_subplots = subplot_spec["ncols"] * subplot_spec["nrows"]

# Work out the defaults for sharex/sharey
axis_to_dim = {"x": "col", "y": "row"}
for axis in "xy":
key = f"share{axis}"
if key in self._subplotspec: # Should we just be updating this?
val = self._subplotspec[key]
else:
if axis in self._pairspec:
if wrap in [None, 1] and self._pairspec.get("cartesian", True):
val = axis_to_dim[axis]
else:
val = False
else:
val = True
subplot_spec[key] = val
self._subplots = subplots = Subplots(
self._subplotspec, self._facetspec, self._pairspec, setup_data
)

# --- Figure initialization
figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO
self._figure = subplots.init_figure(pyplot, figure_kws)

figsize = getattr(self, "_figsize", None)

if pyplot:
self._figure = plt.figure(figsize=figsize)
else:
self._figure = mpl.figure.Figure(figsize=figsize)

subplots = self._figure.subplots(**subplot_spec, squeeze=False)

# --- Building the internal subplot list and add default decorations

self._subplot_list = []

if wrap is not None:
ravel_order = {"col": "C", "row": "F"}[wrap_dim]
subplots_flat = subplots.ravel(ravel_order)
subplots, extra = np.split(subplots_flat, [n_subplots])
for ax in extra:
ax.remove()
if wrap_dim == "col":
subplots = subplots[np.newaxis, :]
else:
subplots = subplots[:, np.newaxis]
if not self._pairspec or self._pairspec["cartesian"]:
iterplots = np.ndenumerate(subplots)
else:
indices = np.arange(n_subplots)
iterplots = zip(zip(indices, indices), subplots.flat)

for (i, j), ax in iterplots:

info = {"ax": ax}

for dim in ["row", "col"]:
idx = {"row": i, "col": j}[dim]
if dim in setup_data:
info[dim] = figure_dimensions[dim][idx]
else:
info[dim] = None

# --- Figure annotation
for sub in subplots:
ax = sub["ax"]
for axis in "xy":

idx = {"x": j, "y": i}[axis]
if axis in self._pairspec:
key = f"{axis}{idx}"
else:
key = axis
info[axis] = key

label = setup_data.names.get(key)
axis_key = sub[axis]
ax.set(**{
f"{axis}scale": self._scales[key]._scale,
f"{axis}label": label, # TODO we should do this elsewhere
# TODO this is the only non "annotation" part of this code
# everything else can happen after .plot(), but we need this first
# Should perhaps separate it out to make that more clear
# (or pass scales into Subplots)
f"{axis}scale": self._scales[axis_key]._scale,
# TODO we still need a way to label axes with names passed in layers
f"{axis}label": setup_data.names.get(axis_key)
})

self._subplot_list.append(info)

# Now do some individual subplot configuration
# TODO this could be moved to a different loop, here or in a subroutine

# TODO need to account for wrap, non-cartesian
if subplot_spec["sharex"] in (True, "col") and subplots.shape[0] - i > 1:
ax.xaxis.label.set_visible(False)
if subplot_spec["sharey"] in (True, "row") and j > 0:
ax.yaxis.label.set_visible(False)

# TODO should titles be set for each position along the pair dimension?
# (e.g., pair on y, facet on cols, should facet titles only go on top row?)
axis_obj = getattr(ax, f"{axis}axis")
if self._subplotspec.get("cartesian", True):
label_side = {"x": "bottom", "y": "left"}.get(axis)
visible = sub[label_side]
axis_obj.get_label().set_visible(visible)
# TODO check that this is the right way to set these attributes
plt.setp(axis_obj.get_majorticklabels(), visible=visible)
plt.setp(axis_obj.get_minorticklabels(), visible=visible)

# TODO title template should be configurable
# TODO Also we want right-side titles for row facets in most cases
# TODO should configure() accept a title= kwarg (for single subplot plots)?
title_parts = []
for idx, dim in zip([i, j], ["row", "col"]):
if dim in setup_data:
for dim in ["row", "col"]:
if sub[dim] is not None:
name = setup_data.names.get(dim, f"_{dim}_")
level = figure_dimensions[dim][idx]
title_parts.append(f"{name} = {level}")
title = " | ".join(title_parts)
ax.set_title(title)
title_parts.append(f"{name} = {sub[dim]}")

has_col = sub["col"] is not None
has_row = sub["row"] is not None
show_title = (
has_col and has_row
or (has_col or has_row) and self._facetspec.get("wrap")
or (has_col and sub["top"])
# TODO or has_row and sub["right"] and <right titles>
or has_row # TODO and not <right titles>
)
if title_parts:
title = " | ".join(title_parts)
title_text = ax.set_title(title)
title_text.set_visible(show_title)

def _setup_mappings(self) -> None:

Expand Down Expand Up @@ -757,7 +662,7 @@ def _generate_pairings(
pair_variables = self._pairspec.get("structure", {})

if not pair_variables:
yield self._subplot_list, df
yield list(self._subplots), df
return

iter_axes = itertools.product(*[
Expand All @@ -776,9 +681,9 @@ def _generate_pairings(
})

subplots = []
for s in self._subplot_list:
if (x is None or s["x"] == x) and (y is None or s["y"] == y):
subplots.append(s)
for sub in self._subplots:
if (x is None or sub["x"] == x) and (y is None or sub["y"] == y):
subplots.append(sub)

yield subplots, df.assign(**reassignments)

Expand Down
9 changes: 4 additions & 5 deletions seaborn/_core/subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ def _determine_axis_sharing(self) -> None:
val = True
self.subplot_spec[key] = val

def init_figure(self, pyplot: bool) -> Figure: # TODO figure_kws dict?
def init_figure(self, pyplot: bool, figure_kws: dict | None = {}) -> Figure:
# TODO other methods don't have defaults, maybe don't have one here either

figure_kws = {"constrained_layout": True} # TODO get from configure?
if figure_kws is None:
figure_kws = {}

if pyplot:
figure = plt.figure(**figure_kws)
Expand Down Expand Up @@ -189,9 +191,6 @@ def init_figure(self, pyplot: bool) -> Figure: # TODO figure_kws dict?

return figure

# TODO moving the parts of existing code that depend on data or Plot, so they
# need to be implemented in a separate method (or in Plot._setup_figure)

def __iter__(self) -> Generator[dict, None, None]: # TODO TypedDict?

yield from self._subplot_list
Expand Down
26 changes: 14 additions & 12 deletions seaborn/tests/_core/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_matplotlib_object_creation(self):
p = Plot()
p._setup_figure()
assert isinstance(p._figure, mpl.figure.Figure)
for sub in p._subplot_list:
for sub in p._subplots:
assert isinstance(sub["ax"], mpl.axes.Axes)

def test_empty(self):
Expand All @@ -375,7 +375,7 @@ def test_single_split_single_layer(self, long_df):
assert m.n_splits == 1

assert m.passed_keys[0] == {}
assert m.passed_axes[0] is p._subplot_list[0]["ax"]
assert m.passed_axes == [sub["ax"] for sub in p._subplots]
assert_frame_equal(m.passed_data[0], p._data.frame)

def test_single_split_multi_layer(self, long_df):
Expand Down Expand Up @@ -435,7 +435,8 @@ def test_one_grouping_variable(self, long_df, split_var):
p = Plot(long_df, x="f", y="z", **{split_var: split_col}).add(m).plot()

split_keys = categorical_order(long_df[split_col])
assert m.passed_axes == [p._subplot_list[0]["ax"] for _ in split_keys]
sub, *_ = p._subplots
assert m.passed_axes == [sub["ax"] for _ in split_keys]
self.check_splits_single_var(p, m, split_var, split_keys)

def test_two_grouping_variables(self, long_df):
Expand All @@ -448,8 +449,9 @@ def test_two_grouping_variables(self, long_df):
p = Plot(long_df, y="z", **variables).add(m).plot()

split_keys = [categorical_order(long_df[col]) for col in split_cols]
sub, *_ = p._subplots
assert m.passed_axes == [
p._subplot_list[0]["ax"] for _ in itertools.product(*split_keys)
sub["ax"] for _ in itertools.product(*split_keys)
]
self.check_splits_multi_vars(p, m, split_vars, split_keys)

Expand Down Expand Up @@ -670,7 +672,7 @@ def check_facet_results_1d(self, p, df, dim, key, order=None):

other_dim = {"row": "col", "col": "row"}[dim]

for subplot, level in zip(p._subplot_list, order):
for subplot, level in zip(p._subplots, order):
assert subplot[dim] == level
assert subplot[other_dim] is None
assert subplot["ax"].get_title() == f"{key} = {level}"
Expand Down Expand Up @@ -722,9 +724,9 @@ def check_facet_results_2d(self, p, df, variables, order=None):
order = {dim: categorical_order(df[key]) for dim, key in variables.items()}

levels = itertools.product(*[order[dim] for dim in ["row", "col"]])
assert len(p._subplot_list) == len(list(levels))
assert len(p._subplots) == len(list(levels))

for subplot, (row_level, col_level) in zip(p._subplot_list, levels):
for subplot, (row_level, col_level) in zip(p._subplots, levels):
assert subplot["row"] == row_level
assert subplot["col"] == col_level
assert subplot["axes"].get_title() == (
Expand Down Expand Up @@ -838,7 +840,7 @@ def check_pair_grid(self, p, x, y):

xys = itertools.product(y, x)

for (y_i, x_j), subplot in zip(xys, p._subplot_list):
for (y_i, x_j), subplot in zip(xys, p._subplots):

ax = subplot["ax"]
assert ax.get_xlabel() == "" if x_j is None else x_j
Expand Down Expand Up @@ -880,7 +882,7 @@ def test_non_cartesian(self, long_df):

p = Plot(long_df).pair(x, y, cartesian=False).plot()

for i, subplot in enumerate(p._subplot_list):
for i, subplot in enumerate(p._subplots):
ax = subplot["ax"]
assert ax.get_xlabel() == x[i]
assert ax.get_ylabel() == y[i]
Expand Down Expand Up @@ -922,7 +924,7 @@ def test_with_facets(self, long_df):
facet_levels = categorical_order(long_df[col])
dims = itertools.product(y, facet_levels)

for (y_i, col_i), subplot in zip(dims, p._subplot_list):
for (y_i, col_i), subplot in zip(dims, p._subplots):

ax = subplot["ax"]
assert ax.get_xlabel() == x
Expand All @@ -938,7 +940,7 @@ def test_error_on_facet_overlap(self, long_df, variables):

facet_dim, pair_axis = variables
p = Plot(long_df, **{facet_dim[:3]: "a"}).pair(**{pair_axis: ["x", "y"]})
expected = f"Cannot facet on the {facet_dim} while pairing on {pair_axis}."
expected = f"Cannot facet the {facet_dim} while pairing on `{pair_axis}`."
with pytest.raises(RuntimeError, match=expected):
p.plot()

Expand All @@ -951,7 +953,7 @@ def test_error_on_wrap_overlap(self, long_df, variables):
.facet(wrap=2)
.pair(**{pair_axis: ["x", "y"]})
)
expected = f"Cannot wrap the {facet_dim} while pairing on {pair_axis}."
expected = f"Cannot wrap the {facet_dim} while pairing on `{pair_axis}``."
with pytest.raises(RuntimeError, match=expected):
p.plot()

Expand Down

0 comments on commit 3c07f98

Please sign in to comment.