From 75dd7df404a1a922e1d7f2cf4a9c101ed5204033 Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Thu, 27 Apr 2023 13:30:26 +1000 Subject: [PATCH] Use PolyCollection over PatchCollection in matplotlib plots --- docs/releases/development.rst | 4 ++ src/emsarray/conventions/_base.py | 42 +++++++++++------- src/emsarray/plot.py | 74 +++++++++++++++---------------- tests/conventions/test_base.py | 24 +++++----- tests/test_plot.py | 15 ++----- 5 files changed, 81 insertions(+), 78 deletions(-) diff --git a/docs/releases/development.rst b/docs/releases/development.rst index 6163b2b..42ab24e 100644 --- a/docs/releases/development.rst +++ b/docs/releases/development.rst @@ -4,3 +4,7 @@ Next release (in development) * Fix an issue with negative coordinates in :func:`~emsarray.cli.utils.bounds_argument` (:pr:`74`). * Add a new ``emsarray plot`` subcommand to the ``emsarray`` command line interface (:pr:`76`). +* Use :class:`matplotlib.collections.PolyCollection` + rather than :class:`~matplotlib.collections.PatchCollection` + for significant speed improvements + (:pr:`77`). diff --git a/src/emsarray/conventions/_base.py b/src/emsarray/conventions/_base.py index e405d56..0429055 100644 --- a/src/emsarray/conventions/_base.py +++ b/src/emsarray/conventions/_base.py @@ -4,6 +4,7 @@ import dataclasses import enum import logging +import warnings from functools import cached_property from typing import ( TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Generic, Hashable, List, @@ -19,8 +20,7 @@ from emsarray.compat.shapely import SpatialIndex from emsarray.operations import depth from emsarray.plot import ( - _requires_plot, animate_on_figure, plot_on_figure, - polygons_to_patch_collection + _requires_plot, animate_on_figure, plot_on_figure, polygons_to_collection ) from emsarray.state import State from emsarray.types import Pathish @@ -30,7 +30,7 @@ from cartopy.crs import CRS from matplotlib.animation import FuncAnimation from matplotlib.axes import Axes - from matplotlib.collections import PatchCollection + from matplotlib.collections import PolyCollection from matplotlib.figure import Figure from matplotlib.quiver import Quiver @@ -552,7 +552,7 @@ def data_crs(self) -> CRS: """ The coordinate reference system that coordinates in this dataset are defined in. - Used by :meth:`.make_patch_collection` and :meth:`.make_quiver`. + Used by :meth:`.make_poly_collection` and :meth:`.make_quiver`. Defaults to :class:`cartopy.crs.PlateCarree`. """ # Lazily imported here as cartopy is an optional dependency @@ -746,35 +746,35 @@ def animate_on_figure( return animate_on_figure(figure, self, coordinate=coordinate, **kwargs) @_requires_plot - def make_patch_collection( + def make_poly_collection( self, data_array: Optional[DataArrayOrName] = None, **kwargs: Any, - ) -> PatchCollection: + ) -> PolyCollection: """ - Make a :class:`~matplotlib.collections.PatchCollection` + Make a :class:`~matplotlib.collections.PolyCollection` from the geometry of this :class:`~xarray.Dataset`. This can be used to make custom matplotlib plots from your data. If a :class:`~xarray.DataArray` is passed in, - the values of that are assigned to the PatchCollection `array` parameter. + the values of that are assigned to the PolyCollection `array` parameter. Parameters ---------- data_array : Hashable or :class:`xarray.DataArray`, optional A data array, or the name of a data variable in this dataset. Optional. If given, the data array is :meth:`linearised <.make_linear>` - and passed to :meth:`PatchCollection.set_array() `. + and passed to :meth:`PolyCollection.set_array() `. The data is used to colour the patches. Refer to the matplotlib documentation for more information on styling. **kwargs Any keyword arguments are passed to the - :class:`~matplotlib.collections.PatchCollection` constructor. + :class:`~matplotlib.collections.PolyCollection` constructor. Returns ------- - :class:`~matplotlib.collections.PatchCollection` - A PatchCollection constructed using the geometry of this dataset. + :class:`~matplotlib.collections.PolyCollection` + A PolyCollection constructed using the geometry of this dataset. Example ------- @@ -791,7 +791,7 @@ def make_patch_collection( ds = emsarray.open_dataset("./tests/datasets/ugrid_mesh2d.nc") ds = ds.isel(record=0, Mesh2_layers=-1) - patches = ds.ems.make_patch_collection('temp') + patches = ds.ems.make_poly_collection('temp') axes.add_collection(patches) figure.colorbar(patches, ax=axes, location='right', label='meters') @@ -802,7 +802,7 @@ def make_patch_collection( if data_array is not None: if 'array' in kwargs: raise TypeError( - "Can not pass both `data_array` and `array` to make_patch_collection" + "Can not pass both `data_array` and `array` to make_poly_collection" ) data_array = self._get_data_array(data_array) @@ -821,7 +821,19 @@ def make_patch_collection( if 'transform' not in kwargs: kwargs['transform'] = self.data_crs - return polygons_to_patch_collection(self.polygons[self.mask], **kwargs) + return polygons_to_collection(self.polygons[self.mask], **kwargs) + + def make_patch_collection( + self, + data_array: Optional[DataArrayOrName] = None, + **kwargs: Any, + ) -> PolyCollection: + warnings.warn( + "Convention.make_patch_collection has been renamed to " + "Convention.make_poly_collection, and now returns a PolyCollection", + category=DeprecationWarning, + ) + return self.make_poly_collection(data_array, **kwargs) @_requires_plot def make_quiver( diff --git a/src/emsarray/plot.py b/src/emsarray/plot.py index 20619ea..ed3f091 100644 --- a/src/emsarray/plot.py +++ b/src/emsarray/plot.py @@ -17,10 +17,10 @@ import cartopy.crs from cartopy.feature import GSHHSFeature from cartopy.mpl import gridliner - from matplotlib import animation, patches + from matplotlib import animation from matplotlib.artist import Artist from matplotlib.axes import Axes - from matplotlib.collections import PatchCollection + from matplotlib.collections import PolyCollection from matplotlib.figure import Figure from shapely.geometry import Polygon CAN_PLOT = True @@ -30,7 +30,7 @@ IMPORT_EXCEPTION = exc -__all___ = ['CAN_PLOT', 'plot_on_figure', 'polygon_to_patch'] +__all___ = ['CAN_PLOT', 'plot_on_figure', 'polygons_to_collection'] _requires_plot = requires_extra(extra='plot', import_error=IMPORT_EXCEPTION) @@ -81,7 +81,7 @@ def bounds_to_extent(bounds: Tuple[float, float, float, float]) -> List[float]: import cartopy.crs as ccrs import matplotlib.pyplot as plt - from emsarray.plot import bounds_to_extent, polygon_to_patch + from emsarray.plot import bounds_to_extent from shapely.geometry import Polygon polygon = Polygon([ @@ -91,44 +91,40 @@ def bounds_to_extent(bounds: Tuple[float, float, float, float]) -> List[float]: figure = plt.figure(figsize=(10, 8), dpi=100) axes = plt.subplot(projection=ccrs.PlateCarree()) axes.set_extent(bounds_to_extent(polygon.buffer(0.1).bounds)) - axes.add_patch(polygon_to_patch(polygon)) - figure.show() """ minx, miny, maxx, maxy = bounds return [minx, maxx, miny, maxy] @_requires_plot -def polygon_to_patch(polygon: Polygon, **kwargs: Any) -> patches.Polygon: - """ - Convert a :class:`shapely.geometry.Polygon ` to a - :class:`matplotlib.patches.Polygon`. - """ - return patches.Polygon(np.transpose(polygon.exterior.xy), **kwargs) - - -@_requires_plot -def polygons_to_patch_collection( +def polygons_to_collection( polygons: Iterable[Polygon], **kwargs: Any, -) -> PatchCollection: +) -> PolyCollection: """ Convert a list of Shapely :class:`Polygons ` - to a matplotlib :class:`~matplotlib.collections.PatchCollection`. + to a matplotlib :class:`~matplotlib.collections.PolyCollection`. Parameters ---------- - polygons : iterable of `Polygon` - The polygons for the patch collection + polygons : iterable of Shapely :class:`Polygons ` + The polygons for the poly collection **kwargs : Any - Keyword arguments to pass to the PatchCollection constructor. + Keyword arguments to pass to the PolyCollection constructor. Returns ------- - :class:`matplotlib.collections.PatchCollection` - The PatchCollection made up of the polygons passed in. + :class:`matplotlib.collections.PolyCollection` + A PolyCollection made up of the polygons passed in. """ - return PatchCollection(map(polygon_to_patch, polygons), **kwargs) + return PolyCollection( + verts=[ + np.asarray(polygon.exterior.coords) + for polygon in polygons + ], + closed=False, + **kwargs + ) @_requires_plot @@ -154,7 +150,7 @@ def plot_on_figure( This is used to build the polygons and vector quivers. scalar : :class:`xarray.DataArray`, optional The data to plot as an :class:`xarray.DataArray`. - This will be passed to :meth:`.Convention.make_patch_collection`. + This will be passed to :meth:`.Convention.make_poly_collection`. vector : tuple of :class:`numpy.ndarray`, optional The *u* and *v* components of a vector field as a tuple of :class:`xarray.DataArray`. @@ -175,18 +171,18 @@ def plot_on_figure( if scalar is None and vector is None: # Plot the polygon shapes for want of anything else to draw - patches = convention.make_patch_collection() - axes.add_collection(patches) + collection = convention.make_poly_collection() + axes.add_collection(collection) if title is None: title = 'Geometry' if scalar is not None: # Plot a scalar variable on the polygons using a colour map - patches = convention.make_patch_collection( + collection = convention.make_poly_collection( scalar, cmap='jet', edgecolor='face') - axes.add_collection(patches) + axes.add_collection(collection) units = scalar.attrs.get('units') - figure.colorbar(patches, ax=axes, location='right', label=units) + figure.colorbar(collection, ax=axes, location='right', label=units) if vector is not None: # Plot a vector variable using a quiver @@ -230,7 +226,7 @@ def animate_on_figure( The coordinate values to vary across frames in the animation. scalar : :class:`xarray.DataArray`, optional The data to plot as an :class:`xarray.DataArray`. - This will be passed to :meth:`.Convention.make_patch_collection`. + This will be passed to :meth:`.Convention.make_poly_collection`. It should have horizontal dimensions appropriate for this convention, and a dimension matching the ``coordinate`` parameter. vector : tuple of :class:`numpy.ndarray`, optional @@ -273,17 +269,17 @@ def animate_on_figure( axes.set_aspect(aspect='equal', adjustable='datalim') axes.title.set_animated(True) - patches = None + collection = None if scalar is not None: # Plot a scalar variable on the polygons using a colour map scalar_values = convention.make_linear(scalar).values[:, convention.mask] - patches = convention.make_patch_collection( + collection = convention.make_poly_collection( cmap='jet', edgecolor='face', clim=(np.nanmin(scalar_values), np.nanmax(scalar_values))) - axes.add_collection(patches) - patches.set_animated(True) + axes.add_collection(collection) + collection.set_animated(True) units = scalar.attrs.get('units') - figure.colorbar(patches, ax=axes, location='right', label=units) + figure.colorbar(collection, ax=axes, location='right', label=units) quiver = None if vector is not None: @@ -333,9 +329,9 @@ def animate(index: int) -> Iterable[Artist]: changes.extend(gridlines.xline_artists) changes.extend(gridlines.yline_artists) - if patches is not None: - patches.set_array(scalar_values[index]) - changes.append(patches) + if collection is not None: + collection.set_array(scalar_values[index]) + changes.append(collection) if quiver is not None: quiver.set_UVC(vector_u_values[index], vector_v_values[index]) diff --git a/tests/conventions/test_base.py b/tests/conventions/test_base.py index 85f8118..4c109bd 100644 --- a/tests/conventions/test_base.py +++ b/tests/conventions/test_base.py @@ -325,28 +325,28 @@ def test_face_centres(): @pytest.mark.matplotlib -def test_make_patch_collection(): +def test_make_poly_collection(): dataset = xr.Dataset({ 'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))), 'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10), }) convention = SimpleConvention(dataset) - patches = convention.make_patch_collection(cmap='plasma', edgecolor='black') + patches = convention.make_poly_collection(cmap='plasma', edgecolor='black') assert len(patches.get_paths()) == len(convention.polygons[convention.mask]) assert patches.get_cmap().name == 'plasma' # Colours get transformed in to RGBA arrays np.testing.assert_equal(patches.get_edgecolor(), [[0., 0., 0., 1.0]]) -def test_make_patch_collection_data_array(): +def test_make_poly_collection_data_array(): dataset = xr.Dataset({ 'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))), 'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10), }) convention = SimpleConvention(dataset) - patches = convention.make_patch_collection(data_array='botz') + patches = convention.make_poly_collection(data_array='botz') assert len(patches.get_paths()) == len(convention.polygons[convention.mask]) values = convention.make_linear(dataset.data_vars['botz'])[convention.mask] @@ -354,7 +354,7 @@ def test_make_patch_collection_data_array(): assert patches.get_clim() == (np.nanmin(values), np.nanmax(values)) -def test_make_patch_collection_data_array_and_array(): +def test_make_poly_collection_data_array_and_array(): dataset = xr.Dataset({ 'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))), 'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10), @@ -365,10 +365,10 @@ def test_make_patch_collection_data_array_and_array(): with pytest.raises(TypeError): # Passing both array and data_array is a TypeError - convention.make_patch_collection(data_array='botz', array=array) + convention.make_poly_collection(data_array='botz', array=array) -def test_make_patch_collection_data_array_and_clim(): +def test_make_poly_collection_data_array_and_clim(): dataset = xr.Dataset({ 'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))), 'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10), @@ -376,11 +376,11 @@ def test_make_patch_collection_data_array_and_clim(): convention = SimpleConvention(dataset) # You can override the default clim if you want - patches = convention.make_patch_collection(data_array='botz', clim=(-12, -8)) + patches = convention.make_poly_collection(data_array='botz', clim=(-12, -8)) assert patches.get_clim() == (-12, -8) -def test_make_patch_collection_data_array_dimensions(): +def test_make_poly_collection_data_array_dimensions(): dataset = xr.Dataset({ 'temp': (['t', 'z', 'y', 'x'], np.random.standard_normal((5, 5, 10, 20))), 'botz': (['y', 'x'], np.random.standard_normal((10, 20)) - 10), @@ -389,12 +389,12 @@ def test_make_patch_collection_data_array_dimensions(): with pytest.raises(ValueError): # temp needs subsetting first, so this should raise an error - convention.make_patch_collection(data_array='temp') + convention.make_poly_collection(data_array='temp') # One way to avoid this is to isel the data array - convention.make_patch_collection(data_array=dataset.data_vars['temp'].isel(z=0, t=0)) + convention.make_poly_collection(data_array=dataset.data_vars['temp'].isel(z=0, t=0)) # Another way to avoid this is to isel the dataset dataset_0 = dataset.isel(z=0, t=0) convention = SimpleConvention(dataset_0) - convention.make_patch_collection(data_array='temp') + convention.make_poly_collection(data_array='temp') diff --git a/tests/test_plot.py b/tests/test_plot.py index ff5640e..19f193b 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -2,26 +2,17 @@ import pytest from shapely.geometry import Polygon -from emsarray.plot import polygon_to_patch, polygons_to_patch_collection +from emsarray.plot import polygons_to_collection @pytest.mark.matplotlib -def test_polygon_to_patch(): - polygon = Polygon([(0, 0), (1, 0), (2, 2,), (0, 1), (0, 0)]) - patch = polygon_to_patch(polygon) - for index, poly_coords in enumerate(polygon.exterior.coords): - patch_coords = patch.get_xy()[index] - assert tuple(poly_coords) == tuple(patch_coords) - - -@pytest.mark.matplotlib -def test_polygons_to_patch_collection(): +def test_polygons_to_collection(): polygons = [ Polygon([(i, 0), (i + 1, 0), (i + 1, 1), (i, 1), (i, 0)]) for i in range(10) ] data = np.random.random(10) * 10 - patch_collection = polygons_to_patch_collection( + patch_collection = polygons_to_collection( polygons, array=data, cmap='autumn', clim=(0, 10)) # Check that the polygons came through