Skip to content

Commit

Permalink
Make Gridliner into an Artist
Browse files Browse the repository at this point in the history
  • Loading branch information
rcomer committed Sep 29, 2023
1 parent f2bb81d commit 44bedc1
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 39 deletions.
21 changes: 10 additions & 11 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ def __init__(self, *args, **kwargs):
self.projection = projection

super().__init__(*args, **kwargs)
self._gridliners = []
self.img_factories = []
self._done_img_factory = False

Expand Down Expand Up @@ -487,16 +486,14 @@ def _draw_preprocess(self, renderer):
self.patch._adjust_location()

self.apply_aspect()
for gl in self._gridliners:
gl._draw_gridliner(renderer=renderer)

def get_tightbbox(self, renderer, *args, **kwargs):
"""
Extend the standard behaviour of
:func:`matplotlib.axes.Axes.get_tightbbox`.
Adjust the axes aspect ratio, background patch location, and add
gridliners before calculating the tight bounding box.
Adjust the axes aspect ratio and background patch location before
calculating the tight bounding box.
"""
# Shared processing steps
self._draw_preprocess(renderer)
Expand All @@ -508,9 +505,8 @@ def draw(self, renderer=None, **kwargs):
"""
Extend the standard behaviour of :func:`matplotlib.axes.Axes.draw`.
Draw grid lines and image factory results before invoking standard
Matplotlib drawing. A global range is used if no limits have yet
been set.
Draw image factory results before invoking standard Matplotlib drawing.
A global range is used if no limits have yet been set.
"""
# Shared processing steps
self._draw_preprocess(renderer)
Expand All @@ -532,15 +528,18 @@ def draw(self, renderer=None, **kwargs):

def _update_title_position(self, renderer):
super()._update_title_position(renderer)
if not self._gridliners:

from cartopy.mpl.gridliner import Gridliner
gridliners = [a for a in self.artists if isinstance(a, Gridliner)]
if not gridliners:
return

if self._autotitlepos is not None and not self._autotitlepos:
return

# Get the max ymax of all top labels
top = -1
for gl in self._gridliners:
for gl in gridliners:
if gl.has_labels():
for label in (gl.top_label_artists +
gl.left_label_artists +
Expand Down Expand Up @@ -1512,7 +1511,7 @@ def gridlines(self, crs=None, draw_labels=False,
labels_bbox_style=labels_bbox_style,
xpadding=xpadding, ypadding=ypadding, offset_angle=offset_angle,
auto_update=auto_update, formatter_kwargs=formatter_kwargs)
self._gridliners.append(gl)
self.add_artist(gl)
return gl

def _gen_axes_patch(self):
Expand Down
64 changes: 39 additions & 25 deletions lib/cartopy/mpl/gridliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import warnings

import matplotlib
import matplotlib.artist
import matplotlib.collections as mcollections
import matplotlib.text
import matplotlib.ticker as mticker
import matplotlib.transforms as mtrans
import numpy as np
Expand Down Expand Up @@ -101,11 +103,7 @@ def _north_south_formatted(latitude, num_format='g'):
_north_south_formatted(v))


class Gridliner:
# NOTE: In future, one of these objects will be add-able to a GeoAxes (and
# maybe even a plain old mpl axes) and it will call the "_draw_gridliner"
# method on draw. This will enable automatic gridline resolution
# determination on zoom/pan.
class Gridliner(matplotlib.artist.Artist):
def __init__(self, axes, crs, draw_labels=False, xlocator=None,
ylocator=None, collection_kwargs=None,
xformatter=None, yformatter=None, dms=False,
Expand All @@ -115,7 +113,7 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
xpadding=5, ypadding=5, offset_angle=25,
auto_update=False, formatter_kwargs=None):
"""
Object used by :meth:`cartopy.mpl.geoaxes.GeoAxes.gridlines`
Artist used by :meth:`cartopy.mpl.geoaxes.GeoAxes.gridlines`
to add gridlines and tick labels to a map.
Parameters
Expand Down Expand Up @@ -234,7 +232,13 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
used for the map, meridians and parallels can cross both the X axis and
the Y axis.
"""
self.axes = axes
super().__init__()

# We do not want the labels clipped to axes.
self.set_clip_on(False)
# Backcompat: the LineCollection was previously added directly to the
# axes, having a default zorder of 2.
self.set_zorder(2)

#: The :class:`~matplotlib.ticker.Locator` to use for the x
#: gridlines and labels.
Expand Down Expand Up @@ -332,10 +336,10 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
raise ValueError(f"Invalid draw_labels argument: {value}")

if auto_inline:
if isinstance(self.axes.projection, _X_INLINE_PROJS):
if isinstance(axes.projection, _X_INLINE_PROJS):
self.x_inline = True
self.y_inline = False
elif isinstance(self.axes.projection, _POLAR_PROJS):
elif isinstance(axes.projection, _POLAR_PROJS):
self.x_inline = False
self.y_inline = True
else:
Expand Down Expand Up @@ -399,7 +403,7 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
#: Control the rotation of labels.
if rotate_labels is None:
rotate_labels = (
self.axes.projection.__class__ in _ROTATE_LABEL_PROJS)
axes.projection.__class__ in _ROTATE_LABEL_PROJS)
if not isinstance(rotate_labels, (bool, float, int)):
raise ValueError("Invalid rotate_labels argument")
self.rotate_labels = rotate_labels
Expand Down Expand Up @@ -436,10 +440,6 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
self._drawn = False
self._auto_update = auto_update

# Check visibility of labels at each draw event
# (or once drawn, only at resize event ?)
self.axes.figure.canvas.mpl_connect('draw_event', self._draw_event)

@property
def xlabels_top(self):
warnings.warn('The .xlabels_top attribute is deprecated. Please '
Expand Down Expand Up @@ -488,9 +488,6 @@ def ylabels_right(self, value):
'use .right_labels to toggle visibility instead.')
self.right_labels = value

def _draw_event(self, event):
self._draw_gridliner(renderer=event.renderer)

def has_labels(self):
return len(self._labels) != 0

Expand Down Expand Up @@ -629,13 +626,9 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None):
return
self._drawn = True

# Clear lists of artists
for lines in [*self.xline_artists, *self.yline_artists]:
lines.remove()
# Clear lists of child artists
self.xline_artists.clear()
self.yline_artists.clear()
for label in self._labels:
label.artist.remove()
self._labels.clear()

# Inits
Expand Down Expand Up @@ -673,6 +666,7 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None):
if not any(x in collection_kwargs for x in ['lw', 'linewidth']):
collection_kwargs.setdefault('linewidth',
matplotlib.rcParams['grid.linewidth'])
collection_kwargs.setdefault('clip_path', self.axes.patch)

# Meridians
lat_min, lat_max = lat_lim
Expand All @@ -696,7 +690,6 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None):
lon_lc = mcollections.LineCollection(lon_lines,
**collection_kwargs)
self.xline_artists.append(lon_lc)
self.axes.add_collection(lon_lc, autolim=False)

# Parallels
lon_min, lon_max = lon_lim
Expand All @@ -711,7 +704,6 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None):
lat_lc = mcollections.LineCollection(lat_lines,
**collection_kwargs)
self.yline_artists.append(lat_lc)
self.axes.add_collection(lat_lc, autolim=False)

#################
# Label drawing #
Expand Down Expand Up @@ -925,7 +917,9 @@ def update_artist(artist, renderer):

# Add text to the plot
text = formatter(tick_value)
artist = self.axes.text(x, y, text, **kw)
artist = matplotlib.text.Text(x, y, text, **kw)
artist.set_figure(self.axes.figure)
artist.axes = self.axes

# Update loc from spine overlapping now that we have a bbox
# of the label.
Expand Down Expand Up @@ -1239,6 +1233,26 @@ def _axes_domain(self, nx=None, ny=None):

return lon_range, lat_range

def get_visible_children(self):
r"""Return a list of the visible child `.Artist`\s."""
all_children = (self.xline_artists + self.yline_artists
+ self.label_artists)
return [c for c in all_children if c.get_visible()]

def get_tightbbox(self, renderer=None):
self._draw_gridliner(renderer=renderer)
bboxes = [c.get_tightbbox(renderer=renderer)
for c in self.get_visible_children()]
if bboxes:
return mtrans.Bbox.union(bboxes)
else:
return mtrans.Bbox.null()

def draw(self, renderer=None):
self._draw_gridliner(renderer=renderer)
for c in self.get_visible_children():
c.draw(renderer=renderer)


class Label:
"""Helper class to manage the attributes for a single label"""
Expand Down
47 changes: 44 additions & 3 deletions lib/cartopy/tests/mpl/test_gridliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# See COPYING and COPYING.LESSER in the root of the repository for full
# licensing details.

import io
from unittest import mock

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
Expand All @@ -13,7 +16,8 @@
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
from cartopy.mpl.gridliner import (LATITUDE_FORMATTER, LONGITUDE_FORMATTER,
classic_formatter, classic_locator)
Gridliner, classic_formatter,
classic_locator)
from cartopy.mpl.ticker import LongitudeFormatter, LongitudeLocator


Expand Down Expand Up @@ -242,8 +246,9 @@ def test_grid_labels_tight():

# Ensure gridliners were drawn
for ax in fig.axes:
for gl in ax._gridliners:
assert hasattr(gl, '_drawn') and gl._drawn
for artist in ax.artists:
if isinstance(artist, Gridliner):
assert hasattr(artist, '_drawn') and artist._drawn

return fig

Expand Down Expand Up @@ -432,3 +437,39 @@ def test_gridliner_formatter_kwargs():
fig.canvas.draw()
labels = [a.get_text() for a in gl.bottom_label_artists if a.get_visible()]
assert labels == ['75°O', '70°O', '65°O', '60°O', '55°O', '50°O', '45°O']


def test_gridliner_count_draws():
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.set_global()
gl = ax.gridlines()

with mock.patch.object(gl, '_draw_gridliner', return_value=None) as mocked:
ax.get_tightbbox(renderer=None)
mocked.assert_called_once()

with mock.patch.object(gl, '_draw_gridliner', return_value=None) as mocked:
fig.draw_without_rendering()
mocked.assert_called_once()


def test_gridliner_remove():
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.set_global()
gl = ax.gridlines(draw_labels=True)
fig.draw_without_rendering() # Generate child artists
gl.remove()

assert not ax.artists
assert not ax.collections


def test_gridliner_save_tight_bbox():
# Smoke test for save with auto_update=True and bbox_inches=Tight (gh2246).
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.set_global()
ax.gridlines(draw_labels=True, auto_update=True)
fig.savefig(io.BytesIO(), bbox_inches='tight')

0 comments on commit 44bedc1

Please sign in to comment.