Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize getting default color from color cycle #2449

Merged
merged 21 commits into from
Jan 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c2557c4
Clean up auto-gray code
mwaskom Jan 24, 2021
c21f737
Add initial versiono of default_color function
mwaskom Jan 24, 2021
8040425
Reenable swarmplot mandaatory kwarg warning deprecator
mwaskom Jan 24, 2021
3c5529c
Expand flexibility of supported scatter coloring kwarg
mwaskom Jan 25, 2021
33df6e3
Change logic of scatter-based coloring to post-proccess mpl object
mwaskom Jan 25, 2021
2b847e7
Test default and specified colors
mwaskom Jan 26, 2021
1490b23
Rework how shared strip/swarm tests work
mwaskom Jan 26, 2021
7e39045
Test user-supplied color array with swarm/strip plot
mwaskom Jan 26, 2021
9cfaf08
Address failures on pinned tests
mwaskom Jan 26, 2021
44f9c76
Use centralized default_color function in scatterplot
mwaskom Jan 26, 2021
bcfb979
Use default_color function in lineplot
mwaskom Jan 26, 2021
6ad20bf
Use general default color function for kdeplot, rugplot, and ecdfplot
mwaskom Jan 27, 2021
1eb0821
Use default_color function in histplot
mwaskom Jan 27, 2021
edbf06e
Get default color after attaching axes
mwaskom Jan 27, 2021
9160b5c
Mark kdeplot datetime autoscale test as xfail due to matplotlib bug
mwaskom Jan 27, 2021
877ad53
Fix logic of color tests
mwaskom Jan 27, 2021
4a30457
Workaround empty fill_between datetime autoscale bug
mwaskom Jan 28, 2021
950cf91
More backcompat and edge casing
mwaskom Jan 28, 2021
1028e86
Add color tests for distribution module
mwaskom Jan 28, 2021
35ae624
Fix bar hist legend artists
mwaskom Jan 28, 2021
6c88f98
Mute color with hue warning for now
mwaskom Jan 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/v0.12.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ v0.12.0 (Unreleased)

- In :func:`swarmplot`, the proportion of points that must overlap before issuing a warning can now be controlled with the `warn_thresh` parameter (:pr:`2447`).

- |Fix| In :func:`lineplot, allowed the `dashes` keyword to set the style of a line without mapping a `style` variable (:pr:`2449`).

- |Fix| |Enhancement| Improved integration with the matplotlib color cycle in most axes-level functions (:pr:`2449`).

- Made `scipy` an optional dependency and added `pip install seaborn[all]` as a method for ensuring the availability of compatible `scipy` and `statsmodels` libraries at install time. This has a few minor implications for existing code, which are explained in the Github pull request (:pr:`2398`).

Expand Down
120 changes: 51 additions & 69 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from textwrap import dedent
from numbers import Number
import warnings
import colorsys
from colorsys import rgb_to_hls
from functools import partial

import numpy as np
Expand All @@ -25,7 +25,7 @@
categorical_order,
)
from . import utils
from .utils import remove_na, _normal_quantile_func, _draw_figure
from .utils import remove_na, _normal_quantile_func, _draw_figure, _default_color
from .algorithms import bootstrap
from .palettes import color_palette, husl_palette, light_palette, dark_palette
from .axisgrid import FacetGrid, _facet_docs
Expand Down Expand Up @@ -144,17 +144,14 @@ def _hue_backcompat(self, color, palette, hue_order, force_hue=False):
def cat_axis(self):
return {"v": "x", "h": "y"}[self.orient]

def _get_gray(self, color="C0"):
def _get_gray(self, colors):
"""Get a grayscale value that looks good with color."""
if "hue" in self.variables:
rgb_colors = list(self._hue_map.lookup_table.values())
else:
rgb_colors = [mpl.colors.to_rgb(color)]

light_vals = [colorsys.rgb_to_hls(*mpl.colors.to_rgb(c))[1] for c in rgb_colors]
if not len(colors):
return None
unique_colors = np.unique(colors, axis=0)
light_vals = [rgb_to_hls(*rgb[:3])[1] for rgb in unique_colors]
lum = min(light_vals) * .6
gray = mpl.colors.rgb2hex((lum, lum, lum))
return gray
return (lum, lum, lum)

def _adjust_cat_axis(self, ax, axis):
"""Set ticks and limits for a categorical variable."""
Expand All @@ -180,13 +177,15 @@ def _adjust_cat_axis(self, ax, axis):
else:
order = categorical_order(data)

n = max(len(order), 1)

if axis == "x":
ax.xaxis.grid(False)
ax.set_xlim(-.5, len(order) - .5, auto=None)
ax.set_xlim(-.5, n - .5, auto=None)
else:
ax.yaxis.grid(False)
# Note limits that correspond to previously-inverted y axis
ax.set_ylim(len(order) - .5, -.5, auto=None)
ax.set_ylim(n - .5, -.5, auto=None)

@property
def _native_width(self):
Expand Down Expand Up @@ -221,10 +220,10 @@ def plot_strips(
jitter,
dodge,
color,
edgecolor,
plot_kws,
):

default_color = "C0" if color is None else color
width = .8 * self._native_width
offsets = self._nested_offsets(width, dodge)

Expand Down Expand Up @@ -256,17 +255,20 @@ def plot_strips(
adjusted_data = sub_data[self.cat_axis] + dodge_move + jitter_move
sub_data.loc[:, self.cat_axis] = adjusted_data

if "hue" in self.variables:
c = self._hue_map(sub_data["hue"])
else:
c = mpl.colors.to_hex(default_color)

for var in "xy":
if self._log_scaled(var):
sub_data[var] = np.power(10, sub_data[var])

ax = self._get_axes(sub_vars)
ax.scatter(sub_data["x"], sub_data["y"], c=c, **plot_kws)
points = ax.scatter(sub_data["x"], sub_data["y"], color=color, **plot_kws)

if "hue" in self.variables:
points.set_facecolors(self._hue_map(sub_data["hue"]))

if edgecolor == "gray": # XXX TODO change to "auto"
points.set_edgecolors(self._get_gray(points.get_facecolors()))
else:
points.set_edgecolors(edgecolor)

# TODO XXX fully impelement legend
show_legend = not self._redundant_hue and self.input_format != "wide"
Expand All @@ -280,11 +282,11 @@ def plot_swarms(
self,
dodge,
color,
edgecolor,
warn_thresh,
plot_kws,
):

default_color = "C0" if color is None else color
width = .8 * self._native_width
offsets = self._nested_offsets(width, dodge)

Expand All @@ -293,8 +295,7 @@ def plot_swarms(
iter_vars.append("hue")

ax = self.ax
centers = []
swarms = []
point_collections = {}
dodge_move = 0

for sub_vars, sub_data in self.iter_data(iter_vars,
Expand All @@ -307,27 +308,29 @@ def plot_swarms(
if not sub_data.empty:
sub_data.loc[:, self.cat_axis] = sub_data[self.cat_axis] + dodge_move

if "hue" in self.variables:
c = self._hue_map(sub_data["hue"])
else:
c = mpl.colors.to_hex(default_color)

for var in "xy":
if self._log_scaled(var):
sub_data[var] = np.power(10, sub_data[var])

ax = self._get_axes(sub_vars)
swarm = ax.scatter(sub_data["x"], sub_data["y"], c=c, **plot_kws)
points = ax.scatter(sub_data["x"], sub_data["y"], color=color, **plot_kws)

if "hue" in self.variables:
points.set_facecolors(self._hue_map(sub_data["hue"]))

if edgecolor == "gray": # XXX TODO change to "auto"
points.set_edgecolors(self._get_gray(points.get_facecolors()))
else:
points.set_edgecolors(edgecolor)

if not sub_data.empty:
centers.append(sub_data[self.cat_axis].iloc[0])
swarms.append(swarm)
point_collections[sub_data[self.cat_axis].iloc[0]] = points

beeswarm = Beeswarm(
width=width, orient=self.orient, warn_thresh=warn_thresh,
)
for center, swarm in zip(centers, swarms):
if swarm.get_offsets().shape[0] > 1:
for center, points in point_collections.items():
if points.get_offsets().shape[0] > 1:

def draw(points, renderer, *, center=center):

Expand All @@ -353,7 +356,7 @@ def draw(points, renderer, *, center=center):

super(points.__class__, points).draw(renderer)

swarm.draw = draw.__get__(swarm)
points.draw = draw.__get__(points)

_draw_figure(ax.figure)

Expand Down Expand Up @@ -658,7 +661,7 @@ def establish_colors(self, color, palette, saturation):
rgb_colors = color_palette(colors)

# Determine the gray color to use for the lines framing the plot
light_vals = [colorsys.rgb_to_hls(*c)[1] for c in rgb_colors]
light_vals = [rgb_to_hls(*c)[1] for c in rgb_colors]
lum = min(light_vals) * .6
gray = mpl.colors.rgb2hex((lum, lum, lum))

Expand Down Expand Up @@ -2775,25 +2778,16 @@ def stripplot(

p._attach(ax)

if not p.has_xy_data:
return ax

palette, hue_order = p._hue_backcompat(color, palette, hue_order)

color = _default_color(ax.scatter, hue, color, kwargs)

p.map_hue(palette=palette, order=hue_order, norm=hue_norm)

# XXX Copying possibly bad default decisions from original code for now
kwargs.setdefault("zorder", 3)
size = kwargs.get("s", size)

# XXX Here especially is tricky. Old code didn't follow the color cycle.
# If new code does, then we won't know the default non-mapped color out here.
# But also I think in general that logic should move to the outer functions.
# XXX Wait how does this work with a custom palette?
# XXX Regardless of implementation, I think we should change this default
# name to "auto" or something similar that doesn't overlap with a real color name
if edgecolor == "gray":
edgecolor = p._get_gray("C0" if color is None else color)

kwargs.update(dict(
s=size ** 2,
edgecolor=edgecolor,
Expand All @@ -2804,6 +2798,7 @@ def stripplot(
jitter=jitter,
dodge=dodge,
color=color,
edgecolor=edgecolor,
plot_kws=kwargs,
)

Expand Down Expand Up @@ -2877,7 +2872,7 @@ def stripplot(
""").format(**_categorical_docs)


# @_deprecate_positional_args
@_deprecate_positional_args
def swarmplot(
*,
x=None, y=None,
Expand Down Expand Up @@ -2910,6 +2905,9 @@ def swarmplot(
return ax

palette, hue_order = p._hue_backcompat(color, palette, hue_order)

color = _default_color(ax.scatter, hue, color, kwargs)

p.map_hue(palette=palette, order=hue_order, norm=hue_norm)

# XXX Copying possibly bad default decisions from original code for now
Expand All @@ -2919,24 +2917,15 @@ def swarmplot(
if linewidth is None:
linewidth = size / 10

# XXX Here especially is tricky. Old code didn't follow the color cycle.
# If new code does, then we won't know the default non-mapped color out here.
# But also I think in general that logic should move to the outer functions.
# XXX Wait how does this work with a custom palette?
# XXX Regardless of implementation, I think we should change this default
# name to "auto" or something similar that doesn't overlap with a real color name
if edgecolor == "gray":
edgecolor = p._get_gray("C0" if color is None else color)

kwargs.update(dict(
s=size ** 2,
edgecolor=edgecolor,
linewidth=linewidth,
))

p.plot_swarms(
dodge=dodge,
color=color,
edgecolor=edgecolor,
warn_thresh=warn_thresh,
plot_kws=kwargs,
)
Expand Down Expand Up @@ -3669,32 +3658,28 @@ def catplot(
# TODO get these defaults programatically?
jitter = kwargs.pop("jitter", True)
dodge = kwargs.pop("dodge", False)
edgecolor = kwargs.pop("edgecolor", "gray")
edgecolor = kwargs.pop("edgecolor", "gray") # XXX TODO default

plot_kws = kwargs.copy()

# XXX Copying possibly bad default decisions from original code for now
plot_kws.setdefault("zorder", 3)
plot_kws.setdefault("s", 25)

if edgecolor == "gray":
edgecolor = p._get_gray("C0" if color is None else color)
plot_kws["edgecolor"] = edgecolor

plot_kws.setdefault("linewidth", 0)

p.plot_strips(
jitter=jitter,
dodge=dodge,
color=color,
edgecolor=edgecolor,
plot_kws=plot_kws,
)

elif kind == "swarm":

# TODO get these defaults programatically?
dodge = kwargs.pop("dodge", False)
edgecolor = kwargs.pop("edgecolor", "gray")
edgecolor = kwargs.pop("edgecolor", "gray") # XXX TODO default
warn_thresh = kwargs.pop("warn_thresh", .05)

plot_kws = kwargs.copy()
Expand All @@ -3703,16 +3688,13 @@ def catplot(
plot_kws.setdefault("zorder", 3)
plot_kws.setdefault("s", 25)

if edgecolor == "gray":
edgecolor = p._get_gray("C0" if color is None else color)
plot_kws["edgecolor"] = edgecolor

if plot_kws.setdefault("linewidth", 0) is None:
plot_kws["linewidth"] = np.sqrt(plot_kws["s"]) / 10

p.plot_swarms(
dodge=dodge,
color=color,
edgecolor=edgecolor,
warn_thresh=warn_thresh,
plot_kws=plot_kws,
)
Expand Down
Loading