Skip to content

Commit

Permalink
Revert Est to use EstimateAggregator and add (light) tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jul 18, 2022
1 parent d9672c8 commit 01cf7e4
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 88 deletions.
126 changes: 40 additions & 86 deletions seaborn/_stats/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar, Callable
from numbers import Number

import numpy as np
import pandas as pd
from pandas import DataFrame

from seaborn._core.scales import Scale
from seaborn._core.groupby import GroupBy
from seaborn._stats.base import Stat
from seaborn.algorithms import bootstrap
from seaborn.utils import _check_argument
from seaborn._statistics import EstimateAggregator

from seaborn._core.typing import Vector

Expand All @@ -20,21 +20,22 @@ class Agg(Stat):
Parameters
----------
func
func : str or callable
Name of a :class:`pandas.Series` method or a vector -> scalar function.
"""
func: str | Callable[[Vector], float] = "mean"

group_by_orient: ClassVar[bool] = True

def __call__(self, data, groupby, orient, scales):
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

var = {"x": "y", "y": "x"}.get(orient)
res = (
groupby
.agg(data, {var: self.func})
# TODO Could be an option not to drop NA?
.dropna()
.reset_index(drop=True)
)
Expand All @@ -43,64 +44,55 @@ def __call__(self, data, groupby, orient, scales):

@dataclass
class Est(Stat):
"""
Calculate a point estimate and error bar interval.
Parameters
----------
func : str or callable
Name of a :class:`numpy.ndarray` method or a vector -> scalar function.
errorbar : str, (str, float) tuple, or callable
Name of errorbar method (one of "ci", "pi", "se" or "sd"), or a tuple
with a method name ane a level parameter, or a function that maps from a
vector to a (min, max) interval.
n_boot : int
Number of bootstrap samples to draw for "ci" errorbars.
seed : int
Seed for the PRNG used to draw bootstrap samples.
"""
func: str | Callable[[Vector], float] = "mean"
errorbar: str | tuple[str, float] = ("ci", 95)
n_boot: int = 1000
seed: int | None = None

group_by_orient: ClassVar[bool] = True

def _process(self, data, var):

vals = data[var]

estimate = vals.agg(self.func)

# Options that produce no error bars
if self.error_method is None:
err_min = err_max = np.nan
elif len(data) <= 1:
err_min = err_max = np.nan

# Generic errorbars from user-supplied function
elif callable(self.error_method):
err_min, err_max = self.error_method(vals)

# Parametric options
elif self.error_method == "sd":
half_interval = vals.std() * self.error_level
err_min, err_max = estimate - half_interval, estimate + half_interval
elif self.error_method == "se":
half_interval = vals.sem() * self.error_level
err_min, err_max = estimate - half_interval, estimate + half_interval

# Nonparametric options
elif self.error_method == "pi":
err_min, err_max = _percentile_interval(vals, self.error_level)
elif self.error_method == "ci":
boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
# units = data.get("units", None) # TODO change to unit
units = None
boots = bootstrap(vals, units=units, func=self.func, **boot_kws)
err_min, err_max = _percentile_interval(boots, self.error_level)

res = {var: estimate, f"{var}min": err_min, f"{var}max": err_max}
def _process(
self, data: DataFrame, var: str, estimator: EstimateAggregator
) -> DataFrame:
# Needed because GroupBy.apply assumes func is DataFrame -> DataFrame
# which we could probably make more general to allow Series return
res = estimator(data, var)
return pd.DataFrame([res])

def __call__(self, data, groupby, orient, scales):
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

method, level = _validate_errorbar_arg(self.errorbar)
self.error_method = method
self.error_level = level
boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)

var = {"x": "y", "y": "x"}.get(orient)
res = (
groupby
.apply(data, self._process, var)
.dropna()
.apply(data, self._process, var, engine)
.dropna(subset=["x", "y"])
.reset_index(drop=True)
)

res = res.fillna({f"{var}min": res[var], f"{var}max": res[var]})

return res


Expand All @@ -110,41 +102,3 @@ class Rolling(Stat):

def __call__(self, data, groupby, orient, scales):
...


def _percentile_interval(data, width):
"""Return a percentile interval from data of a given width."""
edge = (100 - width) / 2
percentiles = edge, 100 - edge
return np.nanpercentile(data, percentiles)


def _validate_errorbar_arg(arg):
"""Check type and value of errorbar argument and assign default level."""
DEFAULT_LEVELS = {
"ci": 95,
"pi": 95,
"se": 1,
"sd": 1,
}

usage = "`errorbar` must be a callable, string, or (string, number) tuple"

if arg is None:
return None, None
elif callable(arg):
return arg, None
elif isinstance(arg, str):
method = arg
level = DEFAULT_LEVELS.get(method, None)
else:
try:
method, level = arg
except (ValueError, TypeError) as err:
raise err.__class__(usage) from err

_check_argument("errorbar", list(DEFAULT_LEVELS), method)
if level is not None and not isinstance(level, Number):
raise TypeError(usage)

return method, level
3 changes: 3 additions & 0 deletions seaborn/_stats/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class Hist(Stat):
# Q: would Discrete() scale imply binwidth=1 or bins centered on integers?
discrete: bool = False

# TODO Note that these methods are mostly copied from _statistics.Histogram,
# but it only computes univariate histograms. We should reconcile the code.

def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete):
"""Inner function that takes bin parameters as arguments."""
vals = vals.dropna()
Expand Down
58 changes: 56 additions & 2 deletions tests/_stats/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@

import numpy as np
import pandas as pd

import pytest
from pandas.testing import assert_frame_equal

from seaborn._core.groupby import GroupBy
from seaborn._stats.aggregation import Agg
from seaborn._stats.aggregation import Agg, Est


class TestAgg:
class AggregationFixtures:

@pytest.fixture
def df(self, rng):
Expand All @@ -27,6 +28,9 @@ def get_groupby(self, df, orient):
cols = [c for c in df if c != other]
return GroupBy(cols)


class TestAgg(AggregationFixtures):

def test_default(self, df):

ori = "x"
Expand Down Expand Up @@ -69,3 +73,53 @@ def test_func(self, df, func):

expected = df.groupby("x", as_index=False)["y"].agg(func)
assert_frame_equal(res, expected)


class TestEst(AggregationFixtures):

# Note: Most of the underlying code is exercised in tests/test_statistics

@pytest.mark.parametrize("func", [np.mean, "mean"])
def test_mean_sd(self, df, func):

ori = "x"
df = df[["x", "y"]]
gb = self.get_groupby(df, ori)
res = Est(func, "sd")(df, gb, ori, {})

grouped = df.groupby("x", as_index=False)["y"]
est = grouped.mean()
err = grouped.std()
expected = est.assign(ymin=est["y"] - err["y"], ymax=est["y"] + err["y"])
assert_frame_equal(res, expected)

def test_sd_single_obs(self):

y = 1.5
ori = "x"
df = pd.DataFrame([{"x": "a", "y": y}])
gb = self.get_groupby(df, ori)
res = Est("mean", "sd")(df, gb, ori, {})
expected = df.assign(ymin=y, ymax=y)
assert_frame_equal(res, expected)

def test_median_pi(self, df):

ori = "x"
df = df[["x", "y"]]
gb = self.get_groupby(df, ori)
res = Est("median", ("pi", 100))(df, gb, ori, {})

grouped = df.groupby("x", as_index=False)["y"]
est = grouped.median()
expected = est.assign(ymin=grouped.min()["y"], ymax=grouped.max()["y"])
assert_frame_equal(res, expected)

def test_seed(self, df):

ori = "x"
gb = self.get_groupby(df, ori)
args = df, gb, ori, {}
res1 = Est("mean", "ci", seed=99)(*args)
res2 = Est("mean", "ci", seed=99)(*args)
assert_frame_equal(res1, res2)

0 comments on commit 01cf7e4

Please sign in to comment.