Skip to content

Commit

Permalink
Generalize getting default color from color cycle (#2449)
Browse files Browse the repository at this point in the history
* Clean up auto-gray code

* Add initial versiono of default_color function

* Reenable swarmplot mandaatory kwarg warning deprecator

* Expand flexibility of supported scatter coloring kwarg

* Change logic of scatter-based coloring to post-proccess mpl object

* Test default and specified colors

* Rework how shared strip/swarm tests work

* Test user-supplied color array with swarm/strip plot

* Address failures on pinned tests

* Use centralized default_color function in scatterplot

* Use default_color function in lineplot

* Use general default color function for kdeplot, rugplot, and ecdfplot

* Use default_color function in histplot

* Get default color after attaching axes

* Mark kdeplot datetime autoscale test as xfail due to matplotlib bug

* Fix logic of color tests

* Workaround empty fill_between datetime autoscale bug

Fixes #2133 on newer matplotlibs

* More backcompat and edge casing

* Add color tests for distribution module

* Fix bar hist legend artists

* Mute color with hue warning for now
  • Loading branch information
mwaskom authored Jan 30, 2021
1 parent 15db683 commit 10aa7a8
Show file tree
Hide file tree
Showing 8 changed files with 558 additions and 314 deletions.
3 changes: 3 additions & 0 deletions doc/releases/v0.12.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ v0.12.0 (Unreleased)

- In :func:`swarmplot`, the proportion of points that must overlap before issuing a warning can now be controlled with the `warn_thresh` parameter (:pr:`2447`).

- |Fix| In :func:`lineplot, allowed the `dashes` keyword to set the style of a line without mapping a `style` variable (:pr:`2449`).

- |Fix| |Enhancement| Improved integration with the matplotlib color cycle in most axes-level functions (:pr:`2449`).

- Made `scipy` an optional dependency and added `pip install seaborn[all]` as a method for ensuring the availability of compatible `scipy` and `statsmodels` libraries at install time. This has a few minor implications for existing code, which are explained in the Github pull request (:pr:`2398`).

Expand Down
120 changes: 51 additions & 69 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from textwrap import dedent
from numbers import Number
import warnings
import colorsys
from colorsys import rgb_to_hls
from functools import partial

import numpy as np
Expand All @@ -25,7 +25,7 @@
categorical_order,
)
from . import utils
from .utils import remove_na, _normal_quantile_func, _draw_figure
from .utils import remove_na, _normal_quantile_func, _draw_figure, _default_color
from .algorithms import bootstrap
from .palettes import color_palette, husl_palette, light_palette, dark_palette
from .axisgrid import FacetGrid, _facet_docs
Expand Down Expand Up @@ -144,17 +144,14 @@ def _hue_backcompat(self, color, palette, hue_order, force_hue=False):
def cat_axis(self):
return {"v": "x", "h": "y"}[self.orient]

def _get_gray(self, color="C0"):
def _get_gray(self, colors):
"""Get a grayscale value that looks good with color."""
if "hue" in self.variables:
rgb_colors = list(self._hue_map.lookup_table.values())
else:
rgb_colors = [mpl.colors.to_rgb(color)]

light_vals = [colorsys.rgb_to_hls(*mpl.colors.to_rgb(c))[1] for c in rgb_colors]
if not len(colors):
return None
unique_colors = np.unique(colors, axis=0)
light_vals = [rgb_to_hls(*rgb[:3])[1] for rgb in unique_colors]
lum = min(light_vals) * .6
gray = mpl.colors.rgb2hex((lum, lum, lum))
return gray
return (lum, lum, lum)

def _adjust_cat_axis(self, ax, axis):
"""Set ticks and limits for a categorical variable."""
Expand All @@ -180,13 +177,15 @@ def _adjust_cat_axis(self, ax, axis):
else:
order = categorical_order(data)

n = max(len(order), 1)

if axis == "x":
ax.xaxis.grid(False)
ax.set_xlim(-.5, len(order) - .5, auto=None)
ax.set_xlim(-.5, n - .5, auto=None)
else:
ax.yaxis.grid(False)
# Note limits that correspond to previously-inverted y axis
ax.set_ylim(len(order) - .5, -.5, auto=None)
ax.set_ylim(n - .5, -.5, auto=None)

@property
def _native_width(self):
Expand Down Expand Up @@ -221,10 +220,10 @@ def plot_strips(
jitter,
dodge,
color,
edgecolor,
plot_kws,
):

default_color = "C0" if color is None else color
width = .8 * self._native_width
offsets = self._nested_offsets(width, dodge)

Expand Down Expand Up @@ -256,17 +255,20 @@ def plot_strips(
adjusted_data = sub_data[self.cat_axis] + dodge_move + jitter_move
sub_data.loc[:, self.cat_axis] = adjusted_data

if "hue" in self.variables:
c = self._hue_map(sub_data["hue"])
else:
c = mpl.colors.to_hex(default_color)

for var in "xy":
if self._log_scaled(var):
sub_data[var] = np.power(10, sub_data[var])

ax = self._get_axes(sub_vars)
ax.scatter(sub_data["x"], sub_data["y"], c=c, **plot_kws)
points = ax.scatter(sub_data["x"], sub_data["y"], color=color, **plot_kws)

if "hue" in self.variables:
points.set_facecolors(self._hue_map(sub_data["hue"]))

if edgecolor == "gray": # XXX TODO change to "auto"
points.set_edgecolors(self._get_gray(points.get_facecolors()))
else:
points.set_edgecolors(edgecolor)

# TODO XXX fully impelement legend
show_legend = not self._redundant_hue and self.input_format != "wide"
Expand All @@ -280,11 +282,11 @@ def plot_swarms(
self,
dodge,
color,
edgecolor,
warn_thresh,
plot_kws,
):

default_color = "C0" if color is None else color
width = .8 * self._native_width
offsets = self._nested_offsets(width, dodge)

Expand All @@ -293,8 +295,7 @@ def plot_swarms(
iter_vars.append("hue")

ax = self.ax
centers = []
swarms = []
point_collections = {}
dodge_move = 0

for sub_vars, sub_data in self.iter_data(iter_vars,
Expand All @@ -307,27 +308,29 @@ def plot_swarms(
if not sub_data.empty:
sub_data.loc[:, self.cat_axis] = sub_data[self.cat_axis] + dodge_move

if "hue" in self.variables:
c = self._hue_map(sub_data["hue"])
else:
c = mpl.colors.to_hex(default_color)

for var in "xy":
if self._log_scaled(var):
sub_data[var] = np.power(10, sub_data[var])

ax = self._get_axes(sub_vars)
swarm = ax.scatter(sub_data["x"], sub_data["y"], c=c, **plot_kws)
points = ax.scatter(sub_data["x"], sub_data["y"], color=color, **plot_kws)

if "hue" in self.variables:
points.set_facecolors(self._hue_map(sub_data["hue"]))

if edgecolor == "gray": # XXX TODO change to "auto"
points.set_edgecolors(self._get_gray(points.get_facecolors()))
else:
points.set_edgecolors(edgecolor)

if not sub_data.empty:
centers.append(sub_data[self.cat_axis].iloc[0])
swarms.append(swarm)
point_collections[sub_data[self.cat_axis].iloc[0]] = points

beeswarm = Beeswarm(
width=width, orient=self.orient, warn_thresh=warn_thresh,
)
for center, swarm in zip(centers, swarms):
if swarm.get_offsets().shape[0] > 1:
for center, points in point_collections.items():
if points.get_offsets().shape[0] > 1:

def draw(points, renderer, *, center=center):

Expand All @@ -353,7 +356,7 @@ def draw(points, renderer, *, center=center):

super(points.__class__, points).draw(renderer)

swarm.draw = draw.__get__(swarm)
points.draw = draw.__get__(points)

_draw_figure(ax.figure)

Expand Down Expand Up @@ -658,7 +661,7 @@ def establish_colors(self, color, palette, saturation):
rgb_colors = color_palette(colors)

# Determine the gray color to use for the lines framing the plot
light_vals = [colorsys.rgb_to_hls(*c)[1] for c in rgb_colors]
light_vals = [rgb_to_hls(*c)[1] for c in rgb_colors]
lum = min(light_vals) * .6
gray = mpl.colors.rgb2hex((lum, lum, lum))

Expand Down Expand Up @@ -2775,25 +2778,16 @@ def stripplot(

p._attach(ax)

if not p.has_xy_data:
return ax

palette, hue_order = p._hue_backcompat(color, palette, hue_order)

color = _default_color(ax.scatter, hue, color, kwargs)

p.map_hue(palette=palette, order=hue_order, norm=hue_norm)

# XXX Copying possibly bad default decisions from original code for now
kwargs.setdefault("zorder", 3)
size = kwargs.get("s", size)

# XXX Here especially is tricky. Old code didn't follow the color cycle.
# If new code does, then we won't know the default non-mapped color out here.
# But also I think in general that logic should move to the outer functions.
# XXX Wait how does this work with a custom palette?
# XXX Regardless of implementation, I think we should change this default
# name to "auto" or something similar that doesn't overlap with a real color name
if edgecolor == "gray":
edgecolor = p._get_gray("C0" if color is None else color)

kwargs.update(dict(
s=size ** 2,
edgecolor=edgecolor,
Expand All @@ -2804,6 +2798,7 @@ def stripplot(
jitter=jitter,
dodge=dodge,
color=color,
edgecolor=edgecolor,
plot_kws=kwargs,
)

Expand Down Expand Up @@ -2877,7 +2872,7 @@ def stripplot(
""").format(**_categorical_docs)


# @_deprecate_positional_args
@_deprecate_positional_args
def swarmplot(
*,
x=None, y=None,
Expand Down Expand Up @@ -2910,6 +2905,9 @@ def swarmplot(
return ax

palette, hue_order = p._hue_backcompat(color, palette, hue_order)

color = _default_color(ax.scatter, hue, color, kwargs)

p.map_hue(palette=palette, order=hue_order, norm=hue_norm)

# XXX Copying possibly bad default decisions from original code for now
Expand All @@ -2919,24 +2917,15 @@ def swarmplot(
if linewidth is None:
linewidth = size / 10

# XXX Here especially is tricky. Old code didn't follow the color cycle.
# If new code does, then we won't know the default non-mapped color out here.
# But also I think in general that logic should move to the outer functions.
# XXX Wait how does this work with a custom palette?
# XXX Regardless of implementation, I think we should change this default
# name to "auto" or something similar that doesn't overlap with a real color name
if edgecolor == "gray":
edgecolor = p._get_gray("C0" if color is None else color)

kwargs.update(dict(
s=size ** 2,
edgecolor=edgecolor,
linewidth=linewidth,
))

p.plot_swarms(
dodge=dodge,
color=color,
edgecolor=edgecolor,
warn_thresh=warn_thresh,
plot_kws=kwargs,
)
Expand Down Expand Up @@ -3669,32 +3658,28 @@ def catplot(
# TODO get these defaults programatically?
jitter = kwargs.pop("jitter", True)
dodge = kwargs.pop("dodge", False)
edgecolor = kwargs.pop("edgecolor", "gray")
edgecolor = kwargs.pop("edgecolor", "gray") # XXX TODO default

plot_kws = kwargs.copy()

# XXX Copying possibly bad default decisions from original code for now
plot_kws.setdefault("zorder", 3)
plot_kws.setdefault("s", 25)

if edgecolor == "gray":
edgecolor = p._get_gray("C0" if color is None else color)
plot_kws["edgecolor"] = edgecolor

plot_kws.setdefault("linewidth", 0)

p.plot_strips(
jitter=jitter,
dodge=dodge,
color=color,
edgecolor=edgecolor,
plot_kws=plot_kws,
)

elif kind == "swarm":

# TODO get these defaults programatically?
dodge = kwargs.pop("dodge", False)
edgecolor = kwargs.pop("edgecolor", "gray")
edgecolor = kwargs.pop("edgecolor", "gray") # XXX TODO default
warn_thresh = kwargs.pop("warn_thresh", .05)

plot_kws = kwargs.copy()
Expand All @@ -3703,16 +3688,13 @@ def catplot(
plot_kws.setdefault("zorder", 3)
plot_kws.setdefault("s", 25)

if edgecolor == "gray":
edgecolor = p._get_gray("C0" if color is None else color)
plot_kws["edgecolor"] = edgecolor

if plot_kws.setdefault("linewidth", 0) is None:
plot_kws["linewidth"] = np.sqrt(plot_kws["s"]) / 10

p.plot_swarms(
dodge=dodge,
color=color,
edgecolor=edgecolor,
warn_thresh=warn_thresh,
plot_kws=plot_kws,
)
Expand Down
Loading

0 comments on commit 10aa7a8

Please sign in to comment.