Skip to content

Commit

Permalink
Complete subplots tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jul 17, 2021
1 parent 8ceb7e6 commit 841a3c9
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 9 deletions.
14 changes: 8 additions & 6 deletions seaborn/_core/subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Generator
from matplotlib.figure import Figure
from seaborn._core.data import PlotData


Expand All @@ -33,6 +35,7 @@ def __init__(

def _check_dimension_uniqueness(self, data: PlotData) -> None:
"""Reject specs that pair and facet on (or wrap to) same figure dimension."""
err = None
collisions = {"x": ["columns", "rows"], "y": ["rows", "columns"]}
for pair_axis, (multi_dim, wrap_dim) in collisions.items():
if self.facet_spec.get("wrap") and "col" in data and "row" in data:
Expand All @@ -52,12 +55,11 @@ def _check_dimension_uniqueness(self, data: PlotData) -> None:
err = f"Cannot wrap the {wrap_dim} while pairing on `{pair_axis}``."
elif wrap_dim[:3] in data and self.pair_spec.get("wrap"):
err = f"Cannot wrap the {multi_dim} while faceting the {wrap_dim}."
else:
continue

if err is not None:
raise RuntimeError(err) # TODO what err class? Define PlotSpecError?

def _determine_grid_dimensions(self, data: PlotData):
def _determine_grid_dimensions(self, data: PlotData) -> None:

self.grid_dimensions = {}
for dim, axis in zip(["col", "row"], ["x", "y"]):
Expand Down Expand Up @@ -116,7 +118,7 @@ def _determine_axis_sharing(self) -> None:
val = True
self.subplot_spec[key] = val

def init_figure(self, pyplot: bool): # TODO figsize param or figure_kws dict?
def init_figure(self, pyplot: bool) -> Figure: # TODO figure_kws dict?

figure_kws = {"constrained_layout": True} # TODO get from configure?

Expand Down Expand Up @@ -190,10 +192,10 @@ def init_figure(self, pyplot: bool): # TODO figsize param or figure_kws dict?
# TODO moving the parts of existing code that depend on data or Plot, so they
# need to be implemented in a separate method (or in Plot._setup_figure)

def __iter__(self):
def __iter__(self) -> Generator[dict, None, None]: # TODO TypedDict?

yield from self._subplot_list

def __len__(self):
def __len__(self) -> int:

return len(self._subplot_list)
93 changes: 90 additions & 3 deletions seaborn/tests/_core/test_subplots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import numpy as np
import pytest

from seaborn._core.data import PlotData
Expand Down Expand Up @@ -247,6 +248,13 @@ def test_x_any_y_paired_non_cartesian_wrapped(self, long_df):
assert s.subplot_spec["sharex"] is False
assert s.subplot_spec["sharey"] is False

def test_forced_unshared_facets(self, long_df):

data = PlotData(long_df, {"col": "a", "row": "f"})
s = Subplots({"sharex": False, "sharey": "row"}, {}, {}, data)
assert s.subplot_spec["sharex"] is False
assert s.subplot_spec["sharey"] == "row"


class TestSubplotElements:

Expand Down Expand Up @@ -381,7 +389,34 @@ def test_single_paired_var(self, long_df, var):
@pytest.mark.parametrize("var", ["x", "y"])
def test_single_paired_var_wrapped(self, long_df, var):

... # TODO
other_var = {"x": "y", "y": "x"}[var]
variables = {other_var: "a"}
pairings = ["x", "y", "z", "a", "b"]
wrap = len(pairings) - 2
pair_spec = {var: pairings, "wrap": wrap}
data = PlotData(long_df, variables)
s = Subplots({}, {}, pair_spec, data)
s.init_figure(False)

assert len(s) == len(pairings)

for i, e in enumerate(s):
assert e[var] == f"{var}{i}"
assert e[other_var] == other_var
assert e["col"] is e["row"] is None

tests = (
i < wrap,
i >= wrap or i >= len(s) % wrap,
i % wrap == 0,
i % wrap == wrap - 1 or i + 1 == len(s),
)
sides = {
"x": ["top", "bottom", "left", "right"],
"y": ["left", "right", "top", "bottom"],
}
for side, expected in zip(sides[var], tests):
assert e[side] == expected

def test_both_paired_variables(self, long_df):

Expand Down Expand Up @@ -414,6 +449,58 @@ def test_both_paired_variables(self, long_df):
assert e["x"] == f"x{j}"
assert e["y"] == f"y{i}"

def test_one_facet_one_paired(self, long_df):
def test_both_paired_non_cartesian(self, long_df):

pair_spec = {"x": ["a", "b", "c"], "y": ["x", "y", "z"], "cartesian": False}
data = PlotData(long_df, {})
s = Subplots({}, {}, pair_spec, data)
s.init_figure(False)

... # TODO
for i, e in enumerate(s):
assert e["x"] == f"x{i}"
assert e["y"] == f"y{i}"
assert e["col"] is e["row"] is None
assert e["left"] == (i == 0)
assert e["right"] == (i == (len(s) - 1))
assert e["top"]
assert e["bottom"]

@pytest.mark.parametrize("dim,var", [("col", "y"), ("row", "x")])
def test_one_facet_one_paired(self, long_df, dim, var):

other_var = {"x": "y", "y": "x"}[var]
other_dim = {"col": "row", "row": "col"}[dim]

variables = {other_var: "z", dim: "s"}
pairings = ["x", "y", "t"]
pair_spec = {var: pairings}

data = PlotData(long_df, variables)
s = Subplots({}, {}, pair_spec, data)
s.init_figure(False)

levels = categorical_order(long_df[variables[dim]])
n_cols = len(levels) if dim == "col" else len(pairings)
n_rows = len(levels) if dim == "row" else len(pairings)

assert len(s) == len(levels) * len(pairings)

es = list(s)

for e in es[:n_cols]:
assert e["top"]
for e in es[::n_cols]:
assert e["left"]
for e in es[n_cols - 1::n_cols]:
assert e["right"]
for e in es[-n_cols:]:
assert e["bottom"]

if dim == "row":
es = np.reshape(es, (n_rows, n_cols)).T.ravel()

for i, e in enumerate(es):
assert e[dim] == levels[i % len(pairings)]
assert e[other_dim] is None
assert e[var] == f"{var}{i // len(levels)}"
assert e[other_var] == other_var

0 comments on commit 841a3c9

Please sign in to comment.