Adding support for checkpointing #149

Open
pevnak opened this issue Sep 1, 2023 · 12 comments

Comments

@pevnak
Contributor

pevnak commented Sep 1, 2023

I am copying this here from my post on Slack so that it does not get lost.

I think it might be worth adding rudimentary support for checkpointing as

using Transformers, Zygote

# Wrapper that recomputes the wrapped block in the backward pass
# instead of storing its intermediate activations.
struct Checkpointed{S} <: Transformers.Layers.AbstractTransformerBlock
    f::S
end

Base.show(io::IO, c::Checkpointed) = print(io, c.f)

(m::Checkpointed)(args...) = Zygote.checkpointed(m.f, args...)

and then wrap the blocks in Checkpointed as

decoder = Transformers.Layers.Chain(Transformer(map(Checkpointed, decoder.layers[1].blocks)), decoder.layers[2]) 

and while it is probably not the nicest representation, it seems to work.
The running times are approximately 50% longer, which I think is expected, since the forward pass needs to be done twice.

I do not know if this is something that is wanted. If so, I could try to turn it into a more proper solution and improve it. Ideally, one would like to have an option to download a model from HF and add checkpointing to it. I think HF has this option.

@chengchingwen
Owner

Sounds good to have!

HF handles it in the forward method of the hf-models (equivalent to Layers.Transformer). I'm not sure Checkpointed as an AbstractTransformerBlock is the best place to add the checkpoint functionality. Some alternative ideas I currently have in mind:

  1. Generalize Checkpointed{S} <: LayerStruct and overload Checkpointed{<:Transformer} to add checkpointing per block.
  2. Modify Layers.applyblocks to allow hooks and use Zygote.checkpointed as the hook function.
  3. Similar to 2. but provide a HookedTransformerBlock <: AbstractTransformerBlock.

The wrapping function can be implemented with a postwalk, similar to Layers.set_dropout.
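
For illustration, a minimal sketch of such a wrapping helper, written here with Functors.fmap standing in for the postwalk mechanism; checkpoint_blocks and is_block are hypothetical names, and Checkpointed is the wrapper from the first comment:

using Functors, Transformers

is_block(x) = x isa Transformers.Layers.AbstractTransformerBlock

# Walk the model tree and wrap every transformer block so it is
# recomputed in the backward pass instead of storing its activations.
checkpoint_blocks(model) =
    fmap(model; exclude = x -> is_block(x) || Functors.isleaf(x)) do x
        is_block(x) ? Checkpointed(x) : x
    end

Whether fmap descends into every Transformers.jl layer as needed here depends on how the layers are registered with Functors, so a dedicated postwalk in the spirit of Layers.set_dropout may be the cleaner route.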

@pevnak
Contributor Author

pevnak commented Sep 1, 2023

I will look at your suggestions. Checkpointed as an AbstractTransformerBlock was a quick and dirty trick. I like the postwalk trick.

@ToucheSir

One thing you'll want to think about is stateful layers like Dropout and BatchNorm, which would not behave the same in subsequent calls. For the former, I think some mechanism to snapshot RNG state would be required, and for the latter maybe an explicit overload?

@chengchingwen
Owner

It seems the problem is that we cannot know whether a Dropout or BatchNorm is being executed inside a checkpointed environment?

@pevnak
Contributor Author

pevnak commented Sep 1, 2023

@ToucheSir I have not thought about this. Is there still a switch to toggle between train and test mode? That would effectively solve the problem.

@chengchingwen
Owner

That doesn't sound the same. One could always turn all dropouts off completely, but normally we would want the checkpointed recomputation to use the same dropout state as the first forward call, so that the gradients with and without checkpointing are the same.
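
A toy illustration of the concern (noisy is a made-up dropout-like function, not a Transformers.jl API): Zygote.checkpointed re-runs the forward pass inside the pullback, so a stochastic layer draws a fresh mask there and the resulting gradient no longer corresponds to the mask used for the primal output.

using Zygote, Random

# Stand-in for a dropout-like layer that draws randomness on every call.
noisy(x) = x .* (rand(length(x)) .> 0.5)

x = randn(4)
# Without checkpointing: the gradient matches the mask drawn in the forward pass.
g_plain = Zygote.gradient(x -> sum(noisy(x)), x)
# With checkpointing: the forward is recomputed in the pullback with a new mask,
# so the gradient generally does not match the mask used in the original forward pass.
g_ckpt = Zygote.gradient(x -> sum(Zygote.checkpointed(noisy, x)), x)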

@ToucheSir

If the pullback is only called once, I believe BatchNorm and co should actually not require any special handling. Otherwise, the approach would be to traverse the model looking for these layers, saving their current train/test status, doing the checkpointing and then restoring the saved status.

As Peter notes, Dropout is trickier because you still need the RNG state around to create a mask. The most straightforward solution using struct Checkpointed would be to recurse through the model looking for Dropout layers and snapshotting their RNG state beforehand. Then that state can be restored whenever the checkpointing runs. I haven't quite thought about how this interacts with RNGs shared between layers (as is the default), but that should be solvable.
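
A rough sketch of that snapshot/restore idea (the helper names are made up for illustration, and this assumes each Dropout's rng supports copy/copy!, e.g. a MersenneTwister or Xoshiro):

using Flux, Random

# Collect a copy of the RNG of every Dropout layer before the first forward pass.
snapshot_dropout_rngs(model) =
    [layer => copy(layer.rng) for layer in Flux.modules(model) if layer isa Flux.Dropout]

# Restore the saved RNG state right before the checkpointed recomputation,
# so the same dropout masks are drawn again.
function restore_dropout_rngs!(snapshots)
    for (layer, rng) in snapshots
        copy!(layer.rng, rng)
    end
end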

Medium-long term, we may want to consider a mechanism like https://github.com/vchuravy/ScopedValues.jl for exposing whether checkpointing is currently active in Flux itself. Then layers can query that info and change their behaviour accordingly without a wrapper.
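
A minimal sketch of what that could look like with ScopedValues.jl (the names CHECKPOINTING and checkpointed_call are hypothetical, not an existing Flux API):

using ScopedValues

# Scoped flag that layers could query to know whether they are being
# (re)run inside a checkpointed region.
const CHECKPOINTING = ScopedValue(false)

checkpointed_call(f, args...) = with(CHECKPOINTING => true) do
    f(args...)
end

# A dropout-like layer could then branch on CHECKPOINTING[] to e.g.
# reuse a stored mask instead of drawing a fresh one.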

@chengchingwen
Owner

@ToucheSir I wonder if we could subtype AContext in Zygote with a CheckpointedContext and overload the pullback behavior for dropout, or something along those lines?

@ToucheSir

Maybe, but generally we'd like to avoid coupling Flux to Zygote wherever possible (e.g. no custom pullbacks).

@chengchingwen
Owner

I would say that we would only need to couple NNlib to Zygote, since dropout has been moved out of Flux.

@ToucheSir

Yeah, NNlib has no dep (hard or weak) on Zygote right now and it'd be better to keep it that way. Porting Zygote.checkpointed to use the ChainRules API shouldn't be an issue; we just need to decide whether it lives in Flux or NNlib.
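
For reference, a rough sketch of what a ChainRules-based version could look like (illustrative only, not the actual Zygote implementation; checkpointed here is a hypothetical standalone function):

using ChainRulesCore

# Primal: just call the function; the interesting part is the rrule below.
checkpointed(f, xs...) = f(xs...)

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(checkpointed), f, xs...)
    y = f(xs...)  # plain forward pass; no pullback is stored
    function checkpointed_pullback(ȳ)
        # Recompute the forward pass with AD to obtain the pullback on demand.
        _, pb = rrule_via_ad(cfg, f, xs...)
        return (NoTangent(), pb(ȳ)...)
    end
    return y, checkpointed_pullback
end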

@pevnak
Contributor Author

pevnak commented Sep 4, 2023

These are good points, @ToucheSir. I will get to this within a two-week timeframe; I am a bit busy with academic stuff right now.
