Skip to content

Commit

Permalink
Merge pull request #114 from csiro-coasts/transect-attributes
Browse files Browse the repository at this point in the history
Fix transect plot title and labels
  • Loading branch information
mx-moth authored Sep 26, 2023
2 parents f7fffc2 + c44242b commit 4853266
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 34 deletions.
4 changes: 3 additions & 1 deletion docs/releases/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
Next release (in development)
=============================

* ...
* Fix transect plot title and units.
All attributes were being dropped accidentally in `prepare_data_array_for_transect()`.
(:pr:`114`).
23 changes: 3 additions & 20 deletions src/emsarray/conventions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from emsarray.exceptions import NoSuchCoordinateError
from emsarray.operations import depth
from emsarray.plot import (
_requires_plot, animate_on_figure, plot_on_figure, polygons_to_collection
_requires_plot, animate_on_figure, make_plot_title, plot_on_figure,
polygons_to_collection
)
from emsarray.state import State
from emsarray.types import Bounds, Pathish
Expand Down Expand Up @@ -938,25 +939,7 @@ def plot_on_figure(
#
# Users can supply their own titles
# if this automatic behaviour is insufficient
title_bits: list[str] = []
long_name = kwargs['scalar'].attrs.get('long_name')
if long_name is not None:
title_bits.append(str(long_name))
try:
time_coordinate = self.dataset.variables[self.get_time_name()]
except KeyError:
pass
else:
# Add a time stamp when the time coordinate has a single value.
# This happens when you `.sel()` a single time slice to plot -
# as long as the time coordinate is a proper coordinate with
# matching dimension name, not an auxiliary coordinate.
if time_coordinate.size == 1:
time = time_coordinate.values[0]
title_bits.append(str(time))

if title_bits:
kwargs['title'] = '\n'.join(title_bits)
kwargs['title'] = make_plot_title(self.dataset, kwargs['scalar'])

plot_on_figure(figure, self, **kwargs)

Expand Down
43 changes: 43 additions & 0 deletions src/emsarray/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy
import xarray

from emsarray.exceptions import NoSuchCoordinateError
from emsarray.types import Landmark
from emsarray.utils import requires_extra

Expand Down Expand Up @@ -217,6 +218,48 @@ def polygons_to_collection(
)


def make_plot_title(
dataset: xarray.Dataset,
data_array: xarray.DataArray,
) -> Optional[str]:
"""
Make a suitable plot title for a variable.
This will attempt to find a name for the variable by looking through the attributes.
If the variable has a time coordinate,
and the time coordinate has a single value,
the time step is appended after the title.
"""
if 'long_name' in data_array.attrs:
title = str(data_array.attrs['long_name'])
elif 'standard_name' in data_array.attrs:
title = str(data_array.attrs['standard_name'])
elif data_array.name is not None:
title = str(data_array.name)
else:
return None

# Check if this variable has a time coordinate
try:
time_coordinate = dataset.ems.time_coordinate
except NoSuchCoordinateError:
return title
if time_coordinate.name not in data_array.coords:
return title
# Fetch the coordinate from the data array itself,
# in case someone did `data_array = dataset['temp'].isel(time=0)`
time_coordinate = data_array.coords[time_coordinate.name]

if len(time_coordinate.dims) == 0:
time_value = time_coordinate.values
elif time_coordinate.size == 1:
time_value = time_coordinate.values[0]
else:
return title

time_string = numpy.datetime_as_string(time_value, unit='auto')
return f'{title}\n{time_string}'


@_requires_plot
def plot_on_figure(
figure: Figure,
Expand Down
17 changes: 12 additions & 5 deletions src/emsarray/transect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import shapely
import xarray
from cartopy import crs
from matplotlib import animation, cm, pyplot
from matplotlib import animation, colormaps, pyplot
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.collections import PolyCollection
Expand All @@ -20,7 +20,7 @@
from matplotlib.ticker import EngFormatter, Formatter

from emsarray.conventions import Convention, Index
from emsarray.plot import _requires_plot
from emsarray.plot import _requires_plot, make_plot_title
from emsarray.types import Landmark
from emsarray.utils import move_dimensions_to_end

Expand Down Expand Up @@ -62,7 +62,7 @@ def plot(
figure = pyplot.figure(layout="constrained", figsize=figsize)
transect = Transect(dataset, line)
transect.plot_on_figure(figure, data_array, **kwargs)
figure.show()
pyplot.show()
return figure


Expand Down Expand Up @@ -493,6 +493,10 @@ def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarra
The input data array transformed to have the correct shape
for plotting on the transect.
"""
# Some of the following operations drop attrs,
# so keep a reference to the original ones
attrs = data_array.attrs

data_array = self.convention.ravel(data_array)

depth_dimension = self.transect_dataset.coords['depth'].dims[0]
Expand All @@ -502,6 +506,9 @@ def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarra
linear_indices = self.transect_dataset['linear_index'].values
data_array = data_array.isel({index_dimension: linear_indices})

# Restore attrs after reformatting
data_array.attrs.update(attrs)

return data_array

def _find_depth_bounds(self, data_array: xarray.DataArray) -> Tuple[int, int]:
Expand Down Expand Up @@ -749,11 +756,11 @@ def _plot_on_figure(
axes.set_ylim(depth_limit_deep, depth_limit_shallow)

if title is None:
title = data_array.attrs.get('long_name')
title = make_plot_title(self.dataset, data_array)
if title is not None:
axes.set_title(title)

cmap = cm.get_cmap(cmap).copy()
cmap = colormaps[cmap].copy()
cmap.set_bad(ocean_floor_colour)
collection = self.make_poly_collection(
cmap=cmap, clim=(numpy.nanmin(data_array), numpy.nanmax(data_array)))
Expand Down
1 change: 1 addition & 0 deletions tests/conventions/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def test_plot():
'botz': (['y', 'x'], numpy.random.standard_normal((10, 20)) - 10),
})
convention = SimpleConvention(dataset)
convention.bind()

# Naming a simple variable should work fine
convention.plot('botz')
Expand Down
12 changes: 4 additions & 8 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,15 @@ def test_plot(
datasets: pathlib.Path,
tmp_path: pathlib.Path,
):
"""
Test plotting a variable with no long_name attribute works.
Regression test for https://github.com/csiro-coasts/emsarray/issues/105
"""
dataset = emsarray.tutorial.open_dataset('gbr4')
dataset = emsarray.tutorial.open_dataset('fraser')
temp = dataset['temp'].copy()
temp = temp.isel(time=0, k=-1)

dataset.ems.plot(temp)

figure = matplotlib.pyplot.gcf()
axes = figure.axes[0]
assert axes.get_title() == 'Temperature\n2022-05-11T14:00:00.000000000'
assert axes.get_title() == 'Temperature\n2022-05-11T14:00'

matplotlib.pyplot.savefig(tmp_path / 'plot.png')
logger.info("Saved plot to %r", tmp_path / 'plot.png')
Expand All @@ -86,7 +82,7 @@ def test_plot_no_long_name(
Test plotting a variable with no long_name attribute works.
Regression test for https://github.com/csiro-coasts/emsarray/issues/105
"""
dataset = emsarray.tutorial.open_dataset('gbr4')
dataset = emsarray.tutorial.open_dataset('fraser')
temp = dataset['temp'].copy()
temp = temp.isel(time=0, k=-1)
del temp.attrs['long_name']
Expand All @@ -95,7 +91,7 @@ def test_plot_no_long_name(

figure = matplotlib.pyplot.gcf()
axes = figure.axes[0]
assert axes.get_title() == '2022-05-11T14:00:00.000000000'
assert axes.get_title() == 'temp\n2022-05-11T14:00'

matplotlib.pyplot.savefig(tmp_path / 'plot.png')
logger.info("Saved plot to %r", tmp_path / 'plot.png')
57 changes: 57 additions & 0 deletions tests/test_transect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
import pathlib

import matplotlib
import pytest
import shapely

import emsarray.transect

logger = logging.getLogger(__name__)


@pytest.mark.matplotlib(mock_coast=True)
@pytest.mark.tutorial
def test_plot(
datasets: pathlib.Path,
tmp_path: pathlib.Path,
):
dataset = emsarray.tutorial.open_dataset('gbr4')
temp = dataset['temp'].copy()
temp = temp.isel(time=-1)

line = shapely.LineString([
[152.9768944, -25.4827962],
[152.9701996, -25.4420345],
[152.9727745, -25.3967620],
[152.9623032, -25.3517828],
[152.9401588, -25.3103560],
[152.9173279, -25.2538563],
[152.8962135, -25.1942238],
[152.8692627, -25.0706729],
[152.8623962, -24.9698750],
[152.8472900, -24.8415806],
[152.8308105, -24.6470172],
[152.7607727, -24.3521012],
[152.6392365, -24.1906056],
[152.4792480, -24.0615124],
])
emsarray.transect.plot(
dataset, line, temp,
bathymetry=dataset['botz'])

figure = matplotlib.pyplot.gcf()
axes = figure.axes[0]
# This is assembled from the variable long_name and the time coordinate
assert axes.get_title() == 'Temperature\n2022-05-11T14:00'
# This is the long_name of the depth coordinate
assert axes.get_ylabel() == 'Z coordinate'
# This is made up
assert axes.get_xlabel() == 'Distance along transect'

colorbar = figure.axes[-1]
# This is the variable units
assert colorbar.get_ylabel() == 'degrees C'

matplotlib.pyplot.savefig(tmp_path / 'plot.png')
logger.info("Saved plot to %r", tmp_path / 'plot.png')

0 comments on commit 4853266

Please sign in to comment.