Skip to content

Commit

Permalink
Add width parameter to barplot (#2860)
Browse files Browse the repository at this point in the history
* Add width parameter to barplot

* Add width to countplot as well

* Update barplot test fixture
  • Loading branch information
mwaskom authored Jun 16, 2022
1 parent a48dc8f commit 1e24db4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 19 deletions.
2 changes: 2 additions & 0 deletions doc/releases/v0.12.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ Other updates

- |Enhancement| Example datasets are now stored in an OS-specific cache location (as determined by `appdirs`) rather than in the user's home directory. Users should feel free to remove `~/seaborn-data` if desired (:pr:`2773`).

- |Enhancement| Added a `width` parameter to :func:`barplot` (:pr:`2860`).

- |Enhancement| Error bars in :func:`regplot` now inherit the alpha value of the points they correspond to (:pr:`2540`).

- |Enhancement| When using :func:`pairplot` with `corner=True` and `diag_kind=None`, the top left y axis label is no longer hidden (:pr:2850`).
Expand Down
14 changes: 8 additions & 6 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,15 +1576,16 @@ class _BarPlotter(_CategoricalStatPlotter):

def __init__(self, x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units, seed,
orient, color, palette, saturation, errcolor,
errwidth, capsize, dodge):
orient, color, palette, saturation, width,
errcolor, errwidth, capsize, dodge):
"""Initialize the plotter."""
self.establish_variables(x, y, hue, data, orient,
order, hue_order, units)
self.establish_colors(color, palette, saturation)
self.estimate_statistic(estimator, ci, n_boot, seed)

self.dodge = dodge
self.width = width

self.errcolor = errcolor
self.errwidth = errwidth
Expand Down Expand Up @@ -2743,7 +2744,7 @@ def swarmplot(
def barplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None,
orient=None, color=None, palette=None, saturation=.75,
orient=None, color=None, palette=None, saturation=.75, width=.8,
errcolor=".26", errwidth=None, capsize=None, dodge=True,
ax=None,
**kwargs,
Expand All @@ -2752,7 +2753,7 @@ def barplot(
plotter = _BarPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units, seed,
orient, color, palette, saturation,
errcolor, errwidth, capsize, dodge)
width, errcolor, errwidth, capsize, dodge)

if ax is None:
ax = plt.gca()
Expand Down Expand Up @@ -2793,6 +2794,7 @@ def barplot(
{color}
{palette}
{saturation}
{width}
errcolor : matplotlib color
Color used for the error bar lines.
{errwidth}
Expand Down Expand Up @@ -2910,7 +2912,7 @@ def pointplot(

def countplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
orient=None, color=None, palette=None, saturation=.75,
orient=None, color=None, palette=None, saturation=.75, width=.8,
dodge=True, ax=None, **kwargs
):

Expand All @@ -2936,7 +2938,7 @@ def countplot(
x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units, seed,
orient, color, palette, saturation,
errcolor, errwidth, capsize, dodge
width, errcolor, errwidth, capsize, dodge
)

plotter.value_label = "count"
Expand Down
27 changes: 14 additions & 13 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2203,26 +2203,27 @@ class TestBarPlotter(CategoricalFixture):
estimator=np.mean, ci=95, n_boot=100, units=None, seed=None,
order=None, hue_order=None,
orient=None, color=None, palette=None,
saturation=.75, errcolor=".26", errwidth=None,
saturation=.75, width=0.8,
errcolor=".26", errwidth=None,
capsize=None, dodge=True
)

def test_nested_width(self):

kws = self.default_kws.copy()
ax = cat.barplot(data=self.df, x="g", y="y", hue="h")
for bar in ax.patches:
assert bar.get_width() == pytest.approx(.8 / 2)
ax.clear()

p = cat._BarPlotter(**kws)
p.establish_variables("g", "y", hue="h", data=self.df)
assert p.nested_width == .8 / 2

p = cat._BarPlotter(**kws)
p.establish_variables("h", "y", "g", data=self.df)
assert p.nested_width == .8 / 3
ax = cat.barplot(data=self.df, x="g", y="y", hue="g", width=.5)
for bar in ax.patches:
assert bar.get_width() == pytest.approx(.5 / 3)
ax.clear()

kws["dodge"] = False
p = cat._BarPlotter(**kws)
p.establish_variables("h", "y", "g", data=self.df)
assert p.nested_width == .8
ax = cat.barplot(data=self.df, x="g", y="y", hue="g", dodge=False)
for bar in ax.patches:
assert bar.get_width() == pytest.approx(.8)
ax.clear()

def test_draw_vertical_bars(self):

Expand Down

0 comments on commit 1e24db4

Please sign in to comment.