-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Perc stat for computing percentiles (#3063)
* Add Perc stat
* Add Perc tests
* Fix orientation test
* Add Perc to API docs
* Get Literal from typing_extensions when necessary
* Make robust to missing data
* Numpy backcompat
* Add backcompat conditional in test too
* Add API examples
- Loading branch information
Showing
7 changed files
with
300 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2d44a326-029b-47ff-b560-5f4b6a4bb73f", | ||
"metadata": { | ||
"tags": [ | ||
"hide" | ||
] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import seaborn.objects as so\n", | ||
"from seaborn import load_dataset\n", | ||
"diamonds = load_dataset(\"diamonds\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"id": "65e975a2-2559-4bf1-8851-8bbbf52bf22d", | ||
"metadata": {}, | ||
"source": [ | ||
"The default behavior computes the quartiles and min/max of the input data:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "36f927f5-3b64-4871-a355-adadc4da769b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"p = (\n", | ||
" so.Plot(diamonds, \"cut\", \"price\")\n", | ||
" .scale(y=\"log\")\n", | ||
")\n", | ||
"p.add(so.Dot(), so.Perc())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"id": "feba1b99-0f71-4b18-8e7e-bd5470cc2d0c", | ||
"metadata": {}, | ||
"source": [ | ||
"Passing an integer will compute that many evenly-spaced percentiles:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f030dd39-1223-475a-93e1-1759a8971a6c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"p.add(so.Dot(), so.Perc(20))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"id": "85bd754b-122e-4475-8727-2d584a90a38e", | ||
"metadata": {}, | ||
"source": [ | ||
"Passing a list will compute exactly those percentiles:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2fde7549-45b5-411a-afba-eb0da754d9e9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"p.add(so.Dot(), so.Perc([10, 25, 50, 75, 90]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"id": "7be16a13-dfc8-4595-a904-42f9be10f4f6", | ||
"metadata": {}, | ||
"source": [ | ||
"Combine with a range mark to show a percentile interval:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "05c561c6-0449-4a61-96d1-390611a1b694", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"(\n", | ||
" so.Plot(diamonds, \"price\", \"cut\")\n", | ||
" .add(so.Dots(pointsize=1, alpha=.2), so.Jitter(.3))\n", | ||
" .add(so.Range(color=\"k\"), so.Perc([25, 75]), so.Shift(y=.2))\n", | ||
" .scale(x=\"log\")\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d464157c-3187-49c1-9cd8-71f284ce4c50", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "py310", | ||
"language": "python", | ||
"name": "py310" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -86,6 +86,7 @@ Stat objects | |
Agg | ||
Est | ||
Hist | ||
Perc | ||
PolyFit | ||
|
||
Move objects | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
|
||
from __future__ import annotations | ||
from dataclasses import dataclass | ||
from typing import ClassVar, cast | ||
try: | ||
from typing import Literal | ||
except ImportError: | ||
from typing_extensions import Literal # type: ignore | ||
|
||
import numpy as np | ||
from pandas import DataFrame | ||
|
||
from seaborn._core.scales import Scale | ||
from seaborn._core.groupby import GroupBy | ||
from seaborn._stats.base import Stat | ||
from seaborn.external.version import Version | ||
|
||
|
||
# Names of the percentile estimation methods, used only for static typing of
# the ``method`` argument forwarded to ``np.percentile``. Copied from
# https://github.com/numpy/numpy/blob/main/numpy/lib/function_base.pyi
_MethodKind = Literal[
    "inverted_cdf", "averaged_inverted_cdf", "closest_observation",
    "interpolated_inverted_cdf", "hazen", "weibull", "linear",
    "median_unbiased", "normal_unbiased",
    "lower", "higher", "midpoint", "nearest",
]
|
||
|
||
@dataclass
class Perc(Stat):
    """
    Replace observations with percentile values.

    Parameters
    ----------
    k : list of numbers or int
        If a list of numbers, this gives the percentiles (in [0, 100]) to compute.
        If an integer, compute `k` evenly-spaced percentiles between 0 and 100.
        For example, `k=5` computes the 0, 25, 50, 75, and 100th percentiles.
    method : str
        Method for interpolating percentiles between observed datapoints.
        See :func:`numpy.percentile` for valid options and more information.

    Examples
    --------
    .. include:: ../docstrings/objects.Perc.rst

    """
    k: int | list[float] = 5
    method: str = "linear"

    group_by_orient: ClassVar[bool] = True

    def _percentile(self, data: DataFrame, var: str) -> DataFrame:
        """
        Compute the requested percentiles of one column.

        Parameters
        ----------
        data : DataFrame
            Frame containing the column named by `var`.
        var : str
            Name of the column whose percentiles are computed.

        Returns
        -------
        DataFrame
            One row per percentile, with the computed values in a column named
            `var` and the percentile levels in a column named "percentile".

        """
        k: list[float]
        # Accept numpy integer scalars as well as builtin ints so that a count
        # such as ``np.int64(5)`` is expanded rather than treated as a list.
        if isinstance(self.k, (int, np.integer)):
            k = list(np.linspace(0, 100, self.k))
        else:
            k = self.k
        method = cast(_MethodKind, self.method)

        # Missing observations are dropped once and reused by both branches.
        values = data[var].dropna()

        # numpy renamed the `interpolation` keyword to `method` in 1.22.
        if Version(np.__version__) < Version("1.22.0"):
            res = np.percentile(values, k, interpolation=method)  # type: ignore
        else:
            res = np.percentile(values, k, method=method)
        return DataFrame({var: res, "percentile": k})

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:
        """
        Apply the percentile computation within each group of `groupby`.

        The percentiles are taken over the variable opposite to `orient`
        ("y" when `orient` is "x", and vice versa). `scales` is not used.
        """
        var = {"x": "y", "y": "x"}[orient]
        return groupby.apply(data, self._percentile, var)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
import pytest | ||
from numpy.testing import assert_array_equal | ||
|
||
from seaborn._core.groupby import GroupBy | ||
from seaborn._stats.order import Perc | ||
from seaborn.external.version import Version | ||
|
||
|
||
class Fixtures:
    """Shared fixture and helper for the stat tests in this module."""

    @pytest.fixture
    def df(self, rng):
        """Frame with a constant empty-string `x` and 30 normal draws in `y`."""
        # `rng` is a fixture supplied from outside this file.
        columns = {"x": "", "y": rng.normal(size=30)}
        return pd.DataFrame(columns)

    def get_groupby(self, df, orient):
        """Build a GroupBy over every column except the one opposite `orient`."""
        # TODO note, copied from aggregation
        value_var = {"x": "y", "y": "x"}[orient]
        grouping_vars = [name for name in df if name != value_var]
        return GroupBy(grouping_vars)
|
||
|
||
class TestPerc(Fixtures):
    """Tests for the `Perc` stat."""

    def test_int_k(self, df):
        """An integer `k` of 3 gives the 0th, 50th, and 100th percentiles."""
        orient = "x"
        groupby = self.get_groupby(df, orient)
        levels = [0, 50, 100]
        out = Perc(3)(df, groupby, orient, {})
        assert_array_equal(out["percentile"], levels)
        assert_array_equal(out["y"], np.percentile(df["y"], levels))

    def test_list_k(self, df):
        """A list `k` is used as the percentile levels directly."""
        orient = "x"
        levels = [0, 20, 100]
        groupby = self.get_groupby(df, orient)
        out = Perc(k=levels)(df, groupby, orient, {})
        assert_array_equal(out["percentile"], levels)
        assert_array_equal(out["y"], np.percentile(df["y"], levels))

    def test_orientation(self, df):
        """With `orient="y"` the percentiles are computed on the `x` column."""
        swapped = df.rename(columns={"x": "y", "y": "x"})
        orient = "y"
        groupby = self.get_groupby(swapped, orient)
        out = Perc(k=3)(swapped, groupby, orient, {})
        assert_array_equal(out["x"], np.percentile(swapped["x"], [0, 50, 100]))

    def test_method(self, df):
        """The `method` parameter is forwarded to `np.percentile`."""
        orient = "x"
        method = "nearest"
        levels = [0, 25, 50, 75, 100]
        groupby = self.get_groupby(df, orient)
        out = Perc(k=5, method=method)(df, groupby, orient, {})
        # numpy before 1.22 spells the keyword `interpolation`.
        legacy_numpy = Version(np.__version__) < Version("1.22.0")
        keyword = "interpolation" if legacy_numpy else "method"
        expected = np.percentile(df["y"], levels, **{keyword: method})
        assert_array_equal(out["y"], expected)

    def test_grouped(self, df, rng):
        """Percentiles are computed separately for each level of `x`."""
        orient = "x"
        grouped_df = df.assign(x=rng.choice(["a", "b", "c"], len(df)))
        groupby = self.get_groupby(grouped_df, orient)
        levels = [10, 90]
        out = Perc(levels)(grouped_df, groupby, orient, {})
        for level, sub in out.groupby("x"):
            assert_array_equal(sub["percentile"], levels)
            in_group = grouped_df["x"] == level
            expected = np.percentile(grouped_df.loc[in_group, "y"], levels)
            assert_array_equal(sub["y"], expected)

    def test_with_na(self, df):
        """Missing values in the data column are ignored."""
        orient = "x"
        df.loc[:5, "y"] = np.nan
        levels = [10, 90]
        groupby = self.get_groupby(df, orient)
        out = Perc(levels)(df, groupby, orient, {})
        expected = np.percentile(df["y"].dropna(), levels)
        assert_array_equal(out["y"], expected)