-
-
Notifications
You must be signed in to change notification settings - Fork 572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Issue 1031 jax #1038
Issue 1031 jax #1038
Conversation
…or jax compilation
Codecov Report
@@ Coverage Diff @@
## develop #1038 +/- ##
===========================================
+ Coverage 97.76% 97.81% +0.04%
===========================================
Files 243 245 +2
Lines 12667 13177 +510
===========================================
+ Hits 12384 12889 +505
- Misses 283 288 +5
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @martinjrobins , I don't understand most of the evaluate/BDF code but since tests pass I'm sure it's all fine - and thanks for documenting well in case anyone wants to check it. Looking forward to seeing how this works and hopefully they add sparse matrices soon.
How much work is it to adapt BDF to work for DAEs?
Also, do you think it's possible to write this code in a way that we can generate a C function from the solver?
Thanks Tino. Yea, the BDF code was a pain to write as you can't use any flow control or classes, it does make it difficult to read! I'm not sure what changes need to be made to solve a DAE as I've not looked into that yet. I've not done the solution of daes before, but I would suspect you would have to code up a separate algorithm. You can't generate C code from Jax as it never uses C. JAX traces your python code and produces its own expression tree. It uses that to build a computational graph using XLA. The input to XLA is an intermediate representation language called HLO IR, so rather than C, the language that you could emit would be HLO, this could then be compiled and run on any machine using XLA. I'm not sure exactly how to get the HLO using JAX, but in theory it should be a matter of taking any jax function (e.g. the jax_bdf_integrate function in this case), compiling it with JAX and then telling it to give you the generated HLO. |
Thanks for the info. It looks like it's possible to adapt BDF to solve DAEs: https://www.cs.usask.ca/~spiteri/M314/notes/AP/chap10.pdf |
Nice, thanks for the link. I'll have a look.
Cheers,
Martin
…On Tue, 7 Jul 2020, 20:20 Valentin Sulzer, ***@***.***> wrote:
Thanks for the info. It looks like it's possible to adapt BDF to solve
ODEs: https://www.cs.usask.ca/~spiteri/M314/notes/AP/chap10.pdf
—
You are receiving this because you modified the open/close state.
Reply to this email directly, view it on GitHub
<#1038 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAIYL5HEKLOJO6FNQWCY7T3R2NYQZANCNFSM4NRTZF3A>
.
|
Description
Adds support for evaluating expression trees using JAX, and adds a new solver
JaxSolver
with two methods: A RK4(5) method usingjax.experimental.odeint
, and a custom BDF method written using jax. This solver is currently only useful for ode models with no termination events, and since JAX does not support sparse matrices they are all converted to dense arrays in the model. The latter limitation should hopefully be addressed over time by JAX (see jax-ml/jax#765).This solver would be useful for running on a GPU/TPU, and, for very small state vectors (due to the restriction to dense arrays), is also very quick on a CPU. The next step after this PR would be to implement the adjoint sensitivities for the new BDF solver (already there for the RK45) and expose the raw jax solvers so a user can calculate sensitivities for parameter estimation.
UPDATE: just saw that jax is not supported under windows :(, so this would be linux/mac only....
Fixes #1031
Type of change
Key checklist:
$ flake8
$ python run-tests.py --unit
$ cd docs
and then$ make clean; make html
You can run all three at once, using
$ python run-tests.py --quick
.Further checks: