Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Est stat and Interval mark to show error bars #2912

Merged
merged 9 commits into from
Jul 24, 2022
9 changes: 5 additions & 4 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,10 +1354,11 @@ def split_generator(keep_na=False) -> Generator:
# Matplotlib (usually?) masks nan data, so this should "work".
# Downstream code can also drop these rows, at some speed cost.
present = axes_df.notna().all(axis=1)
axes_df = axes_df.assign(
x=axes_df["x"].where(present),
y=axes_df["y"].where(present),
)
nulled = {}
for axis in "xy":
if axis in axes_df:
nulled[axis] = axes_df[axis].where(present)
axes_df = axes_df.assign(**nulled)
else:
axes_df = axes_df.dropna()

Expand Down
5 changes: 4 additions & 1 deletion seaborn/_core/scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,10 @@ def normalize(x):
]

def spacer(x):
    """
    Return the smallest gap between the distinct non-null values of ``x``.

    Parameters
    ----------
    x : object with ``dropna().unique()`` (e.g. a :class:`pandas.Series`)
        Values to measure spacing between.

    Returns
    -------
    float
        Minimum difference between consecutive sorted unique values, or
        ``nan`` when fewer than two distinct non-null values are present.

    """
    uniq = x.dropna().unique()
    # With fewer than two distinct values there is no gap to measure, and
    # np.min would raise on the empty array returned by np.diff.
    if len(uniq) < 2:
        return np.nan
    return np.min(np.diff(np.sort(uniq)))
new._spacer = spacer

# TODO How to allow disabling of legend for all uses of property?
Expand Down
6 changes: 2 additions & 4 deletions seaborn/_marks/bars.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,8 @@ def _plot(self, split_gen, scales, orient):
# Workaround for matplotlib autoscaling bug
# https://github.com/matplotlib/matplotlib/issues/11898
# https://github.com/matplotlib/matplotlib/issues/23129
xy = np.vstack([path.vertices for path in col.get_paths()])
ax.dataLim.update_from_data_xy(
xy, ax.ignore_existing_data_limits, updatex=True, updatey=True
)
xys = np.vstack([path.vertices for path in col.get_paths()])
ax.update_datalim(xys)

if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):

Expand Down
94 changes: 80 additions & 14 deletions seaborn/_marks/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def _plot(self, split_gen, scales, orient):
if self._sort:
data = data.sort_values(orient)

artist_kws = self.artist_kws.copy()
self._handle_capstyle(artist_kws, vals)

line = mpl.lines.Line2D(
data["x"].to_numpy(),
data["y"].to_numpy(),
Expand All @@ -61,7 +64,7 @@ def _plot(self, split_gen, scales, orient):
markerfacecolor=vals["fillcolor"],
markeredgecolor=vals["edgecolor"],
markeredgewidth=vals["edgewidth"],
**self.artist_kws,
**artist_kws,
)
ax.add_line(line)

Expand All @@ -77,6 +80,9 @@ def _legend_artist(self, variables, value, scales):
if Version(mpl.__version__) < Version("3.3.0"):
vals["marker"] = vals["marker"]._marker

artist_kws = self.artist_kws.copy()
self._handle_capstyle(artist_kws, vals)

return mpl.lines.Line2D(
[], [],
color=vals["color"],
Expand All @@ -87,9 +93,17 @@ def _legend_artist(self, variables, value, scales):
markerfacecolor=vals["fillcolor"],
markeredgecolor=vals["edgecolor"],
markeredgewidth=vals["edgewidth"],
**self.artist_kws,
**artist_kws,
)

def _handle_capstyle(self, kws, vals):
    """
    Copy the solid capstyle onto ``dash_capstyle`` for non-dashed lines.

    ``kws`` is modified in place. Nothing happens when the second element
    of ``vals["linestyle"]`` is not None.
    """
    # Work around for this matplotlib issue:
    # https://github.com/matplotlib/matplotlib/issues/23437
    if vals["linestyle"][1] is not None:
        return

    # Prefer an explicitly passed solid_capstyle, else fall back to the rc value
    rc_default = mpl.rcParams["lines.solid_capstyle"]
    kws["dash_capstyle"] = kws.get("solid_capstyle", rc_default)


@dataclass
class Line(Path):
Expand All @@ -111,7 +125,15 @@ class Paths(Mark):

_sort: ClassVar[bool] = False

def _plot(self, split_gen, scales, orient):
def __post_init__(self):
    """Give ``artist_kws`` a ``capstyle`` entry sourced from the rc if unset."""
    # LineCollection artists have a capstyle property but don't source its
    # value from the rc, so we do that manually here. Unfortunately, because
    # we add only one LineCollection, we have to use the same capstyle for all
    # lines, even when they are dashed. It's a slight inconsistency.
    rc_capstyle = mpl.rcParams["lines.solid_capstyle"]
    self.artist_kws.setdefault("capstyle", rc_capstyle)

def _setup_lines(self, split_gen, scales, orient):

line_data = {}

Expand All @@ -131,36 +153,42 @@ def _plot(self, split_gen, scales, orient):
if self._sort:
data = data.sort_values(orient)

# TODO comment about block consolidation
# Column stack to avoid block consolidation
xy = np.column_stack([data["x"], data["y"]])
line_data[ax]["segments"].append(xy)
line_data[ax]["colors"].append(vals["color"])
line_data[ax]["linewidths"].append(vals["linewidth"])
line_data[ax]["linestyles"].append(vals["linestyle"])

return line_data

def _plot(self, split_gen, scales, orient):
    """
    Draw the prepared lines, adding one LineCollection to each axes.

    The per-axes keyword data comes from ``self._setup_lines``; each entry is
    unpacked into the LineCollection constructor together with
    ``self.artist_kws``.
    """
    line_data = self._setup_lines(split_gen, scales, orient)

    for ax, ax_data in line_data.items():
        lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws)
        # Handle datalim update manually
        # https://github.com/matplotlib/matplotlib/issues/23129
        # TODO get paths from lines object?
        ax.add_collection(lines, autolim=False)
        # np.concatenate raises on an empty list, so skip the limit update
        # when no segments were produced for this axes.
        if ax_data["segments"]:
            xy = np.concatenate(ax_data["segments"])
            ax.update_datalim(xy)

def _legend_artist(self, variables, value, scales):
    """
    Build the empty Line2D used to represent this mark in a legend.

    The collection-style ``capstyle`` keyword in ``self.artist_kws`` is
    translated to the ``solid_capstyle`` / ``dash_capstyle`` pair before
    being passed to Line2D.
    """
    key = resolve_properties(self, {v: value for v in variables}, scales)

    # Work on a copy so the mark's own artist_kws keep their "capstyle" entry
    artist_kws = self.artist_kws.copy()
    capstyle = artist_kws.pop("capstyle")
    artist_kws["solid_capstyle"] = capstyle
    artist_kws["dash_capstyle"] = capstyle

    return mpl.lines.Line2D(
        [], [],
        color=key["color"],
        linewidth=key["linewidth"],
        linestyle=key["linestyle"],
        **artist_kws,
    )


Expand All @@ -170,3 +198,41 @@ class Lines(Paths):
A faster but less-flexible mark for drawing many lines.
"""
_sort: ClassVar[bool] = True


@dataclass
class Interval(Paths):
    """
    An oriented line mark drawn between min/max values.
    """
    def _setup_lines(self, split_gen, scales, orient):
        """
        Assemble per-axes LineCollection keyword data for the intervals.

        For every subset yielded by ``split_gen``, the ``{other}min`` and
        ``{other}max`` columns are melted into a single value column and the
        rows are grouped by the orient variable, giving one segment per group.
        The subset's color, linewidth and linestyle are repeated once per
        segment.

        Returns a dict mapping each axes to a dict with "segments", "colors",
        "linewidths" and "linestyles" lists.
        """
        value_var = {"x": "y", "y": "x"}[orient]
        bound_cols = [orient, f"{value_var}min", f"{value_var}max"]
        list_names = ("segments", "colors", "linewidths", "linestyles")

        line_data = {}

        for keys, data, ax in split_gen(keep_na=not self._sort):

            if ax not in line_data:
                line_data[ax] = {name: [] for name in list_names}
            ax_lines = line_data[ax]

            vals = resolve_properties(self, keys, scales)
            vals["color"] = resolve_color(self, keys, scales=scales)

            # Stack the min and max columns so each orient value has two rows
            stacked = data[bound_cols].melt(orient, value_name=value_var)
            stacked = stacked[["x", "y"]]

            segments = []
            for _, pair in stacked.groupby(orient):
                segments.append(pair.to_numpy())
            ax_lines["segments"].extend(segments)

            # Every segment in this subset shares the same resolved properties
            count = len(segments)
            for list_name, prop in (
                ("colors", "color"),
                ("linewidths", "linewidth"),
                ("linestyles", "linestyle"),
            ):
                ax_lines[list_name].extend([vals[prop]] * count)

        return line_data
76 changes: 57 additions & 19 deletions seaborn/_stats/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Callable

import pandas as pd
from pandas import DataFrame

from seaborn._core.scales import Scale
from seaborn._core.groupby import GroupBy
from seaborn._stats.base import Stat
from seaborn._statistics import EstimateAggregator

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable
from numbers import Number
from seaborn._core.typing import Vector
from seaborn._core.typing import Vector


@dataclass
Expand All @@ -18,23 +20,22 @@ class Agg(Stat):

Parameters
----------
func
Name of a method understood by Pandas or an arbitrary vector -> scalar function.
func : str or callable
Name of a :class:`pandas.Series` method or a vector -> scalar function.

"""
# TODO In current practice we will always have a numeric x/y variable,
# but they may represent non-numeric values. Needs clear documentation.
func: str | Callable[[Vector], Number] = "mean"
func: str | Callable[[Vector], float] = "mean"

group_by_orient: ClassVar[bool] = True

def __call__(self, data, groupby, orient, scales):
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

var = {"x": "y", "y": "x"}.get(orient)
res = (
groupby
.agg(data, {var: self.func})
# TODO Could be an option not to drop NA?
.dropna()
.reset_index(drop=True)
)
Expand All @@ -43,19 +44,56 @@ def __call__(self, data, groupby, orient, scales):

@dataclass
class Est(Stat):
    """
    Calculate a point estimate and error bar interval.

    Parameters
    ----------
    func : str or callable
        Name of a :class:`numpy.ndarray` method or a vector -> scalar function.
    errorbar : str, (str, float) tuple, or callable
        Name of errorbar method (one of "ci", "pi", "se" or "sd"), or a tuple
        with a method name and a level parameter, or a function that maps from a
        vector to a (min, max) interval.
    n_boot : int
        Number of bootstrap samples to draw for "ci" errorbars.
    seed : int
        Seed for the PRNG used to draw bootstrap samples.

    """
    func: str | Callable[[Vector], float] = "mean"
    # TODO type errorbar options with literal?
    errorbar: str | tuple[str, float] = ("ci", 95)
    n_boot: int = 1000
    seed: int | None = None

    group_by_orient: ClassVar[bool] = True

    def _process(
        self, data: DataFrame, var: str, estimator: EstimateAggregator
    ) -> DataFrame:
        """Apply ``estimator`` to ``data`` for ``var``; wrap the result in a frame."""
        # Needed because GroupBy.apply assumes func is DataFrame -> DataFrame
        # which we could probably make more general to allow Series return
        res = estimator(data, var)
        return pd.DataFrame([res])

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:
        """
        Compute the estimate and interval bounds within each group.

        Rows with a missing "x" or "y" are dropped; missing ``{var}min`` /
        ``{var}max`` bounds are filled with the estimate itself, where ``var``
        is the variable opposite to ``orient``.
        """
        boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
        engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)

        var = {"x": "y", "y": "x"}.get(orient)
        res = (
            groupby
            .apply(data, self._process, var, engine)
            .dropna(subset=["x", "y"])
            .reset_index(drop=True)
        )

        # Where no interval was computed, fall back to a zero-length interval
        res = res.fillna({f"{var}min": res[var], f"{var}max": res[var]})

        return res


@dataclass
Expand Down
3 changes: 3 additions & 0 deletions seaborn/_stats/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class Hist(Stat):
# Q: would Discrete() scale imply binwidth=1 or bins centered on integers?
discrete: bool = False

# TODO Note that these methods are mostly copied from _statistics.Histogram,
# but it only computes univariate histograms. We should reconcile the code.

def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete):
"""Inner function that takes bin parameters as arguments."""
vals = vals.dropna()
Expand Down
31 changes: 28 additions & 3 deletions seaborn/objects.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,41 @@
"""
TODO Give this module a useful docstring
A declarative, object-oriented interface for creating statistical graphics.

The seaborn.objects namespace contains a number of classes that can be composed
together to build a customized visualization.

The main object is :class:`Plot`, which is the starting point for all figures.
Pass :class:`Plot` a dataset and specify assignments from its variables to
roles in the plot. Build up the visualization by calling its methods.

There are four other general types of objects in this interface:

- :class:`Mark` subclasses, which create matplotlib artists for visualization
- :class:`Stat` subclasses, which apply statistical transforms before plotting
- :class:`Move` subclasses, which make further adjustments to reduce overplotting

These classes are passed to :meth:`Plot.add` to define a layer in the plot.
Each layer has a :class:`Mark` and optional :class:`Stat` and/or :class:`Move`.
Plots can have multiple layers.

The other general type of object is a :class:`Scale` subclass, which provides an
interface for controlling the mappings between data values and visual properties.
Pass :class:`Scale` objects to :meth:`Plot.scale`.

See the documentation for other :class:`Plot` methods to learn about the many
ways that a plot can be enhanced and customized.

"""
from seaborn._core.plot import Plot # noqa: F401

from seaborn._marks.base import Mark # noqa: F401
from seaborn._marks.area import Area, Ribbon # noqa: F401
from seaborn._marks.bars import Bar, Bars # noqa: F401
from seaborn._marks.lines import Line, Lines, Path, Paths # noqa: F401
from seaborn._marks.lines import Line, Lines, Path, Paths, Interval # noqa: F401
from seaborn._marks.scatter import Dot, Scatter # noqa: F401

from seaborn._stats.base import Stat # noqa: F401
from seaborn._stats.aggregation import Agg # noqa: F401
from seaborn._stats.aggregation import Agg, Est # noqa: F401
from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401
from seaborn._stats.histograms import Hist # noqa: F401

Expand Down
Loading