[RFC] Support checkpointing multiple callbacks of the same type #6467
My idea is to use an identifier for each callback instance and derive a unique key when saving. This would allow us to restore from a checkpoint even if the user has changed the callback list. Rough sketch:

```python
class Callback:
    @property
    def identifier(self):  # or name, or whatever
        return self.__class__.__name__


class ModelCheckpoint(Callback):
    @property
    def identifier(self):
        return self.monitor  # handle the special case where this is not available
```

When we init the trainer, we check that all identifiers are unique; otherwise we warn the user that they need to set it (we could provide a setter too). |
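A minimal sketch of the uniqueness check described in the comment above; the `identifier` property is taken from that sketch, while the helper name and warning text here are only illustrative:

```python
from collections import Counter
import warnings


def _check_unique_identifiers(callbacks):
    # Count how many callbacks resolve to the same identifier.
    counts = Counter(cb.identifier for cb in callbacks)
    duplicates = [name for name, count in counts.items() if count > 1]
    if duplicates:
        warnings.warn(
            f"Multiple callbacks share the identifier(s) {duplicates}; "
            "their checkpointed states would overwrite each other. "
            "Please set a unique identifier on each instance."
        )
```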
@awaelchli Continuing #2908 (comment) discussion here:
The last checkpoint in
But multiple checkpoints can share the same monitor, for example, you could choose to early stop after a threshold or if you notice a large regression in performance (both monitoring the same quantity) |
I agree that the identifier should be on the callback interface. For model checkpointing, this could be a hash of the arguments passed into the constructor. We'd need to resolve backwards compatibility for existing checkpoints as well with this pitch |
In that case I would suggest making the identifier also a setter. Then in the Trainer init, we can check whether the identifiers are unique and, if not, kindly ask the user to set the identifier to some meaningful, unique name. When we can't automatically determine a unique key, I think it would be reasonable to ask the user to disambiguate. What do you think? Essentially this is what the user can do today: they could subclass the checkpoint class just to give it a new name. |
@awaelchli @carmocca does this identifier need to be human readable? why couldn't it be an arbitrary string? in that case, wouldn't a hash be sufficient to avoid collisions? |
I like the idea of generating a hash off the constructor arguments, although this couldn't live in the base `Callback` class. We need to handle BC both in terms of (1) an old checkpoint running on new code and (2) a new checkpoint running on old code. |
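For illustration, a minimal sketch of deriving an identifier from the constructor arguments as suggested above; the helper name and hashing details are assumptions, and the caveats are exactly what the replies below push back on:

```python
import hashlib


def _identifier_from_args(callback_cls, init_kwargs: dict) -> str:
    # Hash the constructor arguments so that two instances configured
    # differently get different identifiers. Note the caveat discussed in
    # the replies below: changing any argument (even `verbose`) changes
    # the hash, which affects resuming.
    payload = repr(sorted(init_kwargs.items())).encode()
    digest = hashlib.sha1(payload).hexdigest()[:8]
    return f"{callback_cls.__name__}-{digest}"


# e.g. _identifier_from_args(ModelCheckpoint, {"monitor": "val_loss"})
# returns "ModelCheckpoint-" followed by an 8-character hash
```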
With this you make the assumption that the user does not change the arguments when resuming. Does this assumption hold for all callback types and use cases? |
Are you saying if the user manually modifies the checkpoint? Before running |
No, when they manually modify their own source code that they are resuming in. |
I see. With the hash, it's not only if the user changes an argument: even if the model checkpoint or any callback offered by Lightning directly changes its arguments or default values, we end up with a new hash. So preserving backwards compatibility from the framework side can be very tricky.
|
Well, I think we just need to keep both saving and loading in mind. The hashing you suggested will certainly solve the saving bug you describe, about overwriting with the last state. However, if you go with the hashing approach, what will that mean for loading? Now the user will resume training but the source code has slightly changed, say `verbose` was set to `False`. So in my mind, there are some arguments that should change the hash value but others that should not matter, to allow some flexibility so that we can still load. This is simply an argument against hashing all parameters. |
+1 This still leaves open the question of how we can introduce, rename, or deprecate existing parameters in callbacks while preserving backwards compatibility. Taking #6146 as an example: this would certainly change the identifier, which means old checkpoints that had model checkpoint callback state would not be loaded. Is that reasonable? And maybe Lightning provides tools to help upgrade checkpoints? |
@awaelchli Do you expect that the
Saving:
Loading:
then the identifier can focus on the specific business logic/internal state |
I don't think so, and we could probably find a more suitable name (`state_key`, `state_identifier`, `state_name`, ...?). Postfixing sounds reasonable. However, keep in mind that if the user renames their custom callback class (say, a simple refactor) between the first run and resuming from it, they won't be able to load the state. I would say this is not a major concern, but just to mention this limitation. For backward compatibility, with your approach we could do something like this:

```python
def _compute_key(callback: Callback) -> str:
    key = type(callback).__name__  # the class name is the base key
    if saved_version >= "1.3":  # the version of PL that saved the checkpoint
        # checkpoint uses the new key format
        key += callback.identifier
    return key
```

We save the PL version into the checkpoint, so we will know which format was used to save the file. |
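For context, a possible way such a key could be used on the loading side. This is a sketch under assumptions: callback states living under a `"callbacks"` entry, `_compute_key` knowing the saved version, and a simplified restore hook:

```python
def restore_callback_states(callbacks, checkpoint: dict) -> None:
    # Look up each callback's saved state under its computed key. Old
    # checkpoints still resolve because _compute_key falls back to the
    # legacy key format based on the version stored in the checkpoint.
    saved_states = checkpoint.get("callbacks", {})
    for callback in callbacks:
        state = saved_states.get(_compute_key(callback))
        if state is not None:
            callback.on_load_checkpoint(state)  # placeholder for the actual restore hook
```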
@awaelchli @tchaton this identifier can also be used to disambiguate between callbacks passed to the trainer and callbacks configured from the lightning module |
@awaelchli based on this setup, it looks like the callback implementation should be responsible for BC since it owns the identifier.
@awaelchli @ananthsub Hi, I was facing a related issue and just went through your discussion here, and I want to suggest another angle. I have a requirement to keep the saved checkpoints independent of the source code, namely by somehow exposing control over the key used for saving the callback under `callback_hook.py::TrainerCallbackHookMixin::on_save_checkpoint`. I'd appreciate your thoughts. |
I have encountered that problem too. As you say, the problem is that we dump `type(callback)`, which creates a pickle dependency on the source code. For now, did you find a workaround for your case? |
@awaelchli, actually I didn't; I disabled the callback state saving for now, which is, to say the least, suboptimal. |
Yes, it would replace `type(callback)` and thus the pickle dependency on the source code. However, if you have old checkpoints, we will still have to load them using the old format, so it will solve your problem only for new checkpoints created from this point forward. I started drafting the PR #6886 here. |
I thought a bit about backward compatibility. I propose to convert a checkpoint at loading time. Every time we introduce a backward-incompatible change, we provide a migration targeting a specific version to upgrade. For example:

```python
@Migration(target="1.2.0")
def upgrade_something(checkpoint: dict) -> dict:
    # Here take the checkpoint from version 1.2.0 and upgrade it to the next higher version 1.2.1
    # ...
    return checkpoint


@Migration(target="1.2.1")
def upgrade_something_else(checkpoint: dict) -> dict:
    # Here take the checkpoint from version 1.2.1 and upgrade it to the next higher version 1.2.2
    # ...
    return checkpoint
```

The `upgrade_checkpoint` function then applies all migrations in order:

```python
def upgrade_checkpoint(checkpoint: dict) -> dict:
    for migration in all_migrations.values():
        if migration is None:
            # The only thing the default migration does is bump the version.
            checkpoint = default_migration(checkpoint)
        else:
            checkpoint = migration(checkpoint)
    return checkpoint
```

What do you think of this approach? @carmocca @ananthsub |
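For concreteness, one possible shape for such a decorator, registering each migration under the version it targets. The names `Migration` and `all_migrations` follow the sketch above, but the registry layout is an assumption:

```python
from typing import Callable, Dict, Optional

# Maps each released version to the migration that upgrades a checkpoint
# saved at that version to the next one; None means "nothing to do except
# bump the version", handled by default_migration in the loop above.
all_migrations: Dict[str, Optional[Callable[[dict], dict]]] = {}


class Migration:
    def __init__(self, target: str):
        self.target = target

    def __call__(self, fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
        # Register the function so upgrade_checkpoint can apply it in order.
        all_migrations[self.target] = fn
        return fn
```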
Do you expect users to use these decorators themselves? I like the idea, but I would go with the simple yet effective:

```python
def upgrade_checkpoint(checkpoint):
    if checkpoint["__version__"] == "...":
        ...
    if checkpoint["__version__"] == "...":
        ...
    ...
    return checkpoint
```

instead of these decorators, given that it's not common at all that we break ckpt compatibility. Also, maybe instead of updating the original version field, we should add a separate `compat_version` field. |
No, just for internal use. We don't want users to interfere here. This is an internal improvement, because we currently have hardcoded backward-compatible loading logic in the core. And for the support of multiple callbacks discussed here, I need a sane, standardized way to migrate.
It's about not bloating the internals of Lightning with logic that is only there to provide backward compatibility for loading an old checkpoint. I propose to put this logic in one place and take care of it at the loading stage, so that the rest of Lightning assumes the same format everywhere. Second, I propose to upgrade iteratively from one version to the next. This makes it easier to extend every time we make a change in a new version. |
After a bit of offline chat, we have different opinions. Adrian proposes always enforcing in-between version steps:

```python
upgrades = {
    "1.2.0": some_function,
    "1.2.1": lambda x: x,  # no-op for versions without changes
    "1.2.2": some_other_upgrade,
    # ...
}
```

where we have to update this every release. On the other hand, I propose avoiding the no-op version checks:

```python
def should_update(version: str, target: str) -> bool:
    return compare_version(version, operator.lt, target)


def upgrade_checkpoint(checkpoint):
    # set the default compat version for old checkpoints
    checkpoint.setdefault("compat_version", checkpoint["version"])
    if should_update(checkpoint["compat_version"], "1.2.0"):
        update_to_1_2_0(checkpoint)
        assert checkpoint["compat_version"] == "1.2.0"
    if should_update(checkpoint["compat_version"], "1.2.7"):
        update_to_1_2_7(checkpoint)
        assert checkpoint["compat_version"] == "1.2.7"
    # No need to update compat_version higher than 1.2.7
    # if there are no breaking changes
    whatever_else()
    return checkpoint
```

and adapting the code if necessary in a future release. Both options imply incremental updates to the version. The difference is whether or not in-between no-op steps should be enforced. |
I prefer @carmocca's approach to avoid the boilerplate across versions. I think it's a lower initial investment, and if we find that it's lacking, we can shift to @awaelchli's proposal for stricter checkpoint upgrades. |
I'm implementing your suggestion here in this PR #7161. However, I disagree with all your observations.
Carlos:
- The steps in between would of course be implied, not explicitly defined by us.
- Only when a release changes the checkpoint format => 0 effort for most releases.
Ananth:
- This has more boilerplate code. It is repeated and is more error prone. It is exactly what I have avoided in my proposed API.
Current plan: |
I misunderstood this point thinking that it would be explicitly defined. This sounds great @awaelchli ! |
🚀 Feature
Currently when dumping the checkpoint dict, we overwrite callback states if there are multiple callbacks of the same type: https://github.com/PyTorchLightning/pytorch-lightning/blob/1c013b43e049d423184323313014774738827c90/pytorch_lightning/trainer/callback_hook.py#L209-L224
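In simplified form (not the actual hook signature), the states are keyed by the callback's type, so a second instance of the same class replaces the first:

```python
def dump_callback_states(callbacks) -> dict:
    callback_states = {}
    for callback in callbacks:
        state = callback.on_save_checkpoint()  # simplified; the real hook takes more arguments
        if state:
            # Keyed by type: two ModelCheckpoint instances collide here and
            # only the last one's state ends up in the checkpoint.
            callback_states[type(callback)] = state
    return callback_states
```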
Motivation
This blocks #2908 and is an existing bug or at least unexpected behavior with checkpointing callback states today.
Pitch
[WIP]
- Add a `state_identifier` property to the base `Callback` class. Should it raise `NotImplementedError`? Reasoning: as it stands, we'd call this only when saving/loading callback states. Could we be even more selective and call this only when we have multiple callback instances of the same type? Is that complexity warranted?
- For instance, `ModelCheckpoint` could override it (TODO: provide a better default implementation based on its args here).
- Use the `state_identifier` to further disambiguate (see the sketch after this list).
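A rough sketch of how the pitched `state_identifier` could look; the `ModelCheckpoint` override shown here (keying off `monitor`) is just one possibility raised in the discussion:

```python
class Callback:
    @property
    def state_identifier(self) -> str:
        # Default: the class name, which matches today's behavior when there
        # is a single instance per class.
        return type(self).__name__


class ModelCheckpoint(Callback):
    def __init__(self, monitor: str = "val_loss"):
        self.monitor = monitor

    @property
    def state_identifier(self) -> str:
        # Disambiguate multiple instances by the monitored quantity.
        return f"{type(self).__name__}[monitor={self.monitor}]"
```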
Alternatives
From @awaelchli: The problematic part is around `Trainer(resume_from_checkpoint=..., callbacks=[...])`, which would restore the trainer state plus the callback state. If the source code changes, what will happen? Well, in the normal case you wouldn't modify the code to perfectly resume training, right? But I think it is still worth discussing what kind of flexibility we could allow, if any. For instance, if we hashed all the constructor args, a callback that changes a `verbose` flag from `False` to `True` would no longer be able to access its previously checkpointed state.
Additional context