diff --git a/doc/releases/v0.11.0.txt b/doc/releases/v0.11.0.txt index c24fbb9392..5a0739bc96 100644 --- a/doc/releases/v0.11.0.txt +++ b/doc/releases/v0.11.0.txt @@ -21,3 +21,5 @@ v0.11.0 (Unreleased) - Added the ``axes_dict`` attribute to :class:`FacetGrid` for named access to the component axes. GH2046 - Made :meth:`FacetGrid.set_axis_labels` clear labels from "interior" axes. GH2046 + +- Improved :meth:`FacetGrid.set_titles` with `margin titles=True`, such that texts representing the original row titles are removed before adding new ones. GH2083 diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index a9c3f0316f..d9e34cdec0 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -385,6 +385,7 @@ def __init__( self._col_var = col self._margin_titles = margin_titles + self._margin_titles_texts = [] self._col_wrap = col_wrap self._hue_var = hue_var self._colors = colors @@ -1005,16 +1006,24 @@ def set_titles(self, template=None, row_template=None, col_template=None, template = utils.to_utf8(template) if self._margin_titles: + + # Remove any existing title texts + for text in self._margin_titles_texts: + text.remove() + self._margin_titles_texts = [] + if self.row_names is not None: # Draw the row titles on the right edge of the grid for i, row_name in enumerate(self.row_names): ax = self.axes[i, -1] args.update(dict(row_name=row_name)) title = row_template.format(**args) - bgcolor = self.fig.get_facecolor() - ax.annotate(title, xy=(1.02, .5), xycoords="axes fraction", - rotation=270, ha="left", va="center", - backgroundcolor=bgcolor, **kwargs) + text = ax.annotate( + title, xy=(1.02, .5), xycoords="axes fraction", + rotation=270, ha="left", va="center", + **kwargs + ) + self._margin_titles_texts.append(text) if self.col_names is not None: # Draw the column titles as normal titles diff --git a/seaborn/tests/test_axisgrid.py b/seaborn/tests/test_axisgrid.py index 95a72fea2a..9af96e77e6 100644 --- a/seaborn/tests/test_axisgrid.py +++ b/seaborn/tests/test_axisgrid.py @@ -473,19 +473,23 @@ def test_set_titles_margin_titles(self): g.map(plt.plot, "x", "y") # Test the default titles - nt.assert_equal(g.axes[0, 0].get_title(), "b = m") - nt.assert_equal(g.axes[0, 1].get_title(), "b = n") - nt.assert_equal(g.axes[1, 0].get_title(), "") + assert g.axes[0, 0].get_title() == "b = m" + assert g.axes[0, 1].get_title() == "b = n" + assert g.axes[1, 0].get_title() == "" # Test the row "titles" - nt.assert_equal(g.axes[0, 1].texts[0].get_text(), "a = a") - nt.assert_equal(g.axes[1, 1].texts[0].get_text(), "a = b") - - # Test a provided title - g.set_titles(col_template="{col_var} == {col_name}") - nt.assert_equal(g.axes[0, 0].get_title(), "b == m") - nt.assert_equal(g.axes[0, 1].get_title(), "b == n") - nt.assert_equal(g.axes[1, 0].get_title(), "") + assert g.axes[0, 1].texts[0].get_text() == "a = a" + assert g.axes[1, 1].texts[0].get_text() == "a = b" + assert g.axes[0, 1].texts[0] is g._margin_titles_texts[0] + + # Test provided titles + g.set_titles(col_template="{col_name}", row_template="{row_name}") + assert g.axes[0, 0].get_title() == "m" + assert g.axes[0, 1].get_title() == "n" + assert g.axes[1, 0].get_title() == "" + + assert len(g.axes[1, 1].texts) == 1 + assert g.axes[1, 1].texts[0].get_text() == "b" def test_set_ticklabels(self):