Skip to content

Commit

Permalink
Normalisation for RGB imshow
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Jan 17, 2018
1 parent f3deb2f commit a1f475d
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 3 deletions.
2 changes: 2 additions & 0 deletions doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ example, consider the original data in Kelvins rather than Celsius:
The Celsius data contain 0, so a diverging color map was used. The
Kelvins do not have 0, so the default color map was used.

.. _robust-plotting:

Robust
~~~~~~

Expand Down
1 change: 1 addition & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Enhancements
- Support for using `Zarr`_ as storage layer for xarray.
By `Ryan Abernathey <https://github.com/rabernat>`_.
- :func:`xarray.plot.imshow` now handles RGB and RGBA images.
Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``.
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
- Experimental support for parsing ENVI metadata to coordinates and attributes
in :py:func:`xarray.open_rasterio`.
Expand Down
38 changes: 36 additions & 2 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import pandas as pd
from datetime import datetime

from .utils import (_determine_cmap_params, _infer_xy_labels, get_axis,
import_matplotlib_pyplot)
from .utils import (ROBUST_PERCENTILE, _determine_cmap_params,
_infer_xy_labels, get_axis, import_matplotlib_pyplot)
from .facetgrid import FacetGrid
from xarray.core.pycompat import basestring

Expand Down Expand Up @@ -326,6 +326,30 @@ def line(self, *args, **kwargs):
return line(self._da, *args, **kwargs)


def _rescale_imshow_rgb(darray, vmin, vmax, robust):
assert robust or vmin is not None or vmax is not None
# There's a cyclic dependency via DataArray, so we can't import from
# xarray.ufuncs in global scope.
from xarray.ufuncs import maximum, minimum
# Calculate vmin and vmax automatically for `robust=True`
if robust:
if vmax is None:
vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE)
if vmin is None:
vmin = np.nanpercentile(darray, ROBUST_PERCENTILE)
# If not robust and one bound is None, calculate the default other bound
elif vmax is None:
vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1
elif vmin is None:
vmin = 0
# Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float
# to avoid precision loss, integer over/underflow, etc with extreme inputs.
# After scaling, downcast to 32-bit float. This substantially reduces
# memory usage after we hand `darray` off to matplotlib.
darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4')
return minimum(maximum(darray, 0), 1)


def _plot2d(plotfunc):
"""
Decorator for common 2d plotting logic
Expand Down Expand Up @@ -449,6 +473,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
if imshow_rgb:
# Don't add a colorbar when showing an image with explicit colors
add_colorbar = False
# Matplotlib does not support normalising RGB data, so do it here.
# See eg. https://github.com/matplotlib/matplotlib/pull/10220
if robust or vmax is not None or vmin is not None:
darray = _rescale_imshow_rgb(darray, vmin, vmax, robust)
vmin, vmax, robust = None, None, False

# Handle facetgrids first
if row or col:
Expand Down Expand Up @@ -625,6 +654,11 @@ def imshow(x, y, z, ax, **kwargs):
dimension can be interpreted as RGB or RGBA color channels and
allows this dimension to be specified via the kwarg ``rgb=``.
Unlike matplotlib, Xarray can apply ``vmin`` and ``vmax`` to RGB or RGBA
data, by applying a single scaling factor and offset to all bands.
Passing ``robust=True`` infers ``vmin`` and ``vmax``
:ref:`in the usual way <robust-plotting>`.
.. note::
This function needs uniformly spaced coordinates to
properly label the axes. Call DataArray.plot() to check.
Expand Down
4 changes: 3 additions & 1 deletion xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from ..core.utils import is_scalar


ROBUST_PERCENTILE = 2.0


def _load_default_cmap(fname='default_colormap.csv'):
"""
Returns viridis color map
Expand Down Expand Up @@ -165,7 +168,6 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
cmap_params : dict
Use depends on the type of the plotting function
"""
ROBUST_PERCENTILE = 2.0
import matplotlib as mpl

calc_data = np.ravel(plot_data[~pd.isnull(plot_data)])
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,16 @@ def test_rgb_errors_bad_dim_sizes(self):
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')

def test_normalize_rgb_imshow(self):
for kwds in (
dict(vmin=-1), dict(vmax=2),
dict(vmin=-1, vmax=1), dict(vmin=0, vmax=0),
dict(vmin=0, robust=True), dict(vmax=-1, robust=True),
):
da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4))
arr = da.plot.imshow(**kwds).get_array()
assert 0 <= arr.min() <= arr.max() <= 1, kwds


class TestFacetGrid(PlotTestCase):
def setUp(self):
Expand Down

0 comments on commit a1f475d

Please sign in to comment.