Skip to content

Commit

Permalink
Continue building out subplots tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jul 16, 2021
1 parent 5f4b67d commit b0ce0e7
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 7 deletions.
8 changes: 3 additions & 5 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,8 @@ def pair(
#
# - Implementing this will require lots of downscale changes in figure setup,
# and especially the axis scaling, which will need to be pair specific
#
# - Do we want to allow lists of vectors to define the pairing? Everywhere
# else we have a variable specification, we accept Hashable | Vector
# - Ideally this SHOULD work without special handling now. But it does not
# because things downstream are not thought out clearly.

# TODO lists of vectors currently work, but I'm not sure where best to test

# TODO add data kwarg here? (it's everywhere else...)

Expand Down Expand Up @@ -197,6 +194,7 @@ def pair(
if keys:
pairspec["structure"][axis] = keys

# TODO raise here if cartesian is False and len(x) != len(y)?
pairspec["cartesian"] = cartesian
pairspec["wrap"] = wrap

Expand Down
9 changes: 8 additions & 1 deletion seaborn/_core/subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,15 @@ def init_figure(self, pyplot: bool): # TODO figsize param or figure_kws dict?

self._subplot_list.append(info)

return figure

# TODO moving the parts of existing code that depend on data or Plot, so they
# need to be implemented in a separate method (or in Plot._setup_figure)

def __iter__(self):
pass # TODO use this for looping over each subplot?

yield from self._subplot_list

def __len__(self):

return len(self._subplot_list)
129 changes: 128 additions & 1 deletion seaborn/tests/_core/test_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,18 @@ def test_wrapped_x_pairing_and_facetd_rows(self, long_df):
Subplots({}, {}, {"x": ["x", "y", "z"], "wrap": 2}, data)


class TestGridDimensions:
class TestSubplotSpec:

def test_single_subplot(self, long_df):

data = PlotData(long_df, {"x": "x", "y": "y"})
s = Subplots({}, {}, {}, data)

assert s.n_subplots == 1
assert s.subplot_spec["ncols"] == 1
assert s.subplot_spec["nrows"] == 1
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] is True

def test_single_facet(self, long_df):

Expand All @@ -55,6 +66,8 @@ def test_single_facet(self, long_df):
assert s.n_subplots == n_levels
assert s.subplot_spec["ncols"] == n_levels
assert s.subplot_spec["nrows"] == 1
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] is True

def test_two_facets(self, long_df):

Expand All @@ -68,6 +81,8 @@ def test_two_facets(self, long_df):
assert s.n_subplots == n_cols * n_rows
assert s.subplot_spec["ncols"] == n_cols
assert s.subplot_spec["nrows"] == n_rows
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] is True

def test_col_facet_wrapped(self, long_df):

Expand All @@ -80,6 +95,8 @@ def test_col_facet_wrapped(self, long_df):
assert s.n_subplots == n_levels
assert s.subplot_spec["ncols"] == wrap
assert s.subplot_spec["nrows"] == n_levels // wrap + 1
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] is True

def test_row_facet_wrapped(self, long_df):

Expand All @@ -92,6 +109,8 @@ def test_row_facet_wrapped(self, long_df):
assert s.n_subplots == n_levels
assert s.subplot_spec["ncols"] == n_levels // wrap + 1
assert s.subplot_spec["nrows"] == wrap
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] is True

def test_col_facet_wrapped_single_row(self, long_df):

Expand All @@ -104,6 +123,8 @@ def test_col_facet_wrapped_single_row(self, long_df):
assert s.n_subplots == n_levels
assert s.subplot_spec["ncols"] == n_levels
assert s.subplot_spec["nrows"] == 1
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] is True

def test_x_and_y_paired(self, long_df):

Expand All @@ -115,6 +136,8 @@ def test_x_and_y_paired(self, long_df):
assert s.n_subplots == len(x) * len(y)
assert s.subplot_spec["ncols"] == len(x)
assert s.subplot_spec["nrows"] == len(y)
assert s.subplot_spec["sharex"] == "col"
assert s.subplot_spec["sharey"] == "row"

def test_x_paired(self, long_df):

Expand All @@ -125,6 +148,8 @@ def test_x_paired(self, long_df):
assert s.n_subplots == len(x)
assert s.subplot_spec["ncols"] == len(x)
assert s.subplot_spec["nrows"] == 1
assert s.subplot_spec["sharex"] == "col"
assert s.subplot_spec["sharey"] is True

def test_y_paired(self, long_df):

Expand All @@ -135,6 +160,8 @@ def test_y_paired(self, long_df):
assert s.n_subplots == len(y)
assert s.subplot_spec["ncols"] == 1
assert s.subplot_spec["nrows"] == len(y)
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] == "row"

def test_x_paired_and_wrapped(self, long_df):

Expand All @@ -146,6 +173,8 @@ def test_x_paired_and_wrapped(self, long_df):
assert s.n_subplots == len(x)
assert s.subplot_spec["ncols"] == wrap
assert s.subplot_spec["nrows"] == len(x) // wrap + 1
assert s.subplot_spec["sharex"] is False
assert s.subplot_spec["sharey"] is True

def test_y_paired_and_wrapped(self, long_df):

Expand All @@ -157,3 +186,101 @@ def test_y_paired_and_wrapped(self, long_df):
assert s.n_subplots == len(y)
assert s.subplot_spec["ncols"] == len(y) // wrap + 1
assert s.subplot_spec["nrows"] == wrap
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] is False

def test_col_faceted_y_paired(self, long_df):

y = ["x", "y", "z"]
key = "a"
data = PlotData(long_df, {"x": "f", "col": key})
s = Subplots({}, {}, {"y": y}, data)

n_levels = len(categorical_order(long_df[key]))
assert s.n_subplots == n_levels * len(y)
assert s.subplot_spec["ncols"] == n_levels
assert s.subplot_spec["nrows"] == len(y)
assert s.subplot_spec["sharex"] is True
assert s.subplot_spec["sharey"] == "row"

def test_row_faceted_x_paired(self, long_df):

x = ["f", "s"]
key = "a"
data = PlotData(long_df, {"y": "z", "row": key})
s = Subplots({}, {}, {"x": x}, data)

n_levels = len(categorical_order(long_df[key]))
assert s.n_subplots == n_levels * len(x)
assert s.subplot_spec["ncols"] == len(x)
assert s.subplot_spec["nrows"] == n_levels
assert s.subplot_spec["sharex"] == "col"
assert s.subplot_spec["sharey"] is True

def test_x_any_y_paired_non_cartesian(self, long_df):

x = ["a", "b", "c"]
y = ["x", "y", "z"]

data = PlotData(long_df, {})
s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False}, data)

assert s.n_subplots == len(x)
assert s.subplot_spec["ncols"] == len(y)
assert s.subplot_spec["nrows"] == 1
assert s.subplot_spec["sharex"] is False
assert s.subplot_spec["sharey"] is False

def test_x_any_y_paired_non_cartesian_wrapped(self, long_df):

x = ["a", "b", "c"]
y = ["x", "y", "z"]
wrap = 2

data = PlotData(long_df, {})
s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False, "wrap": wrap}, data)

assert s.n_subplots == len(x)
assert s.subplot_spec["ncols"] == wrap
assert s.subplot_spec["nrows"] == len(x) // wrap + 1
assert s.subplot_spec["sharex"] is False
assert s.subplot_spec["sharey"] is False


class TestSubplotElements:

def test_single_subplot(self, long_df):

data = PlotData(long_df, {"x": "x", "y": "y"})
s = Subplots({}, {}, {}, data)
f = s.init_figure(False)

assert len(s) == 1
for i, e in enumerate(s):
for side in ["left", "right", "bottom", "top"]:
assert e[side]
for dim in ["col", "row"]:
assert e[dim] is None
for axis in "xy":
assert e[axis] == axis
assert e["ax"] == f.axes[i]

@pytest.mark.parametrize("dim", ["col", "row"])
def test_single_facet_dim(self, long_df, dim):

key = "a"
data = PlotData(long_df, {"x": "x", "y": "y", dim: key})
s = Subplots({}, {}, {}, data)
s.init_figure(False)

levels = categorical_order(long_df[key])
assert len(s) == len(levels)

for i, e in enumerate(s):
assert e[dim] == levels[i]
for axis in "xy":
assert e[axis] == axis
assert e["top"] == (dim == "col" or i == 0)
assert e["bottom"] == (dim == "col" or i == len(levels) - 1)
assert e["left"] == (dim == "row" or i == 0)
assert e["right"] == (dim == "row" or i == len(levels) - 1)

0 comments on commit b0ce0e7

Please sign in to comment.