[bugfix] Prevent deepcopy of dataloaders / Trainer in SWA Callback #8472

Merged 16 commits on Jul 20, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -485,6 +485,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}` runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))


- Fixed the SWA callback to avoid `deepcopy` of dataloaders and trainer ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))


## [1.3.8] - 2021-07-01

### Fixed
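For readers skimming the changelog entry above: `deepcopy` recursively copies everything the copied object references, so a `LightningModule` that still points at its dataloaders and trainer would drag copies of them along. The toy sketch below is plain Python, not Lightning code, and every name in it is illustrative; it only shows why detaching a heavy reference before copying keeps the copy lightweight.

```python
from copy import deepcopy


class ToyModule:
    def __init__(self):
        self.weights = [0.0] * 10            # small state: cheap to copy
        self.dataloader = [0.0] * 100_000    # stands in for dataloaders / trainer


module = ToyModule()

heavy_copy = deepcopy(module)                # also duplicates the 100k-element list
assert heavy_copy.dataloader is not module.dataloader

module.dataloader = None                     # detach the heavy reference first...
light_copy = deepcopy(module)                # ...and only the small shell is copied
assert light_copy.dataloader is None
```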
26 changes: 25 additions & 1 deletion pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -15,6 +15,7 @@
Stochastic Weight Averaging Callback
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from contextlib import contextmanager
from copy import deepcopy
from typing import Callable, Optional, Union

@@ -137,9 +138,32 @@ def swa_end(self) -> int:
def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

@contextmanager
def _prevent_data_deepcopy(self, pl_module: 'pl.LightningModule'):
"""
This function is used to prevent deepcopy of dataloaders / trainer while deepcopying the ``LightningModule``.
"""
train_dataloader = pl_module.train_dataloader
val_dataloader = pl_module.val_dataloader
test_dataloader = pl_module.test_dataloader
predict_dataloader = pl_module.predict_dataloader
trainer = pl_module.trainer
pl_module.train_dataloader = None
pl_module.val_dataloader = None
pl_module.test_dataloader = None
pl_module.predict_dataloader = None
pl_module.trainer = None
try:
    yield
finally:
    # restore the original references even if the deepcopy fails
    pl_module.train_dataloader = train_dataloader
    pl_module.val_dataloader = val_dataloader
    pl_module.test_dataloader = test_dataloader
    pl_module.predict_dataloader = predict_dataloader
    pl_module.trainer = trainer

def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
# copy the model before moving it to the accelerator device.
self._average_model = deepcopy(pl_module)
with self._prevent_data_deepcopy(pl_module):
self._average_model = deepcopy(pl_module)

def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
optimizers = trainer.optimizers
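The context manager added above follows a common detach-and-restore pattern. Below is a minimal, standalone sketch of that pattern (illustrative names only, not part of the Lightning API): heavy attributes are set to `None` for the duration of the copy and restored afterwards, even if the copy raises.

```python
from contextlib import contextmanager
from copy import deepcopy


class ToyModule:
    def __init__(self, trainer, dataloader):
        self.trainer = trainer
        self.dataloader = dataloader


@contextmanager
def detach_heavy_refs(module):
    trainer, dataloader = module.trainer, module.dataloader
    module.trainer = None
    module.dataloader = None
    try:
        yield
    finally:
        # always restore the live object's references, even on failure
        module.trainer = trainer
        module.dataloader = dataloader


module = ToyModule(trainer=object(), dataloader=list(range(100_000)))
with detach_heavy_refs(module):
    average_model = deepcopy(module)         # copies only the lightweight shell

assert average_model.trainer is None and average_model.dataloader is None
assert module.trainer is not None and module.dataloader is not None
```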
36 changes: 35 additions & 1 deletion tests/callbacks/test_stochastic_weight_avg.py
@@ -19,7 +19,7 @@
from torch import nn
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
@@ -217,3 +217,37 @@ def configure_optimizers(self):
assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
else:
assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)


def test_trainer_stochastic_weight_averaging_deepcopy(tmpdir):
"""Test to ensure SWA Callback doesn't deecopy dataloaders and datamodule potentially leading to OOM"""

train_dataloader = DataLoader(RandomDataset(32, 64))

class TestModel(BoringModel):

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer

class StochasticWeightAveragingCheck(StochasticWeightAveraging):

def on_before_accelerator_backend_setup(self, trainer: 'Trainer', pl_module: 'LightningModule'):
super().on_before_accelerator_backend_setup(trainer, pl_module)
assert self._average_model.train_dataloader is None
assert self._average_model.val_dataloader is None
assert self._average_model.test_dataloader is None
assert self._average_model.predict_dataloader is None
assert self._average_model.trainer is None
assert pl_module.train_dataloader is not None
assert pl_module.trainer is not None

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=StochasticWeightAveragingCheck(swa_lrs=1e-3),
limit_train_batches=4,
limit_val_batches=4,
max_epochs=2,
)
trainer.fit(model, train_dataloader=train_dataloader)