diff --git a/seaborn/_stats/aggregation.py b/seaborn/_stats/aggregation.py index 3560977ab0..73da54bdad 100644 --- a/seaborn/_stats/aggregation.py +++ b/seaborn/_stats/aggregation.py @@ -1,14 +1,14 @@ from __future__ import annotations from dataclasses import dataclass from typing import ClassVar, Callable -from numbers import Number -import numpy as np import pandas as pd +from pandas import DataFrame +from seaborn._core.scales import Scale +from seaborn._core.groupby import GroupBy from seaborn._stats.base import Stat -from seaborn.algorithms import bootstrap -from seaborn.utils import _check_argument +from seaborn._statistics import EstimateAggregator from seaborn._core.typing import Vector @@ -20,7 +20,7 @@ class Agg(Stat): Parameters ---------- - func + func : str or callable Name of a :class:`pandas.Series` method or a vector -> scalar function. """ @@ -28,13 +28,14 @@ class Agg(Stat): group_by_orient: ClassVar[bool] = True - def __call__(self, data, groupby, orient, scales): + def __call__( + self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale], + ) -> DataFrame: var = {"x": "y", "y": "x"}.get(orient) res = ( groupby .agg(data, {var: self.func}) - # TODO Could be an option not to drop NA? .dropna() .reset_index(drop=True) ) @@ -43,7 +44,23 @@ def __call__(self, data, groupby, orient, scales): @dataclass class Est(Stat): + """ + Calculate a point estimate and error bar interval. + + Parameters + ---------- + func : str or callable + Name of a :class:`numpy.ndarray` method or a vector -> scalar function. + errorbar : str, (str, float) tuple, or callable + Name of errorbar method (one of "ci", "pi", "se" or "sd"), or a tuple + with a method name ane a level parameter, or a function that maps from a + vector to a (min, max) interval. + n_boot : int + Number of bootstrap samples to draw for "ci" errorbars. + seed : int + Seed for the PRNG used to draw bootstrap samples. + """ func: str | Callable[[Vector], float] = "mean" errorbar: str | tuple[str, float] = ("ci", 95) n_boot: int = 1000 @@ -51,56 +68,31 @@ class Est(Stat): group_by_orient: ClassVar[bool] = True - def _process(self, data, var): - - vals = data[var] - - estimate = vals.agg(self.func) - - # Options that produce no error bars - if self.error_method is None: - err_min = err_max = np.nan - elif len(data) <= 1: - err_min = err_max = np.nan - - # Generic errorbars from user-supplied function - elif callable(self.error_method): - err_min, err_max = self.error_method(vals) - - # Parametric options - elif self.error_method == "sd": - half_interval = vals.std() * self.error_level - err_min, err_max = estimate - half_interval, estimate + half_interval - elif self.error_method == "se": - half_interval = vals.sem() * self.error_level - err_min, err_max = estimate - half_interval, estimate + half_interval - - # Nonparametric options - elif self.error_method == "pi": - err_min, err_max = _percentile_interval(vals, self.error_level) - elif self.error_method == "ci": - boot_kws = {"n_boot": self.n_boot, "seed": self.seed} - # units = data.get("units", None) # TODO change to unit - units = None - boots = bootstrap(vals, units=units, func=self.func, **boot_kws) - err_min, err_max = _percentile_interval(boots, self.error_level) - - res = {var: estimate, f"{var}min": err_min, f"{var}max": err_max} + def _process( + self, data: DataFrame, var: str, estimator: EstimateAggregator + ) -> DataFrame: + # Needed because GroupBy.apply assumes func is DataFrame -> DataFrame + # which we could probably make more general to allow Series return + res = estimator(data, var) return pd.DataFrame([res]) - def __call__(self, data, groupby, orient, scales): + def __call__( + self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale], + ) -> DataFrame: - method, level = _validate_errorbar_arg(self.errorbar) - self.error_method = method - self.error_level = level + boot_kws = {"n_boot": self.n_boot, "seed": self.seed} + engine = EstimateAggregator(self.func, self.errorbar, **boot_kws) var = {"x": "y", "y": "x"}.get(orient) res = ( groupby - .apply(data, self._process, var) - .dropna() + .apply(data, self._process, var, engine) + .dropna(subset=["x", "y"]) .reset_index(drop=True) ) + + res = res.fillna({f"{var}min": res[var], f"{var}max": res[var]}) + return res @@ -110,41 +102,3 @@ class Rolling(Stat): def __call__(self, data, groupby, orient, scales): ... - - -def _percentile_interval(data, width): - """Return a percentile interval from data of a given width.""" - edge = (100 - width) / 2 - percentiles = edge, 100 - edge - return np.nanpercentile(data, percentiles) - - -def _validate_errorbar_arg(arg): - """Check type and value of errorbar argument and assign default level.""" - DEFAULT_LEVELS = { - "ci": 95, - "pi": 95, - "se": 1, - "sd": 1, - } - - usage = "`errorbar` must be a callable, string, or (string, number) tuple" - - if arg is None: - return None, None - elif callable(arg): - return arg, None - elif isinstance(arg, str): - method = arg - level = DEFAULT_LEVELS.get(method, None) - else: - try: - method, level = arg - except (ValueError, TypeError) as err: - raise err.__class__(usage) from err - - _check_argument("errorbar", list(DEFAULT_LEVELS), method) - if level is not None and not isinstance(level, Number): - raise TypeError(usage) - - return method, level diff --git a/seaborn/_stats/histograms.py b/seaborn/_stats/histograms.py index 8f069b64e5..85abed1036 100644 --- a/seaborn/_stats/histograms.py +++ b/seaborn/_stats/histograms.py @@ -31,6 +31,9 @@ class Hist(Stat): # Q: would Discrete() scale imply binwidth=1 or bins centered on integers? discrete: bool = False + # TODO Note that these methods are mostly copied from _statistics.Histogram, + # but it only computes univariate histograms. We should reconcile the code. + def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete): """Inner function that takes bin parameters as arguments.""" vals = vals.dropna() diff --git a/tests/_stats/test_aggregation.py b/tests/_stats/test_aggregation.py index ed5b7e4d03..c2a030886e 100644 --- a/tests/_stats/test_aggregation.py +++ b/tests/_stats/test_aggregation.py @@ -1,14 +1,15 @@ +import numpy as np import pandas as pd import pytest from pandas.testing import assert_frame_equal from seaborn._core.groupby import GroupBy -from seaborn._stats.aggregation import Agg +from seaborn._stats.aggregation import Agg, Est -class TestAgg: +class AggregationFixtures: @pytest.fixture def df(self, rng): @@ -27,6 +28,9 @@ def get_groupby(self, df, orient): cols = [c for c in df if c != other] return GroupBy(cols) + +class TestAgg(AggregationFixtures): + def test_default(self, df): ori = "x" @@ -69,3 +73,53 @@ def test_func(self, df, func): expected = df.groupby("x", as_index=False)["y"].agg(func) assert_frame_equal(res, expected) + + +class TestEst(AggregationFixtures): + + # Note: Most of the underlying code is exercised in tests/test_statistics + + @pytest.mark.parametrize("func", [np.mean, "mean"]) + def test_mean_sd(self, df, func): + + ori = "x" + df = df[["x", "y"]] + gb = self.get_groupby(df, ori) + res = Est(func, "sd")(df, gb, ori, {}) + + grouped = df.groupby("x", as_index=False)["y"] + est = grouped.mean() + err = grouped.std() + expected = est.assign(ymin=est["y"] - err["y"], ymax=est["y"] + err["y"]) + assert_frame_equal(res, expected) + + def test_sd_single_obs(self): + + y = 1.5 + ori = "x" + df = pd.DataFrame([{"x": "a", "y": y}]) + gb = self.get_groupby(df, ori) + res = Est("mean", "sd")(df, gb, ori, {}) + expected = df.assign(ymin=y, ymax=y) + assert_frame_equal(res, expected) + + def test_median_pi(self, df): + + ori = "x" + df = df[["x", "y"]] + gb = self.get_groupby(df, ori) + res = Est("median", ("pi", 100))(df, gb, ori, {}) + + grouped = df.groupby("x", as_index=False)["y"] + est = grouped.median() + expected = est.assign(ymin=grouped.min()["y"], ymax=grouped.max()["y"]) + assert_frame_equal(res, expected) + + def test_seed(self, df): + + ori = "x" + gb = self.get_groupby(df, ori) + args = df, gb, ori, {} + res1 = Est("mean", "ci", seed=99)(*args) + res2 = Est("mean", "ci", seed=99)(*args) + assert_frame_equal(res1, res2)