diff --git a/doc/nextgen/api.rst b/doc/nextgen/api.rst index 4ede326ef1..41522ebc99 100644 --- a/doc/nextgen/api.rst +++ b/doc/nextgen/api.rst @@ -18,13 +18,14 @@ Plot interface Plot Plot.add - Plot.scale Plot.facet Plot.pair - Plot.configure + Plot.layout Plot.on Plot.plot Plot.save + Plot.scale + Plot.share Plot.show Marks diff --git a/doc/nextgen/demo.ipynb b/doc/nextgen/demo.ipynb index 08fde40e61..2924c9debe 100644 --- a/doc/nextgen/demo.ipynb +++ b/doc/nextgen/demo.ipynb @@ -717,7 +717,7 @@ " .facet(col=\"day\")\n", " .add(so.Dots(color=\".75\"), col=None)\n", " .add(so.Dots(), color=\"day\")\n", - " .configure(figsize=(7, 3))\n", + " .layout(size=(7, 3))\n", ")" ] }, @@ -808,7 +808,7 @@ "(\n", " so.Plot(tips)\n", " .pair(x=tips.columns, wrap=3)\n", - " .configure(sharey=False)\n", + " .share(y=False)\n", " .add(so.Bar(), so.Hist())\n", ")" ] diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index 6b95991dba..368ecffda8 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -20,7 +20,6 @@ "outputs": [], "source": [ "import seaborn as sns\n", - "sns.set_theme()\n", "tips = sns.load_dataset(\"tips\")\n", "\n", "import seaborn.objects as so\n", @@ -31,7 +30,7 @@ " )\n", " .facet(\"time\")\n", " .add(so.Dots())\n", - " .configure(figsize=(7, 4))\n", + " .layout(size=(7, 4))\n", ")" ] }, diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 1a96d3c6cb..0e1f924581 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -150,6 +150,7 @@ class Plot: _layers: list[Layer] _scales: dict[str, Scale] + _shares: dict[str, bool | str] _limits: dict[str, tuple[Any, Any]] _labels: dict[str, str | Callable[[str], str]] _theme: dict[str, Any] @@ -159,6 +160,7 @@ class Plot: _figure_spec: dict[str, Any] _subplot_spec: dict[str, Any] + _layout_spec: dict[str, Any] def __init__( self, @@ -180,6 +182,7 @@ def __init__( self._layers = [] self._scales = {} + self._shares = {} self._limits = {} self._labels = {} self._theme = {} @@ -189,6 +192,7 @@ def __init__( self._figure_spec = {} self._subplot_spec = {} + self._layout_spec = {} self._target = None @@ -250,6 +254,7 @@ def _clone(self) -> Plot: new._layers.extend(self._layers) new._scales.update(self._scales) + new._shares.update(self._shares) new._limits.update(self._limits) new._labels.update(self._labels) new._theme.update(self._theme) @@ -259,6 +264,7 @@ def _clone(self) -> Plot: new._figure_spec.update(self._figure_spec) new._subplot_spec.update(self._subplot_spec) + new._layout_spec.update(self._layout_spec) new._target = self._target @@ -272,7 +278,7 @@ def _theme_with_defaults(self) -> dict[str, Any]: "xaxis", "xtick", "yaxis", "ytick", ] base = { - k: v for k, v in mpl.rcParamsDefault.items() + k: mpl.rcParamsDefault[k] for k in mpl.rcParams if any(k.startswith(p) for p in style_groups) } theme = { @@ -584,6 +590,21 @@ def scale(self, **scales: Scale) -> Plot: new._scales.update(scales) return new + def share(self, **shares: bool | str) -> Plot: + """ + Control sharing of axis limits and ticks across subplots. + + Keywords correspond to variables defined in the plot, and values can be + boolean (to share across all subplots), or one of "row" or "col" (to share + more selectively across one dimension of a grid). + + Behavior for non-coordinate variables is currently undefined. + + """ + new = self._clone() + new._shares.update(shares) + return new + def limit(self, **limits: tuple[Any, Any]) -> Plot: """ Control the range of visible data. @@ -624,23 +645,22 @@ def label(self, *, title=None, **variables: str | Callable[[str], str]) -> Plot: new._labels.update(variables) return new - def configure( + def layout( self, - figsize: tuple[float, float] | None = None, - sharex: bool | str | None = None, - sharey: bool | str | None = None, + *, + size: tuple[float, float] | None = None, + algo: str | None = "tight", # TODO document ) -> Plot: """ Control the figure size and layout. Parameters ---------- - figsize: (width, height) - Size of the resulting figure, in inches. - sharex, sharey : bool, "row", or "col" - Whether axis limits should be shared across subplots. Boolean values apply - across the entire grid, whereas `"row"` or `"col"` have a smaller scope. - Shared axes will have tick labels disabled. + size : (width, height) + Size of the resulting figure, in inches. Size is inclusive of legend when + using pyplot, but not otherwise. + algo : {{"tight", "constrained", None}} + Name of algorithm for automatically adjusting the layout to remove overlap. """ # TODO add an "auto" mode for figsize that roughly scales with the rcParams @@ -650,12 +670,8 @@ def configure( new = self._clone() - new._figure_spec["figsize"] = figsize - - if sharex is not None: - new._subplot_spec["sharex"] = sharex - if sharey is not None: - new._subplot_spec["sharey"] = sharey + new._figure_spec["figsize"] = size + new._layout_spec["algo"] = algo return new @@ -881,6 +897,10 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: facet_spec = p._facet_spec.copy() pair_spec = p._pair_spec.copy() + for axis in "xy": + if axis in p._shares: + subplot_spec[f"share{axis}"] = p._shares[axis] + for dim in ["col", "row"]: if dim in common.frame and dim not in facet_spec["structure"]: order = categorical_order(common.frame[dim]) @@ -915,7 +935,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: # ~~ Decoration visibility - # TODO there should be some override (in Plot.configure?) so that + # TODO there should be some override (in Plot.layout?) so that # tick labels can be shown on interior shared axes axis_obj = getattr(ax, f"{axis}axis") visible_side = {"x": "bottom", "y": "left"}.get(axis) @@ -935,10 +955,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: for t in getattr(axis_obj, f"get_{group}ticklabels")(): t.set_visible(show_tick_labels) - # TODO title template should be configurable - # ---- Also we want right-side titles for row facets in most cases? - # ---- Or wrapped? That can get annoying too. - # TODO should configure() accept a title= kwarg (for single subplot plots)? + # TODO we want right-side titles for row facets in most cases? # Let's have what we currently call "margin titles" but properly using the # ax.set_title interface (see my gist) title_parts = [] @@ -1508,6 +1525,9 @@ def _make_legend(self, p: Plot) -> None: else: merged_contents[key] = artists.copy(), labels + # TODO explain + loc = "center right" if self._pyplot else "center left" + base_legend = None for (name, _), (handles, labels) in merged_contents.items(): @@ -1516,7 +1536,7 @@ def _make_legend(self, p: Plot) -> None: handles, labels, title=name, - loc="center left", + loc=loc, bbox_to_anchor=(.98, .55), ) @@ -1550,9 +1570,11 @@ def _finalize_figure(self, p: Plot) -> None: hi = cast(float, hi) + 0.5 ax.set(**{f"{axis}lim": (lo, hi)}) - # TODO this should be configurable - if not self._figure.get_constrained_layout(): + layout_algo = p._layout_spec.get("algo", "tight") + if layout_algo == "tight": self._figure.set_tight_layout(True) + elif layout_algo == "constrained": + self._figure.set_constrained_layout(True) @contextmanager diff --git a/tests/_core/test_plot.py b/tests/_core/test_plot.py index 14c40672fa..b83ffc6bc1 100644 --- a/tests/_core/test_plot.py +++ b/tests/_core/test_plot.py @@ -503,7 +503,7 @@ def test_facet_categories_unshared(self): p = ( Plot(x=["a", "b", "a", "c"]) .facet(col=["x", "x", "y", "y"]) - .configure(sharex=False) + .share(x=False) .add(m) .plot() ) @@ -527,7 +527,7 @@ def test_facet_categories_single_dim_shared(self): Plot(df, x="x") .facet(row="row", col="col") .add(m) - .configure(sharex="row") + .share(x="row") .plot() ) @@ -562,7 +562,7 @@ def test_pair_categories_shared(self): data = [("a", "a"), ("b", "c")] df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1) m = MockMark() - p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).configure(sharex=True).plot() + p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).share(x=True).plot() for ax in p._figure.axes: assert ax.get_xticks() == [0, 1, 2] @@ -994,6 +994,12 @@ def test_save(self): tag = xml.etree.ElementTree.fromstring(buf.getvalue()).tag assert tag == "{http://www.w3.org/2000/svg}svg" + def test_layout_size(self): + + size = (4, 2) + p = Plot().layout(size=size).plot() + assert tuple(p._figure.get_size_inches()) == size + def test_on_axes(self): ax = mpl.figure.Figure().subplots() @@ -1239,11 +1245,27 @@ def test_2d_with_order(self, long_df, reorder): p = Plot(long_df).facet(**variables, order=order) self.check_facet_results_2d(p, long_df, variables, order) - def test_figsize(self): + @pytest.mark.parametrize("algo", ["tight", "constrained"]) + def test_layout_algo(self, algo): + + if algo == "constrained" and Version(mpl.__version__) < Version("3.3.0"): + pytest.skip("constrained_layout requires matplotlib>=3.3") + + p = Plot().facet(["a", "b"]).limit(x=(.1, .9)) + + p1 = p.layout(algo=algo).plot() + p2 = p.layout(algo=None).plot() + + # Force a draw (we probably need a method for this) + p1.save(io.BytesIO()) + p2.save(io.BytesIO()) + + bb11, bb12 = [ax.get_position() for ax in p1._figure.axes] + bb21, bb22 = [ax.get_position() for ax in p2._figure.axes] - figsize = (4, 2) - p = Plot().configure(figsize=figsize).plot() - assert tuple(p._figure.get_size_inches()) == figsize + sep1 = bb12.corners()[0, 0] - bb11.corners()[2, 0] + sep2 = bb22.corners()[0, 0] - bb21.corners()[2, 0] + assert sep1 < sep2 def test_axis_sharing(self, long_df): @@ -1257,13 +1279,13 @@ def test_axis_sharing(self, long_df): shareset = getattr(root, f"get_shared_{axis}_axes")() assert all(shareset.joined(root, ax) for ax in other) - p2 = p.configure(sharex=False, sharey=False).plot() + p2 = p.share(x=False, y=False).plot() root, *other = p2._figure.axes for axis in "xy": shareset = getattr(root, f"get_shared_{axis}_axes")() assert not any(shareset.joined(root, ax) for ax in other) - p3 = p.configure(sharex="col", sharey="row").plot() + p3 = p.share(x="col", y="row").plot() shape = ( len(categorical_order(long_df[variables["row"]])), len(categorical_order(long_df[variables["col"]])), @@ -1448,7 +1470,7 @@ def test_axis_sharing(self, long_df): y_shareset = getattr(root, "get_shared_y_axes")() assert not any(y_shareset.joined(root, ax) for ax in other) - p2 = p.configure(sharex=False, sharey=False).plot() + p2 = p.share(x=False, y=False).plot() root, *other = p2._figure.axes for axis in "xy": shareset = getattr(root, f"get_shared_{axis}_axes")() @@ -1712,7 +1734,7 @@ def test_2d_unshared(self): p = ( Plot() .facet(col=["a", "b"], row=["x", "y"]) - .configure(sharex=False, sharey=False) + .share(x=False, y=False) .plot() ) subplots = list(p._subplots)