From 2b8b80f4edcede4131fb2cb48c1667f78906378b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 4 Aug 2015 00:20:13 -0700 Subject: [PATCH] reorganize discrete color maps --- xray/plot/plot.py | 146 +++++++++++++++++++---------------------- xray/test/test_plot.py | 102 +++++++++++----------------- 2 files changed, 107 insertions(+), 141 deletions(-) diff --git a/xray/plot/plot.py b/xray/plot/plot.py index 2df0e711170..982779e06a2 100644 --- a/xray/plot/plot.py +++ b/xray/plot/plot.py @@ -208,7 +208,7 @@ def _update_axes_limits(ax, xincrease, yincrease): def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, center=None, robust=False, extend=None, - levels=None): + levels=None, filled=True, cnorm=None): """ Use some heuristics to set good defaults for colorbar and range. @@ -218,13 +218,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, import matplotlib as mpl calc_data = plot_data[~pd.isnull(plot_data)] - bounds_set = True if vmin is None: vmin = np.percentile(calc_data, 2) if robust else calc_data.min() - bounds_set = False if vmax is None: vmax = np.percentile(calc_data, 98) if robust else calc_data.max() - bounds_set = False # Simple heuristics for whether these data should have a divergent map divergent = ((vmin < 0) and (vmax > 0)) or center is not None @@ -249,43 +246,62 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, else: cmap = "viridis" + # Allow viridis before matplotlib 1.5 if cmap == "viridis": cmap = _load_default_cmap() + # Handle discrete levels + if levels is not None: + if isinstance(levels, int): + ticker = mpl.ticker.MaxNLocator(levels) + levels = ticker.tick_values(vmin, vmax) + vmin, vmax = levels[0], levels[-1] + if extend is None: - extend_set = False - extend_min = calc_data.min() < vmin - extend_max = calc_data.max() > vmax - if extend_min and extend_max: - extend = 'both' - elif extend_min: - extend = 'min' - elif extend_max: - extend = 'max' - else: - extend = 'neither' - else: - extend_set = True + extend = _determine_extend(calc_data, vmin, vmax) if levels is not None: - cmap, cnorm, extend = _determine_discrete_cmap_params(cmap, levels, - vmin, vmax, - extend, - bounds_set, - extend_set) + cmap, cnorm = _build_discrete_cmap(cmap, levels, extend, filled) + + return vmin, vmax, cmap, extend, levels, cnorm + + +def _determine_extend(calc_data, vmin, vmax): + extend_min = calc_data.min() < vmin + extend_max = calc_data.max() > vmax + if extend_min and extend_max: + extend = 'both' + elif extend_min: + extend = 'min' + elif extend_max: + extend = 'max' else: - cnorm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + extend = 'neither' + return extend - return vmin, vmax, cmap, extend, cnorm +def _color_palette(cmap, n_colors): + import matplotlib.pyplot as plt + try: + from seaborn.apionly import color_palette + pal = color_palette(cmap, n_colors=n_colors) + except (TypeError, ImportError): + # TypeError is raised when LinearSegmentedColormap (viridis) is used + # Import Error is raised when seaborn is not installed + # Use homegrown solution if you don't have seaborn or are using viridis + if isinstance(cmap, basestring): + cmap = plt.get_cmap(cmap) -def _determine_discrete_cmap_params(cmap, levels, vmin, vmax, extend, - bounds_set, extend_set): + colors_i = np.linspace(0, 1., n_colors) + pal = cmap(colors_i) + return pal + + +def _build_discrete_cmap(cmap, levels, extend, filled): """ Build a discrete colormap and normalization of the data. """ import matplotlib as mpl - import matplotlib.pyplot as plt def extension_colors(extend): if extend == 'both': @@ -296,54 +312,20 @@ def extension_colors(extend): ext_n = 0 return ext_n - if isinstance(levels, int): - vmax += 10 * np.finfo(float).eps # Add small epison to include vmax - if not bounds_set: - # if there were not user provided bounds, use MaxNLocator to pick - # a nice set of ticks - ticker = mpl.ticker.MaxNLocator(levels) - cticks = ticker.tick_values(vmin, vmax) - print(cticks) - else: - # otherwise, use the user provided vmin/vmax - cticks = np.linspace(vmin, vmax, num=levels + 1, endpoint=True) - ext_n = extension_colors(extend) - n_colors = len(cticks) + ext_n - 1 - else: - try: - cticks = np.asarray(levels) - if not extend_set: - extend_min = cticks[0] > vmin - extend_max = cticks[-1] < vmax - if extend_min and extend_max: - extend = 'both' - elif extend_min: - extend = 'min' - elif extend_max: - extend = 'max' - else: - extend = 'neither' - ext_n = extension_colors(extend) - n_colors = len(levels) + ext_n - 1 - except TypeError as e: - print('Unexpected type (%s) given for levels' % type(levels)) - raise e - try: - from seaborn.apionly import color_palette - pal = color_palette(cmap, n_colors=n_colors) - except (TypeError, ImportError): - # TypeError is raised when LinearSegmentedColormap (viridis) is used - # Import Error is raised when seaborn is not installed - # Use homegrown solution if you don't have seaborn or are using viridis - if isinstance(cmap, basestring): - cmap = plt.get_cmap(cmap) + if not filled: + # non-filled contour plots + extend = 'neither' - colors_i = np.linspace(0, 1., n_colors) - pal = cmap(colors_i) + ext_n = extension_colors(extend) + n_colors = len(levels) + ext_n - 1 + pal = _color_palette(cmap, n_colors) - cmap, cnorm = mpl.colors.from_levels_and_colors(cticks, pal, extend=extend) + new_cmap, cnorm = mpl.colors.from_levels_and_colors( + levels, pal, extend=extend) + # copy the old cmap name, for easier testing + new_cmap.name = getattr(cmap, 'name', cmap) - return cmap, cnorm, extend + return new_cmap, cnorm # MUST run before any 2d plotting functions are defined since @@ -393,7 +375,8 @@ def _plot2d(plotfunc): vmin, vmax : floats, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, - one of these values may be ignored. + one of these values may be ignored. If discrete levels are provided as + an explicit list, both of these values are ignored. cmap : matplotlib colormap name or object, optional The mapping from data values to color space. If not provided, this will be either be ``viridis`` (if the function infers a sequential @@ -449,20 +432,25 @@ def newplotfunc(darray, ax=None, xincrease=None, yincrease=None, _ensure_plottable(x, y) - vmin, vmax, cmap, extend, cnorm = _determine_cmap_params( - z.data, vmin, vmax, cmap, center, robust, extend, levels) + if 'contour' in plotfunc.__name__ and levels is None: + levels = 7 # this is the matplotlib default + filled = plotfunc.__name__ != 'contour' + + vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params( + z.data, vmin, vmax, cmap, center, robust, extend, levels, filled) if 'contour' in plotfunc.__name__: # extend is a keyword argument only for contour and contourf, but # passing it to the colorbar is sufficient for imshow and # pcolormesh kwargs['extend'] = extend + kwargs['levels'] = levels - if 'norm' not in kwargs: - # This allows the user to pass in a custom norm coming via kwargs - kwargs['norm'] = cnorm + # This allows the user to pass in a custom norm coming via kwargs + kwargs.setdefault('norm', cnorm) - ax, primitive = plotfunc(x, y, z, ax=ax, cmap=cmap, **kwargs) + ax, primitive = plotfunc(x, y, z, ax=ax, cmap=cmap, vmin=vmin, + vmax=vmax, **kwargs) ax.set_xlabel(xlab) ax.set_ylabel(ylab) diff --git a/xray/test/test_plot.py b/xray/test/test_plot.py index 0d8d9117aae..a2abafb19d7 100644 --- a/xray/test/test_plot.py +++ b/xray/test/test_plot.py @@ -6,7 +6,7 @@ import xray.plot as xplt from xray.plot.plot import (_infer_interval_breaks, _determine_cmap_params, - _determine_discrete_cmap_params) + _build_discrete_cmap) from . import TestCase, requires_matplotlib @@ -166,79 +166,57 @@ def test_plot_nans(self): class TestDetermineCmapParams(TestCase): def test_robust(self): data = np.random.RandomState(1).rand(100) - vmin, vmax, cmap, extend, cnorm = _determine_cmap_params(data, - robust=True) + vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params( + data, robust=True) self.assertEqual(vmin, np.percentile(data, 2)) self.assertEqual(vmax, np.percentile(data, 98)) self.assertEqual(cmap.name, 'viridis') self.assertEqual(extend, 'both') + self.assertIsNone(levels) + self.assertIsNone(cnorm) def test_center(self): data = np.random.RandomState(2).rand(100) - vmin, vmax, cmap, extend, cnorm = _determine_cmap_params(data, - center=0.5) + vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params( + data, center=0.5) self.assertEqual(vmax - 0.5, 0.5 - vmin) self.assertEqual(cmap, 'RdBu_r') self.assertEqual(extend, 'neither') + self.assertIsNone(levels) + self.assertIsNone(cnorm) - -@requires_matplotlib -class TestDetermineDiscreteCmapParams(TestCase): def test_integer_levels(self): - levels = 8 - vmin = -5 - vmax = 5 - cmap, cnorm, extend = _determine_discrete_cmap_params( - 'Spectral', levels, vmin, vmax, 'neither', True, True) - self.assertEqual(cmap.N, levels) - self.assertEqual(cnorm.N, levels + 1) - self.assertEqual(cnorm.vmin, vmin) - self.assertEqual(cnorm.vmax, vmax + 10 * np.finfo(float).eps) - - cmap, cnorm, extend = _determine_discrete_cmap_params( - 'Blues', levels, vmin, vmax, 'both', True, True) - # extension colors are not included here - self.assertEqual(cmap.N, levels) - self.assertEqual(cnorm.N, levels + 1) - self.assertEqual(cnorm.vmin, vmin) - self.assertEqual(cnorm.vmax, vmax + 10 * np.finfo(float).eps) - - # heuristics for picking nice ticks - cmap, cnorm, extend = _determine_discrete_cmap_params( - 'Spectral', levels, vmin, vmax, 'neither', False, True) - self.assertGreaterEqual(cnorm.vmax, vmax) - self.assertLessEqual(cnorm.vmin, vmin) + data = 1 + np.random.RandomState(3).rand(100) + vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params( + data, levels=5, vmin=0, vmax=5, cmap='Blues') + self.assertEqual(vmin, levels[0]) + self.assertEqual(vmax, levels[-1]) + self.assertEqual(cmap.name, 'Blues') + self.assertEqual(extend, 'neither') + self.assertEqual(cmap.N, 5) + self.assertEqual(cnorm.N, 6) + + vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params( + data, levels=5, vmin=0.5, vmax=1.5) + self.assertEqual(cmap.name, 'viridis') + self.assertEqual(extend, 'max') def test_list_levels(self): - levels = [-4, -2, 0, 2, 4] - vmin = -5 - vmax = 5 - - cmap, cnorm, extend = _determine_discrete_cmap_params( - 'Spectral', levels, vmin, vmax, 'neither', True, True) - self.assertEqual(cmap.N, len(levels) - 1) - self.assertEqual(cnorm.N, len(levels)) - self.assertEqual(cnorm.vmin, min(levels)) - self.assertEqual(cnorm.vmax, max(levels)) - - cmap, cnorm, extend = _determine_discrete_cmap_params( - 'Greens_r', levels, vmin, vmax, 'both', True, True) - self.assertEqual(cmap.N, len(levels) - 1) - self.assertEqual(cnorm.N, len(levels)) - self.assertEqual(cnorm.vmin, min(levels)) - self.assertEqual(cnorm.vmax, max(levels)) - - # levels as an array - cmap, cnorm, extend = _determine_discrete_cmap_params( - 'Greens_r', np.array(levels), vmin, vmax, 'both', True, True) - # levels as a DataArray - cmap, cnorm, extend = _determine_discrete_cmap_params( - 'Greens_r', DataArray(levels), vmin, vmax, 'both', True, True) - - # heuristics for picking extend when using list of levels - cmap, cnorm, extend = _determine_discrete_cmap_params( - 'Greens_r', DataArray(levels), -5, 3, 'both', True, False) - self.assertEqual(extend, 'min') + data = 1 + np.random.RandomState(3).rand(100) + + orig_levels = [0, 1, 2, 3, 4, 5] + # vmin and vmax should be ignored if levels are explicitly provided + vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params( + data, levels=orig_levels, vmin=0, vmax=3) + self.assertEqual(vmin, 0) + self.assertEqual(vmax, 5) + self.assertEqual(cmap.N, 5) + self.assertEqual(cnorm.N, 6) + + for wrap_levels in [list, np.array, pd.Index, DataArray]: + vmin, vmax, cmap, extend, levels, cnorm = _determine_cmap_params( + data, levels=wrap_levels(orig_levels)) + self.assertArrayEqual(levels, orig_levels) class Common2dMixin: @@ -311,8 +289,8 @@ def test_default_cmap(self): self.assertEqual('viridis', cmap_name) def test_can_change_default_cmap(self): - cmap_name = self.plotmethod(cmap='jet').get_cmap().name - self.assertEqual('jet', cmap_name) + cmap_name = self.plotmethod(cmap='Blues').get_cmap().name + self.assertEqual('Blues', cmap_name) def test_diverging_color_limits(self): artist = self.plotmethod()