diff --git a/CHANGELOG.md b/CHANGELOG.md index b273435dff..10ecadbce6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) +## Features + +- Modified `quick_plot.plot` to accept a list of times and generate superimposed graphs for specified time points. ([#4529](https://github.com/pybamm-team/PyBaMM/pull/4529)) + # [v24.11.0](https://github.com/pybamm-team/PyBaMM/tree/v24.11.0) - 2024-11-20 ## Features diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index babfd2e761..e3ad40855e 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -1,6 +1,7 @@ # # Class for quick plotting of variables from models # +from __future__ import annotations import os import numpy as np import pybamm @@ -479,13 +480,13 @@ def reset_axis(self): ): # pragma: no cover raise ValueError(f"Axis limits cannot be NaN for variables '{key}'") - def plot(self, t, dynamic=False): + def plot(self, t: float | list[float], dynamic: bool = False): """Produces a quick plot with the internal states at time t. Parameters ---------- - t : float - Dimensional time (in 'time_units') at which to plot. + t : float or list of float + Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times. dynamic : bool, optional Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot. If True, creates a dynamic plot with a slider. @@ -493,10 +494,10 @@ def plot(self, t, dynamic=False): plt = import_optional_dependency("matplotlib.pyplot") gridspec = import_optional_dependency("matplotlib.gridspec") - cm = import_optional_dependency("matplotlib", "cm") - colors = import_optional_dependency("matplotlib", "colors") - t_in_seconds = t * self.time_scaling_factor + if not isinstance(t, list): + t = [t] + self.fig = plt.figure(figsize=self.figsize) self.gridspec = gridspec.GridSpec(self.n_rows, self.n_cols) @@ -508,6 +509,11 @@ def plot(self, t, dynamic=False): # initialize empty handles, to be created only if the appropriate plots are made solution_handles = [] + # Generate distinct colors for each time point + time_colors = plt.cm.coolwarm( + np.linspace(0, 1, len(t)) + ) # Use a colormap for distinct colors + for k, (key, variable_lists) in enumerate(self.variables.items()): ax = self.fig.add_subplot(self.gridspec[k]) self.axes.add(key, ax) @@ -518,19 +524,17 @@ def plot(self, t, dynamic=False): ax.xaxis.set_major_locator(plt.MaxNLocator(3)) self.plots[key] = defaultdict(dict) variable_handles = [] - # Set labels for the first subplot only (avoid repetition) + if variable_lists[0][0].dimensions == 0: - # 0D plot: plot as a function of time, indicating time t with a line + # 0D plot: plot as a function of time, indicating multiple times with lines ax.set_xlabel(f"Time [{self.time_unit}]") for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): - if len(variable_list) == 1: - # single variable -> use linestyle to differentiate model - linestyle = self.linestyles[i] - else: - # multiple variables -> use linestyle to differentiate - # variables (color differentiates models) - linestyle = self.linestyles[j] + linestyle = ( + self.linestyles[i] + if len(variable_list) == 1 + else self.linestyles[j] + ) full_t = self.ts_seconds[i] (self.plots[key][i][j],) = ax.plot( full_t / self.time_scaling_factor, @@ -542,128 +546,104 @@ def plot(self, t, dynamic=False): solution_handles.append(self.plots[key][i][0]) y_min, y_max = ax.get_ylim() ax.set_ylim(y_min, y_max) - (self.time_lines[key],) = ax.plot( - [ - t_in_seconds / self.time_scaling_factor, - t_in_seconds / self.time_scaling_factor, - ], - [y_min, y_max], - "k--", - lw=1.5, - ) + + # Add vertical lines for each time in the list, using different colors for each time + for idx, t_single in enumerate(t): + t_in_seconds = t_single * self.time_scaling_factor + (self.time_lines[key],) = ax.plot( + [ + t_in_seconds / self.time_scaling_factor, + t_in_seconds / self.time_scaling_factor, + ], + [y_min, y_max], + "--", # Dashed lines + lw=1.5, + color=time_colors[idx], # Different color for each time + label=f"t = {t_single:.2f} {self.time_unit}", + ) + ax.legend() + elif variable_lists[0][0].dimensions == 1: - # 1D plot: plot as a function of x at time t - # Read dictionary of spatial variables + # 1D plot: plot as a function of x at different times spatial_vars = self.spatial_variable_dict[key] spatial_var_name = next(iter(spatial_vars.keys())) - ax.set_xlabel( - f"{spatial_var_name} [{self.spatial_unit}]", - ) - for i, variable_list in enumerate(variable_lists): - for j, variable in enumerate(variable_list): - if len(variable_list) == 1: - # single variable -> use linestyle to differentiate model - linestyle = self.linestyles[i] - else: - # multiple variables -> use linestyle to differentiate - # variables (color differentiates models) - linestyle = self.linestyles[j] - (self.plots[key][i][j],) = ax.plot( - self.first_spatial_variable[key], - variable(t_in_seconds, **spatial_vars), - color=self.colors[i], - linestyle=linestyle, - zorder=10, - ) - variable_handles.append(self.plots[key][0][j]) - solution_handles.append(self.plots[key][i][0]) - # add lines for boundaries between subdomains - for boundary in variable_lists[0][0].internal_boundaries: - boundary_scaled = boundary * self.spatial_factor - ax.axvline(boundary_scaled, color="0.5", lw=1, zorder=0) + ax.set_xlabel(f"{spatial_var_name} [{self.spatial_unit}]") + + for idx, t_single in enumerate(t): + t_in_seconds = t_single * self.time_scaling_factor + + for i, variable_list in enumerate(variable_lists): + for j, variable in enumerate(variable_list): + linestyle = ( + self.linestyles[i] + if len(variable_list) == 1 + else self.linestyles[j] + ) + (self.plots[key][i][j],) = ax.plot( + self.first_spatial_variable[key], + variable(t_in_seconds, **spatial_vars), + color=time_colors[idx], # Different color for each time + linestyle=linestyle, + label=f"t = {t_single:.2f} {self.time_unit}", # Add time label + zorder=10, + ) + variable_handles.append(self.plots[key][0][j]) + solution_handles.append(self.plots[key][i][0]) + + # Add a legend to indicate which plot corresponds to which time + ax.legend() + elif variable_lists[0][0].dimensions == 2: - # Read dictionary of spatial variables + # 2D plot: superimpose plots at different times spatial_vars = self.spatial_variable_dict[key] - # there can only be one entry in the variable list variable = variable_lists[0][0] - # different order based on whether the domains are x-r, x-z or y-z, etc - if self.x_first_and_y_second[key] is False: - x_name = list(spatial_vars.keys())[1][0] - y_name = next(iter(spatial_vars.keys()))[0] - x = self.second_spatial_variable[key] - y = self.first_spatial_variable[key] - var = variable(t_in_seconds, **spatial_vars) - else: - x_name = next(iter(spatial_vars.keys()))[0] - y_name = list(spatial_vars.keys())[1][0] + + for t_single in t: + t_in_seconds = t_single * self.time_scaling_factor x = self.first_spatial_variable[key] y = self.second_spatial_variable[key] var = variable(t_in_seconds, **spatial_vars).T - ax.set_xlabel(f"{x_name} [{self.spatial_unit}]") - ax.set_ylabel(f"{y_name} [{self.spatial_unit}]") - vmin, vmax = self.variable_limits[key] - # store the plot and the var data (for testing) as cant access - # z data from QuadMesh or QuadContourSet object - if self.is_y_z[key] is True: - self.plots[key][0][0] = ax.pcolormesh( - x, - y, - var, - vmin=vmin, - vmax=vmax, - shading=self.shading, + + ax.set_xlabel( + f"{next(iter(spatial_vars.keys()))[0]} [{self.spatial_unit}]" ) - else: - self.plots[key][0][0] = ax.contourf( - x, y, var, levels=100, vmin=vmin, vmax=vmax + ax.set_ylabel( + f"{list(spatial_vars.keys())[1][0]} [{self.spatial_unit}]" ) - self.plots[key][0][1] = var - if vmin is None and vmax is None: - vmin = ax_min(var) - vmax = ax_max(var) - self.colorbars[key] = self.fig.colorbar( - cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)), - ax=ax, - ) - # Set either y label or legend entries - if len(key) == 1: - title = split_long_string(key[0]) - ax.set_title(title, fontsize="medium") - else: - ax.legend( - variable_handles, - [split_long_string(s, 6) for s in key], - bbox_to_anchor=(0.5, 1), - loc="lower center", - ) + vmin, vmax = self.variable_limits[key] + + # Use contourf and colorbars to represent the values + contour_plot = ax.contourf( + x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm" + ) + self.plots[key][0][0] = contour_plot + self.colorbars[key] = self.fig.colorbar(contour_plot, ax=ax) - # Set global legend + self.plots[key][0][1] = var + + ax.set_title(f"t = {t_single:.2f} {self.time_unit}") + + # Set global legend if there are multiple models if len(self.labels) > 1: fig_legend = self.fig.legend( solution_handles, self.labels, loc="lower right" ) - # Get the position of the top of the legend in relative figure units - # There may be a better way ... - try: - legend_top_inches = fig_legend.get_window_extent( - renderer=self.fig.canvas.get_renderer() - ).get_points()[1, 1] - fig_height_inches = (self.fig.get_size_inches() * self.fig.dpi)[1] - legend_top = legend_top_inches / fig_height_inches - except AttributeError: # pragma: no cover - # When testing the examples we set the matplotlib backend to "Template" - # which means that the above code doesn't work. Since this is just for - # that particular test we can just skip it - legend_top = 0 else: - legend_top = 0 + fig_legend = None - # Fix layout + # Fix layout for sliders if dynamic if dynamic: slider_top = 0.05 else: slider_top = 0 - bottom = max(legend_top, slider_top) + bottom = max( + fig_legend.get_window_extent( + renderer=self.fig.canvas.get_renderer() + ).get_points()[1, 1] + if fig_legend + else 0, + slider_top, + ) self.gridspec.tight_layout(self.fig, rect=[0, bottom, 1, 1]) def dynamic_plot(self, show_plot=True, step=None): diff --git a/tests/unit/test_plotting/test_quick_plot.py b/tests/unit/test_plotting/test_quick_plot.py index 188a725680..092f3784f2 100644 --- a/tests/unit/test_plotting/test_quick_plot.py +++ b/tests/unit/test_plotting/test_quick_plot.py @@ -252,6 +252,41 @@ def test_simple_ode_model(self, solver): solution, ["a", "b broadcasted"], variable_limits="bad variable limits" ) + # Test with a list of times + # Test for a 0D variable + quick_plot = pybamm.QuickPlot(solution, ["a"]) + quick_plot.plot(t=[0, 1, 2]) + assert len(quick_plot.plots[("a",)]) == 1 + + # Test for a 1D variable + quick_plot = pybamm.QuickPlot(solution, ["c broadcasted"]) + + time_list = [0.5, 1.5, 2.5] + quick_plot.plot(time_list) + + ax = quick_plot.fig.axes[0] + lines = ax.get_lines() + + variable_key = ("c broadcasted",) + if variable_key in quick_plot.variables: + num_variables = len(quick_plot.variables[variable_key][0]) + else: + raise KeyError( + f"'{variable_key}' is not in quick_plot.variables. Available keys: " + + str(quick_plot.variables.keys()) + ) + + expected_lines = len(time_list) * num_variables + + assert ( + len(lines) == expected_lines + ), f"Expected {expected_lines} superimposed lines, but got {len(lines)}" + + # Test for a 2D variable + quick_plot = pybamm.QuickPlot(solution, ["2D variable"]) + quick_plot.plot(t=[0, 1, 2]) + assert len(quick_plot.plots[("2D variable",)]) == 1 + # Test errors with pytest.raises(ValueError, match="Mismatching variable domains"): pybamm.QuickPlot(solution, [["a", "b broadcasted"]])