Teardown sync-batchnorm after training #11078

Merged: 12 commits, Dec 16, 2021
CHANGELOG.md (4 changes: 3 additions & 1 deletion)
@@ -114,6 +114,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))


- The `DDPPlugin` and `DDPSpawnPlugin` and their subclasses now remove the `SyncBatchNorm` wrappers in `teardown()` to enable proper support at inference after fitting ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))


### Deprecated

@@ -277,7 +279,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))


-
- Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))


-
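For context on the two CHANGELOG entries above, here is a minimal sketch of the user-facing behavior. It assumes a machine with at least two GPUs and a Trainer version that accepts `strategy="ddp"`; `ToyModel` and the random dataset are placeholders for illustration, not code from this PR:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)
    model = ToyModel()
    # sync_batchnorm=True wraps every BatchNorm layer in torch.nn.SyncBatchNorm for training
    trainer = pl.Trainer(gpus=2, strategy="ddp", sync_batchnorm=True, max_epochs=1)
    trainer.fit(model, data)

    # With this PR, teardown() reverts the SyncBatchNorm wrappers after fitting,
    # so scripting the trained module works again.
    scripted = model.to_torchscript(method="script")
```

Previously, the `SyncBatchNorm` wrappers inserted for training stayed on the module after `fit()`, which is what broke torch-scripting and inference outside the distributed environment.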
pytorch_lightning/plugins/training_type/ddp.py (5 changes: 4 additions & 1 deletion)
@@ -48,7 +48,7 @@
    _TORCH_GREATER_EQUAL_1_10,
    rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
@@ -501,6 +501,9 @@ def teardown(self) -> None:
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.sync_batchnorm:
            self.model = _revert_sync_batchnorm(self.model)

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
pytorch_lightning/plugins/training_type/ddp_spawn.py (5 changes: 4 additions & 1 deletion)
@@ -34,7 +34,7 @@
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
@@ -376,6 +376,9 @@ def teardown(self) -> None:
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module

        if self.sync_batchnorm:
            self.model = _revert_sync_batchnorm(self.model)

        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
pytorch_lightning/utilities/distributed.py (34 changes: 34 additions & 0 deletions)
@@ -399,3 +399,37 @@ def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
    if not distributed_available():
        return {0: state}
    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}


class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    def _check_input_dim(self, input):
        # The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
        # is this method, which each subclass overrides.
        # Here, we are bypassing some tensor sanity checks and trusting that the user
        # provides the right input dimensions at inference.
        return


def _revert_sync_batchnorm(module):
    # Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
    # Original author: Kapil Yedidi (@kapily)
    converted_module = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        # Unfortunately, SyncBatchNorm does not store the original class - if it did
        # we could return the one that was originally created.
        converted_module = _BatchNormXd(
            module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats
        )
        if module.affine:
            with torch.no_grad():
                converted_module.weight = module.weight
                converted_module.bias = module.bias
        converted_module.running_mean = module.running_mean
        converted_module.running_var = module.running_var
        converted_module.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            converted_module.qconfig = module.qconfig
    for name, child in module.named_children():
        converted_module.add_module(name, _revert_sync_batchnorm(child))
    del module
    return converted_module
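To make the behavior of the new helper concrete, here is a small standalone sketch; the toy `nn.Sequential` is illustrative only, while `_revert_sync_batchnorm` and `_BatchNormXd` are the private helpers added in this diff:

```python
import torch
from torch import nn

from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm

# Toy module for illustration only.
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))

# This conversion is what Trainer(sync_batchnorm=True) applies before training.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
assert isinstance(model[1], nn.SyncBatchNorm)

# The revert replaces each SyncBatchNorm with a _BatchNormXd that keeps the
# learned affine parameters and running statistics but skips the input-dim check.
model = _revert_sync_batchnorm(model)
assert not isinstance(model[1], nn.SyncBatchNorm)
assert isinstance(model[1], nn.modules.batchnorm._BatchNorm)

# The reverted module behaves like a regular batch-norm network at inference.
model.eval()
with torch.no_grad():
    out = model(torch.randn(4, 8))
```

Because `SyncBatchNorm` does not record which concrete `BatchNorm1d/2d/3d` class it replaced, the dimension-agnostic `_BatchNormXd` serves as the common replacement while preserving all learned parameters and running statistics.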
tests/models/test_sync_batchnorm.py (7 changes: 6 additions & 1 deletion)
@@ -37,6 +37,9 @@ def __init__(self, gpu_count=1, **kwargs):
        self.linear = nn.Linear(28 * 28, 10)
        self.bn_layer = nn.BatchNorm1d(28 * 28)

    def on_train_start(self) -> None:
        assert isinstance(self.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)

    def forward(self, x, batch_idx):
        with torch.no_grad():
            out_bn = self.bn_layer(x.view(x.size(0), -1))
@@ -123,4 +126,6 @@ def test_sync_batchnorm_ddp(tmpdir):
    )

    trainer.fit(model, dm)
    assert trainer.state.finished, "Sync batchnorm failing with DDP"
    # the strategy is responsible for tearing down the batchnorm wrappers
    assert not isinstance(model.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)
    assert isinstance(model.bn_layer, torch.nn.modules.batchnorm._BatchNorm)
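If you want to exercise just the updated test locally, something along these lines should work; it assumes a multi-GPU machine, since the test trains with DDP across two GPUs:

```python
import pytest

# Run only the updated sync-batchnorm test from the repository root.
pytest.main(["tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp", "-v"])
```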