Skip to content

Commit

Permalink
New infer_intervals keyword for pcolormesh (#1079)
Browse files Browse the repository at this point in the history
* fixes #781

* typo

* rename keyword

* the kwargs should also be passed over now

* infer 2d coords, new chapter in docs

* revert to original cartopy test

* update docs

*  other version of _infer_intervals_breaks

* py2 division
  • Loading branch information
fmaussion authored and shoyer committed Nov 10, 2016
1 parent a7bb97c commit 8ded250
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 13 deletions.
53 changes: 53 additions & 0 deletions doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,56 @@ the values on the y axis are decreasing with -0.5 on the top. This is because
the pixels are centered over their coordinates, and the
axis labels and ranges correspond to the values of the
coordinates.

Multidimensional coordinates
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

See also: :ref:`examples.multidim`.

You can plot irregular grids defined by multidimensional coordinates with
xarray, but you'll have to tell the plot function to use these coordinates
instead of the default ones:

.. ipython:: python
lon, lat = np.meshgrid(np.linspace(-20, 20, 5), np.linspace(0, 30, 4))
lon += lat/10
lat += lon/10
da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['y', 'x'],
coords = {'lat': (('y', 'x'), lat),
'lon': (('y', 'x'), lon)})
@savefig plotting_example_2d_irreg.png width=4in
da.plot.pcolormesh('lon', 'lat');
Note that in this case, xarray still follows the pixel centered convention.
This might be undesirable in some cases, for example when your data is defined
on a polar projection (:issue:`781`). This is why the default is to not follow
this convention when plotting on a map:

.. ipython:: python
import cartopy.crs as ccrs
ax = plt.subplot(projection=ccrs.PlateCarree());
da.plot.pcolormesh('lon', 'lat', ax=ax);
ax.scatter(lon, lat, transform=ccrs.PlateCarree());
@savefig plotting_example_2d_irreg_map.png width=4in
ax.coastlines(); ax.gridlines(draw_labels=True);
You can however decide to infer the cell boundaries and use the
``infer_intervals`` keyword:

.. ipython:: python
ax = plt.subplot(projection=ccrs.PlateCarree());
da.plot.pcolormesh('lon', 'lat', ax=ax, infer_intervals=True);
ax.scatter(lon, lat, transform=ccrs.PlateCarree());
@savefig plotting_example_2d_irreg_map_infer.png width=4in
ax.coastlines(); ax.gridlines(draw_labels=True);
.. note::
The data model of xarray does not support datasets with `cell boundaries`_
yet. If you want to use these coordinates, you'll have to make the plots
outside the xarray framework.

.. _cell boundaries: http://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#cell-boundaries
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ Bug fixes
- ``Dataset.concat()`` now preserves variables order (:issue:`1027`).
By `Fabien Maussion <https://github.com/fmaussion>`_.

- Fixed an issue with pcolormesh (:issue:`781`). A new
``infer_intervals`` keyword gives control on whether the cell intervals
should be computed or not.
By `Fabien Maussion <https://github.com/fmaussion>`_.

.. _whats-new.0.8.2:

v0.8.2 (18 August 2016)
Expand Down
59 changes: 46 additions & 13 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,12 @@ def _plot2d(plotfunc):
provided, extend is inferred from vmin, vmax and the data limits.
levels : int or list-like object, optional
Split the colormap (cmap) into discrete color intervals.
infer_intervals : bool, optional
Only applies to pcolormesh. If True, the coordinate intervals are
passed to pcolormesh. If False, the original coordinates are used
(this can be useful for certain map projections). The default is to
always infer intervals, unless the mesh is irregular and plotted on
a map projection.
subplot_kws : dict, optional
Dictionary of keyword arguments for matplotlib subplots. Only applies
to FacetGrid plotting.
Expand All @@ -341,8 +347,9 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
col_wrap=None, xincrease=True, yincrease=True,
add_colorbar=None, add_labels=True, vmin=None, vmax=None,
cmap=None, center=None, robust=False, extend=None,
levels=None, colors=None, subplot_kws=None,
cbar_ax=None, cbar_kwargs=None, **kwargs):
levels=None, infer_intervals=None, colors=None,
subplot_kws=None, cbar_ax=None, cbar_kwargs=None,
**kwargs):
# All 2d plots in xarray share this function signature.
# Method signature below should be consistent.

Expand Down Expand Up @@ -416,6 +423,9 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
kwargs['extend'] = cmap_params['extend']
kwargs['levels'] = cmap_params['levels']

if 'pcolormesh' == plotfunc.__name__:
kwargs['infer_intervals'] = infer_intervals

# This allows the user to pass in a custom norm coming via kwargs
kwargs.setdefault('norm', cmap_params['norm'])

Expand Down Expand Up @@ -456,8 +466,8 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None,
col=None, col_wrap=None, xincrease=True, yincrease=True,
add_colorbar=None, add_labels=True, vmin=None, vmax=None,
cmap=None, colors=None, center=None, robust=False,
extend=None, levels=None, subplot_kws=None,
cbar_ax=None, cbar_kwargs=None, **kwargs):
extend=None, levels=None, infer_intervals=None,
subplot_kws=None, cbar_ax=None, cbar_kwargs=None, **kwargs):
"""
The method should have the same signature as the function.
Expand Down Expand Up @@ -542,29 +552,52 @@ def contourf(x, y, z, ax, **kwargs):
return ax, primitive


def _infer_interval_breaks(coord):
def _infer_interval_breaks(coord, axis=0):
"""
>>> _infer_interval_breaks(np.arange(5))
array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
>>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1)
array([[-0.5, 0.5, 1.5],
[ 2.5, 3.5, 4.5]])
"""
coord = np.asarray(coord)
deltas = 0.5 * (coord[1:] - coord[:-1])
first = coord[0] - deltas[0]
last = coord[-1] + deltas[-1]
return np.r_[[first], coord[:-1] + deltas, [last]]
deltas = 0.5 * np.diff(coord, axis=axis)
first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis)
last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis)
trim_last = tuple(slice(None, -1) if n == axis else slice(None)
for n in range(coord.ndim))
return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis)


@_plot2d
def pcolormesh(x, y, z, ax, **kwargs):
def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
"""
Pseudocolor plot of 2d DataArray
Wraps matplotlib.pyplot.pcolormesh
"""

if not hasattr(ax, 'projection'):
x = _infer_interval_breaks(x)
y = _infer_interval_breaks(y)
# decide on a default for infer_intervals (GH781)
x = np.asarray(x)
if infer_intervals is None:
if hasattr(ax, 'projection'):
if len(x.shape) == 1:
infer_intervals = True
else:
infer_intervals = False
else:
infer_intervals = True

if infer_intervals:
if len(x.shape) == 1:
x = _infer_interval_breaks(x)
y = _infer_interval_breaks(y)
else:
# we have to infer the intervals on both axes
x = _infer_interval_breaks(x, axis=1)
x = _infer_interval_breaks(x, axis=0)
y = _infer_interval_breaks(y, axis=1)
y = _infer_interval_breaks(y, axis=0)

primitive = ax.pcolormesh(x, y, z, **kwargs)

Expand Down
13 changes: 13 additions & 0 deletions xarray/test/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import division

import inspect

import numpy as np
Expand Down Expand Up @@ -112,6 +114,17 @@ def test__infer_interval_breaks(self):
self.assertArrayEqual(pd.date_range('20000101', periods=4) - np.timedelta64(12, 'h'),
_infer_interval_breaks(pd.date_range('20000101', periods=3)))

# make a bounded 2D array that we will center and re-infer
xref, yref = np.meshgrid(np.arange(6), np.arange(5))
cx = (xref[1:, 1:] + xref[:-1, :-1]) / 2
cy = (yref[1:, 1:] + yref[:-1, :-1]) / 2
x = _infer_interval_breaks(cx, axis=1)
x = _infer_interval_breaks(x, axis=0)
y = _infer_interval_breaks(cy, axis=1)
y = _infer_interval_breaks(y, axis=0)
np.testing.assert_allclose(xref, x)
np.testing.assert_allclose(yref, y)

def test_datetime_dimension(self):
nrow = 3
ncol = 4
Expand Down

0 comments on commit 8ded250

Please sign in to comment.