Skip to content

Commit

Permalink
Merge pull request #77 from csiro-coasts/polycollection
Browse files Browse the repository at this point in the history
Use PolyCollection over PatchCollection in matplotlib plots
  • Loading branch information
mx-moth authored Apr 27, 2023
2 parents 91a1bd4 + 75dd7df commit 59c2cc6
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 78 deletions.
4 changes: 4 additions & 0 deletions docs/releases/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ Next release (in development)

* Fix an issue with negative coordinates in :func:`~emsarray.cli.utils.bounds_argument` (:pr:`74`).
* Add a new ``emsarray plot`` subcommand to the ``emsarray`` command line interface (:pr:`76`).
* Use :class:`matplotlib.collections.PolyCollection`
rather than :class:`~matplotlib.collections.PatchCollection`
for significant speed improvements
(:pr:`77`).
42 changes: 27 additions & 15 deletions src/emsarray/conventions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dataclasses
import enum
import logging
import warnings
from functools import cached_property
from typing import (
TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Generic, Hashable, List,
Expand All @@ -19,8 +20,7 @@
from emsarray.compat.shapely import SpatialIndex
from emsarray.operations import depth
from emsarray.plot import (
_requires_plot, animate_on_figure, plot_on_figure,
polygons_to_patch_collection
_requires_plot, animate_on_figure, plot_on_figure, polygons_to_collection
)
from emsarray.state import State
from emsarray.types import Pathish
Expand All @@ -30,7 +30,7 @@
from cartopy.crs import CRS
from matplotlib.animation import FuncAnimation
from matplotlib.axes import Axes
from matplotlib.collections import PatchCollection
from matplotlib.collections import PolyCollection
from matplotlib.figure import Figure
from matplotlib.quiver import Quiver

Expand Down Expand Up @@ -552,7 +552,7 @@ def data_crs(self) -> CRS:
"""
The coordinate reference system that coordinates in this dataset are
defined in.
Used by :meth:`.make_patch_collection` and :meth:`.make_quiver`.
Used by :meth:`.make_poly_collection` and :meth:`.make_quiver`.
Defaults to :class:`cartopy.crs.PlateCarree`.
"""
# Lazily imported here as cartopy is an optional dependency
Expand Down Expand Up @@ -746,35 +746,35 @@ def animate_on_figure(
return animate_on_figure(figure, self, coordinate=coordinate, **kwargs)

@_requires_plot
def make_patch_collection(
def make_poly_collection(
self,
data_array: Optional[DataArrayOrName] = None,
**kwargs: Any,
) -> PatchCollection:
) -> PolyCollection:
"""
Make a :class:`~matplotlib.collections.PatchCollection`
Make a :class:`~matplotlib.collections.PolyCollection`
from the geometry of this :class:`~xarray.Dataset`.
This can be used to make custom matplotlib plots from your data.
If a :class:`~xarray.DataArray` is passed in,
the values of that are assigned to the PatchCollection `array` parameter.
the values of that are assigned to the PolyCollection `array` parameter.
Parameters
----------
data_array : Hashable or :class:`xarray.DataArray`, optional
A data array, or the name of a data variable in this dataset. Optional.
If given, the data array is :meth:`linearised <.make_linear>`
and passed to :meth:`PatchCollection.set_array() <matplotlib.cm.ScalarMappable.set_array>`.
and passed to :meth:`PolyCollection.set_array() <matplotlib.cm.ScalarMappable.set_array>`.
The data is used to colour the patches.
Refer to the matplotlib documentation for more information on styling.
**kwargs
Any keyword arguments are passed to the
:class:`~matplotlib.collections.PatchCollection` constructor.
:class:`~matplotlib.collections.PolyCollection` constructor.
Returns
-------
:class:`~matplotlib.collections.PatchCollection`
A PatchCollection constructed using the geometry of this dataset.
:class:`~matplotlib.collections.PolyCollection`
A PolyCollection constructed using the geometry of this dataset.
Example
-------
Expand All @@ -791,7 +791,7 @@ def make_patch_collection(
ds = emsarray.open_dataset("./tests/datasets/ugrid_mesh2d.nc")
ds = ds.isel(record=0, Mesh2_layers=-1)
patches = ds.ems.make_patch_collection('temp')
patches = ds.ems.make_poly_collection('temp')
axes.add_collection(patches)
figure.colorbar(patches, ax=axes, location='right', label='meters')
Expand All @@ -802,7 +802,7 @@ def make_patch_collection(
if data_array is not None:
if 'array' in kwargs:
raise TypeError(
"Can not pass both `data_array` and `array` to make_patch_collection"
"Can not pass both `data_array` and `array` to make_poly_collection"
)

data_array = self._get_data_array(data_array)
Expand All @@ -821,7 +821,19 @@ def make_patch_collection(
if 'transform' not in kwargs:
kwargs['transform'] = self.data_crs

return polygons_to_patch_collection(self.polygons[self.mask], **kwargs)
return polygons_to_collection(self.polygons[self.mask], **kwargs)

def make_patch_collection(
self,
data_array: Optional[DataArrayOrName] = None,
**kwargs: Any,
) -> PolyCollection:
warnings.warn(
"Convention.make_patch_collection has been renamed to "
"Convention.make_poly_collection, and now returns a PolyCollection",
category=DeprecationWarning,
)
return self.make_poly_collection(data_array, **kwargs)

@_requires_plot
def make_quiver(
Expand Down
74 changes: 35 additions & 39 deletions src/emsarray/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import cartopy.crs
from cartopy.feature import GSHHSFeature
from cartopy.mpl import gridliner
from matplotlib import animation, patches
from matplotlib import animation
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.collections import PatchCollection
from matplotlib.collections import PolyCollection
from matplotlib.figure import Figure
from shapely.geometry import Polygon
CAN_PLOT = True
Expand All @@ -30,7 +30,7 @@
IMPORT_EXCEPTION = exc


__all___ = ['CAN_PLOT', 'plot_on_figure', 'polygon_to_patch']
__all___ = ['CAN_PLOT', 'plot_on_figure', 'polygons_to_collection']


_requires_plot = requires_extra(extra='plot', import_error=IMPORT_EXCEPTION)
Expand Down Expand Up @@ -81,7 +81,7 @@ def bounds_to_extent(bounds: Tuple[float, float, float, float]) -> List[float]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from emsarray.plot import bounds_to_extent, polygon_to_patch
from emsarray.plot import bounds_to_extent
from shapely.geometry import Polygon
polygon = Polygon([
Expand All @@ -91,44 +91,40 @@ def bounds_to_extent(bounds: Tuple[float, float, float, float]) -> List[float]:
figure = plt.figure(figsize=(10, 8), dpi=100)
axes = plt.subplot(projection=ccrs.PlateCarree())
axes.set_extent(bounds_to_extent(polygon.buffer(0.1).bounds))
axes.add_patch(polygon_to_patch(polygon))
figure.show()
"""
minx, miny, maxx, maxy = bounds
return [minx, maxx, miny, maxy]


@_requires_plot
def polygon_to_patch(polygon: Polygon, **kwargs: Any) -> patches.Polygon:
"""
Convert a :class:`shapely.geometry.Polygon <Polygon>` to a
:class:`matplotlib.patches.Polygon`.
"""
return patches.Polygon(np.transpose(polygon.exterior.xy), **kwargs)


@_requires_plot
def polygons_to_patch_collection(
def polygons_to_collection(
polygons: Iterable[Polygon],
**kwargs: Any,
) -> PatchCollection:
) -> PolyCollection:
"""
Convert a list of Shapely :class:`Polygons <Polygon>`
to a matplotlib :class:`~matplotlib.collections.PatchCollection`.
to a matplotlib :class:`~matplotlib.collections.PolyCollection`.
Parameters
----------
polygons : iterable of `Polygon`
The polygons for the patch collection
polygons : iterable of Shapely :class:`Polygons <Polygon>`
The polygons for the poly collection
**kwargs : Any
Keyword arguments to pass to the PatchCollection constructor.
Keyword arguments to pass to the PolyCollection constructor.
Returns
-------
:class:`matplotlib.collections.PatchCollection`
The PatchCollection made up of the polygons passed in.
:class:`matplotlib.collections.PolyCollection`
A PolyCollection made up of the polygons passed in.
"""
return PatchCollection(map(polygon_to_patch, polygons), **kwargs)
return PolyCollection(
verts=[
np.asarray(polygon.exterior.coords)
for polygon in polygons
],
closed=False,
**kwargs
)


@_requires_plot
Expand All @@ -154,7 +150,7 @@ def plot_on_figure(
This is used to build the polygons and vector quivers.
scalar : :class:`xarray.DataArray`, optional
The data to plot as an :class:`xarray.DataArray`.
This will be passed to :meth:`.Convention.make_patch_collection`.
This will be passed to :meth:`.Convention.make_poly_collection`.
vector : tuple of :class:`numpy.ndarray`, optional
The *u* and *v* components of a vector field
as a tuple of :class:`xarray.DataArray`.
Expand All @@ -175,18 +171,18 @@ def plot_on_figure(

if scalar is None and vector is None:
# Plot the polygon shapes for want of anything else to draw
patches = convention.make_patch_collection()
axes.add_collection(patches)
collection = convention.make_poly_collection()
axes.add_collection(collection)
if title is None:
title = 'Geometry'

if scalar is not None:
# Plot a scalar variable on the polygons using a colour map
patches = convention.make_patch_collection(
collection = convention.make_poly_collection(
scalar, cmap='jet', edgecolor='face')
axes.add_collection(patches)
axes.add_collection(collection)
units = scalar.attrs.get('units')
figure.colorbar(patches, ax=axes, location='right', label=units)
figure.colorbar(collection, ax=axes, location='right', label=units)

if vector is not None:
# Plot a vector variable using a quiver
Expand Down Expand Up @@ -230,7 +226,7 @@ def animate_on_figure(
The coordinate values to vary across frames in the animation.
scalar : :class:`xarray.DataArray`, optional
The data to plot as an :class:`xarray.DataArray`.
This will be passed to :meth:`.Convention.make_patch_collection`.
This will be passed to :meth:`.Convention.make_poly_collection`.
It should have horizontal dimensions appropriate for this convention,
and a dimension matching the ``coordinate`` parameter.
vector : tuple of :class:`numpy.ndarray`, optional
Expand Down Expand Up @@ -273,17 +269,17 @@ def animate_on_figure(
axes.set_aspect(aspect='equal', adjustable='datalim')
axes.title.set_animated(True)

patches = None
collection = None
if scalar is not None:
# Plot a scalar variable on the polygons using a colour map
scalar_values = convention.make_linear(scalar).values[:, convention.mask]
patches = convention.make_patch_collection(
collection = convention.make_poly_collection(
cmap='jet', edgecolor='face',
clim=(np.nanmin(scalar_values), np.nanmax(scalar_values)))
axes.add_collection(patches)
patches.set_animated(True)
axes.add_collection(collection)
collection.set_animated(True)
units = scalar.attrs.get('units')
figure.colorbar(patches, ax=axes, location='right', label=units)
figure.colorbar(collection, ax=axes, location='right', label=units)

quiver = None
if vector is not None:
Expand Down Expand Up @@ -333,9 +329,9 @@ def animate(index: int) -> Iterable[Artist]:
changes.extend(gridlines.xline_artists)
changes.extend(gridlines.yline_artists)

if patches is not None:
patches.set_array(scalar_values[index])
changes.append(patches)
if collection is not None:
collection.set_array(scalar_values[index])
changes.append(collection)

if quiver is not None:
quiver.set_UVC(vector_u_values[index], vector_v_values[index])
Expand Down
24 changes: 12 additions & 12 deletions tests/conventions/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,36 +325,36 @@ def test_face_centres():


@pytest.mark.matplotlib
def test_make_patch_collection():
def test_make_poly_collection():
dataset = xr.Dataset({
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
})
convention = SimpleConvention(dataset)

patches = convention.make_patch_collection(cmap='plasma', edgecolor='black')
patches = convention.make_poly_collection(cmap='plasma', edgecolor='black')
assert len(patches.get_paths()) == len(convention.polygons[convention.mask])
assert patches.get_cmap().name == 'plasma'
# Colours get transformed in to RGBA arrays
np.testing.assert_equal(patches.get_edgecolor(), [[0., 0., 0., 1.0]])


def test_make_patch_collection_data_array():
def test_make_poly_collection_data_array():
dataset = xr.Dataset({
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
})
convention = SimpleConvention(dataset)

patches = convention.make_patch_collection(data_array='botz')
patches = convention.make_poly_collection(data_array='botz')
assert len(patches.get_paths()) == len(convention.polygons[convention.mask])

values = convention.make_linear(dataset.data_vars['botz'])[convention.mask]
np.testing.assert_equal(patches.get_array(), values)
assert patches.get_clim() == (np.nanmin(values), np.nanmax(values))


def test_make_patch_collection_data_array_and_array():
def test_make_poly_collection_data_array_and_array():
dataset = xr.Dataset({
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
Expand All @@ -365,22 +365,22 @@ def test_make_patch_collection_data_array_and_array():

with pytest.raises(TypeError):
# Passing both array and data_array is a TypeError
convention.make_patch_collection(data_array='botz', array=array)
convention.make_poly_collection(data_array='botz', array=array)


def test_make_patch_collection_data_array_and_clim():
def test_make_poly_collection_data_array_and_clim():
dataset = xr.Dataset({
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
})
convention = SimpleConvention(dataset)

# You can override the default clim if you want
patches = convention.make_patch_collection(data_array='botz', clim=(-12, -8))
patches = convention.make_poly_collection(data_array='botz', clim=(-12, -8))
assert patches.get_clim() == (-12, -8)


def test_make_patch_collection_data_array_dimensions():
def test_make_poly_collection_data_array_dimensions():
dataset = xr.Dataset({
'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))),
'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10),
Expand All @@ -389,12 +389,12 @@ def test_make_patch_collection_data_array_dimensions():

with pytest.raises(ValueError):
# temp needs subsetting first, so this should raise an error
convention.make_patch_collection(data_array='temp')
convention.make_poly_collection(data_array='temp')

# One way to avoid this is to isel the data array
convention.make_patch_collection(data_array=dataset.data_vars['temp'].isel(z=0, t=0))
convention.make_poly_collection(data_array=dataset.data_vars['temp'].isel(z=0, t=0))

# Another way to avoid this is to isel the dataset
dataset_0 = dataset.isel(z=0, t=0)
convention = SimpleConvention(dataset_0)
convention.make_patch_collection(data_array='temp')
convention.make_poly_collection(data_array='temp')
Loading

0 comments on commit 59c2cc6

Please sign in to comment.