Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add apply and pipe methods to Grid objects for fluent customization #2928

Merged
merged 4 commits into from
Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/whatsnew/v0.12.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ Other updates

- |Feature| It is now possible to aggregate / sort a :func:`lineplot` along the y axis using `orient="y"` (:pr:`2854`).

- |Feature| It is now possible to specify `estimator` as a string in :func:`barplot` and :func:`pointplot`, in addition to a callable (:pr:`2866`).
- |Feature| Made it easier to customize :class:`FacetGrid` / :class:`PairGrid` / :class:`JointGrid` with a fluent (method-chained) style by adding `apply`/ `pipe` methods. Additionally, fixed the `tight_layout` and `refline` methods so that they return `self` (:pr:`2926`).

- |Enhancement| Added a `width` parameter to :func:`barplot` (:pr:`2860`).

- |Enhancement| It is now possible to specify `estimator` as a string in :func:`barplot` and :func:`pointplot`, in addition to a callable (:pr:`2866`).

- |Enhancement| Error bars in :func:`regplot` now inherit the alpha value of the points they correspond to (:pr:`2540`).

- |Enhancement| When using :func:`pairplot` with `corner=True` and `diag_kind=None`, the top left y axis label is no longer hidden (:pr:`2850`).
Expand Down
32 changes: 32 additions & 0 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,35 @@ def figure(self):
"""Access the :class:`matplotlib.figure.Figure` object underlying the grid."""
return self._figure

def apply(self, func, *args, **kwargs):
"""
Pass the grid to a user-supplied function and return self.

The `func` must accept an object of this type for its first
positional argument. Additional arguments are passed through.
The return value of `func` is ignored; this method returns self.
See the `pipe` method if you want the return value.

Added in v0.12.0.

"""
func(self, *args, **kwargs)
return self

def pipe(self, func, *args, **kwargs):
"""
Pass the grid to a user-supplied function and return its value.

The `func` must accept an object of this type for its first
positional argument. Additional arguments are passed through.
The return value of `func` becomes the return value of this method.
See the `apply` method if you want to return self instead.

Added in v0.12.0.

"""
return func(self, *args, **kwargs)

def savefig(self, *args, **kwargs):
"""
Save an image of the plot.
Expand Down Expand Up @@ -84,6 +113,7 @@ def tight_layout(self, *args, **kwargs):
if self._tight_layout_pad is not None:
kwargs.setdefault("pad", self._tight_layout_pad)
self._figure.tight_layout(*args, **kwargs)
return self

def add_legend(self, legend_data=None, title=None, label_order=None,
adjust_subtitles=False, **kwargs):
Expand Down Expand Up @@ -1006,6 +1036,8 @@ def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws):
if y is not None:
self.map(plt.axhline, y=y, **line_kws)

return self

# ------ Properties that are part of the public API and documented by Sphinx

@property
Expand Down
23 changes: 23 additions & 0 deletions tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,29 @@ def test_refline(self):
assert g.axes[0, 0].lines[-1].get_color() == color
assert g.axes[0, 0].lines[-1].get_linestyle() == linestyle

def test_apply(self, long_df):

def f(grid, color):
grid.figure.set_facecolor(color)

color = (.1, .6, .3, .9)
g = ag.FacetGrid(long_df)
res = g.apply(f, color)
assert res is g
assert g.figure.get_facecolor() == color

def test_pipe(self, long_df):

def f(grid, color):
grid.figure.set_facecolor(color)
return color

color = (.1, .6, .3, .9)
g = ag.FacetGrid(long_df)
res = g.pipe(f, color)
assert res == color
assert g.figure.get_facecolor() == color


class TestPairGrid:

Expand Down