Skip to content

Commit

Permalink
Add label parameter to pointplot
Browse files Browse the repository at this point in the history
Unbreaks FacetGrid + pointplot (fixes #3004) and is generally useful.
  • Loading branch information
mwaskom committed Sep 12, 2022
1 parent 1e6739f commit 46c3390
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
11 changes: 7 additions & 4 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,7 +1596,7 @@ class _PointPlotter(_CategoricalStatPlotter):
def __init__(self, x, y, hue, data, order, hue_order,
estimator, errorbar, n_boot, units, seed,
markers, linestyles, dodge, join, scale,
orient, color, palette, errwidth=None, capsize=None):
orient, color, palette, errwidth, capsize, label):
"""Initialize the plotter."""
self.establish_variables(x, y, hue, data, orient,
order, hue_order, units)
Expand Down Expand Up @@ -1631,6 +1631,7 @@ def __init__(self, x, y, hue, data, order, hue_order,
self.scale = scale
self.errwidth = errwidth
self.capsize = capsize
self.label = label

@property
def hue_offsets(self):
Expand Down Expand Up @@ -1678,7 +1679,7 @@ def draw_points(self, ax):
x, y = pointpos, self.statistic
ax.scatter(x, y,
linewidth=mew, marker=marker, s=markersize,
facecolor=colors, edgecolor=colors)
facecolor=colors, edgecolor=colors, label=self.label)

else:

Expand Down Expand Up @@ -2829,15 +2830,15 @@ def pointplot(
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
markers="o", linestyles="-", dodge=False, join=True, scale=1,
orient=None, color=None, palette=None, errwidth=None, ci="deprecated",
capsize=None, ax=None,
capsize=None, label=None, ax=None,
):

errorbar = utils._deprecate_ci(errorbar, ci)

plotter = _PointPlotter(x, y, hue, data, order, hue_order,
estimator, errorbar, n_boot, units, seed,
markers, linestyles, dodge, join, scale,
orient, color, palette, errwidth, capsize)
orient, color, palette, errwidth, capsize, label)

if ax is None:
ax = plt.gca()
Expand Down Expand Up @@ -2893,6 +2894,8 @@ def pointplot(
{palette}
{errwidth}
{capsize}
label : string, optional
Label to represent the plot in a legend, only relevant when not using `hue`.
{ax_in}
Returns
Expand Down
12 changes: 12 additions & 0 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@

from seaborn.external.version import Version
from seaborn._oldcore import categorical_order
from seaborn.axisgrid import FacetGrid
from seaborn.categorical import (
_CategoricalPlotterNew,
Beeswarm,
catplot,
pointplot,
stripplot,
swarmplot,
)
Expand Down Expand Up @@ -2745,6 +2747,16 @@ def test_errorbar(self, long_df):
expected = mean - 2 * sd, mean + 2 * sd
assert_array_equal(line.get_ydata(), expected)

def test_on_facetgrid(self, long_df):

g = FacetGrid(long_df, hue="a")
g.map(pointplot, "a", "y")
g.add_legend()

order = categorical_order(long_df["a"])
legend_texts = [t.get_text() for t in g.legend.texts]
assert legend_texts == order


class TestCountPlot(CategoricalFixture):

Expand Down

0 comments on commit 46c3390

Please sign in to comment.