From b0ce0e7a9e3534fdad04ef9e287e4c6bb19fe684 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 15 Jul 2021 21:35:21 -0400 Subject: [PATCH] Continue building out subplots tests --- seaborn/_core/plot.py | 8 +- seaborn/_core/subplots.py | 9 +- seaborn/tests/_core/test_subplots.py | 129 ++++++++++++++++++++++++++- 3 files changed, 139 insertions(+), 7 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 8da54368a2..cacabcca86 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -136,11 +136,8 @@ def pair( # # - Implementing this will require lots of downscale changes in figure setup, # and especially the axis scaling, which will need to be pair specific - # - # - Do we want to allow lists of vectors to define the pairing? Everywhere - # else we have a variable specification, we accept Hashable | Vector - # - Ideally this SHOULD work without special handling now. But it does not - # because things downstream are not thought out clearly. + + # TODO lists of vectors currently work, but I'm not sure where best to test # TODO add data kwarg here? (it's everywhere else...) @@ -197,6 +194,7 @@ def pair( if keys: pairspec["structure"][axis] = keys + # TODO raise here if cartesian is False and len(x) != len(y)? pairspec["cartesian"] = cartesian pairspec["wrap"] = wrap diff --git a/seaborn/_core/subplots.py b/seaborn/_core/subplots.py index 642b26aa80..2fd5f657be 100644 --- a/seaborn/_core/subplots.py +++ b/seaborn/_core/subplots.py @@ -185,8 +185,15 @@ def init_figure(self, pyplot: bool): # TODO figsize param or figure_kws dict? self._subplot_list.append(info) + return figure + # TODO moving the parts of existing code that depend on data or Plot, so they # need to be implemented in a separate method (or in Plot._setup_figure) def __iter__(self): - pass # TODO use this for looping over each subplot? + + yield from self._subplot_list + + def __len__(self): + + return len(self._subplot_list) diff --git a/seaborn/tests/_core/test_subplots.py b/seaborn/tests/_core/test_subplots.py index 2b0a24a040..13d959939a 100644 --- a/seaborn/tests/_core/test_subplots.py +++ b/seaborn/tests/_core/test_subplots.py @@ -43,7 +43,18 @@ def test_wrapped_x_pairing_and_facetd_rows(self, long_df): Subplots({}, {}, {"x": ["x", "y", "z"], "wrap": 2}, data) -class TestGridDimensions: +class TestSubplotSpec: + + def test_single_subplot(self, long_df): + + data = PlotData(long_df, {"x": "x", "y": "y"}) + s = Subplots({}, {}, {}, data) + + assert s.n_subplots == 1 + assert s.subplot_spec["ncols"] == 1 + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True def test_single_facet(self, long_df): @@ -55,6 +66,8 @@ def test_single_facet(self, long_df): assert s.n_subplots == n_levels assert s.subplot_spec["ncols"] == n_levels assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True def test_two_facets(self, long_df): @@ -68,6 +81,8 @@ def test_two_facets(self, long_df): assert s.n_subplots == n_cols * n_rows assert s.subplot_spec["ncols"] == n_cols assert s.subplot_spec["nrows"] == n_rows + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True def test_col_facet_wrapped(self, long_df): @@ -80,6 +95,8 @@ def test_col_facet_wrapped(self, long_df): assert s.n_subplots == n_levels assert s.subplot_spec["ncols"] == wrap assert s.subplot_spec["nrows"] == n_levels // wrap + 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True def test_row_facet_wrapped(self, long_df): @@ -92,6 +109,8 @@ def test_row_facet_wrapped(self, long_df): assert s.n_subplots == n_levels assert s.subplot_spec["ncols"] == n_levels // wrap + 1 assert s.subplot_spec["nrows"] == wrap + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True def test_col_facet_wrapped_single_row(self, long_df): @@ -104,6 +123,8 @@ def test_col_facet_wrapped_single_row(self, long_df): assert s.n_subplots == n_levels assert s.subplot_spec["ncols"] == n_levels assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True def test_x_and_y_paired(self, long_df): @@ -115,6 +136,8 @@ def test_x_and_y_paired(self, long_df): assert s.n_subplots == len(x) * len(y) assert s.subplot_spec["ncols"] == len(x) assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] == "col" + assert s.subplot_spec["sharey"] == "row" def test_x_paired(self, long_df): @@ -125,6 +148,8 @@ def test_x_paired(self, long_df): assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == len(x) assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] == "col" + assert s.subplot_spec["sharey"] is True def test_y_paired(self, long_df): @@ -135,6 +160,8 @@ def test_y_paired(self, long_df): assert s.n_subplots == len(y) assert s.subplot_spec["ncols"] == 1 assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] == "row" def test_x_paired_and_wrapped(self, long_df): @@ -146,6 +173,8 @@ def test_x_paired_and_wrapped(self, long_df): assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == wrap assert s.subplot_spec["nrows"] == len(x) // wrap + 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is True def test_y_paired_and_wrapped(self, long_df): @@ -157,3 +186,101 @@ def test_y_paired_and_wrapped(self, long_df): assert s.n_subplots == len(y) assert s.subplot_spec["ncols"] == len(y) // wrap + 1 assert s.subplot_spec["nrows"] == wrap + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is False + + def test_col_faceted_y_paired(self, long_df): + + y = ["x", "y", "z"] + key = "a" + data = PlotData(long_df, {"x": "f", "col": key}) + s = Subplots({}, {}, {"y": y}, data) + + n_levels = len(categorical_order(long_df[key])) + assert s.n_subplots == n_levels * len(y) + assert s.subplot_spec["ncols"] == n_levels + assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] == "row" + + def test_row_faceted_x_paired(self, long_df): + + x = ["f", "s"] + key = "a" + data = PlotData(long_df, {"y": "z", "row": key}) + s = Subplots({}, {}, {"x": x}, data) + + n_levels = len(categorical_order(long_df[key])) + assert s.n_subplots == n_levels * len(x) + assert s.subplot_spec["ncols"] == len(x) + assert s.subplot_spec["nrows"] == n_levels + assert s.subplot_spec["sharex"] == "col" + assert s.subplot_spec["sharey"] is True + + def test_x_any_y_paired_non_cartesian(self, long_df): + + x = ["a", "b", "c"] + y = ["x", "y", "z"] + + data = PlotData(long_df, {}) + s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False}, data) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == len(y) + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is False + + def test_x_any_y_paired_non_cartesian_wrapped(self, long_df): + + x = ["a", "b", "c"] + y = ["x", "y", "z"] + wrap = 2 + + data = PlotData(long_df, {}) + s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False, "wrap": wrap}, data) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == wrap + assert s.subplot_spec["nrows"] == len(x) // wrap + 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is False + + +class TestSubplotElements: + + def test_single_subplot(self, long_df): + + data = PlotData(long_df, {"x": "x", "y": "y"}) + s = Subplots({}, {}, {}, data) + f = s.init_figure(False) + + assert len(s) == 1 + for i, e in enumerate(s): + for side in ["left", "right", "bottom", "top"]: + assert e[side] + for dim in ["col", "row"]: + assert e[dim] is None + for axis in "xy": + assert e[axis] == axis + assert e["ax"] == f.axes[i] + + @pytest.mark.parametrize("dim", ["col", "row"]) + def test_single_facet_dim(self, long_df, dim): + + key = "a" + data = PlotData(long_df, {"x": "x", "y": "y", dim: key}) + s = Subplots({}, {}, {}, data) + s.init_figure(False) + + levels = categorical_order(long_df[key]) + assert len(s) == len(levels) + + for i, e in enumerate(s): + assert e[dim] == levels[i] + for axis in "xy": + assert e[axis] == axis + assert e["top"] == (dim == "col" or i == 0) + assert e["bottom"] == (dim == "col" or i == len(levels) - 1) + assert e["left"] == (dim == "row" or i == 0) + assert e["right"] == (dim == "row" or i == len(levels) - 1)