Skip to content

Commit

Permalink
Add apply and pipe methods to Grid objects for fluent customization (#…
Browse files Browse the repository at this point in the history
…2928)

* Return self from tight_layout and refline

* Add apply and pipe methods to FacetGrid for fluent customization

* Move apply/pipe down to base class so JointGrid/PaiGrid get them too

* Tweak docstrings
  • Loading branch information
mwaskom authored Jul 31, 2022
1 parent 6460a21 commit 949dec3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
4 changes: 3 additions & 1 deletion doc/whatsnew/v0.12.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ Other updates

- |Feature| It is now possible to aggregate / sort a :func:`lineplot` along the y axis using `orient="y"` (:pr:`2854`).

- |Feature| It is now possible to specify `estimator` as a string in :func:`barplot` and :func:`pointplot`, in addition to a callable (:pr:`2866`).
- |Feature| Made it easier to customize :class:`FacetGrid` / :class:`PairGrid` / :class:`JointGrid` with a fluent (method-chained) style by adding `apply`/ `pipe` methods. Additionally, fixed the `tight_layout` and `refline` methods so that they return `self` (:pr:`2926`).

- |Enhancement| Added a `width` parameter to :func:`barplot` (:pr:`2860`).

- |Enhancement| It is now possible to specify `estimator` as a string in :func:`barplot` and :func:`pointplot`, in addition to a callable (:pr:`2866`).

- |Enhancement| Error bars in :func:`regplot` now inherit the alpha value of the points they correspond to (:pr:`2540`).

- |Enhancement| When using :func:`pairplot` with `corner=True` and `diag_kind=None`, the top left y axis label is no longer hidden (:pr:`2850`).
Expand Down
32 changes: 32 additions & 0 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,35 @@ def figure(self):
"""Access the :class:`matplotlib.figure.Figure` object underlying the grid."""
return self._figure

def apply(self, func, *args, **kwargs):
"""
Pass the grid to a user-supplied function and return self.
The `func` must accept an object of this type for its first
positional argument. Additional arguments are passed through.
The return value of `func` is ignored; this method returns self.
See the `pipe` method if you want the return value.
Added in v0.12.0.
"""
func(self, *args, **kwargs)
return self

def pipe(self, func, *args, **kwargs):
"""
Pass the grid to a user-supplied function and return its value.
The `func` must accept an object of this type for its first
positional argument. Additional arguments are passed through.
The return value of `func` becomes the return value of this method.
See the `apply` method if you want to return self instead.
Added in v0.12.0.
"""
return func(self, *args, **kwargs)

def savefig(self, *args, **kwargs):
"""
Save an image of the plot.
Expand Down Expand Up @@ -86,6 +115,7 @@ def tight_layout(self, *args, **kwargs):
if self._tight_layout_pad is not None:
kwargs.setdefault("pad", self._tight_layout_pad)
self._figure.tight_layout(*args, **kwargs)
return self

def add_legend(self, legend_data=None, title=None, label_order=None,
adjust_subtitles=False, **kwargs):
Expand Down Expand Up @@ -1007,6 +1037,8 @@ def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws):
if y is not None:
self.map(plt.axhline, y=y, **line_kws)

return self

# ------ Properties that are part of the public API and documented by Sphinx

@property
Expand Down
23 changes: 23 additions & 0 deletions tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,29 @@ def test_refline(self):
assert g.axes[0, 0].lines[-1].get_color() == color
assert g.axes[0, 0].lines[-1].get_linestyle() == linestyle

def test_apply(self, long_df):

def f(grid, color):
grid.figure.set_facecolor(color)

color = (.1, .6, .3, .9)
g = ag.FacetGrid(long_df)
res = g.apply(f, color)
assert res is g
assert g.figure.get_facecolor() == color

def test_pipe(self, long_df):

def f(grid, color):
grid.figure.set_facecolor(color)
return color

color = (.1, .6, .3, .9)
g = ag.FacetGrid(long_df)
res = g.pipe(f, color)
assert res == color
assert g.figure.get_facecolor() == color


class TestPairGrid:

Expand Down

0 comments on commit 949dec3

Please sign in to comment.