Problems with progress bar and jax.grad #396

Closed
abocquet opened this issue Apr 3, 2024 · 6 comments
Labels
bug Something isn't working

Comments

@abocquet
Contributor

abocquet commented Apr 3, 2024

Hello @patrick-kidger

I'm following up on the progress bar feature. It seems to be broken when computing the gradient (thanks @gautierronan for the MWE):

import diffrax as dx
import jax
import jax.numpy as jnp

def todiff(p):
    f = lambda t, y, args: -y
    term = dx.ODETerm(f)
    solver = dx.Dopri5()
    y0 = p * jnp.array([2.0, 3.0])

    # Using TextProgressMeter
    # NotImplementedError: Differentiation rule for 'unvmap_max' not implemented
    #progress_meter = dx.TextProgressMeter()

    # Using TqdmProgressMeter
    # XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error:
    # KeyError: 8
    progress_meter = dx.TqdmProgressMeter()

    solution = dx.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.1, y0=y0, progress_meter=progress_meter
    )
    return solution.ys[-1, 0]

todiff(jnp.array(1.0)) # OK
jax.grad(todiff)(jnp.array(1.0)) # breaks

As you can see, the errors are not the same depending on the type of progress bar we use.

TextProgressMeter

JAX tries to compute the gradient with respect to the progress variable (ref). This can be fixed fairly easily by adding progress = jax.lax.stop_gradient(progress).
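
As a minimal sketch of what `jax.lax.stop_gradient` does in a situation like this (the function `scaled` and its `progress` value are hypothetical stand-ins, not diffrax internals), the snippet below computes a progress-like quantity from the differentiated input and compares the gradient with and without the stop:

```python
import jax
import jax.numpy as jnp


def scaled(p, block_gradient):
    # Hypothetical progress-like value that depends on the traced input.
    progress = p / 4.0
    if block_gradient:
        # Treat progress as a constant for differentiation purposes.
        progress = jax.lax.stop_gradient(progress)
    return p * progress


p = jnp.array(2.0)
grad_blocked = jax.grad(scaled)(p, True)   # only the explicit factor p is differentiated
grad_open = jax.grad(scaled)(p, False)     # gradient also flows through progress
```

With the stop in place, JAX no longer propagates tangents through `progress`, which is presumably why the missing differentiation rule is not hit in the text meter case.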

Once fixed, we see a first progress bar going from 0 to 100 for the forward pass. For the backward pass we get a second progress bar that goes from 100 to 0. It would be nice to have it go from 0 to 100 too.

TqdmProgressMeter

This one is more subtle: diffrax tries to use the same tqdm progress bar twice, once for the forward pass and once for the backward pass. The backward pass crashes because the progress bar was thrown away at the end of the forward pass.

Proposed fixes

For the text progress bar, we would need to know whether we are in the forward or backward pass. I don't know if this is doable as of today, and I see two solutions:

  • give a flag to the loop function to know whether we are in a backward pass
  • check if progress decreases in the progress bars step function. This feels a bit hacky but limits the footprint on the code.
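
The second option could be sketched roughly as follows. This is plain Python with a hypothetical `DirectionAwareMeter` class, not the diffrax progress meter interface; it only illustrates the "progress decreased, so we must be going backward" heuristic:

```python
class DirectionAwareMeter:
    """Hypothetical meter that always displays progress from 0 to 1."""

    def __init__(self):
        self.last = None        # last raw progress value seen
        self.backward = False   # set once a decrease has been observed

    def step(self, progress: float) -> float:
        # A decrease in raw progress is taken as the start of the backward pass.
        if self.last is not None and progress < self.last:
            self.backward = True
        self.last = progress
        # Flip the value on the backward pass so the display still counts up.
        return 1.0 - progress if self.backward else progress


meter = DirectionAwareMeter()
shown = [meter.step(x) for x in [0.0, 0.5, 1.0, 0.75, 0.0]]
```

The heuristic only notices the backward pass after the first decrease, so a backward pass whose first raw value equals the final forward value would still be displayed unflipped for that one step, which is part of why it feels hacky.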

For the tqdm progress bar, the solution is to instantiate a new progress bar before doing the backward pass. I'm a bit stuck here because I can't find where diffrax tells JAX how to compute the gradient. I would have expected filter_custom_vjp (ref) to fill this role, but this part of the code is not used for the MWE. Could you point me to the part of the code where this is done, please? Once I understand this I can probably send a fix for both issues quickly.

@patrick-kidger
Owner

Thanks for the report! Just letting you know that I'm not ignoring this -- I'm hoping to have a fix lined up in the next week or so.
I think (but have not yet checked) that this is something best handled in Equinox -- I think we can arrange for equinox.internal.while_loop to be slightly more careful about where it propagates JVP tracers. I'll keep you posted!

@patrick-kidger patrick-kidger added the bug Something isn't working label Apr 9, 2024
patrick-kidger added a commit that referenced this issue Apr 9, 2024
This PR makes both the `TextProgressMeter` and the `TqdmProgressMeter` work through the progress bar manager. In particular this means that we will not get any printout on the backward pass for the `TextProgressMeter` (as its `meter_idx` is deleted during `close`, called at the end of the forward pass), whereas previously we would get a binomial printout as the backward pass worked through its recursive checkpointing. This means we now have consistency with `TqdmProgressMeter`, for which there is less of a clear way of providing any output during the backward pass.

(I'd be open to changing the above -- adding options for printout during backpropagation -- but this would probably be fairly tricky. I could see it requiring defining a new primitive, called at the start of integration, for closing the "backward bar", whilst still remaining JVP-compatible.)

In addition, we now have a `try`/`except` around the bar lookup during `step`. This means we do not raise an error when the bar does not exist during backpropagation.

Depends on patrick-kidger/equinox#697.
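
The guarded lookup described in the commit message above can be illustrated with a small sketch. The `BarManager` class and its methods are hypothetical and written in plain Python; the real manager in diffrax may be structured differently:

```python
class BarManager:
    """Hypothetical registry mapping a meter index to its bar state."""

    def __init__(self):
        self.bars = {}
        self.next_idx = 0

    def init(self) -> int:
        idx = self.next_idx
        self.next_idx += 1
        self.bars[idx] = {"progress": 0.0}
        return idx

    def step(self, idx: int, progress: float) -> bool:
        try:
            bar = self.bars[idx]
        except KeyError:
            # The bar was already closed (e.g. at the end of the forward
            # pass), so silently skip the update instead of raising.
            return False
        bar["progress"] = progress
        return True

    def close(self, idx: int) -> None:
        del self.bars[idx]


manager = BarManager()
idx = manager.init()
updated_before_close = manager.step(idx, 0.5)
manager.close(idx)
updated_after_close = manager.step(idx, 0.2)
```

In this sketch a `step` on a closed bar becomes a no-op, which mirrors the behaviour described for backpropagation: no printout and no `KeyError`.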
@patrick-kidger
Owner

Alright, should be fixed in #398. (Which also requires a development version of Equinox.)

So (a) WDYT, and (b) can you check that this works for your actual use-case?

If it does, then I'll look to merge. In particular I'm looking to do a new release of Equinox in the next few days, so that change at least will become public. (Diffrax will probably be a few more weeks before the next release, as I'd still like to get #344 in for the next release.)

@abocquet
Contributor Author

Thanks for taking a look!
nonbatchable and unvmap_* are still a bit of dark magic to me, but otherwise everything looks good.
The integration in dynamiqs works well too.

@abocquet
Contributor Author

Hello again,
Before we close this issue, I tested the code a bit more. When using vmap, I see that the progress bar does not update with the tqdm option, but it works fine with the text option. I'll try to investigate and keep you updated.

@patrick-kidger
Owner

Thanks for flagging that up! I've just pushed another commit to the same branch, which I think should fix things. I also figured out how to write some tests for tqdm.
So take 2 -- let me know if this works for you!

@abocquet
Contributor Author

I just tried it on my MWE, it seems to fully work now. Thank you 🙏
