Skip to content

Commit

Permalink
Merge pull request #1656 from sdeastham/fix/get_arr_masked
Browse files Browse the repository at this point in the history
Combine wrapped collection into GeoMesh data when get_array is called
  • Loading branch information
greglucas authored Oct 11, 2020
2 parents 304b6d4 + 8c2367c commit d12c86c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
8 changes: 8 additions & 0 deletions lib/cartopy/mpl/geocollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ class GeoQuadMesh(QuadMesh):
# come from GeoAxes.pcolormesh. These methods morph a QuadMesh by
# fiddling with instance.__class__.

def get_array(self):
# Retrieve the array - use copy to avoid any chance of overwrite
A = super(QuadMesh, self).get_array().copy()
# If the input array has a mask, retrieve the associated data
if hasattr(self, '_wrapped_mask'):
A[self._wrapped_mask] = self._wrapped_collection_fix.get_array()
return A

def set_array(self, A):
# raise right away if A is 2-dimensional.
if A.ndim > 1:
Expand Down
43 changes: 41 additions & 2 deletions lib/cartopy/tests/mpl/test_mpl_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,47 @@ def test_pcolormesh_global_with_wrap1():
ax.set_global() # make sure everything is visible


def test_pcolormesh_get_array_with_mask():
# make up some realistic data with bounds (such as data from the UM)
nx, ny = 36, 18
xbnds = np.linspace(0, 360, nx, endpoint=True)
ybnds = np.linspace(-90, 90, ny, endpoint=True)

x, y = np.meshgrid(xbnds, ybnds)
data = np.exp(np.sin(np.deg2rad(x)) + np.cos(np.deg2rad(y)))
data = data[:-1, :-1]

ax = plt.subplot(211, projection=ccrs.PlateCarree())
c = plt.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree())
assert c._wrapped_collection_fix is not None, \
'No pcolormesh wrapping was done when it should have been.'

assert np.array_equal(data.ravel(), c.get_array()), \
'Data supplied does not match data retrieved in wrapped case'

ax.coastlines()
ax.set_global() # make sure everything is visible

# Case without wrapping
nx, ny = 36, 18
xbnds = np.linspace(-60, 60, nx, endpoint=True)
ybnds = np.linspace(-80, 80, ny, endpoint=True)

x, y = np.meshgrid(xbnds, ybnds)
data = np.exp(np.sin(np.deg2rad(x)) + np.cos(np.deg2rad(y)))
data2 = data[:-1, :-1]

ax = plt.subplot(212, projection=ccrs.PlateCarree())
c = plt.pcolormesh(xbnds, ybnds, data2, transform=ccrs.PlateCarree())
ax.coastlines()
ax.set_global() # make sure everything is visible

assert getattr(c, "_wrapped_collection_fix", None) is None, \
'pcolormesh wrapping was done when it should not have been.'

assert np.array_equal(data2.ravel(), c.get_array()), \
'Data supplied does not match data retrieved in unwrapped case'

tolerance = 1.61
if (5, 0, 0) <= ccrs.PROJ4_VERSION < (5, 1, 0):
tolerance += 0.8
Expand Down Expand Up @@ -564,8 +605,6 @@ def test_pcolormesh_diagonal_wrap():
ax = plt.axes(projection=ccrs.PlateCarree())
mesh = ax.pcolormesh(xs, ys, zs)

# Check that the quadmesh is masked
assert np.ma.is_masked(mesh.get_array())
# And that the wrapped_collection is added
assert hasattr(mesh, "_wrapped_collection_fix")

Expand Down

0 comments on commit d12c86c

Please sign in to comment.