Problems with progress bar and jax.grad
#396
Comments
Thanks for the report! Just letting you know that I'm not ignoring this: I'm hoping to have a fix lined up in the next week or so.
This PR makes both the `TextProgressMeter` and the `TqdmProgressMeter` work through the progress bar manager. In particular, this means that we no longer get any printout on the backward pass for the `TextProgressMeter` (as its `meter_idx` is deleted during `close`, which is called at the end of the forward pass), whereas previously we would get a binomial printout as the backward pass worked through its recursive checkpointing. This makes it consistent with `TqdmProgressMeter`, for which there is no clear way to provide any output during the backward pass.
(I'd be open to changing the above, i.e. adding options for printout during backpropagation, but this would probably be fairly tricky. I could see it requiring a new primitive, called at the start of integration, for closing the "backward bar", whilst still remaining JVP-compatible.)
In addition, we now have a `try`/`except` around the bar lookup during `step`, so we no longer raise an error when the bar does not exist during backpropagation.
Depends on patrick-kidger/equinox#697.
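To make the new control flow concrete, here is a pure-Python sketch of the behaviour described above. `ProgressBarManager`, its `init`/`step`/`close` methods and the list-of-strings "printout" are hypothetical stand-ins, not diffrax's actual classes or the JAX callbacks they run through; the only point illustrated is that once `close` deletes the entry for `meter_idx`, the `try`/`except` in `step` turns later backward-pass calls into silent no-ops instead of errors.

```python
class ProgressBarManager:
    """Hypothetical registry of open progress bars, keyed by `meter_idx`."""

    def __init__(self):
        self._bars = {}  # meter_idx -> list of lines "printed" for that bar
        self._next_idx = 0

    def init(self):
        meter_idx = self._next_idx
        self._next_idx += 1
        self._bars[meter_idx] = []
        return meter_idx

    def step(self, meter_idx, progress):
        try:
            lines = self._bars[meter_idx]
        except KeyError:
            # The bar was already closed (e.g. we are in the backward pass after
            # `close` ran at the end of the forward pass): silently do nothing.
            return
        lines.append(f"{100 * progress:.0f}%")

    def close(self, meter_idx):
        # Removing the entry is what makes later `step` calls no-ops.
        return self._bars.pop(meter_idx)


if __name__ == "__main__":
    manager = ProgressBarManager()
    idx = manager.init()
    for p in [0.0, 0.5, 1.0]:  # forward pass
        manager.step(idx, p)
    print(manager.close(idx))  # ['0%', '50%', '100%']
    for p in [1.0, 0.5, 0.0]:  # backward pass: no output, no error
        manager.step(idx, p)
```

The trade-off is the one the PR accepts: the backward pass no longer errors, but it also prints nothing.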
Alright, should be fixed in #398. (Which also requires a development version of Equinox.) So (a) WDYT, and (b) can you check that this works for your actual use-case? If it does, then I'll look to merge. In particular I'm looking to do a new release of Equinox in the next few days, so that change at least will become public. (Diffrax will probably be a few more weeks before the next release, as I'd still like to get #344 in for the next release.)
Thanks for taking a look!
Hello again,
Thanks for flagging that up! I've just pushed another commit to the same branch that I think should fix things. I also figured out how to write some tests for tqdm.
I just tried it on my MWE, it seems to fully work now. Thank you 🙏 |
Hello @patrick-kidger
I'm coming back regarding the progress bar feature. It seems to be broken when computing the gradient (thanks @gautierronan for the MWE):
As you can see, the error differs depending on which type of progress bar we use.
TextProgressMeter
JAX tries to compute the gradient with respect to the `progress` variable (ref). This can be solved fairly easily by adding `progress = jax.lax.stop_gradient(progress)`.
Once fixed, we see a first progress bar going from 0 to 100 for the forward pass, and then a second progress bar for the backward pass that goes from 100 to 0. It would be nice to have it go from 0 to 100 too.
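A minimal sketch of the `stop_gradient` fix in plain JAX, assuming a toy explicit-Euler loop is a fair stand-in for the solver's step loop (this is not diffrax code, and `solve_with_progress` is a hypothetical name). The progress fraction is still computed and returned at every step, but wrapping it in `jax.lax.stop_gradient` makes `jax.grad` treat it as a constant: the gradient of the solution is unchanged, while the gradient through the display-only values is zero.

```python
import jax
import jax.numpy as jnp


def solve_with_progress(y0, t1, num_steps=10):
    """Toy explicit-Euler solve of dy/dt = -y on [0, t1] that also reports progress.

    Hypothetical stand-in for a diffrax solve; `progress` is display-only.
    """
    dt = t1 / num_steps

    def body(carry, _):
        y, t = carry
        y = y - dt * y
        t = t + dt
        # The progress fraction is only ever shown to the user, so block gradients
        # through it; otherwise jax.grad also differentiates `progress`.
        progress = jax.lax.stop_gradient(t / t1)
        return (y, t), progress

    init = (jnp.asarray(y0, jnp.float32), jnp.zeros((), jnp.float32))
    (y, _), progresses = jax.lax.scan(body, init, None, length=num_steps)
    return y, progresses


if __name__ == "__main__":
    # Gradient of the solution is unaffected: d y(t1) / d y0 = (1 - dt) ** num_steps.
    print(jax.grad(lambda y0: solve_with_progress(y0, 1.0)[0])(2.0))
    # Gradient through the display-only progress values is zero.
    print(jax.grad(lambda t1: solve_with_progress(1.0, t1)[1].sum())(1.0))
```

This sketch only addresses the differentiation error; the direction of the backward bar (100 to 0) is a separate display question.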
TqdmProgressMeter
This one is more subtle: diffrax tries to use the same `tqdm` progress bar twice, for both the forward and the backward pass. The backward pass crashes because the progress bar was thrown away during the forward pass.
Proposed fixes
For the text progress bar, we would need to know whether we are in the forward or the backward pass. I don't know if this is doable as of today, and I see 2 solutions:
- detect when `progress` decreases in the progress bar's `step` function. This feels a bit hacky but limits the footprint on the code.
For the tqdm progress bar, the solution is to instantiate a new progress bar before doing the backward pass. I'm a bit stuck here because I can't find where diffrax tells JAX how to compute the gradient. I would have expected `filter_custom_vjp` (ref) to fill this role, but this part of the code is not used for the MWE. Could you point me to the part of the code where this is done, please? Once I understand this, I can probably quickly send a fix for both issues.
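A framework-free sketch of both proposals above, assuming a hypothetical `DirectionAwareMeter` class (not diffrax's API, and ignoring how meters are wired into JAX). `step` treats a decrease in `progress` as the start of the backward pass and displays `1 - progress`, so the backward bar also runs from 0 to 100; if the bar was already closed, it opens a fresh one, in the spirit of instantiating a new tqdm bar before the backward pass.

```python
from dataclasses import dataclass, field


@dataclass
class _BarState:
    # Hypothetical per-bar state, for illustration only.
    last_progress: float = 0.0
    backward: bool = False
    shown: list = field(default_factory=list)  # percentages "printed" so far


class DirectionAwareMeter:
    """Hypothetical text meter that infers the pass direction from `progress`."""

    def __init__(self):
        self._bars = {}
        self._next_idx = 0

    def init(self):
        meter_idx = self._next_idx
        self._next_idx += 1
        self._bars[meter_idx] = _BarState()
        return meter_idx

    def step(self, meter_idx, progress):
        bar = self._bars.get(meter_idx)
        if bar is None:
            # The bar was closed after the forward pass: open a fresh one for the
            # backward pass instead of crashing (the tqdm proposal).
            bar = _BarState(last_progress=progress, backward=True)
            self._bars[meter_idx] = bar
        elif progress < bar.last_progress:
            # Heuristic from the text-bar proposal: progress went down, so assume
            # we are now in the backward pass.
            bar.backward = True
        bar.last_progress = progress
        # Flip the backward pass so that it is also displayed from 0 to 100.
        displayed = 1.0 - progress if bar.backward else progress
        bar.shown.append(round(100 * displayed))
        return bar

    def close(self, meter_idx):
        self._bars.pop(meter_idx, None)


if __name__ == "__main__":
    meter = DirectionAwareMeter()
    idx = meter.init()
    for p in [0.0, 0.5, 1.0, 1.0, 0.5, 0.0]:  # forward, then backward
        bar = meter.step(idx, p)
    print(bar.shown)  # [0, 50, 100, 100, 50, 100]
```

The output also shows why the decrease heuristic feels hacky: the first backward step (at `progress == 1.0`) cannot be told apart from the last forward step, so it is still displayed as 100 before the direction flips.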