Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multiple optimizer restart with fault-tolerant training #9537

Merged
merged 21 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
* Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
* Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))


- Checkpoint saving & loading extensibility:
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/loops/optimization/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

def reset(self) -> None:
if not self.restarting or self.done:
if not self.restarting:
# when reset() is called from outside (manually), we reset the loop progress
self.optim_progress.optimizer_position = 0
self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

Expand All @@ -203,6 +204,8 @@ def on_run_start( # type: ignore[override]
) -> None:
self._batch_idx = batch_idx
self._indices, self._optimizers = zip(*optimizers)
if self.done:
self.optim_progress.optimizer_position = 0

def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
result = self._run_optimization(
Expand Down
115 changes: 114 additions & 1 deletion tests/loops/optimization/test_optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from torch.optim import Adam, SGD

from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


def test_closure_result_deepcopy():
Expand Down Expand Up @@ -127,3 +130,113 @@ def configure_optimizers(self):
assert all(isinstance(opt, LightningOptimizer) for opt in pl_optimizer_sequence)
optimizer_sequence = [opt._optimizer.__class__.__name__ for opt in pl_optimizer_sequence]
assert list(zip(opt_idx_sequence, optimizer_sequence)) == expected


class CustomException(Exception):
pass


@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("stop_epoch", (0, 1))
@pytest.mark.parametrize("stop_batch", (0, 1, 2))
@pytest.mark.parametrize("n_optimizers,stop_optimizer", [(2, 0), (2, 1), (3, 2)])
def test_loop_restart_progress_multiple_optimizers(tmpdir, n_optimizers, stop_optimizer, stop_epoch, stop_batch):
"""Test that Lightning can resume from a point where a training_step failed while in the middle of processing
several optimizer steps for one batch.

The test asserts that we end up with the same trained weights as if no failure occured.
"""

n_batches = 3
n_epochs = 2

def _assert_optimizer_sequence(method_mock, expected):
positional_args = [c[0] for c in method_mock.call_args_list]
sequence = [arg[3] for arg in positional_args]
assert sequence == expected

num_optimizers_incomplete = stop_epoch * n_batches * n_optimizers + stop_batch * n_optimizers + stop_optimizer

opt_idx_sequence_complete = list(range(n_optimizers)) * n_epochs * n_batches # [0, 1, 2, 0, 1, 2, 0, 1, ...]
# +1 because we fail inside the closure inside optimizer_step()
opt_idx_sequence_incomplete = opt_idx_sequence_complete[: (num_optimizers_incomplete + 1)]
opt_idx_sequence_resumed = opt_idx_sequence_complete[num_optimizers_incomplete:]

class MultipleOptimizerModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
if (
fail
and self.current_epoch == stop_epoch
and batch_idx == stop_batch
and optimizer_idx == stop_optimizer
):
raise CustomException
return super().training_step(batch, batch_idx)

def configure_optimizers(self):
return [torch.optim.SGD(self.parameters(), lr=0.1) for _ in range(n_optimizers)]

# run without a failure, collect weights
fail = False
seed_everything(0)
model = MultipleOptimizerModel()
model.training_epoch_end = None
model.optimizer_step = Mock(wraps=model.optimizer_step)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
limit_val_batches=0,
num_sanity_val_steps=0,
logger=False,
checkpoint_callback=False,
)
trainer.fit(model)
weights_complete = model.parameters()
_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete)

# simulate a failure
fail = True
seed_everything(0)
model = MultipleOptimizerModel()
model.training_epoch_end = None
model.optimizer_step = Mock(wraps=model.optimizer_step)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
limit_val_batches=0,
num_sanity_val_steps=0,
logger=False,
checkpoint_callback=False,
)
with pytest.raises(CustomException):
trainer.fit(model)

_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_incomplete)

# resume from failure and collect weights
fail = False
seed_everything(0)
model = MultipleOptimizerModel()
model.training_epoch_end = None
model.optimizer_step = Mock(wraps=model.optimizer_step)
trainer = Trainer(
resume_from_checkpoint=str(tmpdir / ".pl_auto_save.ckpt"),
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
limit_val_batches=0,
num_sanity_val_steps=0,
logger=False,
checkpoint_callback=False,
)
trainer.fit(model)
weights_resumed = model.parameters()

# check that the final weights of a resumed run match the weights of a run that never failed
for w0, w1 in zip(weights_complete, weights_resumed):
assert torch.allclose(w0, w1)

_assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_resumed)
39 changes: 25 additions & 14 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,11 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
assert state_dict == {"state_dict": {"a": 1}, "progress": {"increment": 1}}


@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("stop_epoch", (1, 2))
@pytest.mark.parametrize("stop_batch", (1, 2))
@pytest.mark.parametrize("n_dataloaders,stop_dataloader", [(2, 0), (2, 1), (3, 2)])
@RunIf(min_torch="1.7.0")
def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, stop_dataloader, stop_epoch, stop_batch):
n_batches = 5
n_epochs = 3
Expand Down Expand Up @@ -284,10 +284,8 @@ def val_dataloader(self):
)

# simulate a failure
try:
with pytest.raises(CustomException):
trainer.fit(model)
except CustomException:
pass

ckpt_path = str(tmpdir / ".pl_auto_save.ckpt")
checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"]
Expand Down Expand Up @@ -334,13 +332,13 @@ def val_dataloader(self):
assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected


@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3))
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
@pytest.mark.parametrize("stop_epoch", (1, 2))
@pytest.mark.parametrize("stop_batch", (1, 2))
@pytest.mark.parametrize("stop_optimizer", (1, 2))
@RunIf(min_torch="1.7.0")
def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0
n_epochs = 3
Expand Down Expand Up @@ -382,10 +380,8 @@ def configure_optimizers_multiple(self):
)

# simulate a failure
try:
with pytest.raises(CustomException):
trainer.fit(model)
except CustomException:
pass

ckpt_path = str(tmpdir / ".pl_auto_save.ckpt")
assert os.path.exists(ckpt_path)
Expand Down Expand Up @@ -506,19 +502,34 @@ def configure_optimizers_multiple(self):
# need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
# fit loop to have an iterator, which is only available during training
checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY

assert state_dict == checkpoint["loops"]["fit_loop"]

trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
# with `restart_progress=True`, we expect all `ready` counters to be reset to `completed`
tchaton marked this conversation as resolved.
Show resolved Hide resolved
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=True)

epoch_progress = trainer.fit_loop.epoch_progress
assert epoch_progress.current.ready == stop_epoch
assert epoch_progress.current.completed == stop_epoch

batch_progress = trainer.fit_loop.epoch_loop.batch_progress
assert batch_progress.current.ready == be_batches_completed
assert batch_progress.current.completed == be_batches_completed

optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
assert optim_progress.optimizer.step.current.ready == be_total_opt_steps
assert optim_progress.optimizer.step.current.completed == be_total_opt_steps
assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad
assert optim_progress.optimizer.zero_grad.current.completed == be_total_zero_grad

state_dict = trainer.fit_loop.state_dict()
assert state_dict != checkpoint["loops"]["fit_loop"]
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch


@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
@RunIf(min_torch="1.7.0")
def test_loop_state_on_complete_run(n_optimizers, tmpdir):
n_epochs = 3
n_batches = 3
Expand Down Expand Up @@ -647,8 +658,8 @@ def configure_optimizers_multiple(self):
assert checkpoint["loops"]["fit_loop"] == expected


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_fit_loop_reset(tmpdir):
"""Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed
loop or from a mid-epoch checkpoint."""
Expand Down Expand Up @@ -702,7 +713,7 @@ def mid_epoch_reset_assertions():
epoch_loop.reset()
optimizer_loop.reset()
mid_epoch_reset_assertions()
assert optimizer_loop.optim_progress.optimizer_position == 0
assert optimizer_loop.optim_progress.optimizer_position == 1

# reset state loaded from a checkpoint from the end of an epoch
end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt"))
Expand Down Expand Up @@ -745,4 +756,4 @@ def mid_epoch_reset_assertions():
assert epoch_loop.batch_progress.current.ready == 0
assert epoch_loop.batch_progress.current.completed == 0

assert optimizer_loop.optim_progress.optimizer_position == 0
assert optimizer_loop.optim_progress.optimizer_position == 1