Skip to content

Commit

Permalink
perf: refactor and speed-ups for Jax BDF Solver (#4456)
Browse files Browse the repository at this point in the history
* Performance refactor for Jax BDF, "BDF" as default for JaxSolver, bugfixes for calculate_sensitivities, adds JAX vectorised example

* update docstring, add changelog entry

* feat: adds property for explicit sensitivity attribute, suggested changes from review

* examples: adds JIT compiled comparison

* Apply suggestions from code review

Co-authored-by: Martin Robinson <[email protected]>

* fix: post suggestions property alignment

* tests: adds calculate_sensitivities check for JaxSolver

* tests: move calculate_senstivities unit test for coverage

---------

Co-authored-by: Martin Robinson <[email protected]>
  • Loading branch information
BradyPlanden and martinjrobins authored Sep 26, 2024
1 parent 1390ea3 commit 62a7ee8
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 232 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

## Optimizations

- Performance refactor of JAX BDF Solver with default Jax method set to `"BDF"`. ([#4456](https://github.com/pybamm-team/PyBaMM/pull/4456))
- Improved performance of initialization and reinitialization of ODEs in the (`IDAKLUSolver`). ([#4453](https://github.com/pybamm-team/PyBaMM/pull/4453))
- Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416))

Expand Down
57 changes: 57 additions & 0 deletions examples/scripts/multiprocess_jax_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pybamm
import time
import numpy as np


# This script provides an example for massively vectorised
# model solves using the JAX BDF solver. First,
# we set up the model and process parameters
model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"
model.events = [] # remove events (not supported in jax)
geometry = model.default_geometry
param = pybamm.ParameterValues("Chen2020")
param.update({"Current function [A]": "[input]"})
param.process_geometry(geometry)
param.process_model(model)

# Discretise and setup solver
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
t_eval = np.linspace(0, 3600, 100)
solver = pybamm.JaxSolver(atol=1e-6, rtol=1e-6, method="BDF")

# Set number of vectorised solves
values = np.linspace(0.01, 1.0, 1000)
inputs = [{"Current function [A]": value} for value in values]

# Run solve for all inputs, with a just-in-time compilation
# occurring on the first solve. All sequential solves will
# use the compiled code, with a large performance improvement.
start_time = time.time()
sol = solver.solve(model, t_eval, inputs=inputs)
print(f"Time taken: {time.time() - start_time}") # 1.3s

# Rerun the vectorised solve, showing performance improvement
start_time = time.time()
compiled_sol = solver.solve(model, t_eval, inputs=inputs)
print(f"Compiled time taken: {time.time() - start_time}") # 0.42s

# Plot one of the solves
plot = pybamm.QuickPlot(
sol[5],
[
"Negative particle concentration [mol.m-3]",
"Electrolyte concentration [mol.m-3]",
"Positive particle concentration [mol.m-3]",
"Current [A]",
"Negative electrode potential [V]",
"Electrolyte potential [V]",
"Positive electrode potential [V]",
"Voltage [V]",
],
time_unit="seconds",
spatial_unit="um",
)
plot.dynamic_plot()
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def run_scripts(session):
# https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with
# is fixed
session.install("setuptools", silent=False)
session.install("-e", ".[all,dev]", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
session.run("python", "-m", "pytest", "-m", "scripts")


Expand Down
12 changes: 6 additions & 6 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def root_method(self):
def supports_parallel_solve(self):
return False

@property
def requires_explicit_sensitivities(self):
return True

@root_method.setter
def root_method(self, method):
if method == "casadi":
Expand Down Expand Up @@ -141,7 +145,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):

# see if we need to form the explicit sensitivity equations
calculate_sensitivities_explicit = (
model.calculate_sensitivities and not isinstance(self, pybamm.IDAKLUSolver)
model.calculate_sensitivities and self.requires_explicit_sensitivities
)

self._set_up_model_sensitivities_inplace(
Expand Down Expand Up @@ -494,11 +498,7 @@ def _set_up_model_sensitivities_inplace(
# if we have a mass matrix, we need to extend it
def extend_mass_matrix(M):
M_extend = [M.entries] * (num_parameters + 1)
M_extend_pybamm = pybamm.Matrix(block_diag(M_extend, format="csr"))
return M_extend_pybamm

model.mass_matrix = extend_mass_matrix(model.mass_matrix)
model.mass_matrix = extend_mass_matrix(model.mass_matrix)
return pybamm.Matrix(block_diag(M_extend, format="csr"))

model.mass_matrix = extend_mass_matrix(model.mass_matrix)

Expand Down
4 changes: 4 additions & 0 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,10 @@ def _demote_64_to_32(self, x: pybamm.EvaluatorJax):
def supports_parallel_solve(self):
return True

@property
def requires_explicit_sensitivities(self):
return False

def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
"""
Solve a DAE model defined by residuals with initial conditions y0.
Expand Down
Loading

0 comments on commit 62a7ee8

Please sign in to comment.