Adding support for checkpointing #149
Sounds good to have! HF handles it in the forward method of hf-models (equiv. …
The wrapping function can be implemented with …
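For instance, a minimal sketch (the `checkpoint` name is illustrative; I am assuming the wrapper simply defers to `Zygote.checkpointed`):

```julia
using Zygote

# Wrap a function so that its pullback recomputes the forward pass
# instead of storing intermediate activations.
checkpoint(f) = (xs...) -> Zygote.checkpointed(f, xs...)
```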
I will look at your suggestions. Checkpointed as a …
One thing you'll want to think about is stateful layers like Dropout and BatchNorm, which would not behave the same in subsequent calls. For the former, I think some mechanism to snapshot RNG state would be required, and for the latter maybe an explicit overload?
It seems the problem is that we cannot know whether a Dropout or BatchNorm is executed in a checkpointed environment?
@ToucheSir I have not thought about this. Is there still a switch to toggle train and test mode? That would effectively solve the problem.
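For reference, Flux's `trainmode!`/`testmode!` is the switch in question; a small illustration of the toggle itself (not of its interaction with checkpointing):

```julia
using Flux

m = Chain(Dense(4 => 4), Dropout(0.5), BatchNorm(4))
Flux.testmode!(m)   # Dropout becomes a no-op, BatchNorm uses its running stats
Flux.trainmode!(m)  # restore stochastic dropout and batch statistics
```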
That doesn't sound the same. One could always turn off all dropouts completely, but normally we would want the checkpointed recomputation to use the same dropout state as the first forward call, so that the gradients with and without checkpointing are identical.
If the pullback is only called once, I believe BatchNorm and co. should actually not require any special handling. Otherwise, the approach would be to traverse the model looking for these layers, saving their current train/test status, doing the checkpointing, and then restoring the saved status. As Peter notes, Dropout is trickier because you still need the RNG state around to create a mask. The most straightforward solution using …

Medium-to-long term, we may want to consider a mechanism like https://github.com/vchuravy/ScopedValues.jl for exposing whether checkpointing is currently active in Flux itself. Then layers can query that info and change their behaviour accordingly, without a wrapper.
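As a hedged sketch of the RNG idea (the name and the explicit-RNG requirement are my assumptions, not a worked-out design): snapshot the RNG before the first forward pass and restore it before the recomputation, so Dropout draws the identical mask both times.

```julia
using Random, Zygote

function checkpointed_with_rng(f, rng::MersenneTwister, xs...)
    snapshot = copy(rng)              # save the RNG state up front
    Zygote.checkpointed(xs...) do ys...
        Zygote.ignore() do
            copy!(rng, snapshot)      # replay the same random stream on recompute
        end
        f(ys...)
    end
end
```

This assumes the Dropout inside `f` actually draws from this explicit `rng` (Flux's `Dropout` accepts an `rng` keyword); with the default task-local RNG, the snapshot/restore mechanics would look different.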
@ToucheSir I wonder if we could subtype AContext in Zygote with a CheckpointedContext and overload the pullback behaviour for dropout and the like?
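Something like the following speculative sketch (`CheckpointedContext` is hypothetical, and dispatching rules on it would rely on Zygote internals):

```julia
using Zygote

# A context subtype that differentiation rules could, in principle,
# dispatch on to detect that they run inside a checkpointed recompute.
struct CheckpointedContext <: Zygote.AContext end
```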
Maybe, but generally we'd like to avoid coupling Flux to Zygote wherever possible (e.g. no custom pullbacks).
I would say we would only need to couple NNlib to Zygote, since dropout has been moved out of Flux.
Yeah, NNlib has no dep (hard or weak) on Zygote right now and it'd be better to keep it that way. Porting …
These are good points, @ToucheSir. I will come back to this in a two-week timeframe; I am a bit busy now with academic stuff.
I am copying this here from my post on Slack so that it does not get lost.
I think it might be worth adding rudimentary support for checkpointing, as
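in this minimal sketch (the `Checkpointed` name and the delegation to `Zygote.checkpointed` are my guess at a first cut):

```julia
using Flux, Zygote

struct Checkpointed{T}
    layer::T
end
Flux.@functor Checkpointed

# The forward pass runs normally; the pullback recomputes it instead
# of keeping the layer's intermediate activations alive.
(c::Checkpointed)(xs...) = Zygote.checkpointed(c.layer, xs...)
```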
and then wrap blocks in Checkpointed as
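shown here (again a sketch, with `Dense` blocks standing in for the real transformer blocks):

```julia
blocks = [Dense(64 => 64, gelu) for _ in 1:4]
model  = Chain(Checkpointed.(blocks)...)

x = randn(Float32, 64, 8)
# the pullback recomputes each wrapped block's forward pass
grads = Zygote.gradient(m -> sum(m(x)), model)
```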
and while it is probably not the nicest representation, it seems to work.
The running time is approximately 50% longer, which I think is expected, since the forward pass needs to be done twice.
I do not know if this is something that is wanted. If yes, I might try to turn this into a more proper solution and improve it. Ideally, one would like an option to download a model from HF and enable checkpointing on it; I think HF offers this option.