
[RFC] Support checkpointing multiple callbacks of the same type #6467

Closed
ananthsub opened this issue Mar 10, 2021 · 27 comments · Fixed by #7187

@ananthsub (Contributor) commented Mar 10, 2021

🚀 Feature

Currently, when dumping the checkpoint dict, we overwrite callback states if there are multiple callbacks of the same type: https://github.com/PyTorchLightning/pytorch-lightning/blob/1c013b43e049d423184323313014774738827c90/pytorch_lightning/trainer/callback_hook.py#L209-L224
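
For illustration, here is a toy reproduction of the collision (standalone stand-ins, not Lightning's actual classes): because the states dict is keyed by type, the second instance silently overwrites the first.

class Callback:
    def on_save_checkpoint(self):
        return {}

class ModelCheckpoint(Callback):
    def __init__(self, monitor):
        self.monitor = monitor
        self.best_score = None

    def on_save_checkpoint(self):
        return {"monitor": self.monitor, "best_score": self.best_score}

callbacks = [ModelCheckpoint("val_loss"), ModelCheckpoint("val_acc")]
callback_states = {}
for cb in callbacks:
    # keyed by type: the second ModelCheckpoint silently overwrites the first
    callback_states[type(cb)] = cb.on_save_checkpoint()

assert len(callback_states) == 1  # only the "val_acc" state survived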

Motivation

This blocks #2908 and is an existing bug, or at least unexpected behavior, in how callback states are checkpointed today.

Pitch

[WIP]

  1. We add a state_identifier property to the base Callback class
@property
def state_identifier(self) -> str:
    return ""
  • TODO: should this raise NotImplementedError? Reasoning: as it stands, we'd call this only when saving/loading callback states. Could we be even more selective and call this only when we have multiple callback instances of the same type? Is that complexity warranted?
  2. Callback implementations define this property based on the business logic of the class in order to determine uniqueness, while also preserving flexibility for development.

For instance:

class ModelCheckpoint(Callback):
    ...

    @property
    def state_identifier(self):
        # handle the special case where `monitor` is not available;
        # include other params from here
        return f"monitor={self.monitor}"

TODO: provide better default implementation based on args here
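
As one hedged sketch of such a default (not part of the pitch above; the `_state_relevant_args` attribute is hypothetical), the base class could join a stable subset of constructor args:

class Callback:

    @property
    def state_identifier(self) -> str:
        # subclasses would populate `_state_relevant_args` in __init__;
        # sorting makes the identifier independent of argument order
        relevant = getattr(self, "_state_relevant_args", {})
        return ",".join(f"{k}={v}" for k, v in sorted(relevant.items()))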

  3. At save time, we still partition the callback state dicts by type, but include the state_identifier to further disambiguate.
# we include the checkpoint dict in order to look at the lightning version or
# other metadata that can be used to better preserve backwards compatibility
def _compute_callback_state_key(self, checkpoint: Dict[str, Any], callback: Callback) -> str:
    # `state_identifier` is a property; the type must be rendered as a
    # string before concatenation
    return type(callback).__name__ + callback.state_identifier

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
    """Called when saving a model checkpoint."""
    callback_states = {}
    for callback in self.callbacks:
        if self.__is_old_signature(callback.on_save_checkpoint):
            rank_zero_warn(
                "`Callback.on_save_checkpoint` signature has changed in v1.3."
                " A `checkpoint` parameter has been added."
                " Support for the old signature will be removed in v1.5", DeprecationWarning
            )
            state = callback.on_save_checkpoint(self, self.lightning_module)  # noqa: parameter-unfilled
        else:
            state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
        if state:
            key = self._compute_callback_state_key(checkpoint, callback)
            callback_states[key] = state
    return callback_states
  4. At load time, we do something similar:
def on_load_checkpoint(self, checkpoint):
    """Called when loading a model checkpoint."""
    callback_states = checkpoint.get('callbacks')
    # Todo: the `callback_states` are dropped with TPUSpawn as they
    # can't be saved using `xm.save`
    # https://github.com/pytorch/xla/issues/2773
    if callback_states is not None:
        for callback in self.callbacks:
            key = self._compute_callback_state_key(checkpoint, callback)
            state = callback_states.get(key)
            if state:
                state = deepcopy(state)
                callback.on_load_checkpoint(state)

Alternatives

  1. Why don't Callback implementations simply hash the value of the constructor args?

From @awaelchli: The problematic part is around Trainer(resume_from_checkpoint=..., callbacks=[...]), which would restore the trainer state plus callback state. If the source code changes, what will happen? Well, in the normal case you wouldn't modify the code to perfectly resume training, right? But I think it is still worth discussing what kind of flexibility we could allow, if any. For instance, if we hashed all the constructor args, a callback that changes a verbose flag from False to True would no longer be able to access its previously checkpointed state.
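
A toy sketch of that failure mode (a standalone stand-in class, assuming the callback records its constructor args): hashing every arg means flipping a cosmetic flag produces a different key.

import hashlib

class ToyCheckpoint:
    def __init__(self, monitor="val_loss", verbose=False):
        self._init_args = {"monitor": monitor, "verbose": verbose}

def init_arg_hash(cb) -> str:
    payload = repr(sorted(cb._init_args.items()))
    return hashlib.sha1(payload.encode()).hexdigest()[:8]

# flipping `verbose` between runs changes the key and orphans the saved state
assert init_arg_hash(ToyCheckpoint(verbose=False)) != init_arg_hash(ToyCheckpoint(verbose=True))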

Additional context

ananthsub added the feature, help wanted, design, and callback labels on Mar 10, 2021
@awaelchli (Contributor)

My idea is to use an identifier for each callback instance and derive a unique key when saving. This would allow us to restore from a checkpoint even if the user has changed the callback list.
In ModelCheckpoint, the unique identifier could be derived from the monitor key automatically. The same goes for early stopping and other callbacks that save state.

Rough sketch:

class Callback:

    @property
    def identifier(self):  # or `name`, or whatever
        return self.__class__.__name__


class ModelCheckpoint(Callback):

    @property
    def identifier(self):
        return self.monitor  # handle special case where this is not available

When we init the trainer, we check that all identifiers are unique; otherwise we warn the user that they need to set one (we could provide a setter too).
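
A minimal sketch of that check, assuming the identifier property above (whether to warn or raise is open):

from collections import Counter

def _validate_unique_identifiers(callbacks) -> None:
    counts = Counter(cb.identifier for cb in callbacks)
    duplicates = [ident for ident, n in counts.items() if n > 1]
    if duplicates:
        # ask the user to disambiguate instead of silently overwriting state
        raise ValueError(
            f"Multiple callbacks share the identifiers {duplicates}; "
            "please set a unique identifier on each instance."
        )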

@carmocca (Contributor)

@awaelchli Continuing #2908 (comment) discussion here:

> Yes it's true for loading but just saving the checkpoints with multiple instances should work today?

The state of the last callback of a given type in self.callbacks will be the one that gets saved

> In ModelCheckpoint, the unique identifier could be derived from the monitor key automatically. The same goes for early stopping and other callbacks that save state.

But multiple checkpoints can share the same monitor, for example, you could choose to early stop after a threshold or if you notice a large regression in performance (both monitoring the same quantity)

@ananthsub (Contributor, Author) commented Mar 10, 2021

I agree that the identifier should be on the callback interface.

For model checkpointing, this could be a hash of the arguments passed into the constructor.

We'd also need to resolve backwards compatibility for existing checkpoints with this pitch.

@awaelchli (Contributor)

> But multiple checkpoints can share the same monitor, for example, you could choose to early stop after a threshold or if you notice a large regression in performance (both monitoring the same quantity)

In that case I would suggest making the identifier also a setter. Then in the Trainer init, we can check whether the identifiers are unique and, if not, kindly ask the user to set the identifier to some meaningful, unique name. When we can't automatically determine a unique key, I think it would be reasonable to ask the user to disambiguate. What do you think? Essentially this is what the user can do today: they could subclass the Checkpoint class just to give it a new name.
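
A sketch of what a property-plus-setter could look like (names as in the earlier sketch; this is illustrative, not the merged design):

class Callback:
    _identifier = None

    @property
    def identifier(self) -> str:
        # fall back to the class name unless the user set one explicitly
        return self._identifier or self.__class__.__name__

    @identifier.setter
    def identifier(self, value: str) -> None:
        self._identifier = value

Two EarlyStopping callbacks monitoring the same quantity could then be told apart with, e.g., stop_on_threshold.identifier = "early_stop_threshold".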

@ananthsub (Contributor, Author)

@awaelchli @carmocca does this identifier need to be human readable? why couldn't it be an arbitrary string? in that case, wouldn't a hash be sufficient to avoid collisions?

@carmocca (Contributor) commented Mar 10, 2021

I like the idea of generating a hash off the __init__ parameters. I don't see why anybody would need to customize it to something readable.

This couldn't live in the base Callback class, though; each subclass would need to override it to generate the hash from its own init arguments. Also, not all arguments are hashable.
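
One way around the hashability issue (a sketch, under the assumption that the reprs involved are stable across runs) is to hash reprs instead of the values themselves:

import hashlib

def hash_init_args(**init_args) -> str:
    # repr() sidesteps unhashable values (lists, dicts, callables)
    payload = repr(sorted((k, repr(v)) for k, v in init_args.items()))
    return hashlib.sha1(payload.encode()).hexdigest()[:12]

# e.g. inside a subclass __init__:
#   self._state_key = hash_init_args(monitor=monitor, save_top_k=save_top_k)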

We need to handle BC both in terms of (1) old checkpoint running on new code and (2) new checkpoint running on old code

@awaelchli (Contributor) commented Mar 10, 2021

> I like the idea of generating a hash off the __init__ parameters.

With this you make the assumption that the user does not change the arguments when resuming. Does this assumption hold for all callback types and use cases?

@carmocca (Contributor)

> With this you make the assumption that the user does not change the arguments when resuming.

Are you saying if the user manually modifies the checkpoint? Before running load_from_checkpoint(checkpoint)?

@awaelchli (Contributor) commented Mar 10, 2021

No, when they manually modify the source code they are resuming in.
The problematic part we are talking about here has nothing to do with load_from_checkpoint(checkpoint). This function does not restore Trainer state. This is about Trainer(resume_from_checkpoint=..., callbacks=[...]), which would restore the trainer state plus callbacks. If the source code changes, what will happen? Well, in the normal case you wouldn't modify the code to perfectly resume training, right? But I think it is still worth discussing what kind of flexibility we could allow, if any.

@ananthsub (Contributor, Author) commented Mar 10, 2021

  1. We'll need to use the identifier here to index into the callback states. As it stands, if the identifier of the newly instantiated callback is different from the one in the checkpoint, then we will skip loading the state.

I see: with the hash, it's not only when the user changes an argument. If the model checkpoint or any callback offered by Lightning directly changes its arguments or default values, we also end up with a new hash. So preserving backwards compatibility from the framework side can be very tricky.

  2. Right now, people can configure multiple callbacks and we silently override the callback state. To me that's a bug. So the approach laid out here can be seen as an improvement / a move towards strictness that was previously missing.

@awaelchli (Contributor) commented Mar 10, 2021

Well, I think we just need to keep both saving and loading in mind. The hashing as you suggested will certainly solve the saving bug you describe, where the last state overwrites the others. However, if you go with the hashing approach, what will that mean for loading?
Say the user trained with this callback: ModelCheckpoint(save_top_k=-1, verbose=True), and let's assume for simplicity that the hashing function maps to the key
save_top_k=-1__verbose=True

Now the user resumes training, but the source code has slightly changed, say verbose was set to False.
So when loading, we compare the key of the current ModelCheckpoint, save_top_k=-1__verbose=False, with the one in the checkpoint, and it doesn't match. What do we do? We can't load the state, so it's as if the user dropped this callback.

So in my mind, there are some arguments that should change the hash value, but others that should not matter, to allow some flexibility while still being able to load. This is simply an argument against hashing all parameters.
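
A toy sketch of that selectivity (a standalone stand-in; the `_STATE_ARGS` allowlist is illustrative): only the args that define the state's identity feed the key, so cosmetic flags can change freely.

import hashlib

class ToyEarlyStopping:
    # only these args define the checkpointed state's identity;
    # cosmetic flags like `verbose` are deliberately excluded
    _STATE_ARGS = ("monitor", "mode")

    def __init__(self, monitor="val_loss", mode="min", verbose=False):
        self.monitor, self.mode, self.verbose = monitor, mode, verbose

    @property
    def state_identifier(self) -> str:
        payload = repr([(k, getattr(self, k)) for k in self._STATE_ARGS])
        return hashlib.sha1(payload.encode()).hexdigest()[:8]

# changing `verbose` between runs no longer changes the key
assert ToyEarlyStopping(verbose=False).state_identifier == ToyEarlyStopping(verbose=True).state_identifier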

@ananthsub (Contributor, Author) commented Mar 10, 2021

> So in my mind, there are some arguments that should change the hash value, but others that should not matter, to allow some flexibility while still being able to load. This is simply an argument against hashing all parameters.

+1

This still leaves open the question of how we can introduce, rename, or deprecate parameters in callbacks while preserving backwards compatibility. Taking #6146 as an example: it would certainly change the identifier, which means old checkpoints that had ModelCheckpoint callback state would not be loaded. Is that reasonable? And maybe Lightning could provide tools to help upgrade checkpoints?

@ananthsub (Contributor, Author) commented Mar 10, 2021

@awaelchli Do you expect that the identifier would be used outside of the checkpoint save/load? Another option is that when we save/load the checkpoint, we still partition by type, and the identifier is appended as a suffix, like this:

def _compute_key(callback: Callback) -> str:
    # the type must be rendered as a string before concatenation
    return type(callback).__name__ + callback.identifier

Saving:

state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) 
key = _compute_key(callback)
callback_states[key] = state 

Loading

for callback in self.callbacks:
    key = _compute_key(callback)
    state = callback_states.get(key)
    if state:
        state = deepcopy(state)
        callback.on_load_checkpoint(state)

Then the identifier can focus on the specific business logic / internal state.

@awaelchli (Contributor) commented Mar 10, 2021

> @awaelchli Do you expect that the identifier would be used outside of the checkpoint save/load?

I don't think so and we could probably find a more suitable name (state_key, state_identifier, state_name, ...?).

Postfixing sounds reasonable. However, keep in mind that if the user renames their custom callback class (say, a simple refactor) between the first run and resuming from it, they won't be able to load the state. I would say this is not a major concern, but it's a limitation worth mentioning.

For backward compatibility, with your approach we could do something like this:

def _compute_key(callback: Callback) -> str:
    key = type(callback).__name__
    # `saved_version` is the version of PL that saved the checkpoint
    if saved_version >= "1.3":
        # checkpoint uses the new key format
        key += callback.identifier
    return key

We save the pl version into the checkpoint, so we will know which format was used to save the file.
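
For instance (a sketch, assuming the version is stored under the "pytorch-lightning_version" key; proper comparison needs parsed versions rather than raw strings):

from packaging.version import Version

def _saved_version(checkpoint: dict) -> Version:
    return Version(checkpoint.get("pytorch-lightning_version", "0.0.0"))

def _uses_new_key_format(checkpoint: dict) -> bool:
    # checkpoints written before 1.3 keep the old type-only key format
    return _saved_version(checkpoint) >= Version("1.3.0")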

@ananthsub ananthsub changed the title Support checkpointing multiple callbacks of the same type [RFC] Support checkpointing multiple callbacks of the same type Mar 11, 2021
@ananthsub ananthsub added this to the 1.3 milestone Mar 11, 2021
@ananthsub (Contributor, Author)

@awaelchli @tchaton this identifier can also be used to disambiguate between callbacks passed to the trainer and callbacks configured from the lightning module

@ananthsub (Contributor, Author) commented Mar 11, 2021

@awaelchli Based on this setup, it looks like the callback implementation should be responsible for BC since it owns the state_identifier. As a result, the onus for preserving backwards compatibility falls on the callback implementors. Do you see this as a significant issue with this approach? I don't see a way around it yet. Given that callbacks are for non-essential code, and since this hasn't been raised in a separate issue AFAICT, it might be very niche. But it's a fun thought exercise 😄

@Alexfinkelshtein

@awaelchli @ananthsub Hi, I was facing a related issue and just went through your discussion here; I want to suggest another angle.

I have a requirement to keep the saved checkpoints independent of the source code.
Currently, the callback type is used as the key for state_dict['callbacks'], which makes the checkpoint source-code dependent when using custom callbacks.
I think a solution for my issue might also solve the problem described in this thread.

Namely, somehow exposing control over the key used for saving the callback under callback_hook.py::TrainerCallbackHookMixin::on_save_checkpoint.
Maybe by extending the identifier you described above to an (identifier, callable, state) tuple that the user can output in a custom manner.
For example: by enabling a variable number of outputs for callback.on_save_checkpoint, such that a single output is the state (backwards compatible), while two or three outputs might also cover the initializer and identifier (somewhat similar to the configure_optimizers output handling).

I'd appreciate your thoughts.

@awaelchli (Contributor)

I have encountered that problem too. As you say, the problem is that we dump type(callback), so the pickle depends on the source code of your custom callback and can't be loaded outside PL. I believe the hashing idea we talked about here can solve this problem. I plan to start experimenting with this soon.

For now, did you find a workaround for your case?

@Alexfinkelshtein commented Apr 7, 2021

@awaelchli Actually I didn't; I disabled the callback state saving for now, which is, to say the least, suboptimal.
From what I understood, the hashing idea was targeting the unique-identifier problem. Maybe I missed something, but does it also somehow replace the type(callback)?

@awaelchli (Contributor) commented Apr 8, 2021

Yes, it would replace the type(callback) and thus the pickle dependency on the source code. However, if you have old checkpoints we will still have to load them using the old format, so it will solve your problem only for new checkpoints created from this point forward.
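
For illustration, the keying change amounts to roughly this (a simplified sketch, not the exact code in the PR):

class MyCustomCallback:
    pass

callback = MyCustomCallback()
state = {"count": 3}

callback_states = {}
# a plain string key: unpickling the checkpoint no longer needs to
# import the module that defines the user's callback class
callback_states[type(callback).__qualname__] = state
assert "MyCustomCallback" in callback_states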

I started drafting PR #6886.
Suggestions are of course welcome.

@awaelchli (Contributor) commented Apr 21, 2021

I thought a bit about backward compatibility. I propose to convert a checkpoint to the current version in LightningModule.load_from_checkpoint by applying a set of migrations. Depending on the version the checkpoint was saved at, we apply a different migration to the next higher version.

Every time we introduce a backward incompatible change, we provide a migration targeting a specific version to upgrade. For example:

@Migration(target="1.2.0")
def upgrade_something(checkpoint: dict) -> dict:
    # Here take the checkpoint from version 1.2.0 and upgrade it to the next higher version 1.2.1
    # ...
    return checkpoint


@Migration(target="1.2.1")
def upgrade_something_else(checkpoint: dict) -> dict:
    # Here take the checkpoint from version 1.2.1 and upgrade it to the next higher version 1.2.2
    # ...
    return checkpoint

The Migration decorator registers the function in a dictionary.
For most versions, we don't have a change, so these upgrades will become an identity function.
It leads to a chain of migrations:

def upgrade_checkpoint(checkpoint: dict) -> dict:
    for migration in all_migrations.values():
        if migration is None:
            # the only thing the default migration does is bump the version
            checkpoint = default_migration(checkpoint)
        else:
            checkpoint = migration(checkpoint)
    return checkpoint
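
For completeness, the decorator's registration could be as small as this (a sketch matching the all_migrations dict used above; versions without a registered migration would map to None and fall back to the default):

from typing import Callable, Dict, Optional

all_migrations: Dict[str, Optional[Callable[[dict], dict]]] = {}

class Migration:
    """Registers an upgrade function for checkpoints saved at `target`."""

    def __init__(self, target: str):
        self.target = target

    def __call__(self, fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
        all_migrations[self.target] = fn
        return fn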

In LightningModule.load_from_checkpoint() we call this upgrade function after we loaded the checkpoint.
We end up with a checkpoint that is standardized to the current version of the installed Lightning.

What do you think of this approach? @carmocca @ananthsub

@carmocca (Contributor)

Do you expect users to use these decorators themselves?

I like the idea but I would go with the simple yet effective:

def upgrade_checkpoint(checkpoint):
    if checkpoint["__version__"] == "...":
        ...
    if checkpoint["__version__"] == "...":
        ...
    ...
    return checkpoint

Instead of these decorators, given that it's not common at all that we break ckpt compatibility.
It is fine though if we expect users to register upgrade functions themselves.

Also, maybe instead of updating the original version field, we should add an upgraded_version field, so we can know which version actually created the checkpoint and to what version it has been adapted (if it has been).

@awaelchli (Contributor)

> Do you expect users to use these decorators themselves?

No, just for internal use. We don't want users to interfere here. This is an improvement internally, because we have hardcoded backward compatible loading logic in the core. And for the support of multiple callbacks discussed here, I need a sane, standardized way to migrate.

> it's not common at all that we break ckpt compatibility.

It's about not bloating the internals of Lightning with logic that is just there to provide backward compatibility for loading an old checkpoint. I propose to put this logic in one place and take care of it at the loading stage, so that the rest of Lightning can assume the same format everywhere. Second, I propose to upgrade iteratively from one version to the next. This makes it easier to add on every time we make a change in a new version.

@carmocca (Contributor)

After a bit of offline chat, we have different opinions:

Adrian proposes always enforcing in-between version steps:

upgrades = {
    "1.2.0": some_function,
    "1.2.1": lambda x: x,
    "1.2.2": some_other_upgrade,
    …
}

Where we have to update this every release.

On the other hand, I propose avoiding the no-op version checks:

def should_update(version: str, target: str) -> bool:
    return compare_version(version, operator.lt, target)
    

def upgrade_checkpoint(checkpoint):
    # set the default compat version for old checkpoints
    checkpoint.setdefault("compat_version", checkpoint["version"])

    if should_update(checkpoint["compat_version"], "1.2.0"):
        update_to_1_2_0(checkpoint)
        assert checkpoint["compat_version"] == "1.2.0"
        
    if should_update(checkpoint["compat_version"], "1.2.7"):
        update_to_1_2_7(checkpoint)
        assert checkpoint["compat_version"] == "1.2.7"
    
    # No need to update compat_version higher than 1.2.7
    # if there are no breaking changes

    whatever_else()
    
    return checkpoint

And adapting the code if necessary in a future release

Both options imply incremental updates to the version. The difference is whether or not in-between no-op steps should be enforced.

@ananthsub (Contributor, Author)

I prefer @carmocca's approach to avoid the boilerplate across versions. I think it's a lower initial investment, and if we find that it's lacking, we can shift to @awaelchli's proposal for stricter checkpoint upgrades.

@awaelchli (Contributor)

I'm implementing your suggestion in PR #7161. However, I disagree with all your observations.

Carlos:

> Adrian proposes always enforcing in-between version steps:

upgrades = {
    "1.2.0": some_function,
    "1.2.1": lambda x: x,
    "1.2.2": some_other_upgrade,
    …
}

The steps in between would of course be implied, not explicitly defined by us.

> Where we have to update this every release.

Only when a release changes the checkpoint format => zero effort for most releases.

Ananth:

> I prefer @carmocca's approach to avoid the boilerplate across versions.

This has more boilerplate code; it is repeated and more error prone. It is exactly what I avoided in my proposed API.
Happy to iterate on either side.

Current plan:
[1 / 3] #6886, adds the string identifier for callbacks (fixes pickle issue), can be merged as part of 1.3 if desired.
[2 / 3] #7161, adds migration logic necessary to keep backward compat for step 3/3
[3 / 3] will add unique string identifiers for callbacks as discussed here.

@ananthsub (Contributor, Author)

> The steps in between would of course be implied, not explicitly defined by us.

I misunderstood this point, thinking that it would be explicitly defined. This sounds great @awaelchli!
