-
-
Notifications
You must be signed in to change notification settings - Fork 572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Performance refactor for Jax BDF Solver, fixes #4455 #4456
Performance refactor for Jax BDF Solver, fixes #4455 #4456
Conversation
…fixes for calculate_sensitivities, adds JAX vectorised example
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## develop #4456 +/- ##
===========================================
- Coverage 99.46% 99.46% -0.01%
===========================================
Files 293 293
Lines 22386 22332 -54
===========================================
- Hits 22266 22212 -54
Misses 120 120 ☔ View full report in Codecov by Sentry. |
Benchmarking for the BDF method is provided below. This was ran on a MacBook Pro M3Pro laptop, with the DFN discretised with Unless otherwise stated, the results below are in seconds. This is the average across ten simulations of the example script (with the number of simulations reduced from 1000 to 100), including the JIT compilation time.
|
Great stuff, thanks @BradyPlanden. I'll have a closer look at the code in a tic, I'm really suprised how the timings have evolved, when I wrote this the jit compilation of the BDF solver was very slow compared with RK45, hence the RK45 default. But it looks like there have been many changes in JAX since then. Do you have timings including the solve by itself (without including JIT)? |
I was also a bit surprised, as I didn't go into this refactor looking for performance, mostly just to clean up the code and understand the underlying JAX methods. I suspect there are areas for further improvement. Here are the post-JIT timings per 100 solves, now in milliseconds.
The JIT compilation for the BDF has improved quite a bit, there are a few areas I think it's most likely to come from. I think reducing the memory copies and using newer JAX methods ( Edit: it would be interesting to compare these numbers to the recent idaklu parallelisation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great, thanks @BradyPlanden, glad to see the BDF jit compiling so much faster now :) Interested to see the runtime differences to idaklu.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @BradyPlanden. I don't think we need to store additional data to implement this property, see below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great, thanks @BradyPlanden :)
Ready for merge now :) |
Description
This PR refactors the JAX BDF solver for performance updates. It also updates the JaxSolver's default method to be
"BDF"
as performance is much higher than"RK45"
. This PR also bug fixes thecalculate_sensitivities
error described in #4455. Lastly, this PR adds JAX vectorised example script to showcase the JaxSolvers performance in highly vectorised usage.Fixes #4455
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.
Key checklist:
$ pre-commit run
(or$ nox -s pre-commit
) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)$ python run-tests.py --all
(or$ nox -s tests
)$ python run-tests.py --doctest
(or$ nox -s doctests
)You can run integration tests, unit tests, and doctests together at once, using
$ python run-tests.py --quick
(or$ nox -s quick
).Further checks: