Skip to content

Commit

Permalink
update test and fix restarting logic
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Sep 15, 2021
1 parent 4674f83 commit 1af5368
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/optimization/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

def reset(self) -> None:
    """Reset the loop's state for a fresh run over the optimizers.

    On a normal (non-restarting) run the optimizer position is rewound to the
    first optimizer and one empty output bucket is allocated per configured
    optimizer.  When ``self.restarting`` is true the saved position is kept
    untouched so that a resumed run continues with the optimizer that was
    active when the previous run stopped.
    """
    if not self.restarting:
        # Start from the first optimizer and clear the per-optimizer outputs.
        self.optim_progress.optimizer_position = 0
        self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

Expand All @@ -226,6 +226,7 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignor
self.optim_progress.optimizer_position += 1

def on_run_end(self) -> _OUTPUTS_TYPE:
    """Finish the run: rewind the optimizer position and return the outputs.

    The internal output buffer is replaced with an empty list so the loop
    does not keep a reference to the (potentially large) collected results.
    """
    self.optim_progress.optimizer_position = 0
    collected = self.outputs
    self.outputs = []  # release the reference to free memory
    return collected

Expand Down
12 changes: 6 additions & 6 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,11 @@ def _assert_optimizer_sequence(method_mock, expected):
assert sequence == expected

num_optimizers_incomplete = stop_epoch * n_batches * n_optimizers + stop_batch * n_optimizers + stop_optimizer
opt_idx_sequence_full = list(range(n_optimizers)) * n_epochs * n_batches # [0, 1, 2, 0, 1, 2, 0, 1, ...]

opt_idx_sequence_complete = list(range(n_optimizers)) * n_epochs * n_batches # [0, 1, 2, 0, 1, 2, 0, 1, ...]
# +1 because we fail inside the closure inside optimizer_step()
opt_idx_sequence_incomplete = opt_idx_sequence_full[: (num_optimizers_incomplete + 1)]
opt_idx_sequence_resumed = opt_idx_sequence_full[num_optimizers_incomplete:]
opt_idx_sequence_incomplete = opt_idx_sequence_complete[: (num_optimizers_incomplete + 1)]
opt_idx_sequence_resumed = opt_idx_sequence_complete[num_optimizers_incomplete:]

class MultipleOptimizerModel(BoringModel):
def __init__(self):
Expand Down Expand Up @@ -385,8 +385,8 @@ def configure_optimizers(self):
logger=False,
)
trainer.fit(model)
weights_success = model.parameters()
_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_full)
weights_complete = model.parameters()
_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete)

# simulate a failure
fail = True
Expand Down Expand Up @@ -424,7 +424,7 @@ def configure_optimizers(self):
trainer.fit(model)
weights_resumed = model.parameters()

for w0, w1 in zip(weights_success, weights_resumed):
for w0, w1 in zip(weights_complete, weights_resumed):
assert torch.allclose(w0, w1)

_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_resumed)
Expand Down

0 comments on commit 1af5368

Please sign in to comment.