From 8ded25088cbb92bdd588f3e2dcbb0491d280c623 Mon Sep 17 00:00:00 2001 From: Fabien Maussion Date: Thu, 10 Nov 2016 23:55:03 +0100 Subject: [PATCH] New infer_intervals keyword for pcolormesh (#1079) * fixes https://github.com/pydata/xarray/issues/781 * typo * rename keyword * the kwargs should also be passed over now * infer 2d coords, new chapter in docs * revert to original cartopy test * update docs * other version of _infer_intervals_breaks * py2 division --- doc/plotting.rst | 53 ++++++++++++++++++++++++++++++++++++ doc/whats-new.rst | 5 ++++ xarray/plot/plot.py | 59 +++++++++++++++++++++++++++++++--------- xarray/test/test_plot.py | 13 +++++++++ 4 files changed, 117 insertions(+), 13 deletions(-) diff --git a/doc/plotting.rst b/doc/plotting.rst index f2acce8acb2..04c3917f51b 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -525,3 +525,56 @@ the values on the y axis are decreasing with -0.5 on the top. This is because the pixels are centered over their coordinates, and the axis labels and ranges correspond to the values of the coordinates. + +Multidimensional coordinates +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +See also: :ref:`examples.multidim`. + +You can plot irregular grids defined by multidimensional coordinates with +xarray, but you'll have to tell the plot function to use these coordinates +instead of the default ones: + +.. ipython:: python + + lon, lat = np.meshgrid(np.linspace(-20, 20, 5), np.linspace(0, 30, 4)) + lon += lat/10 + lat += lon/10 + da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['y', 'x'], + coords = {'lat': (('y', 'x'), lat), + 'lon': (('y', 'x'), lon)}) + + @savefig plotting_example_2d_irreg.png width=4in + da.plot.pcolormesh('lon', 'lat'); + +Note that in this case, xarray still follows the pixel centered convention. +This might be undesirable in some cases, for example when your data is defined +on a polar projection (:issue:`781`). This is why the default is to not follow +this convention when plotting on a map: + +.. ipython:: python + + import cartopy.crs as ccrs + ax = plt.subplot(projection=ccrs.PlateCarree()); + da.plot.pcolormesh('lon', 'lat', ax=ax); + ax.scatter(lon, lat, transform=ccrs.PlateCarree()); + @savefig plotting_example_2d_irreg_map.png width=4in + ax.coastlines(); ax.gridlines(draw_labels=True); + +You can however decide to infer the cell boundaries and use the +``infer_intervals`` keyword: + +.. ipython:: python + + ax = plt.subplot(projection=ccrs.PlateCarree()); + da.plot.pcolormesh('lon', 'lat', ax=ax, infer_intervals=True); + ax.scatter(lon, lat, transform=ccrs.PlateCarree()); + @savefig plotting_example_2d_irreg_map_infer.png width=4in + ax.coastlines(); ax.gridlines(draw_labels=True); + +.. note:: + The data model of xarray does not support datasets with `cell boundaries`_ + yet. If you want to use these coordinates, you'll have to make the plots + outside the xarray framework. + +.. _cell boundaries: http://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#cell-boundaries diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d30925c8f28..f001c5044ce 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -113,6 +113,11 @@ Bug fixes - ``Dataset.concat()`` now preserves variables order (:issue:`1027`). By `Fabien Maussion `_. +- Fixed an issue with pcolormesh (:issue:`781`). A new + ``infer_intervals`` keyword gives control on whether the cell intervals + should be computed or not. + By `Fabien Maussion `_. + .. _whats-new.0.8.2: v0.8.2 (18 August 2016) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 11a32e9f85e..1e4abb148a8 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -316,6 +316,12 @@ def _plot2d(plotfunc): provided, extend is inferred from vmin, vmax and the data limits. levels : int or list-like object, optional Split the colormap (cmap) into discrete color intervals. + infer_intervals : bool, optional + Only applies to pcolormesh. If True, the coordinate intervals are + passed to pcolormesh. If False, the original coordinates are used + (this can be useful for certain map projections). The default is to + always infer intervals, unless the mesh is irregular and plotted on + a map projection. subplot_kws : dict, optional Dictionary of keyword arguments for matplotlib subplots. Only applies to FacetGrid plotting. @@ -341,8 +347,9 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None, col_wrap=None, xincrease=True, yincrease=True, add_colorbar=None, add_labels=True, vmin=None, vmax=None, cmap=None, center=None, robust=False, extend=None, - levels=None, colors=None, subplot_kws=None, - cbar_ax=None, cbar_kwargs=None, **kwargs): + levels=None, infer_intervals=None, colors=None, + subplot_kws=None, cbar_ax=None, cbar_kwargs=None, + **kwargs): # All 2d plots in xarray share this function signature. # Method signature below should be consistent. @@ -416,6 +423,9 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None, kwargs['extend'] = cmap_params['extend'] kwargs['levels'] = cmap_params['levels'] + if 'pcolormesh' == plotfunc.__name__: + kwargs['infer_intervals'] = infer_intervals + # This allows the user to pass in a custom norm coming via kwargs kwargs.setdefault('norm', cmap_params['norm']) @@ -456,8 +466,8 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None, col=None, col_wrap=None, xincrease=True, yincrease=True, add_colorbar=None, add_labels=True, vmin=None, vmax=None, cmap=None, colors=None, center=None, robust=False, - extend=None, levels=None, subplot_kws=None, - cbar_ax=None, cbar_kwargs=None, **kwargs): + extend=None, levels=None, infer_intervals=None, + subplot_kws=None, cbar_ax=None, cbar_kwargs=None, **kwargs): """ The method should have the same signature as the function. @@ -542,29 +552,52 @@ def contourf(x, y, z, ax, **kwargs): return ax, primitive -def _infer_interval_breaks(coord): +def _infer_interval_breaks(coord, axis=0): """ >>> _infer_interval_breaks(np.arange(5)) array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) + >>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1) + array([[-0.5, 0.5, 1.5], + [ 2.5, 3.5, 4.5]]) """ coord = np.asarray(coord) - deltas = 0.5 * (coord[1:] - coord[:-1]) - first = coord[0] - deltas[0] - last = coord[-1] + deltas[-1] - return np.r_[[first], coord[:-1] + deltas, [last]] + deltas = 0.5 * np.diff(coord, axis=axis) + first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis) + last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis) + trim_last = tuple(slice(None, -1) if n == axis else slice(None) + for n in range(coord.ndim)) + return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis) @_plot2d -def pcolormesh(x, y, z, ax, **kwargs): +def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): """ Pseudocolor plot of 2d DataArray Wraps matplotlib.pyplot.pcolormesh """ - if not hasattr(ax, 'projection'): - x = _infer_interval_breaks(x) - y = _infer_interval_breaks(y) + # decide on a default for infer_intervals (GH781) + x = np.asarray(x) + if infer_intervals is None: + if hasattr(ax, 'projection'): + if len(x.shape) == 1: + infer_intervals = True + else: + infer_intervals = False + else: + infer_intervals = True + + if infer_intervals: + if len(x.shape) == 1: + x = _infer_interval_breaks(x) + y = _infer_interval_breaks(y) + else: + # we have to infer the intervals on both axes + x = _infer_interval_breaks(x, axis=1) + x = _infer_interval_breaks(x, axis=0) + y = _infer_interval_breaks(y, axis=1) + y = _infer_interval_breaks(y, axis=0) primitive = ax.pcolormesh(x, y, z, **kwargs) diff --git a/xarray/test/test_plot.py b/xarray/test/test_plot.py index 4e812d2b602..e529f9dab83 100644 --- a/xarray/test/test_plot.py +++ b/xarray/test/test_plot.py @@ -1,3 +1,5 @@ +from __future__ import division + import inspect import numpy as np @@ -112,6 +114,17 @@ def test__infer_interval_breaks(self): self.assertArrayEqual(pd.date_range('20000101', periods=4) - np.timedelta64(12, 'h'), _infer_interval_breaks(pd.date_range('20000101', periods=3))) + # make a bounded 2D array that we will center and re-infer + xref, yref = np.meshgrid(np.arange(6), np.arange(5)) + cx = (xref[1:, 1:] + xref[:-1, :-1]) / 2 + cy = (yref[1:, 1:] + yref[:-1, :-1]) / 2 + x = _infer_interval_breaks(cx, axis=1) + x = _infer_interval_breaks(x, axis=0) + y = _infer_interval_breaks(cy, axis=1) + y = _infer_interval_breaks(y, axis=0) + np.testing.assert_allclose(xref, x) + np.testing.assert_allclose(yref, y) + def test_datetime_dimension(self): nrow = 3 ncol = 4