Langevin PR #453

Merged
merged 15 commits into from
Sep 1, 2024

Conversation

andyElking
Contributor

@andyElking andyElking commented Jul 1, 2024

Hi Patrick,

Thanks for the heads up, here's the reuploaded PR (with another quick fix).

This PR contains all the new Langevin solvers. All of these inherit from AbstractLangevinSRK in langevin_srk.py.
Another important addition is LangevinTerm in _term.py. I explained why it is needed in a comment below.

I haven't added the new solvers to the docs and autocite yet, because 1) the relevant paper is not on arxiv yet, but might be in a month or two and 2) I expect you might suggest several changes, so I might as well write the docs once the rest is stable. Still, I think the docstrings and comments are quite comprehensive.

I'm making this PR now so you have ample time to have a look at it, but I will be away for the next few weeks, so there is absolutely no hurry.

Best,
Andraž

diffrax/_term.py Outdated Show resolved Hide resolved
@andyElking andyElking force-pushed the langevin_pr branch 2 times, most recently from 2d5dcad to 21a5730 Compare July 11, 2024 15:07
@andyElking
Contributor Author

Hi @patrick-kidger, just bumping this in case you didn't notice that the tests passed. No hurry though :)

Owner

@patrick-kidger patrick-kidger left a comment

Haha, sorry, it's taken me a while to get to this!

Here's a first review. Many of my comments on align apply to the others too, I think.

diffrax/_integrate.py Outdated Show resolved Hide resolved
diffrax/_solver/align.py Outdated Show resolved Hide resolved
diffrax/_solver/align.py Outdated Show resolved Hide resolved
diffrax/_solver/align.py Show resolved Hide resolved
diffrax/_solver/align.py Outdated Show resolved Hide resolved
diffrax/_solver/langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
@andyElking
Contributor Author

andyElking commented Jul 21, 2024

Hi Patrick,

Thanks so much for your review! I addressed almost all of your comments. Two of them I will address in a later commit (it's getting late today haha).

I think you didn't quite understand the reason why the LangevinTerm and the changes in _integrate.py are needed, so I tried to give an explanation, but if that is unclear we can have a call at some point. Let me know :)

Also you might want to read my reply here, but it's a bit hidden way up the conversation: #453 (comment)

@andyElking
Contributor Author

Quick heads up: I now made all the edits you suggested and the tests all passed :)

@andyElking
Contributor Author

@patrick-kidger a note about interpolation for ALIGN:
to maintain a 2nd order of convergence at interpolated points, the interpolated value cannot depend just on t0, t, t1, y0, y1, but must also depend on W0, H0, W1, H1. Do you think it would be feasible to modify the interpolation code in order to allow for that? Maybe I could store W and H as part of the solution? Do you have any ideas how to do this in a way that fits into Diffrax naturally?

@patrick-kidger
Owner

to maintain a 2nd order of convergence at interpolated points, the interpolated value cannot depend just on t0, t, t1, y0, y1, but must also depend on W0, H0, W1, H1. Do you think it would be feasible to modify the interpolation code in order to allow for that? Maybe I could store W and H as part of the solution? Do you have any ideas how to do this in a way that fits into Diffrax naturally?

I think that should be totally fine -- it can go into the dense_info output of a solver step. Take a look at how Runge--Kutta methods output the intermediate stages, for example.
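
As a rough illustration of that mechanism, here is a plain-Python sketch (not Diffrax's actual classes): the step returns a dictionary of per-step quantities, and the interpolant for that step is built from the dictionary. The names `toy_step` and `StepInterpolation` and the keys `w` and `h` are hypothetical stand-ins for the W and H values mentioned above. The linear interpolation is only a placeholder, since the second-order formula is not given in this thread.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StepInterpolation:
    """Interpolant for a single step, built from that step's dense_info."""

    t0: float
    t1: float
    y0: float
    y1: float
    w: float  # hypothetical: Brownian quantity for this step (carried, unused here)
    h: float  # hypothetical: H quantity for this step (carried, unused here)

    def evaluate(self, t: float) -> float:
        # Placeholder only: linear interpolation between the step endpoints.
        theta = (t - self.t0) / (self.t1 - self.t0)
        return self.y0 + theta * (self.y1 - self.y0)


def toy_step(y0, t0, t1, w, h):
    """Hypothetical step for dy = -y dt + dW; returns y1 and a dense_info dict."""
    y1 = y0 - (t1 - t0) * y0 + w  # one Euler-Maruyama step
    dense_info = dict(y0=y0, y1=y1, w=w, h=h)  # extra entries ride along
    return y1, dense_info


y1, dense_info = toy_step(y0=1.0, t0=0.0, t1=0.5, w=0.25, h=0.0)
interp = StepInterpolation(t0=0.0, t1=0.5, **dense_info)
midpoint = interp.evaluate(0.25)
```

The point of the sketch is only that anything the step puts in the dictionary becomes available to the interpolant, so a higher-order formula could use `w` and `h` inside `evaluate`.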

@andyElking
Contributor Author

I think that should be totally fine -- it can go into the dense_info output of a solver step. Take a look at how Runge--Kutta methods output the intermediate stages, for example.

That's good to know, thanks! I'll still make it a separate PR - I have a few other tasks to complete beforehand. Otherwise I think the only comment that remains unresolved is #453 (comment). Let me know if there is anything else you'd like me to improve.

@andyElking
Contributor Author

@patrick-kidger I now made a temporary fix, but there are two things that remain to be solved:

  1. For some reason eqx.filter_eval_shape(term.vf, 0.0, y0, args) doesn't work when term is a LangevinTerm (the stacktrace is in a comment above). I genuinely do not understand what is going wrong, so please help.
  2. I think we are still not on the same page as to why I introduced LangevinTerm in the first place. Please let's have a call at some point to discuss it. And then hopefully it will make sense why using MultiTerm[LangevinDriftTerm, LangevinDiffusionTerm] could lead to incorrect results.

@andyElking andyElking force-pushed the langevin_pr branch 2 times, most recently from 2719cc9 to 67011f0 Compare August 9, 2024 21:03
@andyElking
Contributor Author

Great news @patrick-kidger @lockwo: what we discussed worked! Thanks for your advice!

Patrick, if there is nothing else about the code itself that you'd like me to change, then I'll add all of the new things to the docs. Do you think I should add a short example of how to use the langevin solvers as well? Maybe a simple Langevin Monte Carlo example?
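
For context, a minimal sketch of what such a Langevin Monte Carlo example could look like, written in plain Python with a basic Euler-Maruyama step rather than the solvers from this PR. It assumes the usual underdamped Langevin form dx = v dt, dv = (-gamma v - u grad_f(x)) dt + sqrt(2 gamma u) dW; the function name `simulate_uld` and all parameter values are hypothetical.

```python
import math
import random


def simulate_uld(grad_f, x0, v0, gamma, u, dt, n_steps, rng):
    """Euler-Maruyama for dx = v dt, dv = (-gamma v - u grad_f(x)) dt + sqrt(2 gamma u) dW."""
    x, v = x0, v0
    xs = [x]
    noise_scale = math.sqrt(2.0 * gamma * u)
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment over one step
        x_new = x + v * dt
        v_new = v + (-gamma * v - u * grad_f(x)) * dt + noise_scale * dw
        x, v = x_new, v_new
        xs.append(x)
    return xs


# Target density proportional to exp(-f(x)) with f(x) = x**2 / 2, so grad_f(x) = x.
rng = random.Random(0)
samples = simulate_uld(
    lambda x: x, x0=0.0, v0=0.0, gamma=1.0, u=1.0, dt=0.01, n_steps=5000, rng=rng
)
```

Euler-Maruyama is used here only because it is short; it is not one of the solvers added in this PR.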

Owner

@patrick-kidger patrick-kidger left a comment

Nice! I really like this approach -- I think this is now looking really elegant.

I've gone through and left comments but they're mostly nits and tidy ups. I think we've now cracked the structure of this problem! :)

diffrax/_term.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
test/helpers.py Outdated Show resolved Hide resolved
diffrax/_solver/langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/langevin_srk.py Outdated Show resolved Hide resolved
@patrick-kidger
Owner

Do you think I should add a short example of how to use the langevin solvers as well? Maybe a simple Langevin Monte Carlo example?

I like the sound of that! An example would be great.

I'll emphasise 'short' -- I really try to keep the examples pedagogical.

@andyElking
Contributor Author

Thanks so much for the review, Patrick! And sorry for all my dummy mistakes 😅. I made all the smaller edits already and tomorrow I'll write a short example, a test for the backward solve and put everything into the docs.

diffrax/_term.py Outdated Show resolved Hide resolved
Owner

@patrick-kidger patrick-kidger left a comment

Okay, this basically LGTM!

I've gone through and left what I think is a final round of nits, but I think they should all be super minor.

diffrax/_solver/foster_langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/foster_langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/foster_langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/foster_langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/foster_langevin_srk.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
diffrax/_term.py Outdated Show resolved Hide resolved
docs/api/solvers/sde_solvers.md Show resolved Hide resolved
docs/devdocs/srk_example.ipynb Outdated Show resolved Hide resolved
examples/underdamped_langevin_example.ipynb Outdated Show resolved Hide resolved
@andyElking
Contributor Author

Quick question @patrick-kidger:

Should I make AbstractFosterLangevinSRK public and add it under Abstract Solvers in the docs? I haven't done it so far, because, unlike RK and SRK, where the user just has to specify a tableau, making a custom child of AbstractFosterLangevinSRK is much more involved, so I doubt users who are just using a packaged version of Diffrax would do that. WDYT?

@patrick-kidger
Owner

I think it probably should be public + in the abstract solvers. I agree that writing your own here is incredibly niche, but I think it's useful for inquisitive users to be able to poke at such things.

diffrax/_term.py Outdated Show resolved Hide resolved
@andyElking
Contributor Author

So I added the check that drift and diffusion have the same arguments in AbstractFosterLangevinSRK.__init__ and also added a short test for this in test_underdamped_langevin.py.

I will now do the scan trick.
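
The argument check mentioned above might look roughly like the following plain-Python sketch. The classes `ToyDriftTerm` and `ToyDiffusionTerm`, the function `check_matching_args`, and the choice of `gamma` and `u` as the shared arguments are all hypothetical; the actual check in the PR may compare different fields or structures.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyDriftTerm:
    gamma: float
    u: float


@dataclass(frozen=True)
class ToyDiffusionTerm:
    gamma: float
    u: float


def check_matching_args(drift, diffusion):
    """Raise if the two terms were constructed with different gamma or u."""
    if (drift.gamma, drift.u) != (diffusion.gamma, diffusion.u):
        raise RuntimeError(
            "Drift and diffusion terms must be built with the same arguments: "
            f"got gamma={drift.gamma}, u={drift.u} for the drift and "
            f"gamma={diffusion.gamma}, u={diffusion.u} for the diffusion."
        )


# Matching arguments pass silently.
check_matching_args(ToyDriftTerm(1.0, 2.0), ToyDiffusionTerm(1.0, 2.0))
```

Failing early at construction time means a mismatch shows up as a clear error instead of a silently wrong solve.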

@andyElking
Contributor Author

I think I have now addressed everything you suggested, including the scan trick in both QUICSORT and ShOULD. I'll do another quick check and then I'll pass it back to you.

@andyElking andyElking force-pushed the langevin_pr branch 4 times, most recently from 914ef13 to 5af3e06 Compare September 1, 2024 16:00
@andyElking
Contributor Author

andyElking commented Sep 1, 2024

I went through the code and the docs again and now I think I addressed everything you mentioned. Also please take a look at this comment here and let me know if I should revert it to how it was before. Other than that there are no major changes.

Also please take a look at the conversations I left unresolved, namely this and this. Thanks!

@andyElking
Contributor Author

Sorry, I left a tiny problem in the test I added, I fixed it now.

diffrax/_solver/foster_langevin_srk.py Outdated Show resolved Hide resolved
docs/api/terms.md Outdated Show resolved Hide resolved
@patrick-kidger patrick-kidger changed the base branch from main to dev September 1, 2024 19:28
@patrick-kidger patrick-kidger merged commit fa7417f into patrick-kidger:dev Sep 1, 2024
2 checks passed
@patrick-kidger
Owner

patrick-kidger commented Sep 1, 2024

Aaaaand... merged! 🎉
Great job getting this one done, I'm really happy to have it in! :)

@andyElking
Contributor Author

Thanks for bearing with me and taking your time Patrick!! I really appreciate it!

@andyElking andyElking deleted the langevin_pr branch September 26, 2024 10:21
patrick-kidger pushed a commit that referenced this pull request Nov 29, 2024
* Langevin PR

* Minor fixes

* removed the SORT solver (superseded by QUICSORT)

* made LangevinTerm.term a static field

* temporary fix for _term_compatible and LangevinTerm

* Fixed LangevinTerm YAAAYYYYY

* Nits

* Added Langevin docs, a Langevin example and backwards in time test

* Fixed Patrick's comments

* langevin -> underdamped_langevin

* round of small fixes

* check langevin drift term and diffusion term have same args

* added scan_trick in QUICSORT and ShOULD

* using RuntimeError for when ULD args have wrong structure

* small fixes
patrick-kidger added a commit that referenced this pull request Dec 6, 2024
* Langevin PR (#453)

* Langevin PR

* Minor fixes

* removed the SORT solver (superseded by QUICSORT)

* made LangevinTerm.term a static field

* temporary fix for _term_compatible and LangevinTerm

* Fixed LangevinTerm YAAAYYYYY

* Nits

* Added Langevin docs, a Langevin example and backwards in time test

* Fixed Patrick's comments

* langevin -> underdamped_langevin

* round of small fixes

* check langevin drift term and diffusion term have same args

* added scan_trick in QUICSORT and ShOULD

* using RuntimeError for when ULD args have wrong structure

* small fixes

* tidy-ups

* Split SDE tests in half, to try and avoid GitHub runner issues?

* Added effects_barrier to fix test issue with JAX 0.4.33+

* small fix of docs in all three and a return type in quicsort

* bump doc building pipeline

* Compatibility with JAX 0.4.36, which removes ConcreteArray

* using a fori_loop to save states in edge case t0==t1

* added case for saving t0 data, which was also not getting updated.

Added a test

* using while_loop, ran into issues with reverse-mode diff using the fori_loop

* bug fix for cases when t0=True

* simplified logic for saving, no loop necessary

* added vmap test

* using a fori_loop to save states in edge case t0==t1

* added case for saving t0 data, which was also not getting updated.

Added a test

* using while_loop, ran into issues with reverse-mode diff using the fori_loop

* bug fix for cases when t0=True

* simplified logic for saving, no loop necessary

* added vmap test

* fix t1 out of bounds issue

* fix for steps: don't want to update those values if t0==t1 since we didn't take any steps.

Added test

---------

Co-authored-by: Andraž Jelinčič <[email protected]>
Co-authored-by: Patrick Kidger <[email protected]>
Co-authored-by: andyElking <[email protected]>
patrick-kidger pushed a commit that referenced this pull request Dec 9, 2024
patrick-kidger added a commit that referenced this pull request Dec 9, 2024