diff --git a/doc/_docstrings/objects.Est.ipynb b/doc/_docstrings/objects.Est.ipynb index 3dcac462e5..94aacfa902 100644 --- a/doc/_docstrings/objects.Est.ipynb +++ b/doc/_docstrings/objects.Est.ipynb @@ -109,12 +109,30 @@ "p.add(so.Range(), so.Est(seed=0))" ] }, + { + "cell_type": "markdown", + "id": "df807ef8-b5fb-4eac-b539-1bd4e797ddc2", + "metadata": {}, + "source": [ + "To compute a weighted estimate (and confidence interval), assign a `weight` variable in the layer where you use the stat:" + ] + }, { "cell_type": "code", "execution_count": null, "id": "5e4a0594-e1ee-4f72-971e-3763dd626e8b", "metadata": {}, "outputs": [], + "source": [ + "p.add(so.Range(), so.Est(), weight=\"price\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d0c34d7-fb76-44cf-9079-3ec7f45741d0", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/seaborn/_statistics.py b/seaborn/_statistics.py index c2f01ce7b5..ca0364de37 100644 --- a/seaborn/_statistics.py +++ b/seaborn/_statistics.py @@ -518,6 +518,62 @@ def __call__(self, data, var): return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max}) +class WeightedEstimateAggregator: + + def __init__(self, estimator, errorbar=None, **boot_kws): + """ + Data aggregator that produces a weighted estimate and error bar interval. + + Parameters + ---------- + estimator : string + Function (or method name) that maps a vector to a scalar. Currently + supports only "mean". + errorbar : string or (string, number) tuple + Name of errorbar method or a tuple with a method name and a level parameter. + Currently the only supported method is "ci". + boot_kws + Additional keywords are passed to bootstrap when error_method is "ci". + + """ + if estimator != "mean": + # Note that, while other weighted estimators may make sense (e.g. median), + # I'm not aware of an implementation in our dependencies. We can add one + # in seaborn later, if there is sufficient interest. For now, limit to mean. + raise ValueError(f"Weighted estimator must be 'mean', not {estimator!r}.") + self.estimator = estimator + + method, level = _validate_errorbar_arg(errorbar) + if method is not None and method != "ci": + # As with the estimator, weighted 'sd' or 'pi' error bars may make sense. + # But we'll keep things simple for now and limit to (bootstrap) CI. + raise ValueError(f"Error bar method must be 'ci', not {method!r}.") + self.error_method = method + self.error_level = level + + self.boot_kws = boot_kws + + def __call__(self, data, var): + """Aggregate over `var` column of `data` with estimate and error interval.""" + vals = data[var] + weights = data["weight"] + + estimate = np.average(vals, weights=weights) + + if self.error_method == "ci" and len(data) > 1: + + def error_func(x, w): + return np.average(x, weights=w) + + boots = bootstrap(vals, weights, func=error_func, **self.boot_kws) + err_min, err_max = _percentile_interval(boots, self.error_level) + + else: + err_min = err_max = np.nan + + return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max}) + + class LetterValues: def __init__(self, k_depth, outlier_prop, trust_alpha): diff --git a/seaborn/_stats/aggregation.py b/seaborn/_stats/aggregation.py index d175273e78..aa7677b142 100644 --- a/seaborn/_stats/aggregation.py +++ b/seaborn/_stats/aggregation.py @@ -8,7 +8,10 @@ from seaborn._core.scales import Scale from seaborn._core.groupby import GroupBy from seaborn._stats.base import Stat -from seaborn._statistics import EstimateAggregator +from seaborn._statistics import ( + EstimateAggregator, + WeightedEstimateAggregator, +) from seaborn._core.typing import Vector @@ -54,8 +57,14 @@ class Est(Stat): """ Calculate a point estimate and error bar interval. - For additional information about the various `errorbar` choices, see - the :doc:`errorbar tutorial `. + For more information about the various `errorbar` choices, see the + :doc:`errorbar tutorial `. + + Additional variables: + + - **weight**: When passed to a layer that uses this stat, a weighted estimate + will be computed. Note that use of weights currently limits the choice of + function and error bar method to `"mean"` and `"ci"`, respectively. Parameters ---------- @@ -95,7 +104,10 @@ def __call__( ) -> DataFrame: boot_kws = {"n_boot": self.n_boot, "seed": self.seed} - engine = EstimateAggregator(self.func, self.errorbar, **boot_kws) + if "weight" in data: + engine = WeightedEstimateAggregator(self.func, self.errorbar, **boot_kws) + else: + engine = EstimateAggregator(self.func, self.errorbar, **boot_kws) var = {"x": "y", "y": "x"}[orient] res = ( diff --git a/tests/_stats/test_aggregation.py b/tests/_stats/test_aggregation.py index 08291d449b..b3a5d58aab 100644 --- a/tests/_stats/test_aggregation.py +++ b/tests/_stats/test_aggregation.py @@ -115,6 +115,17 @@ def test_median_pi(self, df): expected = est.assign(ymin=grouped.min()["y"], ymax=grouped.max()["y"]) assert_frame_equal(res, expected) + def test_weighted_mean(self, df, rng): + + weights = rng.uniform(0, 5, len(df)) + gb = self.get_groupby(df[["x", "y"]], "x") + df = df.assign(weight=weights) + res = Est("mean")(df, gb, "x", {}) + for _, res_row in res.iterrows(): + rows = df[df["x"] == res_row["x"]] + expected = np.average(rows["y"], weights=rows["weight"]) + assert res_row["y"] == expected + def test_seed(self, df): ori = "x" diff --git a/tests/test_statistics.py b/tests/test_statistics.py index c0d4e83cf0..fb36e0e922 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -15,6 +15,7 @@ ECDF, EstimateAggregator, LetterValues, + WeightedEstimateAggregator, _validate_errorbar_arg, _no_scipy, ) @@ -632,6 +633,39 @@ def test_errorbar_validation(self): _validate_errorbar_arg(arg) +class TestWeightedEstimateAggregator: + + def test_weighted_mean(self, long_df): + + long_df["weight"] = long_df["x"] + est = WeightedEstimateAggregator("mean") + out = est(long_df, "y") + expected = np.average(long_df["y"], weights=long_df["weight"]) + assert_array_equal(out["y"], expected) + assert_array_equal(out["ymin"], np.nan) + assert_array_equal(out["ymax"], np.nan) + + def test_weighted_ci(self, long_df): + + long_df["weight"] = long_df["x"] + est = WeightedEstimateAggregator("mean", "ci") + out = est(long_df, "y") + expected = np.average(long_df["y"], weights=long_df["weight"]) + assert_array_equal(out["y"], expected) + assert (out["ymin"] <= out["y"]).all() + assert (out["ymax"] >= out["y"]).all() + + def test_limited_estimator(self): + + with pytest.raises(ValueError, match="Weighted estimator must be 'mean'"): + WeightedEstimateAggregator("median") + + def test_limited_ci(self): + + with pytest.raises(ValueError, match="Error bar method must be 'ci'"): + WeightedEstimateAggregator("mean", "sd") + + class TestLetterValues: @pytest.fixture