Skip to content

Commit

Permalink
reorganize discrete color maps
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Aug 4, 2015
1 parent 119c1c8 commit 2b8b80f
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 141 deletions.
146 changes: 67 additions & 79 deletions xray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _update_axes_limits(ax, xincrease, yincrease):

def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
center=None, robust=False, extend=None,
levels=None):
levels=None, filled=True, cnorm=None):
"""
Use some heuristics to set good defaults for colorbar and range.
Expand All @@ -218,13 +218,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
import matplotlib as mpl

calc_data = plot_data[~pd.isnull(plot_data)]
bounds_set = True
if vmin is None:
vmin = np.percentile(calc_data, 2) if robust else calc_data.min()
bounds_set = False
if vmax is None:
vmax = np.percentile(calc_data, 98) if robust else calc_data.max()
bounds_set = False

# Simple heuristics for whether these data should have a divergent map
divergent = ((vmin < 0) and (vmax > 0)) or center is not None
Expand All @@ -249,43 +246,62 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
else:
cmap = "viridis"

# Allow viridis before matplotlib 1.5
if cmap == "viridis":
cmap = _load_default_cmap()

# Handle discrete levels
if levels is not None:
if isinstance(levels, int):
ticker = mpl.ticker.MaxNLocator(levels)
levels = ticker.tick_values(vmin, vmax)
vmin, vmax = levels[0], levels[-1]

if extend is None:
extend_set = False
extend_min = calc_data.min() < vmin
extend_max = calc_data.max() > vmax
if extend_min and extend_max:
extend = 'both'
elif extend_min:
extend = 'min'
elif extend_max:
extend = 'max'
else:
extend = 'neither'
else:
extend_set = True
extend = _determine_extend(calc_data, vmin, vmax)

if levels is not None:
cmap, cnorm, extend = _determine_discrete_cmap_params(cmap, levels,
vmin, vmax,
extend,
bounds_set,
extend_set)
cmap, cnorm = _build_discrete_cmap(cmap, levels, extend, filled)

return vmin, vmax, cmap, extend, levels, cnorm


def _determine_extend(calc_data, vmin, vmax):
extend_min = calc_data.min() < vmin
extend_max = calc_data.max() > vmax
if extend_min and extend_max:
extend = 'both'
elif extend_min:
extend = 'min'
elif extend_max:
extend = 'max'
else:
cnorm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
extend = 'neither'
return extend

return vmin, vmax, cmap, extend, cnorm

def _color_palette(cmap, n_colors):
import matplotlib.pyplot as plt
try:
from seaborn.apionly import color_palette
pal = color_palette(cmap, n_colors=n_colors)
except (TypeError, ImportError):
# TypeError is raised when LinearSegmentedColormap (viridis) is used
# Import Error is raised when seaborn is not installed
# Use homegrown solution if you don't have seaborn or are using viridis
if isinstance(cmap, basestring):
cmap = plt.get_cmap(cmap)

def _determine_discrete_cmap_params(cmap, levels, vmin, vmax, extend,
bounds_set, extend_set):
colors_i = np.linspace(0, 1., n_colors)
pal = cmap(colors_i)
return pal


def _build_discrete_cmap(cmap, levels, extend, filled):
"""
Build a discrete colormap and normalization of the data.
"""
import matplotlib as mpl
import matplotlib.pyplot as plt

def extension_colors(extend):
if extend == 'both':
Expand All @@ -296,54 +312,20 @@ def extension_colors(extend):
ext_n = 0
return ext_n

if isinstance(levels, int):
vmax += 10 * np.finfo(float).eps # Add small epison to include vmax
if not bounds_set:
# if there were not user provided bounds, use MaxNLocator to pick
# a nice set of ticks
ticker = mpl.ticker.MaxNLocator(levels)
cticks = ticker.tick_values(vmin, vmax)
print(cticks)
else:
# otherwise, use the user provided vmin/vmax
cticks = np.linspace(vmin, vmax, num=levels + 1, endpoint=True)
ext_n = extension_colors(extend)
n_colors = len(cticks) + ext_n - 1
else:
try:
cticks = np.asarray(levels)
if not extend_set:
extend_min = cticks[0] > vmin
extend_max = cticks[-1] < vmax
if extend_min and extend_max:
extend = 'both'
elif extend_min:
extend = 'min'
elif extend_max:
extend = 'max'
else:
extend = 'neither'
ext_n = extension_colors(extend)
n_colors = len(levels) + ext_n - 1
except TypeError as e:
print('Unexpected type (%s) given for levels' % type(levels))
raise e
try:
from seaborn.apionly import color_palette
pal = color_palette(cmap, n_colors=n_colors)
except (TypeError, ImportError):
# TypeError is raised when LinearSegmentedColormap (viridis) is used
# Import Error is raised when seaborn is not installed
# Use homegrown solution if you don't have seaborn or are using viridis
if isinstance(cmap, basestring):
cmap = plt.get_cmap(cmap)
if not filled:
# non-filled contour plots
extend = 'neither'

colors_i = np.linspace(0, 1., n_colors)
pal = cmap(colors_i)
ext_n = extension_colors(extend)
n_colors = len(levels) + ext_n - 1
pal = _color_palette(cmap, n_colors)

cmap, cnorm = mpl.colors.from_levels_and_colors(cticks, pal, extend=extend)
new_cmap, cnorm = mpl.colors.from_levels_and_colors(
levels, pal, extend=extend)
# copy the old cmap name, for easier testing
new_cmap.name = getattr(cmap, 'name', cmap)

return cmap, cnorm, extend
return new_cmap, cnorm


# MUST run before any 2d plotting functions are defined since
Expand Down Expand Up @@ -393,7 +375,8 @@ def _plot2d(plotfunc):
vmin, vmax : floats, optional
Values to anchor the colormap, otherwise they are inferred from the
data and other keyword arguments. When a diverging dataset is inferred,
one of these values may be ignored.
one of these values may be ignored. If discrete levels are provided as
an explicit list, both of these values are ignored.
cmap : matplotlib colormap name or object, optional
The mapping from data values to color space. If not provided, this
will be either be ``viridis`` (if the function infers a sequential
Expand Down Expand Up @@ -449,20 +432,25 @@ def newplotfunc(darray, ax=None, xincrease=None, yincrease=None,

_ensure_plottable(x, y)

vmin, vmax, cmap, extend, cnorm = _determine_cmap_params(
z.data, vmin, vmax, cmap, center, robust, extend, levels)
if 'contour' in plotfunc.__name__ and levels is None:
levels = 7 # this is the matplotlib default
filled = plotfunc.__name__ != 'contour'

vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params(
z.data, vmin, vmax, cmap, center, robust, extend, levels, filled)

if 'contour' in plotfunc.__name__:
# extend is a keyword argument only for contour and contourf, but
# passing it to the colorbar is sufficient for imshow and
# pcolormesh
kwargs['extend'] = extend
kwargs['levels'] = levels

if 'norm' not in kwargs:
# This allows the user to pass in a custom norm coming via kwargs
kwargs['norm'] = cnorm
# This allows the user to pass in a custom norm coming via kwargs
kwargs.setdefault('norm', cnorm)

ax, primitive = plotfunc(x, y, z, ax=ax, cmap=cmap, **kwargs)
ax, primitive = plotfunc(x, y, z, ax=ax, cmap=cmap, vmin=vmin,
vmax=vmax, **kwargs)

ax.set_xlabel(xlab)
ax.set_ylabel(ylab)
Expand Down
102 changes: 40 additions & 62 deletions xray/test/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import xray.plot as xplt
from xray.plot.plot import (_infer_interval_breaks,
_determine_cmap_params,
_determine_discrete_cmap_params)
_build_discrete_cmap)

from . import TestCase, requires_matplotlib

Expand Down Expand Up @@ -166,79 +166,57 @@ def test_plot_nans(self):
class TestDetermineCmapParams(TestCase):
def test_robust(self):
data = np.random.RandomState(1).rand(100)
vmin, vmax, cmap, extend, cnorm = _determine_cmap_params(data,
robust=True)
vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params(
data, robust=True)
self.assertEqual(vmin, np.percentile(data, 2))
self.assertEqual(vmax, np.percentile(data, 98))
self.assertEqual(cmap.name, 'viridis')
self.assertEqual(extend, 'both')
self.assertIsNone(levels)
self.assertIsNone(cnorm)

def test_center(self):
data = np.random.RandomState(2).rand(100)
vmin, vmax, cmap, extend, cnorm = _determine_cmap_params(data,
center=0.5)
vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params(
data, center=0.5)
self.assertEqual(vmax - 0.5, 0.5 - vmin)
self.assertEqual(cmap, 'RdBu_r')
self.assertEqual(extend, 'neither')
self.assertIsNone(levels)
self.assertIsNone(cnorm)


@requires_matplotlib
class TestDetermineDiscreteCmapParams(TestCase):
def test_integer_levels(self):
levels = 8
vmin = -5
vmax = 5
cmap, cnorm, extend = _determine_discrete_cmap_params(
'Spectral', levels, vmin, vmax, 'neither', True, True)
self.assertEqual(cmap.N, levels)
self.assertEqual(cnorm.N, levels + 1)
self.assertEqual(cnorm.vmin, vmin)
self.assertEqual(cnorm.vmax, vmax + 10 * np.finfo(float).eps)

cmap, cnorm, extend = _determine_discrete_cmap_params(
'Blues', levels, vmin, vmax, 'both', True, True)
# extension colors are not included here
self.assertEqual(cmap.N, levels)
self.assertEqual(cnorm.N, levels + 1)
self.assertEqual(cnorm.vmin, vmin)
self.assertEqual(cnorm.vmax, vmax + 10 * np.finfo(float).eps)

# heuristics for picking nice ticks
cmap, cnorm, extend = _determine_discrete_cmap_params(
'Spectral', levels, vmin, vmax, 'neither', False, True)
self.assertGreaterEqual(cnorm.vmax, vmax)
self.assertLessEqual(cnorm.vmin, vmin)
data = 1 + np.random.RandomState(3).rand(100)
vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params(
data, levels=5, vmin=0, vmax=5, cmap='Blues')
self.assertEqual(vmin, levels[0])
self.assertEqual(vmax, levels[-1])
self.assertEqual(cmap.name, 'Blues')
self.assertEqual(extend, 'neither')
self.assertEqual(cmap.N, 5)
self.assertEqual(cnorm.N, 6)

vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params(
data, levels=5, vmin=0.5, vmax=1.5)
self.assertEqual(cmap.name, 'viridis')
self.assertEqual(extend, 'max')

def test_list_levels(self):
levels = [-4, -2, 0, 2, 4]
vmin = -5
vmax = 5

cmap, cnorm, extend = _determine_discrete_cmap_params(
'Spectral', levels, vmin, vmax, 'neither', True, True)
self.assertEqual(cmap.N, len(levels) - 1)
self.assertEqual(cnorm.N, len(levels))
self.assertEqual(cnorm.vmin, min(levels))
self.assertEqual(cnorm.vmax, max(levels))

cmap, cnorm, extend = _determine_discrete_cmap_params(
'Greens_r', levels, vmin, vmax, 'both', True, True)
self.assertEqual(cmap.N, len(levels) - 1)
self.assertEqual(cnorm.N, len(levels))
self.assertEqual(cnorm.vmin, min(levels))
self.assertEqual(cnorm.vmax, max(levels))

# levels as an array
cmap, cnorm, extend = _determine_discrete_cmap_params(
'Greens_r', np.array(levels), vmin, vmax, 'both', True, True)
# levels as a DataArray
cmap, cnorm, extend = _determine_discrete_cmap_params(
'Greens_r', DataArray(levels), vmin, vmax, 'both', True, True)

# heuristics for picking extend when using list of levels
cmap, cnorm, extend = _determine_discrete_cmap_params(
'Greens_r', DataArray(levels), -5, 3, 'both', True, False)
self.assertEqual(extend, 'min')
data = 1 + np.random.RandomState(3).rand(100)

orig_levels = [0, 1, 2, 3, 4, 5]
# vmin and vmax should be ignored if levels are explicitly provided
vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params(
data, levels=orig_levels, vmin=0, vmax=3)
self.assertEqual(vmin, 0)
self.assertEqual(vmax, 5)
self.assertEqual(cmap.N, 5)
self.assertEqual(cnorm.N, 6)

for wrap_levels in [list, np.array, pd.Index, DataArray]:
vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params(
data, levels=wrap_levels(orig_levels))
self.assertArrayEqual(levels, orig_levels)


class Common2dMixin:
Expand Down Expand Up @@ -311,8 +289,8 @@ def test_default_cmap(self):
self.assertEqual('viridis', cmap_name)

def test_can_change_default_cmap(self):
cmap_name = self.plotmethod(cmap='jet').get_cmap().name
self.assertEqual('jet', cmap_name)
cmap_name = self.plotmethod(cmap='Blues').get_cmap().name
self.assertEqual('Blues', cmap_name)

def test_diverging_color_limits(self):
artist = self.plotmethod()
Expand Down

0 comments on commit 2b8b80f

Please sign in to comment.