Skip to content

Commit

Permalink
Improvements for the secondary y axis functionality (#723)
Browse files Browse the repository at this point in the history
* ENH: Enhancing the secondary y axis functionality which will be a change in how we do things and might need to be considered for a v2.0

* ENH: Adding example for the secondary y-axis plotting

* ENH: formatting

* ENH: New plot for testing secondary_y

* ENH: New test for secondary_y

* ENH: pep8 update

* ENH: Bug fix

* ENH: Bug fixes for examples

* DOC: Pep8

* DOC: PEP8

* ENH: bug fix for test
  • Loading branch information
AdamTheisen authored Nov 29, 2023
1 parent 4bf0e1e commit 3cd0a2e
Show file tree
Hide file tree
Showing 17 changed files with 372 additions and 186 deletions.
2 changes: 1 addition & 1 deletion act/plotting/contourdisplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ContourDisplay(Display):
"""

def __init__(self, ds, subplot_shape=(1,), ds_name=None, **kwargs):
super().__init__(ds, subplot_shape, ds_name, **kwargs)
super().__init__(ds, subplot_shape, ds_name, secondary_y_allowed=False, **kwargs)

def create_contour(
self,
Expand Down
133 changes: 66 additions & 67 deletions act/plotting/distributiondisplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DistributionDisplay(Display):
"""

def __init__(self, ds, subplot_shape=(1,), ds_name=None, **kwargs):
super().__init__(ds, subplot_shape, ds_name, **kwargs)
super().__init__(ds, subplot_shape, ds_name, secondary_y_allowed=True, **kwargs)

def set_xrng(self, xrng, subplot_index=(0,)):
"""
Expand All @@ -55,7 +55,7 @@ def set_xrng(self, xrng, subplot_index=(0,)):
elif not hasattr(self, 'xrng') and len(self.axes.shape) == 1:
self.xrng = np.zeros((self.axes.shape[0], 2), dtype='datetime64[D]')

self.axes[subplot_index].set_xlim(xrng)
self.axes[subplot_index][0].set_xlim(xrng)
self.xrng[subplot_index, :] = np.array(xrng)

def set_yrng(self, yrng, subplot_index=(0,)):
Expand All @@ -81,7 +81,7 @@ def set_yrng(self, yrng, subplot_index=(0,)):
if yrng[0] == yrng[1]:
yrng[1] = yrng[1] + 1

self.axes[subplot_index].set_ylim(yrng)
self.axes[subplot_index][0].set_ylim(yrng)
self.yrng[subplot_index, :] = yrng

def _get_data(self, dsname, fields):
Expand Down Expand Up @@ -163,13 +163,13 @@ def plot_stacked_bar_graph(
# We will defaut the y direction to have the same # of bins as x
sortby_bins = np.linspace(ydata.values.min(), ydata.values.max(), len(bins))

# Get the current plotting axis, add day/night background and plot data
# Get the current plotting axis
if self.fig is None:
self.fig = plt.figure()

if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

if sortby_field is not None:
if 'units' in ydata.attrs:
Expand All @@ -189,26 +189,26 @@ def plot_stacked_bar_graph(
bins=[bins, sortby_bins],
**hist_kwargs)
x_inds = (x_bins[:-1] + x_bins[1:]) / 2.0
self.axes[subplot_index].bar(
self.axes[subplot_index][0].bar(
x_inds,
my_hist[:, 0].flatten(),
label=(str(y_bins[0]) + ' to ' + str(y_bins[1])),
**kwargs,
)
for i in range(1, len(y_bins) - 1):
self.axes[subplot_index].bar(
self.axes[subplot_index][0].bar(
x_inds,
my_hist[:, i].flatten(),
bottom=my_hist[:, i - 1],
label=(str(y_bins[i]) + ' to ' + str(y_bins[i + 1])),
**kwargs,
)
self.axes[subplot_index].legend()
self.axes[subplot_index][0].legend()
else:
my_hist, bins = np.histogram(xdata.values.flatten(), bins=bins,
density=density, **hist_kwargs)
x_inds = (bins[:-1] + bins[1:]) / 2.0
self.axes[subplot_index].bar(x_inds, my_hist)
self.axes[subplot_index][0].bar(x_inds, my_hist)

# Set Title
if set_title is None:
Expand All @@ -220,9 +220,9 @@ def plot_stacked_bar_graph(
dt_utils.numpy_to_arm_date(self._ds[dsname].time.values[0]),
]
)
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].set_ylabel('count')
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].set_ylabel('count')
self.axes[subplot_index][0].set_xlabel(xtitle)

return_dict = {}
return_dict['plot_handle'] = self.axes[subplot_index]
Expand Down Expand Up @@ -306,13 +306,13 @@ def plot_size_distribution(
+ 'length is equal to the field length!'
)

# Get the current plotting axis, add day/night background and plot data
# Get the current plotting axis
if self.fig is None:
self.fig = plt.figure()

if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

# Set Title
if set_title is None:
Expand All @@ -327,10 +327,10 @@ def plot_size_distribution(
if time is not None:
t = pd.Timestamp(time)
set_title += ''.join([' at ', ':'.join([str(t.hour), str(t.minute), str(t.second)])])
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].step(bins.values, xdata.values, **kwargs)
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index].set_ylabel(ytitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].step(bins.values, xdata.values, **kwargs)
self.axes[subplot_index][0].set_xlabel(xtitle)
self.axes[subplot_index][0].set_ylabel(ytitle)

return self.axes[subplot_index]

Expand Down Expand Up @@ -412,8 +412,9 @@ def plot_stairstep_graph(
self.fig = plt.figure()

if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

if sortby_field is not None:
if 'units' in ydata.attrs:
Expand All @@ -433,26 +434,26 @@ def plot_stairstep_graph(
**hist_kwargs
)
x_inds = (x_bins[:-1] + x_bins[1:]) / 2.0
self.axes[subplot_index].step(
self.axes[subplot_index][0].step(
x_inds,
my_hist[:, 0].flatten(),
label=(str(y_bins[0]) + ' to ' + str(y_bins[1])),
**kwargs,
)
for i in range(1, len(y_bins) - 1):
self.axes[subplot_index].step(
self.axes[subplot_index][0].step(
x_inds,
my_hist[:, i].flatten(),
label=(str(y_bins[i]) + ' to ' + str(y_bins[i + 1])),
**kwargs,
)
self.axes[subplot_index].legend()
self.axes[subplot_index][0].legend()
else:
my_hist, bins = np.histogram(xdata.values.flatten(), bins=bins,
density=density, **hist_kwargs)

x_inds = (bins[:-1] + bins[1:]) / 2.0
self.axes[subplot_index].step(x_inds, my_hist, **kwargs)
self.axes[subplot_index][0].step(x_inds, my_hist, **kwargs)

# Set Title
if set_title is None:
Expand All @@ -464,9 +465,9 @@ def plot_stairstep_graph(
dt_utils.numpy_to_arm_date(self._ds[dsname].time.values[0]),
]
)
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].set_ylabel('count')
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].set_ylabel('count')
self.axes[subplot_index][0].set_xlabel(xtitle)

return_dict = {}
return_dict['plot_handle'] = self.axes[subplot_index]
Expand Down Expand Up @@ -568,10 +569,10 @@ def plot_heatmap(
# Get the current plotting axis, add day/night background and plot data
if self.fig is None:
self.fig = plt.figure()

if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

if 'units' in ydata.attrs:
ytitle = ''.join(['(', ydata.attrs['units'], ')'])
Expand All @@ -597,7 +598,7 @@ def plot_heatmap(
x_inds = (x_bins[:-1] + x_bins[1:]) / 2.0
y_inds = (y_bins[:-1] + y_bins[1:]) / 2.0
xi, yi = np.meshgrid(x_inds, y_inds, indexing='ij')
mesh = self.axes[subplot_index].pcolormesh(xi, yi, my_hist, shading=set_shading, **kwargs)
mesh = self.axes[subplot_index][0].pcolormesh(xi, yi, my_hist, shading=set_shading, **kwargs)

# Set Title
if set_title is None:
Expand All @@ -608,13 +609,13 @@ def plot_heatmap(
dt_utils.numpy_to_arm_date(self._ds[dsname].time.values[0]),
]
)
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].set_ylabel(ytitle)
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].set_ylabel(ytitle)
self.axes[subplot_index][0].set_xlabel(xtitle)
self.add_colorbar(mesh, title='count', subplot_index=subplot_index)

return_dict = {}
return_dict['plot_handle'] = self.axes[subplot_index]
return_dict['plot_handle'] = self.axes[subplot_index][0]
return_dict['x_bins'] = x_bins
return_dict['y_bins'] = y_bins
return_dict['histogram'] = my_hist
Expand All @@ -634,9 +635,9 @@ def set_ratio_line(self, subplot_index=(0, )):
if self.axes is None:
raise RuntimeError('set_ratio_line requires the plot to be displayed.')
# Define the xticks of the figure
xlims = self.axes[subplot_index].get_xticks()
xlims = self.axes[subplot_index][0].get_xticks()
ratio = np.linspace(xlims[0], xlims[-1])
self.axes[subplot_index].plot(ratio, ratio, 'k--')
self.axes[subplot_index][0].plot(ratio, ratio, 'k--')

def plot_scatter(self,
x_field,
Expand Down Expand Up @@ -713,15 +714,12 @@ def plot_scatter(self,

# Define the axes for the figure
if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

# Display the scatter plot, pass keyword args for unspecified attributes
scc = self.axes[subplot_index].scatter(xdata,
ydata,
c=mdata,
**kwargs
)
scc = self.axes[subplot_index][0].scatter(xdata, ydata, c=mdata, **kwargs)

# Set Title
if set_title is None:
Expand All @@ -748,9 +746,9 @@ def plot_scatter(self,
cbar.ax.set_ylabel(ztitle)

# Define the axe title, x-axis label, y-axis label
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].set_ylabel(ytitle)
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].set_ylabel(ytitle)
self.axes[subplot_index][0].set_xlabel(xtitle)

return self.axes[subplot_index]

Expand Down Expand Up @@ -818,8 +816,9 @@ def plot_violin(self,

# Define the axes for the figure
if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

# Define the axe label. If units are avaiable, plot.
if 'units' in ndata.attrs:
Expand All @@ -828,14 +827,14 @@ def plot_violin(self,
axtitle = field

# Display the scatter plot, pass keyword args for unspecified attributes
scc = self.axes[subplot_index].violinplot(ndata,
positions=positions,
vert=vert,
showmeans=showmeans,
showmedians=showmedians,
showextrema=showextrema,
**kwargs
)
scc = self.axes[subplot_index][0].violinplot(ndata,
positions=positions,
vert=vert,
showmeans=showmeans,
showmedians=showmedians,
showextrema=showextrema,
**kwargs
)
if showmeans is True:
scc['cmeans'].set_edgecolor('red')
scc['cmeans'].set_label('mean')
Expand All @@ -853,14 +852,14 @@ def plot_violin(self,
)

# Define the axe title, x-axis label, y-axis label
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index][0].set_title(set_title)
if vert is True:
self.axes[subplot_index].set_ylabel(axtitle)
self.axes[subplot_index][0].set_ylabel(axtitle)
if positions is None:
self.axes[subplot_index].set_xticks([])
self.axes[subplot_index][0].set_xticks([])
else:
self.axes[subplot_index].set_xlabel(axtitle)
self.axes[subplot_index][0].set_xlabel(axtitle)
if positions is None:
self.axes[subplot_index].set_yticks([])
self.axes[subplot_index][0].set_yticks([])

return self.axes[subplot_index]
return self.axes[subplot_index][0]
2 changes: 1 addition & 1 deletion act/plotting/geodisplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, ds, ds_name=None, **kwargs):
raise ImportError(
'Cartopy needs to be installed on your ' 'system to make geographic display plots.'
)
super().__init__(ds, ds_name, **kwargs)
super().__init__(ds, ds_name, secondary_y_allowed=False, **kwargs)
if self.fig is None:
self.fig = plt.figure(**kwargs)

Expand Down
Loading

0 comments on commit 3cd0a2e

Please sign in to comment.