From 53c3e7490ccc3523c449380345ca064f248d9da4 Mon Sep 17 00:00:00 2001 From: "Eric G. Kratz" Date: Wed, 27 Nov 2024 11:18:00 -0500 Subject: [PATCH] Revert "Modifies `QuickPlot.plot` ..." (#4622) * Revert "Modifies `QuickPlot.plot` to take a list of times and show superimposed plots in case of 1D variables. (#4529)" This reverts commit b50d4a95c2639a4e10644332145d33d14e7b7594. * style: pre-commit fixes * Fix changelog with PR number --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 4 + src/pybamm/plotting/quick_plot.py | 210 +++++++++++--------- tests/unit/test_plotting/test_quick_plot.py | 35 ---- 3 files changed, 119 insertions(+), 130 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index efed2fdaba..07905b7209 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) +## Bug fixes + +- Reverted modifications to quickplot from [#4529](https://github.com/pybamm-team/PyBaMM/pull/4529) which caused issues with the plots displaying correct variable names. ([#4622](https://github.com/pybamm-team/PyBaMM/pull/4622)) + # [v24.11.1](https://github.com/pybamm-team/PyBaMM/tree/v24.11.1) - 2024-11-22 ## Features diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index e3ad40855e..babfd2e761 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -1,7 +1,6 @@ # # Class for quick plotting of variables from models # -from __future__ import annotations import os import numpy as np import pybamm @@ -480,13 +479,13 @@ def reset_axis(self): ): # pragma: no cover raise ValueError(f"Axis limits cannot be NaN for variables '{key}'") - def plot(self, t: float | list[float], dynamic: bool = False): + def plot(self, t, dynamic=False): """Produces a quick plot with the internal states at time t. Parameters ---------- - t : float or list of float - Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times. + t : float + Dimensional time (in 'time_units') at which to plot. dynamic : bool, optional Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot. If True, creates a dynamic plot with a slider. @@ -494,10 +493,10 @@ def plot(self, t: float | list[float], dynamic: bool = False): plt = import_optional_dependency("matplotlib.pyplot") gridspec = import_optional_dependency("matplotlib.gridspec") + cm = import_optional_dependency("matplotlib", "cm") + colors = import_optional_dependency("matplotlib", "colors") - if not isinstance(t, list): - t = [t] - + t_in_seconds = t * self.time_scaling_factor self.fig = plt.figure(figsize=self.figsize) self.gridspec = gridspec.GridSpec(self.n_rows, self.n_cols) @@ -509,11 +508,6 @@ def plot(self, t: float | list[float], dynamic: bool = False): # initialize empty handles, to be created only if the appropriate plots are made solution_handles = [] - # Generate distinct colors for each time point - time_colors = plt.cm.coolwarm( - np.linspace(0, 1, len(t)) - ) # Use a colormap for distinct colors - for k, (key, variable_lists) in enumerate(self.variables.items()): ax = self.fig.add_subplot(self.gridspec[k]) self.axes.add(key, ax) @@ -524,17 +518,19 @@ def plot(self, t: float | list[float], dynamic: bool = False): ax.xaxis.set_major_locator(plt.MaxNLocator(3)) self.plots[key] = defaultdict(dict) variable_handles = [] - + # Set labels for the first subplot only (avoid repetition) if variable_lists[0][0].dimensions == 0: - # 0D plot: plot as a function of time, indicating multiple times with lines + # 0D plot: plot as a function of time, indicating time t with a line ax.set_xlabel(f"Time [{self.time_unit}]") for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): - linestyle = ( - self.linestyles[i] - if len(variable_list) == 1 - else self.linestyles[j] - ) + if len(variable_list) == 1: + # single variable -> use linestyle to differentiate model + linestyle = self.linestyles[i] + else: + # multiple variables -> use linestyle to differentiate + # variables (color differentiates models) + linestyle = self.linestyles[j] full_t = self.ts_seconds[i] (self.plots[key][i][j],) = ax.plot( full_t / self.time_scaling_factor, @@ -546,104 +542,128 @@ def plot(self, t: float | list[float], dynamic: bool = False): solution_handles.append(self.plots[key][i][0]) y_min, y_max = ax.get_ylim() ax.set_ylim(y_min, y_max) - - # Add vertical lines for each time in the list, using different colors for each time - for idx, t_single in enumerate(t): - t_in_seconds = t_single * self.time_scaling_factor - (self.time_lines[key],) = ax.plot( - [ - t_in_seconds / self.time_scaling_factor, - t_in_seconds / self.time_scaling_factor, - ], - [y_min, y_max], - "--", # Dashed lines - lw=1.5, - color=time_colors[idx], # Different color for each time - label=f"t = {t_single:.2f} {self.time_unit}", - ) - ax.legend() - + (self.time_lines[key],) = ax.plot( + [ + t_in_seconds / self.time_scaling_factor, + t_in_seconds / self.time_scaling_factor, + ], + [y_min, y_max], + "k--", + lw=1.5, + ) elif variable_lists[0][0].dimensions == 1: - # 1D plot: plot as a function of x at different times + # 1D plot: plot as a function of x at time t + # Read dictionary of spatial variables spatial_vars = self.spatial_variable_dict[key] spatial_var_name = next(iter(spatial_vars.keys())) - ax.set_xlabel(f"{spatial_var_name} [{self.spatial_unit}]") - - for idx, t_single in enumerate(t): - t_in_seconds = t_single * self.time_scaling_factor - - for i, variable_list in enumerate(variable_lists): - for j, variable in enumerate(variable_list): - linestyle = ( - self.linestyles[i] - if len(variable_list) == 1 - else self.linestyles[j] - ) - (self.plots[key][i][j],) = ax.plot( - self.first_spatial_variable[key], - variable(t_in_seconds, **spatial_vars), - color=time_colors[idx], # Different color for each time - linestyle=linestyle, - label=f"t = {t_single:.2f} {self.time_unit}", # Add time label - zorder=10, - ) - variable_handles.append(self.plots[key][0][j]) - solution_handles.append(self.plots[key][i][0]) - - # Add a legend to indicate which plot corresponds to which time - ax.legend() - + ax.set_xlabel( + f"{spatial_var_name} [{self.spatial_unit}]", + ) + for i, variable_list in enumerate(variable_lists): + for j, variable in enumerate(variable_list): + if len(variable_list) == 1: + # single variable -> use linestyle to differentiate model + linestyle = self.linestyles[i] + else: + # multiple variables -> use linestyle to differentiate + # variables (color differentiates models) + linestyle = self.linestyles[j] + (self.plots[key][i][j],) = ax.plot( + self.first_spatial_variable[key], + variable(t_in_seconds, **spatial_vars), + color=self.colors[i], + linestyle=linestyle, + zorder=10, + ) + variable_handles.append(self.plots[key][0][j]) + solution_handles.append(self.plots[key][i][0]) + # add lines for boundaries between subdomains + for boundary in variable_lists[0][0].internal_boundaries: + boundary_scaled = boundary * self.spatial_factor + ax.axvline(boundary_scaled, color="0.5", lw=1, zorder=0) elif variable_lists[0][0].dimensions == 2: - # 2D plot: superimpose plots at different times + # Read dictionary of spatial variables spatial_vars = self.spatial_variable_dict[key] + # there can only be one entry in the variable list variable = variable_lists[0][0] - - for t_single in t: - t_in_seconds = t_single * self.time_scaling_factor + # different order based on whether the domains are x-r, x-z or y-z, etc + if self.x_first_and_y_second[key] is False: + x_name = list(spatial_vars.keys())[1][0] + y_name = next(iter(spatial_vars.keys()))[0] + x = self.second_spatial_variable[key] + y = self.first_spatial_variable[key] + var = variable(t_in_seconds, **spatial_vars) + else: + x_name = next(iter(spatial_vars.keys()))[0] + y_name = list(spatial_vars.keys())[1][0] x = self.first_spatial_variable[key] y = self.second_spatial_variable[key] var = variable(t_in_seconds, **spatial_vars).T - - ax.set_xlabel( - f"{next(iter(spatial_vars.keys()))[0]} [{self.spatial_unit}]" - ) - ax.set_ylabel( - f"{list(spatial_vars.keys())[1][0]} [{self.spatial_unit}]" + ax.set_xlabel(f"{x_name} [{self.spatial_unit}]") + ax.set_ylabel(f"{y_name} [{self.spatial_unit}]") + vmin, vmax = self.variable_limits[key] + # store the plot and the var data (for testing) as cant access + # z data from QuadMesh or QuadContourSet object + if self.is_y_z[key] is True: + self.plots[key][0][0] = ax.pcolormesh( + x, + y, + var, + vmin=vmin, + vmax=vmax, + shading=self.shading, ) - vmin, vmax = self.variable_limits[key] - - # Use contourf and colorbars to represent the values - contour_plot = ax.contourf( - x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm" + else: + self.plots[key][0][0] = ax.contourf( + x, y, var, levels=100, vmin=vmin, vmax=vmax ) - self.plots[key][0][0] = contour_plot - self.colorbars[key] = self.fig.colorbar(contour_plot, ax=ax) - - self.plots[key][0][1] = var - - ax.set_title(f"t = {t_single:.2f} {self.time_unit}") + self.plots[key][0][1] = var + if vmin is None and vmax is None: + vmin = ax_min(var) + vmax = ax_max(var) + self.colorbars[key] = self.fig.colorbar( + cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)), + ax=ax, + ) + # Set either y label or legend entries + if len(key) == 1: + title = split_long_string(key[0]) + ax.set_title(title, fontsize="medium") + else: + ax.legend( + variable_handles, + [split_long_string(s, 6) for s in key], + bbox_to_anchor=(0.5, 1), + loc="lower center", + ) - # Set global legend if there are multiple models + # Set global legend if len(self.labels) > 1: fig_legend = self.fig.legend( solution_handles, self.labels, loc="lower right" ) + # Get the position of the top of the legend in relative figure units + # There may be a better way ... + try: + legend_top_inches = fig_legend.get_window_extent( + renderer=self.fig.canvas.get_renderer() + ).get_points()[1, 1] + fig_height_inches = (self.fig.get_size_inches() * self.fig.dpi)[1] + legend_top = legend_top_inches / fig_height_inches + except AttributeError: # pragma: no cover + # When testing the examples we set the matplotlib backend to "Template" + # which means that the above code doesn't work. Since this is just for + # that particular test we can just skip it + legend_top = 0 else: - fig_legend = None + legend_top = 0 - # Fix layout for sliders if dynamic + # Fix layout if dynamic: slider_top = 0.05 else: slider_top = 0 - bottom = max( - fig_legend.get_window_extent( - renderer=self.fig.canvas.get_renderer() - ).get_points()[1, 1] - if fig_legend - else 0, - slider_top, - ) + bottom = max(legend_top, slider_top) self.gridspec.tight_layout(self.fig, rect=[0, bottom, 1, 1]) def dynamic_plot(self, show_plot=True, step=None): diff --git a/tests/unit/test_plotting/test_quick_plot.py b/tests/unit/test_plotting/test_quick_plot.py index 092f3784f2..188a725680 100644 --- a/tests/unit/test_plotting/test_quick_plot.py +++ b/tests/unit/test_plotting/test_quick_plot.py @@ -252,41 +252,6 @@ def test_simple_ode_model(self, solver): solution, ["a", "b broadcasted"], variable_limits="bad variable limits" ) - # Test with a list of times - # Test for a 0D variable - quick_plot = pybamm.QuickPlot(solution, ["a"]) - quick_plot.plot(t=[0, 1, 2]) - assert len(quick_plot.plots[("a",)]) == 1 - - # Test for a 1D variable - quick_plot = pybamm.QuickPlot(solution, ["c broadcasted"]) - - time_list = [0.5, 1.5, 2.5] - quick_plot.plot(time_list) - - ax = quick_plot.fig.axes[0] - lines = ax.get_lines() - - variable_key = ("c broadcasted",) - if variable_key in quick_plot.variables: - num_variables = len(quick_plot.variables[variable_key][0]) - else: - raise KeyError( - f"'{variable_key}' is not in quick_plot.variables. Available keys: " - + str(quick_plot.variables.keys()) - ) - - expected_lines = len(time_list) * num_variables - - assert ( - len(lines) == expected_lines - ), f"Expected {expected_lines} superimposed lines, but got {len(lines)}" - - # Test for a 2D variable - quick_plot = pybamm.QuickPlot(solution, ["2D variable"]) - quick_plot.plot(t=[0, 1, 2]) - assert len(quick_plot.plots[("2D variable",)]) == 1 - # Test errors with pytest.raises(ValueError, match="Mismatching variable domains"): pybamm.QuickPlot(solution, [["a", "b broadcasted"]])