Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Bars for more efficient bar plots and improve Bar as well #2893

Merged
merged 9 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 152 additions & 69 deletions seaborn/_marks/bars.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass

import numpy as np
Expand All @@ -23,22 +24,52 @@
from seaborn._core.scales import Scale


@dataclass
class Bar(Mark):
"""
An interval mark drawn between baseline and data values with a width.
"""
color: MappableColor = Mappable("C0", )
alpha: MappableFloat = Mappable(.7, )
fill: MappableBool = Mappable(True, )
edgecolor: MappableColor = Mappable(depend="color", )
edgealpha: MappableFloat = Mappable(1, )
edgewidth: MappableFloat = Mappable(rc="patch.linewidth")
edgestyle: MappableStyle = Mappable("-", )
# pattern: MappableString = Mappable(None, ) # TODO no Property yet
class BarBase(Mark):

width: MappableFloat = Mappable(.8, grouping=False)
baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable?
def _make_patches(self, data, scales, orient):

kws = self._resolve_properties(data, scales)
if orient == "x":
kws["x"] = (data["x"] - data["width"] / 2).to_numpy()
kws["y"] = data["baseline"].to_numpy()
kws["w"] = data["width"].to_numpy()
kws["h"] = (data["y"] - data["baseline"]).to_numpy()
else:
kws["x"] = data["baseline"].to_numpy()
kws["y"] = (data["y"] - data["width"] / 2).to_numpy()
kws["w"] = (data["x"] - data["baseline"]).to_numpy()
kws["h"] = data["width"].to_numpy()

kws.pop("width", None)
kws.pop("baseline", None)

val_dim = {"x": "h", "y": "w"}[orient]
bars, vals = [], []

for i in range(len(data)):

row = {k: v[i] for k, v in kws.items()}

# Skip bars with no value. It's possible we'll want to make this
# an option (i.e so you have an artist for animating or annotating),
# but let's keep things simple for now.
if not np.nan_to_num(row[val_dim]):
continue

bar = mpl.patches.Rectangle(
xy=(row["x"], row["y"]),
width=row["w"],
height=row["h"],
facecolor=row["facecolor"],
edgecolor=row["edgecolor"],
linestyle=row["edgestyle"],
linewidth=row["edgewidth"],
**self.artist_kws,
)
bars.append(bar)
vals.append(row[val_dim])

return bars, vals

def _resolve_properties(self, data, scales):

Expand All @@ -56,58 +87,57 @@ def _resolve_properties(self, data, scales):

return resolved

def _plot(self, split_gen, scales, orient):
def _legend_artist(
self, variables: list[str], value: Any, scales: dict[str, Scale],
) -> Artist:
# TODO return some sensible default?
key = {v: value for v in variables}
key = self._resolve_properties(key, scales)
artist = mpl.patches.Patch(
facecolor=key["facecolor"],
edgecolor=key["edgecolor"],
linewidth=key["edgewidth"],
linestyle=key["edgestyle"],
)
return artist

def coords_to_geometry(x, y, w, b):
# TODO possible too slow with lots of bars (e.g. dense hist)
# Why not just use BarCollection?
if orient == "x":
w, h = w, y - b
xy = x - w / 2, b
else:
w, h = x - b, w
xy = b, y - h / 2
return xy, w, h

val_idx = ["y", "x"].index(orient)
@dataclass
class Bar(BarBase):
"""
An rectangular mark drawn between baseline and data values.
"""
color: MappableColor = Mappable("C0", grouping=False)
alpha: MappableFloat = Mappable(.7, grouping=False)
fill: MappableBool = Mappable(True, grouping=False)
edgecolor: MappableColor = Mappable(depend="color", grouping=False)
edgealpha: MappableFloat = Mappable(1, grouping=False)
edgewidth: MappableFloat = Mappable(rc="patch.linewidth", grouping=False)
edgestyle: MappableStyle = Mappable("-", grouping=False)
# pattern: MappableString = Mappable(None) # TODO no Property yet

for _, data, ax in split_gen():
width: MappableFloat = Mappable(.8, grouping=False)
baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable?

xys = data[["x", "y"]].to_numpy()
data = self._resolve_properties(data, scales)
def _plot(self, split_gen, scales, orient):

bars, vals = [], []
for i, (x, y) in enumerate(xys):
val_idx = ["y", "x"].index(orient)

baseline = data["baseline"][i]
width = data["width"][i]
xy, w, h = coords_to_geometry(x, y, width, baseline)
for _, data, ax in split_gen():

bars, vals = self._make_patches(data, scales, orient)

# Skip bars with no value. It's possible we'll want to make this
# an option (i.e so you have an artist for animating or annotating),
# but let's keep things simple for now.
if not np.nan_to_num(h):
continue
for bar in bars:

# TODO Because we are clipping the artist (see below), the edges end up
# Because we are clipping the artist (see below), the edges end up
# looking half as wide as they actually are. I don't love this clumsy
# workaround, which is going to cause surprises if you work with the
# artists directly. We may need to revisit after feedback.
linewidth = data["edgewidth"][i] * 2
linestyle = data["edgestyle"][i]
bar.set_linewidth(bar.get_linewidth() * 2)
linestyle = bar.get_linestyle()
if linestyle[1]:
linestyle = (linestyle[0], tuple(x / 2 for x in linestyle[1]))

bar = mpl.patches.Rectangle(
xy=xy,
width=w,
height=h,
facecolor=data["facecolor"][i],
edgecolor=data["edgecolor"][i],
linestyle=linestyle,
linewidth=linewidth,
**self.artist_kws,
)
bar.set_linestyle(linestyle)

# This is a bit of a hack to handle the fact that the edge lines are
# centered on the actual extents of the bar, and overlap when bars are
Expand All @@ -121,8 +151,6 @@ def coords_to_geometry(x, y, w, b):
bar.set_clip_box(ax.bbox)
bar.sticky_edges[val_idx][:] = (0, np.inf)
ax.add_patch(bar)
bars.append(bar)
vals.append(h)

# Add a container which is useful for, e.g. Axes.bar_label
if Version(mpl.__version__) >= Version("3.4.0"):
Expand All @@ -133,16 +161,71 @@ def coords_to_geometry(x, y, w, b):
container = mpl.container.BarContainer(bars, **container_kws)
ax.add_container(container)

def _legend_artist(
self, variables: list[str], value: Any, scales: dict[str, Scale],
) -> Artist:
# TODO return some sensible default?
key = {v: value for v in variables}
key = self._resolve_properties(key, scales)
artist = mpl.patches.Patch(
facecolor=key["facecolor"],
edgecolor=key["edgecolor"],
linewidth=key["edgewidth"],
linestyle=key["edgestyle"],
)
return artist

@dataclass
class Bars(BarBase):
"""
A faster Bar mark with defaults that are more suitable for histograms.
"""
color: MappableColor = Mappable("C0", grouping=False)
alpha: MappableFloat = Mappable(.7, grouping=False)
fill: MappableBool = Mappable(True, grouping=False)
edgecolor: MappableColor = Mappable(rc="patch.edgecolor", grouping=False)
edgealpha: MappableFloat = Mappable(1, grouping=False)
edgewidth: MappableFloat = Mappable(auto=True, grouping=False)
edgestyle: MappableStyle = Mappable("-", grouping=False)
# pattern: MappableString = Mappable(None) # TODO no Property yet

width: MappableFloat = Mappable(1, grouping=False)
baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable?

def _plot(self, split_gen, scales, orient):

ori_idx = ["x", "y"].index(orient)
val_idx = ["y", "x"].index(orient)

patches = defaultdict(list)
for _, data, ax in split_gen():
bars, _ = self._make_patches(data, scales, orient)
patches[ax].extend(bars)

collections = {}
for ax, ax_patches in patches.items():

col = mpl.collections.PatchCollection(ax_patches, match_original=True)
col.sticky_edges[val_idx][:] = (0, np.inf)
ax.add_collection(col, autolim=False)
collections[ax] = col

# Workaround for matplotlib autoscaling bug
# https://github.com/matplotlib/matplotlib/issues/11898
# https://github.com/matplotlib/matplotlib/issues/23129
xy = np.vstack([path.vertices for path in col.get_paths()])
ax.dataLim.update_from_data_xy(
xy, ax.ignore_existing_data_limits, updatex=True, updatey=True
)

if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):

for ax in collections:
ax.autoscale_view()

def get_dimensions(collection):
edges, widths = [], []
for verts in (path.vertices for path in collection.get_paths()):
edges.append(min(verts[:, ori_idx]))
widths.append(np.ptp(verts[:, ori_idx]))
return np.array(edges), np.array(widths)

min_width = np.inf
for ax, col in collections.items():
edges, widths = get_dimensions(col)
points = 72 / ax.figure.dpi * abs(
ax.transData.transform([edges + widths] * 2)
- ax.transData.transform([edges] * 2)
)
min_width = min(min_width, min(points[:, ori_idx]))

linewidth = min(.1 * min_width, mpl.rcParams["patch.linewidth"])
for _, col in collections.items():
col.set_linewidth(linewidth)
6 changes: 6 additions & 0 deletions seaborn/_marks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
val: Any = None,
depend: str | None = None,
rc: str | None = None,
auto: bool = False,
grouping: bool = True,
):
"""
Expand All @@ -40,6 +41,8 @@ def __init__(
Use the value of this feature as the default.
rc : str
Use the value of this rcParam as the default.
auto : bool
The default value will depend on other parameters at compile time.
grouping : bool
If True, use the mapped variable to define groups.

Expand All @@ -52,6 +55,7 @@ def __init__(
self._val = val
self._rc = rc
self._depend = depend
self._auto = auto
self._grouping = grouping

def __repr__(self):
Expand All @@ -62,6 +66,8 @@ def __repr__(self):
s = f"<depend:{self._depend}>"
elif self._rc is not None:
s = f"<rc:{self._rc}>"
elif self._auto:
s = "<auto>"
else:
s = "<undefined>"
return s
Expand Down
2 changes: 1 addition & 1 deletion seaborn/_marks/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _plot(self, split_gen, scales, orient):
# (That should be solved upstream by defaulting to "" for unset x/y?)
# (Be mindful of xmin/xmax, etc!)

for keys, data, ax in split_gen():
for _, data, ax in split_gen():

offsets = np.column_stack([data["x"], data["y"]])
data = self._resolve_properties(data, scales)
Expand Down
2 changes: 1 addition & 1 deletion seaborn/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from seaborn._marks.base import Mark # noqa: F401
from seaborn._marks.area import Area, Ribbon # noqa: F401
from seaborn._marks.bars import Bar # noqa: F401
from seaborn._marks.bars import Bar, Bars # noqa: F401
from seaborn._marks.lines import Line, Lines, Path, Paths # noqa: F401
from seaborn._marks.scatter import Dot, Scatter # noqa: F401

Expand Down
Loading