From 949dec3666ab12a366d2fc05ef18d6e90625b5fa Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 31 Jul 2022 17:15:51 -0400 Subject: [PATCH] Add apply and pipe methods to Grid objects for fluent customization (#2928) * Return self from tight_layout and refline * Add apply and pipe methods to FacetGrid for fluent customization * Move apply/pipe down to base class so JointGrid/PaiGrid get them too * Tweak docstrings --- doc/whatsnew/v0.12.0.rst | 4 +++- seaborn/axisgrid.py | 32 ++++++++++++++++++++++++++++++++ tests/test_axisgrid.py | 23 +++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/doc/whatsnew/v0.12.0.rst b/doc/whatsnew/v0.12.0.rst index f26825f5ff..a54c1ff17c 100644 --- a/doc/whatsnew/v0.12.0.rst +++ b/doc/whatsnew/v0.12.0.rst @@ -60,10 +60,12 @@ Other updates - |Feature| It is now possible to aggregate / sort a :func:`lineplot` along the y axis using `orient="y"` (:pr:`2854`). -- |Feature| It is now possible to specify `estimator` as a string in :func:`barplot` and :func:`pointplot`, in addition to a callable (:pr:`2866`). +- |Feature| Made it easier to customize :class:`FacetGrid` / :class:`PairGrid` / :class:`JointGrid` with a fluent (method-chained) style by adding `apply`/ `pipe` methods. Additionally, fixed the `tight_layout` and `refline` methods so that they return `self` (:pr:`2926`). - |Enhancement| Added a `width` parameter to :func:`barplot` (:pr:`2860`). +- |Enhancement| It is now possible to specify `estimator` as a string in :func:`barplot` and :func:`pointplot`, in addition to a callable (:pr:`2866`). + - |Enhancement| Error bars in :func:`regplot` now inherit the alpha value of the points they correspond to (:pr:`2540`). - |Enhancement| When using :func:`pairplot` with `corner=True` and `diag_kind=None`, the top left y axis label is no longer hidden (:pr:`2850`). diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 39fe48145e..4a7dafbd68 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -52,6 +52,35 @@ def figure(self): """Access the :class:`matplotlib.figure.Figure` object underlying the grid.""" return self._figure + def apply(self, func, *args, **kwargs): + """ + Pass the grid to a user-supplied function and return self. + + The `func` must accept an object of this type for its first + positional argument. Additional arguments are passed through. + The return value of `func` is ignored; this method returns self. + See the `pipe` method if you want the return value. + + Added in v0.12.0. + + """ + func(self, *args, **kwargs) + return self + + def pipe(self, func, *args, **kwargs): + """ + Pass the grid to a user-supplied function and return its value. + + The `func` must accept an object of this type for its first + positional argument. Additional arguments are passed through. + The return value of `func` becomes the return value of this method. + See the `apply` method if you want to return self instead. + + Added in v0.12.0. + + """ + return func(self, *args, **kwargs) + def savefig(self, *args, **kwargs): """ Save an image of the plot. @@ -86,6 +115,7 @@ def tight_layout(self, *args, **kwargs): if self._tight_layout_pad is not None: kwargs.setdefault("pad", self._tight_layout_pad) self._figure.tight_layout(*args, **kwargs) + return self def add_legend(self, legend_data=None, title=None, label_order=None, adjust_subtitles=False, **kwargs): @@ -1007,6 +1037,8 @@ def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws): if y is not None: self.map(plt.axhline, y=y, **line_kws) + return self + # ------ Properties that are part of the public API and documented by Sphinx @property diff --git a/tests/test_axisgrid.py b/tests/test_axisgrid.py index 4a33da94a0..3ade4ebac3 100644 --- a/tests/test_axisgrid.py +++ b/tests/test_axisgrid.py @@ -673,6 +673,29 @@ def test_refline(self): assert g.axes[0, 0].lines[-1].get_color() == color assert g.axes[0, 0].lines[-1].get_linestyle() == linestyle + def test_apply(self, long_df): + + def f(grid, color): + grid.figure.set_facecolor(color) + + color = (.1, .6, .3, .9) + g = ag.FacetGrid(long_df) + res = g.apply(f, color) + assert res is g + assert g.figure.get_facecolor() == color + + def test_pipe(self, long_df): + + def f(grid, color): + grid.figure.set_facecolor(color) + return color + + color = (.1, .6, .3, .9) + g = ag.FacetGrid(long_df) + res = g.pipe(f, color) + assert res == color + assert g.figure.get_facecolor() == color + class TestPairGrid: