From 1af53688c738d4244f403857dda5822715e300dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 15 Sep 2021 13:20:46 +0200 Subject: [PATCH] update test and fix restarting logic --- .../loops/optimization/optimizer_loop.py | 3 ++- tests/loops/test_loops.py | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index b76e935f5edd6..aaad5c50ebf94 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -204,7 +204,7 @@ def connect(self, **kwargs: "Loop") -> None: raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") def reset(self) -> None: - if not self.restarting or self.done: + if not self.restarting: self.optim_progress.optimizer_position = 0 self.outputs = [[] for _ in range(len(self.trainer.optimizers))] @@ -226,6 +226,7 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignor self.optim_progress.optimizer_position += 1 def on_run_end(self) -> _OUTPUTS_TYPE: + self.optim_progress.optimizer_position = 0 outputs, self.outputs = self.outputs, [] # free memory return outputs diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 1a308fec983b4..efae7a783c87e 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -348,11 +348,11 @@ def _assert_optimizer_sequence(method_mock, expected): assert sequence == expected num_optimizers_incomplete = stop_epoch * n_batches * n_optimizers + stop_batch * n_optimizers + stop_optimizer - opt_idx_sequence_full = list(range(n_optimizers)) * n_epochs * n_batches # [0, 1, 2, 0, 1, 2, 0, 1, ...] + opt_idx_sequence_complete = list(range(n_optimizers)) * n_epochs * n_batches # [0, 1, 2, 0, 1, 2, 0, 1, ...] # +1 because we fail inside the closure inside optimizer_step() - opt_idx_sequence_incomplete = opt_idx_sequence_full[: (num_optimizers_incomplete + 1)] - opt_idx_sequence_resumed = opt_idx_sequence_full[num_optimizers_incomplete:] + opt_idx_sequence_incomplete = opt_idx_sequence_complete[: (num_optimizers_incomplete + 1)] + opt_idx_sequence_resumed = opt_idx_sequence_complete[num_optimizers_incomplete:] class MultipleOptimizerModel(BoringModel): def __init__(self): @@ -385,8 +385,8 @@ def configure_optimizers(self): logger=False, ) trainer.fit(model) - weights_success = model.parameters() - _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_full) + weights_complete = model.parameters() + _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete) # simulate a failure fail = True @@ -424,7 +424,7 @@ def configure_optimizers(self): trainer.fit(model) weights_resumed = model.parameters() - for w0, w1 in zip(weights_success, weights_resumed): + for w0, w1 in zip(weights_complete, weights_resumed): assert torch.allclose(w0, w1) _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_resumed)