From ed3e6f3eaa86cada4b38d0fa34881db4bf880a46 Mon Sep 17 00:00:00 2001 From: Fabio Kepler Date: Tue, 12 May 2020 14:54:51 +0100 Subject: [PATCH 01/13] Add an additional attribute to ModelCheckpoint to keep track of the best model's path Currently, only the best metric value is directly tracked. This new attribute will help in use cases where the trained model needs to be used or tracked right after training. --- pytorch_lightning/callbacks/model_checkpoint.py | 4 +++- pytorch_lightning/trainer/training_io.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a65855e49cda5..497fe1aff01db 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -112,6 +112,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve # {filename: monitor} self.kth_best_model = '' self.best = 0 + self.best_model = '' self.save_function = None torch_inf = torch.tensor(np.Inf) @@ -265,7 +266,8 @@ def _do_check_save(self, filepath, current, epoch): self.kth_value = self.best_k_models[self.kth_best_model] _op = min if self.mode == 'min' else max - self.best = _op(self.best_k_models.values()) + self.best_model = _op(self.best_k_models, key=self.best_k_models.get) + self.best = self.best_k_models[self.best_model] if self.verbose > 0: log.info( diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 11771b21961ed..ac65a75afbfdc 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -315,6 +315,7 @@ def dump_checkpoint(self, weights_only: bool = False): if not weights_only: if self.checkpoint_callback: checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best + checkpoint['checkpoint_callback_best_model'] = self.checkpoint_callback.best_model if self.early_stop_callback: 
checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait @@ -398,10 +399,11 @@ def restore_training_state(self, checkpoint): ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.' ) - if self.checkpoint_callback is not None and self.checkpoint_callback is not False: + if self.checkpoint_callback: self.checkpoint_callback.best = checkpoint['checkpoint_callback_best'] + self.checkpoint_callback.best_model = checkpoint['checkpoint_callback_best_model'] - if self.early_stop_callback is not None and self.early_stop_callback is not False: + if self.early_stop_callback: self.early_stop_callback.wait = checkpoint['early_stop_callback_wait'] self.early_stop_callback.patience = checkpoint['early_stop_callback_patience'] From 33ea2759467d5426f544a4f1c15d3a698013573f Mon Sep 17 00:00:00 2001 From: Fabio Natanael Kepler Date: Wed, 27 May 2020 11:56:16 +0100 Subject: [PATCH 02/13] Add small description and usage example to docs --- pytorch_lightning/callbacks/model_checkpoint.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 497fe1aff01db..716c7f88c3d03 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,7 +20,10 @@ class ModelCheckpoint(Callback): r""" - Save the model after every epoch. + Save the model after every epoch if it improves. + + After training finishes, use :attr:`best_model` to retrieve the path to the + best checkpoint file. Args: filepath: path to save the model file. @@ -79,6 +82,13 @@ class ModelCheckpoint(Callback): >>> checkpoint_callback = ModelCheckpoint( ... filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}' ... 
) + + # retrieve the best checkpoint after training + >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/') + >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) + >>> model = ... + >>> trainer.fit(model) + >>> print(checkpoint_callback.best_model) """ From 5fe9a582c4b9b9d30491a5f5c66a0220f26f13df Mon Sep 17 00:00:00 2001 From: Fabio Natanael Kepler Date: Wed, 27 May 2020 12:44:11 +0100 Subject: [PATCH 03/13] Fix PEP8 issues --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 716c7f88c3d03..dcb98f10ec72f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -21,7 +21,7 @@ class ModelCheckpoint(Callback): r""" Save the model after every epoch if it improves. - + After training finishes, use :attr:`best_model` to retrieve the path to the best checkpoint file. @@ -82,7 +82,7 @@ class ModelCheckpoint(Callback): >>> checkpoint_callback = ModelCheckpoint( ... filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}' ... 
) - + # retrieve the best checkpoint after training >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/') >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) From 7a92c0038f3b883cfb29607bd4f262338851ec87 Mon Sep 17 00:00:00 2001 From: Fabio Natanael Kepler Date: Wed, 27 May 2020 13:03:15 +0100 Subject: [PATCH 04/13] Fix doctest example --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index dcb98f10ec72f..981a35decbdf4 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -86,8 +86,8 @@ class ModelCheckpoint(Callback): # retrieve the best checkpoint after training >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/') >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) - >>> model = ... - >>> trainer.fit(model) + >>> # model = ... + >>> # trainer.fit(model) >>> print(checkpoint_callback.best_model) """ From 1434249596a1c73631cdf4c839835fb2454b325b Mon Sep 17 00:00:00 2001 From: Fabio Kepler Date: Thu, 28 May 2020 09:10:34 +0100 Subject: [PATCH 05/13] Fix expected output in doctest --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 981a35decbdf4..ca54a32d1fb0d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -89,6 +89,7 @@ class ModelCheckpoint(Callback): >>> # model = ... 
>>> # trainer.fit(model) >>> print(checkpoint_callback.best_model) + """ From 648c152aea63f1f8f88cda2d761ca93f051ae042 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 28 May 2020 10:18:07 +0200 Subject: [PATCH 06/13] Apply suggestions from code review --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ca54a32d1fb0d..c1a1c5b1c2103 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -88,8 +88,8 @@ class ModelCheckpoint(Callback): >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) >>> # model = ... >>> # trainer.fit(model) - >>> print(checkpoint_callback.best_model) - + >>> checkpoint_callback.best_model + '' """ From 8c594803456653b458002a2afab468e895c5dff4 Mon Sep 17 00:00:00 2001 From: Fabio Natanael Kepler Date: Thu, 28 May 2020 10:02:50 +0100 Subject: [PATCH 07/13] Show example as code block instead of doctest --- pytorch_lightning/callbacks/model_checkpoint.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index c1a1c5b1c2103..6df6335a8cc62 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -84,12 +84,11 @@ class ModelCheckpoint(Callback): ... ) # retrieve the best checkpoint after training - >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/') - >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) - >>> # model = ... - >>> # trainer.fit(model) - >>> checkpoint_callback.best_model - '' + checkpoint_callback = ModelCheckpoint(filepath='my/path/') + trainer = Trainer(checkpoint_callback=checkpoint_callback) + # model = ... 
+ # trainer.fit(model) + checkpoint_callback.best_model """ From 0ee63a9baf77320ee961944e1aacf51115a85e31 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 28 May 2020 16:04:38 +0200 Subject: [PATCH 08/13] Apply suggestions from code review --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6df6335a8cc62..4182f0339b02f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -86,8 +86,8 @@ class ModelCheckpoint(Callback): # retrieve the best checkpoint after training checkpoint_callback = ModelCheckpoint(filepath='my/path/') trainer = Trainer(checkpoint_callback=checkpoint_callback) - # model = ... - # trainer.fit(model) + model = ... + trainer.fit(model) checkpoint_callback.best_model """ From 24c27a077b4780ef656f399032925ecd5e839738 Mon Sep 17 00:00:00 2001 From: Fabio Kepler Date: Thu, 28 May 2020 15:30:11 +0100 Subject: [PATCH 09/13] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ad8752763794..689d488f33126 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). 
+- Attribute `best_model` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) + ### Changed - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729)) From 3f4901dc00a48a098c6d2d543654918e40cb18fe Mon Sep 17 00:00:00 2001 From: Fabio Kepler Date: Fri, 29 May 2020 17:45:58 +0100 Subject: [PATCH 10/13] Rename `ModelCheckpoint.best` to `ModelCheckpoint.best_model_score` Also rename `ModelCheckpoint.best_model` (added in this PR) to `ModelCheckpoint.best_model_path`, for consistency, and `kth_best_model` to `kth_best_model_path`. --- CHANGELOG.md | 4 +- .../callbacks/model_checkpoint.py | 42 ++++++++++++------- pytorch_lightning/trainer/training_io.py | 11 +++-- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 689d488f33126..5f5d2c8a0f243 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). -- Attribute `best_model` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) +- Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) ### Changed @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862)) +- Renamed `ModelCheckpoint`'s attributes `best` to `best_model_score` and `kth_best_model` to `kth_best_model_path` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) + ### Deprecated ### Removed diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4182f0339b02f..adb19ead6e02f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -22,8 +22,8 @@ class ModelCheckpoint(Callback): r""" Save the model after every epoch if it improves. - After training finishes, use :attr:`best_model` to retrieve the path to the - best checkpoint file. + After training finishes, use :attr:`best_model_path` to retrieve the path to the + best checkpoint file and :attr:`best_model_score` to retrieve its score. Args: filepath: path to save the model file. @@ -88,7 +88,7 @@ class ModelCheckpoint(Callback): trainer = Trainer(checkpoint_callback=checkpoint_callback) model = ... 
trainer.fit(model) - checkpoint_callback.best_model + checkpoint_callback.best_model_path """ @@ -120,9 +120,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve self.prefix = prefix self.best_k_models = {} # {filename: monitor} - self.kth_best_model = '' - self.best = 0 - self.best_model = '' + self.kth_best_model_path = '' + self.best_model_score = 0 + self.best_model_path = '' self.save_function = None torch_inf = torch.tensor(np.Inf) @@ -140,6 +140,18 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve self.kth_value, self.mode = mode_dict[mode] + @property + def best(self): + rank_zero_warn("attribute `best` has been renamed to `best_model_score` since v0.8.0" + " and will be removed in v1.0.0", DeprecationWarning) + return self.best_model_score + + @property + def kth_best_model(self): + rank_zero_warn("attribute `kth_best_model` has been renamed to `kth_best_model_path` since v0.8.0" + " and will be removed in v1.0.0", DeprecationWarning) + return self.kth_best_model_path + def _del_model(self, filepath): if os.path.isfile(filepath): os.remove(filepath) @@ -171,7 +183,7 @@ def check_monitor_top_k(self, current): "max": torch.gt, }[self.mode] - return monitor_op(current, self.best_k_models[self.kth_best_model]) + return monitor_op(current, self.best_k_models[self.kth_best_model_path]) def format_checkpoint_name(self, epoch, metrics, ver=None): """Generate a filename according to the defined template. 
@@ -263,26 +275,26 @@ def _do_check_save(self, filepath, current, epoch): del_list = [] if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0: - delpath = self.kth_best_model - self.best_k_models.pop(self.kth_best_model) + delpath = self.kth_best_model_path + self.best_k_models.pop(self.kth_best_model_path) del_list.append(delpath) self.best_k_models[filepath] = current if len(self.best_k_models) == self.save_top_k: # monitor dict has reached k elements _op = max if self.mode == 'min' else min - self.kth_best_model = _op(self.best_k_models, - key=self.best_k_models.get) - self.kth_value = self.best_k_models[self.kth_best_model] + self.kth_best_model_path = _op(self.best_k_models, + key=self.best_k_models.get) + self.kth_value = self.best_k_models[self.kth_best_model_path] _op = min if self.mode == 'min' else max - self.best_model = _op(self.best_k_models, key=self.best_k_models.get) - self.best = self.best_k_models[self.best_model] + self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) + self.best_model_score = self.best_k_models[self.best_model_path] if self.verbose > 0: log.info( f'\nEpoch {epoch:05d}: {self.monitor} reached' - f' {current:0.5f} (best {self.best:0.5f}), saving model to' + f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to' f' {filepath} as top {self.save_top_k}') self._save_model(filepath) diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index ac65a75afbfdc..fbb435770de08 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -314,8 +314,8 @@ def dump_checkpoint(self, weights_only: bool = False): if not weights_only: if self.checkpoint_callback: - checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best - checkpoint['checkpoint_callback_best_model'] = self.checkpoint_callback.best_model + checkpoint['checkpoint_callback_best_model_score'] = 
self.checkpoint_callback.best_model_score + checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path if self.early_stop_callback: checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait @@ -400,8 +400,11 @@ def restore_training_state(self, checkpoint): ) if self.checkpoint_callback: - self.checkpoint_callback.best = checkpoint['checkpoint_callback_best'] - self.checkpoint_callback.best_model = checkpoint['checkpoint_callback_best_model'] + if 'checkpoint_callback_best' in checkpoint: + self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best'] + else: + self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score'] + self.checkpoint_callback.best_model_path = checkpoint['checkpoint_callback_best_model_path'] if self.early_stop_callback: self.early_stop_callback.wait = checkpoint['early_stop_callback_wait'] From f09372f8cf47d351540b450caad07bcceeb0443e Mon Sep 17 00:00:00 2001 From: Fabio Natanael Kepler Date: Fri, 29 May 2020 18:58:08 +0100 Subject: [PATCH 11/13] Update pytorch_lightning/trainer/training_io.py Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/training_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 4cb576ce87a06..f694233b8a269 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -407,7 +407,7 @@ def restore_training_state(self, checkpoint): self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best'] else: self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score'] - self.checkpoint_callback.best_model_path = checkpoint['checkpoint_callback_best_model_path'] + self.checkpoint_callback.best_model_path = checkpoint.get('checkpoint_callback_best_model_path') if self.early_stop_callback: self.early_stop_callback.wait = 
checkpoint['early_stop_callback_wait'] From bc2baad43be8294a980a567e9148a2cc98f320b5 Mon Sep 17 00:00:00 2001 From: Fabio Natanael Kepler Date: Fri, 29 May 2020 18:59:22 +0100 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 ++ pytorch_lightning/callbacks/model_checkpoint.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b7a294e6ca62..d6c59a1a09114 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) + - Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917)) ### Removed diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e0ab8e2c2ca57..9336fe309889a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -144,14 +144,14 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve @property def best(self): - rank_zero_warn("attribute `best` has been renamed to `best_model_score` since v0.8.0" - " and will be removed in v1.0.0", DeprecationWarning) + rank_zero_warn("Attribute `best` has been renamed to `best_model_score` since v0.8.0" + " and will be removed in v0.10.0", DeprecationWarning) return self.best_model_score @property def kth_best_model(self): - rank_zero_warn("attribute `kth_best_model` has been renamed to `kth_best_model_path` since v0.8.0" - " and will be removed in v1.0.0", DeprecationWarning) + rank_zero_warn("Attribute `kth_best_model` has been renamed to `kth_best_model_path` since v0.8.0" + " and will be removed in v0.10.0", 
DeprecationWarning) return self.kth_best_model_path def _del_model(self, filepath): From ac36651d25b2aa7a963873e3a6469852275cd56d Mon Sep 17 00:00:00 2001 From: Fabio Kepler Date: Sun, 31 May 2020 09:20:22 +0100 Subject: [PATCH 13/13] Add warning when loading checkpoint from an old version --- pytorch_lightning/trainer/training_io.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index f694233b8a269..fd0385cde4b34 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -403,11 +403,16 @@ def restore_training_state(self, checkpoint): ) if self.checkpoint_callback: - if 'checkpoint_callback_best' in checkpoint: - self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best'] - else: + if 'checkpoint_callback_best_model_score' in checkpoint: self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score'] - self.checkpoint_callback.best_model_path = checkpoint.get('checkpoint_callback_best_model_path') + else: + # Old naming until version 0.7.6 + rank_zero_warn( + 'Loading a checkpoint created with an old version of Lightning; ' + 'this will not be supported in the future.' + ) + self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best'] + self.checkpoint_callback.best_model_path = checkpoint['checkpoint_callback_best_model_path'] if self.early_stop_callback: self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']