From 3db4d5ac7ff3381bdf56b9727a43bb94ab6607ce Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 14 May 2020 14:27:09 -0400 Subject: [PATCH] Add tight_layout method on FacetGrid and PairGrid (#2073) * Add Grid.tight_layout for legend-aware automatic layout * Use Grid.tight_layout internally --- doc/releases/v0.11.0.txt | 2 ++ doc/requirements.txt | 1 + seaborn/axisgrid.py | 21 ++++++++++++++++++--- seaborn/tests/test_axisgrid.py | 12 ++++++++++++ 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/doc/releases/v0.11.0.txt b/doc/releases/v0.11.0.txt index ea7a6780c3..3213da1313 100644 --- a/doc/releases/v0.11.0.txt +++ b/doc/releases/v0.11.0.txt @@ -6,6 +6,8 @@ v0.11.0 (Unreleased) - Standardized the parameter names for the oldest functions (:func:`distplot`, :func:`kdeplot`, and :func:`rugplot`) to be `x` and `y`, as in other functions. Using the old names will warn now and break in the future. +- Added a ``tight_layout`` method to :class:`FacetGrid` and :class:`PairGrid`, which runs the :func:`matplotlib.pyplot.tight_layout` algorithm without interference from the external legend. + - Added an explicit warning in :func:`swarmplot` when more than 2% of the points are overlap in the "gutters" of the swarm. - Added the ``axes_dict`` attribute to :class:`FacetGrid` for named access to the component axes. diff --git a/doc/requirements.txt b/doc/requirements.txt index 1b937fb21a..e61e56d8e8 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,4 +1,5 @@ sphinx==2.3.1 sphinx_bootstrap_theme==0.6.5 # Later versions mess up the css somehow +numpydoc nbconvert ipykernel diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 07763effc0..a9c3f0316f 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -23,6 +23,10 @@ class Grid(object): _margin_titles = False _legend_out = True + def __init__(self): + + self._tight_layout_rect = [0, 0, 1, 1] + def set(self, **kwargs): """Set attributes on each subplot Axes.""" for ax in self.axes.flat: @@ -35,6 +39,12 @@ def savefig(self, *args, **kwargs): kwargs.setdefault("bbox_inches", "tight") self.fig.savefig(*args, **kwargs) + def tight_layout(self, *args, **kwargs): + """Call fig.tight_layout within rect that exclude the legend.""" + kwargs = kwargs.copy() + kwargs.setdefault("rect", self._tight_layout_rect) + self.fig.tight_layout(*args, **kwargs) + def add_legend(self, legend_data=None, title=None, label_order=None, **kwargs): """Draw a legend, maybe placing it outside axes and resizing the figure. @@ -125,6 +135,7 @@ def add_legend(self, legend_data=None, title=None, label_order=None, # Place the subplot axes to give space for the legend self.fig.subplots_adjust(right=right) + self._tight_layout_rect[2] = right else: # Draw a legend in the first axis @@ -240,6 +251,8 @@ def __init__( gridspec_kws=None, size=None ): + super(FacetGrid, self).__init__() + # Handle deprecations if size is not None: height = size @@ -385,7 +398,7 @@ def __init__( # --- Make the axes look good - fig.tight_layout() + self.tight_layout() if despine: self.despine() @@ -871,7 +884,7 @@ def _finalize_grid(self, axlabels): """Finalize the annotations and layout.""" self.set_axis_labels(*axlabels) self.set_titles() - self.fig.tight_layout() + self.tight_layout() def facet_axis(self, row_i, col_j): """Make the axis identified by these indices active and return it.""" @@ -1285,6 +1298,8 @@ def __init__( """ + super(PairGrid, self).__init__() + # Handle deprecations if size is not None: height = size @@ -1373,7 +1388,7 @@ def __init__( if despine: self._despine = True utils.despine(fig=fig) - fig.tight_layout(pad=layout_pad) + self.tight_layout(pad=layout_pad) def map(self, func, **kwargs): """Plot with the same function in every subplot. diff --git a/seaborn/tests/test_axisgrid.py b/seaborn/tests/test_axisgrid.py index 19fa2b6ed4..95a72fea2a 100644 --- a/seaborn/tests/test_axisgrid.py +++ b/seaborn/tests/test_axisgrid.py @@ -306,6 +306,18 @@ def test_legendout_with_colwrap(self): g.map(plt.plot, "x", "y", linewidth=3) g.add_legend() + def test_legend_tight_layout(self): + + g = ag.FacetGrid(self.df, hue='b') + g.map(plt.plot, "x", "y", linewidth=3) + g.add_legend() + g.tight_layout() + + axes_right_edge = g.ax.get_window_extent().xmax + legend_left_edge = g._legend.get_window_extent().xmin + + assert axes_right_edge < legend_left_edge + def test_subplot_kws(self): g = ag.FacetGrid(self.df, despine=False,