From ae9080d583032ab98a731e55753caf82b10c6905 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 8 Aug 2022 19:57:56 -0400 Subject: [PATCH] Work around Series.agg non-aggregation EstimateAggregator (#2946) --- seaborn/_statistics.py | 7 ++++++- tests/test_statistics.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/seaborn/_statistics.py b/seaborn/_statistics.py index e745f3fb70..a19fdea7f1 100644 --- a/seaborn/_statistics.py +++ b/seaborn/_statistics.py @@ -478,7 +478,12 @@ def __init__(self, estimator, errorbar=None, **boot_kws): def __call__(self, data, var): """Aggregate over `var` column of `data` with estimate and error interval.""" vals = data[var] - estimate = vals.agg(self.estimator) + if callable(self.estimator): + # You would think we could pass to vals.agg, and yet: + # https://github.com/mwaskom/seaborn/issues/2943 + estimate = self.estimator(vals) + else: + estimate = vals.agg(self.estimator) # Options that produce no error bars if self.error_method is None: diff --git a/tests/test_statistics.py b/tests/test_statistics.py index a97f3d6877..e39127882a 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -499,6 +499,15 @@ def test_name_estimator(self, long_df): out = agg(long_df, "x") assert out["x"] == long_df["x"].mean() + def test_custom_func_estimator(self, long_df): + + def func(x): + return np.asarray(x).min() + + agg = EstimateAggregator(func) + out = agg(long_df, "x") + assert out["x"] == func(long_df["x"]) + def test_se_errorbars(self, long_df): agg = EstimateAggregator("mean", "se")