From 6aa225f5dae9cc997e232c11a63072923c8c0238 Mon Sep 17 00:00:00 2001 From: Zac Hatfield Dodds Date: Fri, 19 Jan 2018 16:01:06 +1100 Subject: [PATCH] Normalisation for RGB imshow (#1819) * Normalisation for RGB imshow * Add test for error checking --- doc/plotting.rst | 2 ++ doc/whats-new.rst | 1 + xarray/plot/plot.py | 47 +++++++++++++++++++++++++++++++++++++-- xarray/plot/utils.py | 4 +++- xarray/tests/test_plot.py | 20 +++++++++++++++++ 5 files changed, 71 insertions(+), 3 deletions(-) diff --git a/doc/plotting.rst b/doc/plotting.rst index cd081811b99..2b816a24563 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -305,6 +305,8 @@ example, consider the original data in Kelvins rather than Celsius: The Celsius data contain 0, so a diverging color map was used. The Kelvins do not have 0, so the default color map was used. +.. _robust-plotting: + Robust ~~~~~~ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7bfe5991b78..12d3b910ca6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,7 @@ Enhancements - Support for using `Zarr`_ as storage layer for xarray. By `Ryan Abernathey `_. - :func:`xarray.plot.imshow` now handles RGB and RGBA images. + Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``. By `Zac Hatfield-Dodds `_. - Experimental support for parsing ENVI metadata to coordinates and attributes in :py:func:`xarray.open_rasterio`. diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 99ab4176714..d17ceb84e16 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -15,8 +15,8 @@ import pandas as pd from datetime import datetime -from .utils import (_determine_cmap_params, _infer_xy_labels, get_axis, - import_matplotlib_pyplot) +from .utils import (ROBUST_PERCENTILE, _determine_cmap_params, + _infer_xy_labels, get_axis, import_matplotlib_pyplot) from .facetgrid import FacetGrid from xarray.core.pycompat import basestring @@ -326,6 +326,39 @@ def line(self, *args, **kwargs): return line(self._da, *args, **kwargs) +def _rescale_imshow_rgb(darray, vmin, vmax, robust): + assert robust or vmin is not None or vmax is not None + # There's a cyclic dependency via DataArray, so we can't import from + # xarray.ufuncs in global scope. + from xarray.ufuncs import maximum, minimum + # Calculate vmin and vmax automatically for `robust=True` + if robust: + if vmax is None: + vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE) + if vmin is None: + vmin = np.nanpercentile(darray, ROBUST_PERCENTILE) + # If not robust and one bound is None, calculate the default other bound + # and check that an interval between them exists. + elif vmax is None: + vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1 + if vmax < vmin: + raise ValueError( + 'vmin=%r is less than the default vmax (%r) - you must supply ' + 'a vmax > vmin in this case.' % (vmin, vmax)) + elif vmin is None: + vmin = 0 + if vmin > vmax: + raise ValueError( + 'vmax=%r is less than the default vmin (0) - you must supply ' + 'a vmin < vmax in this case.' % vmax) + # Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float + # to avoid precision loss, integer over/underflow, etc with extreme inputs. + # After scaling, downcast to 32-bit float. This substantially reduces + # memory usage after we hand `darray` off to matplotlib. + darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4') + return minimum(maximum(darray, 0), 1) + + def _plot2d(plotfunc): """ Decorator for common 2d plotting logic @@ -449,6 +482,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, if imshow_rgb: # Don't add a colorbar when showing an image with explicit colors add_colorbar = False + # Matplotlib does not support normalising RGB data, so do it here. + # See eg. https://github.com/matplotlib/matplotlib/pull/10220 + if robust or vmax is not None or vmin is not None: + darray = _rescale_imshow_rgb(darray, vmin, vmax, robust) + vmin, vmax, robust = None, None, False # Handle facetgrids first if row or col: @@ -625,6 +663,11 @@ def imshow(x, y, z, ax, **kwargs): dimension can be interpreted as RGB or RGBA color channels and allows this dimension to be specified via the kwarg ``rgb=``. + Unlike matplotlib, Xarray can apply ``vmin`` and ``vmax`` to RGB or RGBA + data, by applying a single scaling factor and offset to all bands. + Passing ``robust=True`` infers ``vmin`` and ``vmax`` + :ref:`in the usual way `. + .. note:: This function needs uniformly spaced coordinates to properly label the axes. Call DataArray.plot() to check. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 58b4d55e0c5..c194b9dd8d8 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -11,6 +11,9 @@ from ..core.utils import is_scalar +ROBUST_PERCENTILE = 2.0 + + def _load_default_cmap(fname='default_colormap.csv'): """ Returns viridis color map @@ -165,7 +168,6 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, cmap_params : dict Use depends on the type of the plotting function """ - ROBUST_PERCENTILE = 2.0 import matplotlib as mpl calc_data = np.ravel(plot_data[~pd.isnull(plot_data)]) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 7306a36e6f5..1573577a092 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1126,6 +1126,26 @@ def test_rgb_errors_bad_dim_sizes(self): with pytest.raises(ValueError): arr.plot.imshow(rgb='band') + def test_normalize_rgb_imshow(self): + for kwds in ( + dict(vmin=-1), dict(vmax=2), + dict(vmin=-1, vmax=1), dict(vmin=0, vmax=0), + dict(vmin=0, robust=True), dict(vmax=-1, robust=True), + ): + da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) + arr = da.plot.imshow(**kwds).get_array() + assert 0 <= arr.min() <= arr.max() <= 1, kwds + + def test_normalize_rgb_one_arg_error(self): + da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) + # If passed one bound that implies all out of range, error: + for kwds in [dict(vmax=-1), dict(vmin=2)]: + with pytest.raises(ValueError): + da.plot.imshow(**kwds) + # If passed two that's just moving the range, *not* an error: + for kwds in [dict(vmax=-1, vmin=-1.2), dict(vmin=2, vmax=2.1)]: + da.plot.imshow(**kwds) + class TestFacetGrid(PlotTestCase): def setUp(self):