Skip to content

Commit

Permalink
Work around Series.agg non-aggregation EstimateAggregator (#2946)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom authored Aug 8, 2022
1 parent a1ede5e commit ae9080d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
7 changes: 6 additions & 1 deletion seaborn/_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,12 @@ def __init__(self, estimator, errorbar=None, **boot_kws):
def __call__(self, data, var):
"""Aggregate over `var` column of `data` with estimate and error interval."""
vals = data[var]
estimate = vals.agg(self.estimator)
if callable(self.estimator):
# You would think we could pass to vals.agg, and yet:
# https://github.com/mwaskom/seaborn/issues/2943
estimate = self.estimator(vals)
else:
estimate = vals.agg(self.estimator)

# Options that produce no error bars
if self.error_method is None:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,15 @@ def test_name_estimator(self, long_df):
out = agg(long_df, "x")
assert out["x"] == long_df["x"].mean()

def test_custom_func_estimator(self, long_df):

def func(x):
return np.asarray(x).min()

agg = EstimateAggregator(func)
out = agg(long_df, "x")
assert out["x"] == func(long_df["x"])

def test_se_errorbars(self, long_df):

agg = EstimateAggregator("mean", "se")
Expand Down

0 comments on commit ae9080d

Please sign in to comment.