Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add required states for resumed ModelCheckpoint GC #10995

Merged
merged 15 commits into from
Dec 20, 2021
Merged

Add required states for resumed ModelCheckpoint GC #10995

merged 15 commits into from
Dec 20, 2021

Conversation

ORippler
Copy link
Contributor

@ORippler ORippler commented Dec 8, 2021

What does this PR do?

Fixes #4911
Related: #5090

Currently, when resuming training the internal states required for continued ModelCheckpointing are neither saved nor restored. This leads to the fact that k new checkpoints are always generated due to this check. These new checkpoints are properly gced/compared to against, but the old ones are not.

Note that this PR does not handle overrides of monitor, dirpath or mode, as also referred to in #4911

Does your PR introduce any breaking changes? If yes, please list them.

It might be that resuming training fails now if it did not fail before, if the paths were changed in the mean time (refer also #4911). I did not check/test for this, but confirmed that resumption of GC now works properly

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃
cc @carmocca @awaelchli @ninginthecloud @jjenniferdai

Copy link
Member

@justusschock justusschock left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ORippler thanks for this fix. To avoid regression again, we need a test for this.

Do you think the following does reflect this issue sufficiently (If so, feel free to take it and commit it directly to your branch):

def test_model_checkpoint_attributes(tmpdir):
    seed_everything()
    model = LogInTwoMethods()

    epochs = 2
    checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint_callback],
        limit_train_batches=10,
        limit_val_batches=10,
        max_epochs=epochs,
        logger=False,
    )

    trainer.fit(model)

    checkpoint = torch.load(os.path.join(tmpdir, 'last.ckpt'))['callbacks'][checkpoint_callback.state_key]
    for k in ("best_models, kth_best_model_path", "kth_value", "last_model_path"):
        assert checkpoint[k] == getattr(checkpoint_callback, k)

@ORippler
Copy link
Contributor Author

ORippler commented Dec 8, 2021

@ORippler thanks for this fix. To avoid regression again, we need a test for this.

Do you think the following does reflect this issue sufficiently (If so, feel free to take it and commit it directly to your branch):

def test_model_checkpoint_attributes(tmpdir):
    seed_everything()
    model = LogInTwoMethods()

    epochs = 2
    checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint_callback],
        limit_train_batches=10,
        limit_val_batches=10,
        max_epochs=epochs,
        logger=False,
    )

    trainer.fit(model)

    checkpoint = torch.load(os.path.join(tmpdir, 'last.ckpt'))['callbacks'][checkpoint_callback.state_key]
    for k in ("best_models, kth_best_model_path", "kth_value", "last_model_path"):
        assert checkpoint[k] == getattr(checkpoint_callback, k)

How would this integrate with the different functionality tests for ModelCheckpoint the overall testing framework ? I see many tests checking whether attributes are written to the ckpt properly, but not whether they are loaded. For example here, we check for current_score being written to disk/to the ckpt, but the current on_load_checkpoint never sets this attribute again.

Is this not something we want to test also? Am a bit confused here.

@justusschock
Copy link
Member

@ORippler that's true. I suppose for end-to-end testing you would have to extend that test by resuming it at a freshly created trainer and examining the properties there.

@carmocca do we want to test for different parametrizations of the callback here?

Note that we do not yet check for proper loading/reinstantiation of
ModelCheckpooint based on the ckpt written to disk
@ORippler
Copy link
Contributor Author

ORippler commented Dec 9, 2021

@ORippler that's true. I suppose for end-to-end testing you would have to extend that test by resuming it at a freshly created trainer and examining the properties there.

@carmocca do we want to test for different parametrizations of the callback here?

I added your test and expanded it to check whether a freshly instantiated ModelCheckpoint also loads the properties.
This is still not a fully functional test though.

Off-Note: Do we have a test that compares for equivalence of results generated by one continuous training run and an interrupted one that is resumed by passing the checkpoint to trainer.fit ? This would be nice to have imo

@justusschock
Copy link
Member

justusschock commented Dec 9, 2021

@ORippler we have https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/trainer/test_trainer.py#L399 which doesn't check the results. However, this is on purpose as the checkpoint does not include any random state and thus continuing from the checkpoint doesn't have to yield the exact same results (different random states when using the global rng for example) until now (this is currently in development).

cc @tchaton to add a similar test once fault tolerance is ready

@tchaton
Copy link
Contributor

tchaton commented Dec 9, 2021

Hey @ORippler,

Yes, we have multiple tests checking the weights are the same before and after for Fault Tolerance. Here they are: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/utilities/test_auto_restart.py

@justusschock justusschock added the bug Something isn't working label Dec 10, 2021
ORippler and others added 2 commits December 15, 2021 18:46
`ModelCheckpoint` is configured to save after every epoch,
but `trainer.fit` is called with `max_steps = 1`

Note there may be a better way of doing this, where `ModelCheckpoint`
is called after `training_step`
@codecov
Copy link

codecov bot commented Dec 16, 2021

Codecov Report

Merging #10995 (3d7994a) into master (e19d93f) will decrease coverage by 4%.
The diff coverage is 100%.

@@           Coverage Diff            @@
##           master   #10995    +/-   ##
========================================
- Coverage      92%      88%    -4%     
========================================
  Files         177      177            
  Lines       16502    16560    +58     
========================================
- Hits        15173    14604   -569     
- Misses       1329     1956   +627     

@mergify mergify bot added the ready PRs ready to be merged label Dec 17, 2021
* First save, then load ckpt.
* Instantiate ModelCheckpoint twice.
Copy link
Contributor

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

tests/checkpointing/test_model_checkpoint.py Show resolved Hide resolved
@justusschock justusschock merged commit 86a3c5e into Lightning-AI:master Dec 20, 2021
awaelchli added a commit that referenced this pull request Dec 21, 2021
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
lexierule pushed a commit that referenced this pull request Dec 21, 2021
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
@rohitgr7 rohitgr7 mentioned this pull request Feb 7, 2022
12 tasks
Comment on lines +347 to +350
"best_k_models": self.best_k_models,
"kth_best_model_path": self.kth_best_model_path,
"kth_value": self.kth_value,
"last_model_path": self.last_model_path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just noticed an issue with doing this.

Since we save each "ModelCheckpoint" mode sequentally, these attributes will not be correct depending on the order if more than 1 mode triggers a save for the same global step:

https://github.com/PyTorchLightning/pytorch-lightning/blob/fe940e195dceb18eb9f3bd512cea56ae3405d464/pytorch_lightning/callbacks/model_checkpoint.py#L366-L373

Currently, a "top-k" checkpoint will not include the last_model_path path even if it's saved right after for this global step.

I'm not sure what would be the best solution here. I think we should start recommending multiple ModelCheckpoint instances as a best practice because these interactions between flags can be unintuitive.

cc @awaelchli @ananthsub @jjenniferdai
Related to #4335 and #11805 (comment)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working callback: model checkpoint ready PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ModelCheckpoint Callback save and restore extension
6 participants