Skip to content

Commit

Permalink
Modifies QuickPlot.plot to take a list of times and show superimpos…
Browse files Browse the repository at this point in the history
…ed plots in case of 1D variables. (pybamm-team#4529)

* Modified  plot method to take a list of times .

* fixed failing tests

* added different colours to different time plots

* fixed failing tests

* added tests for using list of times

* added changelog enrty

* made small changes

* Update CHANGELOG.md

Co-authored-by: Agriya Khetarpal <[email protected]>

---------

Co-authored-by: Arjun Verma <[email protected]>
Co-authored-by: Eric G. Kratz <[email protected]>
Co-authored-by: Agriya Khetarpal <[email protected]>
  • Loading branch information
4 people committed Nov 22, 2024
1 parent 103e94f commit b50d4a9
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 115 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Features

- Modified `quick_plot.plot` to accept a list of times and generate superimposed graphs for specified time points. ([#4529](https://github.com/pybamm-team/PyBaMM/pull/4529))

# [v24.11.0](https://github.com/pybamm-team/PyBaMM/tree/v24.11.0) - 2024-11-20

## Features
Expand Down
210 changes: 95 additions & 115 deletions src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Class for quick plotting of variables from models
#
from __future__ import annotations
import os
import numpy as np
import pybamm
Expand Down Expand Up @@ -479,24 +480,24 @@ def reset_axis(self):
): # pragma: no cover
raise ValueError(f"Axis limits cannot be NaN for variables '{key}'")

def plot(self, t, dynamic=False):
def plot(self, t: float | list[float], dynamic: bool = False):
"""Produces a quick plot with the internal states at time t.
Parameters
----------
t : float
Dimensional time (in 'time_units') at which to plot.
t : float or list of float
Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times.
dynamic : bool, optional
Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot.
If True, creates a dynamic plot with a slider.
"""

plt = import_optional_dependency("matplotlib.pyplot")
gridspec = import_optional_dependency("matplotlib.gridspec")
cm = import_optional_dependency("matplotlib", "cm")
colors = import_optional_dependency("matplotlib", "colors")

t_in_seconds = t * self.time_scaling_factor
if not isinstance(t, list):
t = [t]

self.fig = plt.figure(figsize=self.figsize)

self.gridspec = gridspec.GridSpec(self.n_rows, self.n_cols)
Expand All @@ -508,6 +509,11 @@ def plot(self, t, dynamic=False):
# initialize empty handles, to be created only if the appropriate plots are made
solution_handles = []

# Generate distinct colors for each time point
time_colors = plt.cm.coolwarm(
np.linspace(0, 1, len(t))
) # Use a colormap for distinct colors

for k, (key, variable_lists) in enumerate(self.variables.items()):
ax = self.fig.add_subplot(self.gridspec[k])
self.axes.add(key, ax)
Expand All @@ -518,19 +524,17 @@ def plot(self, t, dynamic=False):
ax.xaxis.set_major_locator(plt.MaxNLocator(3))
self.plots[key] = defaultdict(dict)
variable_handles = []
# Set labels for the first subplot only (avoid repetition)

if variable_lists[0][0].dimensions == 0:
# 0D plot: plot as a function of time, indicating time t with a line
# 0D plot: plot as a function of time, indicating multiple times with lines
ax.set_xlabel(f"Time [{self.time_unit}]")
for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
if len(variable_list) == 1:
# single variable -> use linestyle to differentiate model
linestyle = self.linestyles[i]
else:
# multiple variables -> use linestyle to differentiate
# variables (color differentiates models)
linestyle = self.linestyles[j]
linestyle = (
self.linestyles[i]
if len(variable_list) == 1
else self.linestyles[j]
)
full_t = self.ts_seconds[i]
(self.plots[key][i][j],) = ax.plot(
full_t / self.time_scaling_factor,
Expand All @@ -542,128 +546,104 @@ def plot(self, t, dynamic=False):
solution_handles.append(self.plots[key][i][0])
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min, y_max)
(self.time_lines[key],) = ax.plot(
[
t_in_seconds / self.time_scaling_factor,
t_in_seconds / self.time_scaling_factor,
],
[y_min, y_max],
"k--",
lw=1.5,
)

# Add vertical lines for each time in the list, using different colors for each time
for idx, t_single in enumerate(t):
t_in_seconds = t_single * self.time_scaling_factor
(self.time_lines[key],) = ax.plot(
[
t_in_seconds / self.time_scaling_factor,
t_in_seconds / self.time_scaling_factor,
],
[y_min, y_max],
"--", # Dashed lines
lw=1.5,
color=time_colors[idx], # Different color for each time
label=f"t = {t_single:.2f} {self.time_unit}",
)
ax.legend()

elif variable_lists[0][0].dimensions == 1:
# 1D plot: plot as a function of x at time t
# Read dictionary of spatial variables
# 1D plot: plot as a function of x at different times
spatial_vars = self.spatial_variable_dict[key]
spatial_var_name = next(iter(spatial_vars.keys()))
ax.set_xlabel(
f"{spatial_var_name} [{self.spatial_unit}]",
)
for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
if len(variable_list) == 1:
# single variable -> use linestyle to differentiate model
linestyle = self.linestyles[i]
else:
# multiple variables -> use linestyle to differentiate
# variables (color differentiates models)
linestyle = self.linestyles[j]
(self.plots[key][i][j],) = ax.plot(
self.first_spatial_variable[key],
variable(t_in_seconds, **spatial_vars),
color=self.colors[i],
linestyle=linestyle,
zorder=10,
)
variable_handles.append(self.plots[key][0][j])
solution_handles.append(self.plots[key][i][0])
# add lines for boundaries between subdomains
for boundary in variable_lists[0][0].internal_boundaries:
boundary_scaled = boundary * self.spatial_factor
ax.axvline(boundary_scaled, color="0.5", lw=1, zorder=0)
ax.set_xlabel(f"{spatial_var_name} [{self.spatial_unit}]")

for idx, t_single in enumerate(t):
t_in_seconds = t_single * self.time_scaling_factor

for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
linestyle = (
self.linestyles[i]
if len(variable_list) == 1
else self.linestyles[j]
)
(self.plots[key][i][j],) = ax.plot(
self.first_spatial_variable[key],
variable(t_in_seconds, **spatial_vars),
color=time_colors[idx], # Different color for each time
linestyle=linestyle,
label=f"t = {t_single:.2f} {self.time_unit}", # Add time label
zorder=10,
)
variable_handles.append(self.plots[key][0][j])
solution_handles.append(self.plots[key][i][0])

# Add a legend to indicate which plot corresponds to which time
ax.legend()

elif variable_lists[0][0].dimensions == 2:
# Read dictionary of spatial variables
# 2D plot: superimpose plots at different times
spatial_vars = self.spatial_variable_dict[key]
# there can only be one entry in the variable list
variable = variable_lists[0][0]
# different order based on whether the domains are x-r, x-z or y-z, etc
if self.x_first_and_y_second[key] is False:
x_name = list(spatial_vars.keys())[1][0]
y_name = next(iter(spatial_vars.keys()))[0]
x = self.second_spatial_variable[key]
y = self.first_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars)
else:
x_name = next(iter(spatial_vars.keys()))[0]
y_name = list(spatial_vars.keys())[1][0]

for t_single in t:
t_in_seconds = t_single * self.time_scaling_factor
x = self.first_spatial_variable[key]
y = self.second_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars).T
ax.set_xlabel(f"{x_name} [{self.spatial_unit}]")
ax.set_ylabel(f"{y_name} [{self.spatial_unit}]")
vmin, vmax = self.variable_limits[key]
# store the plot and the var data (for testing) as cant access
# z data from QuadMesh or QuadContourSet object
if self.is_y_z[key] is True:
self.plots[key][0][0] = ax.pcolormesh(
x,
y,
var,
vmin=vmin,
vmax=vmax,
shading=self.shading,

ax.set_xlabel(
f"{next(iter(spatial_vars.keys()))[0]} [{self.spatial_unit}]"
)
else:
self.plots[key][0][0] = ax.contourf(
x, y, var, levels=100, vmin=vmin, vmax=vmax
ax.set_ylabel(
f"{list(spatial_vars.keys())[1][0]} [{self.spatial_unit}]"
)
self.plots[key][0][1] = var
if vmin is None and vmax is None:
vmin = ax_min(var)
vmax = ax_max(var)
self.colorbars[key] = self.fig.colorbar(
cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)),
ax=ax,
)
# Set either y label or legend entries
if len(key) == 1:
title = split_long_string(key[0])
ax.set_title(title, fontsize="medium")
else:
ax.legend(
variable_handles,
[split_long_string(s, 6) for s in key],
bbox_to_anchor=(0.5, 1),
loc="lower center",
)
vmin, vmax = self.variable_limits[key]

# Use contourf and colorbars to represent the values
contour_plot = ax.contourf(
x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm"
)
self.plots[key][0][0] = contour_plot
self.colorbars[key] = self.fig.colorbar(contour_plot, ax=ax)

# Set global legend
self.plots[key][0][1] = var

ax.set_title(f"t = {t_single:.2f} {self.time_unit}")

# Set global legend if there are multiple models
if len(self.labels) > 1:
fig_legend = self.fig.legend(
solution_handles, self.labels, loc="lower right"
)
# Get the position of the top of the legend in relative figure units
# There may be a better way ...
try:
legend_top_inches = fig_legend.get_window_extent(
renderer=self.fig.canvas.get_renderer()
).get_points()[1, 1]
fig_height_inches = (self.fig.get_size_inches() * self.fig.dpi)[1]
legend_top = legend_top_inches / fig_height_inches
except AttributeError: # pragma: no cover
# When testing the examples we set the matplotlib backend to "Template"
# which means that the above code doesn't work. Since this is just for
# that particular test we can just skip it
legend_top = 0
else:
legend_top = 0
fig_legend = None

# Fix layout
# Fix layout for sliders if dynamic
if dynamic:
slider_top = 0.05
else:
slider_top = 0
bottom = max(legend_top, slider_top)
bottom = max(
fig_legend.get_window_extent(
renderer=self.fig.canvas.get_renderer()
).get_points()[1, 1]
if fig_legend
else 0,
slider_top,
)
self.gridspec.tight_layout(self.fig, rect=[0, bottom, 1, 1])

def dynamic_plot(self, show_plot=True, step=None):
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,41 @@ def test_simple_ode_model(self, solver):
solution, ["a", "b broadcasted"], variable_limits="bad variable limits"
)

# Test with a list of times
# Test for a 0D variable
quick_plot = pybamm.QuickPlot(solution, ["a"])
quick_plot.plot(t=[0, 1, 2])
assert len(quick_plot.plots[("a",)]) == 1

# Test for a 1D variable
quick_plot = pybamm.QuickPlot(solution, ["c broadcasted"])

time_list = [0.5, 1.5, 2.5]
quick_plot.plot(time_list)

ax = quick_plot.fig.axes[0]
lines = ax.get_lines()

variable_key = ("c broadcasted",)
if variable_key in quick_plot.variables:
num_variables = len(quick_plot.variables[variable_key][0])
else:
raise KeyError(
f"'{variable_key}' is not in quick_plot.variables. Available keys: "
+ str(quick_plot.variables.keys())
)

expected_lines = len(time_list) * num_variables

assert (
len(lines) == expected_lines
), f"Expected {expected_lines} superimposed lines, but got {len(lines)}"

# Test for a 2D variable
quick_plot = pybamm.QuickPlot(solution, ["2D variable"])
quick_plot.plot(t=[0, 1, 2])
assert len(quick_plot.plots[("2D variable",)]) == 1

# Test errors
with pytest.raises(ValueError, match="Mismatching variable domains"):
pybamm.QuickPlot(solution, [["a", "b broadcasted"]])
Expand Down

0 comments on commit b50d4a9

Please sign in to comment.