Skip to content

Commit

Permalink
ENH: Updates for stripe plot
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamTheisen committed Oct 29, 2024
1 parent 7bf846b commit 9b777e3
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions act/plotting/timeseriesdisplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,9 @@ def plot_stripes(
subplot_index=(0,),
set_title=None,
reference_period=None,
cmap='coolwarm',
cmap='bwr',
cbar_label=None,
colorbar=True,
**kwargs,
):
"""
Expand All @@ -1882,7 +1884,11 @@ def plot_stripes(
If this is set, the plot will subtract the mean of the reference period from the
field to create an anomaly calculation.
cmap : string
Colormap to use for plotting. Defaults to coolwarm
Colormap to use for plotting. Defaults to bwr
cbar_label : str
Option to overwrite default colorbar label.
colorbar : boolean
Option to not plot the colorbar. Default is to plot it
**kwargs : keyword arguments
The keyword arguments for :func:`plt.plot` (1D timeseries) or
:func:`plt.pcolormesh` (2D timeseries).
Expand All @@ -1907,15 +1913,11 @@ def plot_stripes(
dim = list(self._ds[dsname][field].dims)
xdata = self._ds[dsname][dim[0]]

if 'units' in data.attrs:
ytitle = ''.join(['(', data.attrs['units'], ')'])
else:
ytitle = field

delta = 1
start = int(mdates.date2num(xdata.values[0]))
end = int(mdates.date2num(xdata.values[-1]))
delta = stats.mode(xdata.diff('time').values)[0] / np.timedelta64(1, 'D')

# Calculate mean for reference period and subtract from the data
if reference_period is not None:
reference = data.sel(time=slice(reference_period[0], reference_period[1])).mean('time')
data.values = data.values - reference.values
Expand All @@ -1931,8 +1933,10 @@ def plot_stripes(
# Set ax to appropriate axis
ax = self.axes[subplot_index]

col = PatchCollection([Rectangle((y, 0), 1, 1) for y in np.arange(start, end + 1, delta)])

# Plot up data using rectangles
col = PatchCollection(
[Rectangle((y, 0), delta, 1) for y in np.arange(start, end + 1, delta)]
)
col.set_array(data)
col.set_cmap(cmap)
col.set_clim(np.nanmin(data), np.nanmax(data))
Expand All @@ -1947,9 +1951,6 @@ def plot_stripes(
ax.set_yticks([])
ax.set_xlim(start, end + 1)

# Set YTitle
ax.set_ylabel(ytitle)

# Set Title
if set_title is None:
set_title = ' '.join(
Expand All @@ -1961,5 +1962,18 @@ def plot_stripes(
]
)
ax.set_title(set_title)

# Set Colorbar
if colorbar:
if 'units' in data.attrs:
ytitle = ''.join(['(', data.attrs['units'], ')'])
else:
ytitle = field
if cbar_label is None:
cbar_title = ytitle
else:
cbar_title = ''.join(['(', cbar_label, ')'])
self.add_colorbar(col, title=cbar_title, subplot_index=subplot_index)

self.axes[subplot_index] = ax
return self.axes[subplot_index]

0 comments on commit 9b777e3

Please sign in to comment.