diff --git a/CHANGELOG.md b/CHANGELOG.md index 84f01d3797..a03ec39715 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- Improved `QuickPlot` accuracy for simulations with Hermite interpolation. ([#4483](https://github.com/pybamm-team/PyBaMM/pull/4483)) - Added Hermite interpolation to the (`IDAKLUSolver`) that improves the accuracy and performance of post-processing variables. ([#4464](https://github.com/pybamm-team/PyBaMM/pull/4464)) - Added `BasicDFN` model for sodium-ion batteries ([#4451](https://github.com/pybamm-team/PyBaMM/pull/4451)) - Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415)) diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index cddce58d77..babfd2e761 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -84,6 +84,9 @@ class QuickPlot: variable_limits : str or dict of str, optional How to set the axis limits (for 0D or 1D variables) or colorbar limits (for 2D variables). Options are: + n_t_linear: int, optional + The number of linearly spaced time points added to the t axis for each sub-solution. + Note: this is only used if the solution has hermite interpolation enabled. - "fixed" (default): keep all axes fixes so that all data is visible - "tight": make axes tight to plot at each time @@ -105,6 +108,7 @@ def __init__( time_unit=None, spatial_unit="um", variable_limits="fixed", + n_t_linear=100, ): solutions = self.preprocess_solutions(solutions) @@ -169,6 +173,24 @@ def __init__( min_t = np.min([t[0] for t in self.ts_seconds]) max_t = np.max([t[-1] for t in self.ts_seconds]) + hermite_interp = all(sol.hermite_interpolation for sol in solutions) + + def t_sample(sol): + if hermite_interp and n_t_linear > 2: + # Linearly spaced time points + t_linspace = np.linspace(sol.t[0], sol.t[-1], n_t_linear + 2)[1:-1] + t_plot = np.union1d(sol.t, t_linspace) + else: + t_plot = sol.t + return t_plot + + ts_seconds = [] + for sol in solutions: + # Sample time points for each sub-solution + t_sol = [t_sample(sub_sol) for sub_sol in sol.sub_solutions] + ts_seconds.append(np.concatenate(t_sol)) + self.ts_seconds = ts_seconds + # Set timescale if time_unit is None: # defaults depend on how long the simulation is diff --git a/tests/unit/test_plotting/test_quick_plot.py b/tests/unit/test_plotting/test_quick_plot.py index d5d994117d..188a725680 100644 --- a/tests/unit/test_plotting/test_quick_plot.py +++ b/tests/unit/test_plotting/test_quick_plot.py @@ -7,7 +7,12 @@ class TestQuickPlot: - def test_simple_ode_model(self): + _solver_args = [pybamm.CasadiSolver()] + if pybamm.has_idaklu(): + _solver_args.append(pybamm.IDAKLUSolver()) + + @pytest.mark.parametrize("solver", _solver_args) + def test_simple_ode_model(self, solver): model = pybamm.lithium_ion.BaseModel(name="Simple ODE Model") whole_cell = ["negative electrode", "separator", "positive electrode"] @@ -48,9 +53,6 @@ def test_simple_ode_model(self): "NaN variable": pybamm.Scalar(np.nan), } - # ODEs only (don't use Jacobian) - model.use_jacobian = False - # Process and solve geometry = model.default_geometry param = model.default_parameter_values @@ -59,7 +61,6 @@ def test_simple_ode_model(self): mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) disc = pybamm.Discretisation(mesh, model.default_spatial_methods) disc.process_model(model) - solver = model.default_solver t_eval = np.linspace(0, 2, 100) solution = solver.solve(model, t_eval) quick_plot = pybamm.QuickPlot( @@ -142,6 +143,13 @@ def test_simple_ode_model(self): assert quick_plot.n_rows == 2 assert quick_plot.n_cols == 1 + if solution.hermite_interpolation: + t_plot = np.union1d( + solution.t, np.linspace(solution.t[0], solution.t[-1], 100 + 2)[1:-1] + ) + else: + t_plot = t_eval + # Test different time units quick_plot = pybamm.QuickPlot(solution, ["a"]) assert quick_plot.time_scaling_factor == 1 @@ -149,28 +157,28 @@ def test_simple_ode_model(self): quick_plot.plot(0) assert quick_plot.time_scaling_factor == 1 np.testing.assert_array_almost_equal( - quick_plot.plots[("a",)][0][0].get_xdata(), t_eval + quick_plot.plots[("a",)][0][0].get_xdata(), t_plot ) np.testing.assert_array_almost_equal( - quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_eval + quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_plot ) quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="minutes") quick_plot.plot(0) assert quick_plot.time_scaling_factor == 60 np.testing.assert_array_almost_equal( - quick_plot.plots[("a",)][0][0].get_xdata(), t_eval / 60 + quick_plot.plots[("a",)][0][0].get_xdata(), t_plot / 60 ) np.testing.assert_array_almost_equal( - quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_eval + quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_plot ) quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="hours") quick_plot.plot(0) assert quick_plot.time_scaling_factor == 3600 np.testing.assert_array_almost_equal( - quick_plot.plots[("a",)][0][0].get_xdata(), t_eval / 3600 + quick_plot.plots[("a",)][0][0].get_xdata(), t_plot / 3600 ) np.testing.assert_array_almost_equal( - quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_eval + quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_plot ) with pytest.raises(ValueError, match="time unit"): pybamm.QuickPlot(solution, ["a"], time_unit="bad unit")