Skip to content

Commit

Permalink
More accurate QuickPlots with Hermite interpolation (pybamm-team#4483)
Browse files Browse the repository at this point in the history
* Update CHANGELOG.md

accurate quickplots

* evenly sample sub-solutions

* lowercase variable

* move `_solver_args` inside class

---------

Co-authored-by: Eric G. Kratz <[email protected]>
  • Loading branch information
2 people authored and Pritam Kalbhor committed Nov 15, 2024
1 parent 8a30099 commit afb9e80
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- Improved `QuickPlot` accuracy for simulations with Hermite interpolation. ([#4483](https://github.com/pybamm-team/PyBaMM/pull/4483))
- Added Hermite interpolation to the (`IDAKLUSolver`) that improves the accuracy and performance of post-processing variables. ([#4464](https://github.com/pybamm-team/PyBaMM/pull/4464))
- Added `BasicDFN` model for sodium-ion batteries ([#4451](https://github.com/pybamm-team/PyBaMM/pull/4451))
- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415))
Expand Down
22 changes: 22 additions & 0 deletions src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class QuickPlot:
variable_limits : str or dict of str, optional
How to set the axis limits (for 0D or 1D variables) or colorbar limits (for 2D
variables). Options are:
n_t_linear: int, optional
The number of linearly spaced time points added to the t axis for each sub-solution.
Note: this is only used if the solution has hermite interpolation enabled.
- "fixed" (default): keep all axes fixes so that all data is visible
- "tight": make axes tight to plot at each time
Expand All @@ -105,6 +108,7 @@ def __init__(
time_unit=None,
spatial_unit="um",
variable_limits="fixed",
n_t_linear=100,
):
solutions = self.preprocess_solutions(solutions)

Expand Down Expand Up @@ -169,6 +173,24 @@ def __init__(
min_t = np.min([t[0] for t in self.ts_seconds])
max_t = np.max([t[-1] for t in self.ts_seconds])

hermite_interp = all(sol.hermite_interpolation for sol in solutions)

def t_sample(sol):
if hermite_interp and n_t_linear > 2:
# Linearly spaced time points
t_linspace = np.linspace(sol.t[0], sol.t[-1], n_t_linear + 2)[1:-1]
t_plot = np.union1d(sol.t, t_linspace)
else:
t_plot = sol.t
return t_plot

ts_seconds = []
for sol in solutions:
# Sample time points for each sub-solution
t_sol = [t_sample(sub_sol) for sub_sol in sol.sub_solutions]
ts_seconds.append(np.concatenate(t_sol))
self.ts_seconds = ts_seconds

# Set timescale
if time_unit is None:
# defaults depend on how long the simulation is
Expand Down
30 changes: 19 additions & 11 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@


class TestQuickPlot:
def test_simple_ode_model(self):
_solver_args = [pybamm.CasadiSolver()]
if pybamm.has_idaklu():
_solver_args.append(pybamm.IDAKLUSolver())

@pytest.mark.parametrize("solver", _solver_args)
def test_simple_ode_model(self, solver):
model = pybamm.lithium_ion.BaseModel(name="Simple ODE Model")

whole_cell = ["negative electrode", "separator", "positive electrode"]
Expand Down Expand Up @@ -48,9 +53,6 @@ def test_simple_ode_model(self):
"NaN variable": pybamm.Scalar(np.nan),
}

# ODEs only (don't use Jacobian)
model.use_jacobian = False

# Process and solve
geometry = model.default_geometry
param = model.default_parameter_values
Expand All @@ -59,7 +61,6 @@ def test_simple_ode_model(self):
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
solver = model.default_solver
t_eval = np.linspace(0, 2, 100)
solution = solver.solve(model, t_eval)
quick_plot = pybamm.QuickPlot(
Expand Down Expand Up @@ -142,35 +143,42 @@ def test_simple_ode_model(self):
assert quick_plot.n_rows == 2
assert quick_plot.n_cols == 1

if solution.hermite_interpolation:
t_plot = np.union1d(
solution.t, np.linspace(solution.t[0], solution.t[-1], 100 + 2)[1:-1]
)
else:
t_plot = t_eval

# Test different time units
quick_plot = pybamm.QuickPlot(solution, ["a"])
assert quick_plot.time_scaling_factor == 1
quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="seconds")
quick_plot.plot(0)
assert quick_plot.time_scaling_factor == 1
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_xdata(), t_eval
quick_plot.plots[("a",)][0][0].get_xdata(), t_plot
)
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_eval
quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_plot
)
quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="minutes")
quick_plot.plot(0)
assert quick_plot.time_scaling_factor == 60
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_xdata(), t_eval / 60
quick_plot.plots[("a",)][0][0].get_xdata(), t_plot / 60
)
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_eval
quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_plot
)
quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="hours")
quick_plot.plot(0)
assert quick_plot.time_scaling_factor == 3600
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_xdata(), t_eval / 3600
quick_plot.plots[("a",)][0][0].get_xdata(), t_plot / 3600
)
np.testing.assert_array_almost_equal(
quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_eval
quick_plot.plots[("a",)][0][0].get_ydata(), 0.2 * t_plot
)
with pytest.raises(ValueError, match="time unit"):
pybamm.QuickPlot(solution, ["a"], time_unit="bad unit")
Expand Down

0 comments on commit afb9e80

Please sign in to comment.