multiple optimizer restart with fault-tolerant training #9537

Merged
merged 21 commits on Sep 16, 2021
Changes from 11 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -64,6 +64,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
* Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
* Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))


- Checkpoint saving & loading extensibility:
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -204,7 +204,7 @@ def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

def reset(self) -> None:
if not self.restarting or self.done:
if not self.restarting:
self.optim_progress.optimizer_position = 0
self.outputs = [[] for _ in range(len(self.trainer.optimizers))]
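# note: on a mid-run restart the previous `optimizer_position` is kept, so the loop
# resumes with the optimizer that did not finish; the position is cleared in `on_run_end` instead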

@@ -226,6 +226,7 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore
self.optim_progress.optimizer_position += 1

def on_run_end(self) -> _OUTPUTS_TYPE:
self.optim_progress.optimizer_position = 0
outputs, self.outputs = self.outputs, [] # free memory
return outputs

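As a usage sketch (illustrative, not part of the PR's changes): with fault-tolerant training enabled through the `PL_FAULT_TOLERANT_TRAINING` environment variable, a run that fails mid-epoch writes a `.pl_auto_save.ckpt` checkpoint into the trainer's root directory, and a fresh `Trainer` pointed at that checkpoint resumes the optimizer loop at the optimizer that had not yet stepped. The model, directory, and hyperparameters below are made up for illustration.

import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # opt in to fault-tolerant training

class TwoOptimizerModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # with multiple optimizers, `training_step` receives the optimizer index
        return self.layer(batch[0]).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)

    def configure_optimizers(self):
        return [torch.optim.SGD(self.parameters(), lr=0.1) for _ in range(2)]

root_dir = "lightning_runs"  # illustrative path
trainer = Trainer(default_root_dir=root_dir, max_epochs=2, limit_train_batches=3)
try:
    trainer.fit(TwoOptimizerModel())
except Exception:
    # a preempted or failed run leaves `.pl_auto_save.ckpt` behind; resume from it
    trainer = Trainer(
        default_root_dir=root_dir,
        max_epochs=2,
        limit_train_batches=3,
        resume_from_checkpoint=os.path.join(root_dir, ".pl_auto_save.ckpt"),
    )
    trainer.fit(TwoOptimizerModel())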
139 changes: 125 additions & 14 deletions tests/loops/test_loops.py
@@ -16,12 +16,12 @@
from dataclasses import dataclass
from typing import Any, Dict, Iterator
from unittest import mock
from unittest.mock import ANY
from unittest.mock import ANY, Mock

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loops import Loop, TrainingBatchLoop
from pytorch_lightning.trainer.progress import BaseProgress
@@ -251,11 +251,11 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
assert state_dict == {"state_dict": {"a": 1}, "progress": {"increment": 1}}


@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("stop_epoch", (1, 2))
@pytest.mark.parametrize("stop_batch", (1, 2))
@pytest.mark.parametrize("n_dataloaders,stop_dataloader", [(2, 0), (2, 1), (3, 2)])
@RunIf(min_torch="1.7.0")
def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_dataloader, stop_epoch, stop_batch):
n_batches = 5
n_epochs = 3
@@ -284,10 +284,8 @@ def val_dataloader(self):
)

# simulate a failure
try:
with pytest.raises(CustomException):
trainer.fit(model)
except CustomException:
pass

ckpt_path = str(tmpdir / ".pl_auto_save.ckpt")
checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"]
@@ -334,13 +332,112 @@ def val_dataloader(self):
assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected


@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("stop_epoch", (0, 1))
@pytest.mark.parametrize("stop_batch", (0, 1, 2))
@pytest.mark.parametrize("n_optimizers,stop_optimizer", [(2, 0), (2, 1), (3, 2)])
def test_loop_restart_progress_multiple_optimizers(tmpdir, n_optimizers, stop_optimizer, stop_epoch, stop_batch):
n_batches = 3
n_epochs = 2
fail = False

def _assert_optimizer_sequence(method_mock, expected):
positional_args = [c[0] for c in method_mock.call_args_list]
sequence = [arg[3] for arg in positional_args]
assert sequence == expected

num_optimizers_incomplete = stop_epoch * n_batches * n_optimizers + stop_batch * n_optimizers + stop_optimizer

opt_idx_sequence_complete = list(range(n_optimizers)) * n_epochs * n_batches # [0, 1, 2, 0, 1, 2, 0, 1, ...]
# +1 because we fail inside the closure inside optimizer_step()
opt_idx_sequence_incomplete = opt_idx_sequence_complete[: (num_optimizers_incomplete + 1)]
opt_idx_sequence_resumed = opt_idx_sequence_complete[num_optimizers_incomplete:]
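# worked example (illustrative parametrization): with n_optimizers=3, stop_optimizer=2,
# stop_epoch=1, stop_batch=1 and n_batches=3, num_optimizers_incomplete = 1*3*3 + 1*3 + 2 = 14,
# so the failing run performs 15 `optimizer_step` calls (the 15th raises inside its closure)
# and the resumed run replays the complete sequence from index 14 onwards, i.e. starting
# at optimizer 2 of the batch that failed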

class MultipleOptimizerModel(BoringModel):
def __init__(self):
super().__init__()

def training_step(self, batch, batch_idx, optimizer_idx):
if (
fail
and self.current_epoch == stop_epoch
and batch_idx == stop_batch
and optimizer_idx == stop_optimizer
):
raise CustomException
return super().training_step(batch, batch_idx)

def configure_optimizers(self):
return [torch.optim.SGD(self.parameters(), lr=0.1) for _ in range(n_optimizers)]

# run without a failure, collect weights
fail = False
seed_everything(0)
model = MultipleOptimizerModel()
model.training_epoch_end = None
model.optimizer_step = Mock(wraps=model.optimizer_step)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
num_sanity_val_steps=0,
logger=False,
)
trainer.fit(model)
weights_complete = model.parameters()
_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete)

# simulate a failure
fail = True
seed_everything(0)
model = MultipleOptimizerModel()
model.training_epoch_end = None
model.optimizer_step = Mock(wraps=model.optimizer_step)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
num_sanity_val_steps=0,
logger=False,
)
with pytest.raises(CustomException):
trainer.fit(model)

_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_incomplete)

# resume from failure and collect weights
fail = False
seed_everything(0)
model = MultipleOptimizerModel()
model.training_epoch_end = None
model.optimizer_step = Mock(wraps=model.optimizer_step)
trainer = Trainer(
resume_from_checkpoint=str(tmpdir / ".pl_auto_save.ckpt"),
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
limit_val_batches=0,
num_sanity_val_steps=0,
logger=False,
)
trainer.fit(model)
weights_resumed = model.parameters()

# check that the final weights of a resumed run match the weights of a run that never failed
for w0, w1 in zip(weights_complete, weights_resumed):
assert torch.allclose(w0, w1)

_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_resumed)


@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3))
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
@pytest.mark.parametrize("stop_epoch", (1, 2))
@pytest.mark.parametrize("stop_batch", (1, 2))
@pytest.mark.parametrize("stop_optimizer", (1, 2))
@RunIf(min_torch="1.7.0")
def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir):
stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0
n_epochs = 3
@@ -382,10 +479,8 @@ def configure_optimizers_multiple(self):
)

# simulate a failure
try:
with pytest.raises(CustomException):
trainer.fit(model)
except CustomException:
pass

ckpt_path = str(tmpdir / ".pl_auto_save.ckpt")
assert os.path.exists(ckpt_path)
@@ -506,19 +601,35 @@ def configure_optimizers_multiple(self):
# need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
# fit loop to have an iterator, which is only available during training
checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY

assert state_dict == checkpoint["loops"]["fit_loop"]

trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
# with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=True)

epoch_progress = trainer.fit_loop.epoch_progress
assert epoch_progress.current.ready == stop_epoch
assert epoch_progress.current.completed == stop_epoch

batch_progress = trainer.fit_loop.epoch_loop.batch_progress
assert batch_progress.current.ready == be_batches_completed
assert batch_progress.current.completed == be_batches_completed

optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
assert optim_progress.optimizer.step.current.ready == be_total_opt_steps
assert optim_progress.optimizer.step.current.completed == be_total_opt_steps
assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad
assert optim_progress.optimizer.zero_grad.current.completed == be_total_zero_grad

# loading with `restart_progress=True` modified the state (the `ready` counters were reset),
# so it no longer matches the raw checkpoint, while the epoch `started` counters are preserved
state_dict = trainer.fit_loop.state_dict()
assert state_dict != checkpoint["loops"]["fit_loop"]
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch


@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
@RunIf(min_torch="1.7.0")
def test_loop_state_on_complete_run(n_optimizers, tmpdir):
n_epochs = 3
n_batches = 3
@@ -647,8 +758,8 @@ def configure_optimizers_multiple(self):
assert checkpoint["loops"]["fit_loop"] == expected


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_fit_loop_reset(tmpdir):
"""Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed
loop or from a mid-epoch checkpoint."""