Teardown sync-batchnorm after training (#11078)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
4 people authored Dec 16, 2021
1 parent 46d6fbf commit 2b0075a
Showing 6 changed files with 54 additions and 7 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
@@ -117,8 +117,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))


- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022))
- The `DDPPlugin` and `DDPSpawnPlugin` and their subclasses now remove the `SyncBatchNorm` wrappers in `teardown()` to enable proper support at inference after fitting ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))


- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022))

### Deprecated

@@ -276,13 +278,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


-
- Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))


-



## [1.5.6] - 2021-12-15

### Fixed
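
To make the two CHANGELOG entries above concrete, here is a minimal, hedged sketch (not taken from this PR): a toy LightningModule with a hypothetical `bn` layer is trained with `Trainer(sync_batchnorm=True)`, and exporting the trained module to TorchScript afterwards is expected to succeed because the strategy now reverts the `SyncBatchNorm` wrappers in `teardown()`. The model, layer names, and data below are invented for the example; it assumes a 2-GPU machine and the Trainer arguments available around Lightning 1.5/1.6.

# Hedged sketch, not part of the commit. Assumes 2 GPUs and Lightning ~1.5/1.6 Trainer arguments.
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ScriptableModel(pl.LightningModule):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(32)  # wrapped in SyncBatchNorm for the duration of fit()
        self.head = nn.Linear(32, 2)

    def forward(self, x):
        return self.head(self.bn(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    model = ScriptableModel()
    trainer = pl.Trainer(max_epochs=1, gpus=2, strategy="ddp_spawn", sync_batchnorm=True)
    trainer.fit(model, DataLoader(dataset, batch_size=8))
    # After teardown() the SyncBatchNorm wrappers are removed again, so scripting the
    # trained module no longer fails.
    scripted = model.to_torchscript()
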
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/base.py
@@ -130,7 +130,7 @@ def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None:
)
# instantiate the loop
kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"}
loop = type_or_object(**kwargs) # type: ignore[call-arg]
loop = type_or_object(**kwargs)
else:
loop = type_or_object

5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -48,7 +48,7 @@
_TORCH_GREATER_EQUAL_1_10,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
init_dist_connection,
@@ -504,6 +504,9 @@ def teardown(self) -> None:
if isinstance(self.model, DistributedDataParallel):
self.model = self.lightning_module

if self.sync_batchnorm:
self.model = _revert_sync_batchnorm(self.model)

if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -34,7 +34,7 @@
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
init_dist_connection,
@@ -378,6 +378,9 @@ def teardown(self) -> None:
if isinstance(self.model, DistributedDataParallel):
self.model = self.lightning_module

if self.sync_batchnorm:
self.model = _revert_sync_batchnorm(self.model)

if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()
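
For context on the two teardown changes above (this is not part of the diff): when the user passes `sync_batchnorm=True`, the DDP-based plugins convert the model's batch-norm layers with `torch.nn.SyncBatchNorm.convert_sync_batchnorm` during setup, and the new lines undo that conversion once fitting ends. A small, hedged sketch of that round trip on a toy module (the model is invented; only `convert_sync_batchnorm` and the new `_revert_sync_batchnorm` helper are real APIs):

# Hedged sketch of the setup/teardown round trip; runs on CPU without a process group.
import torch
import torch.nn as nn
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm

model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.ReLU())

# Roughly what the plugin does when fit() starts with sync_batchnorm=True:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
assert isinstance(model[1], torch.nn.SyncBatchNorm)

# What the teardown() additions above do when fit() ends:
model = _revert_sync_batchnorm(model)
assert not isinstance(model[1], torch.nn.SyncBatchNorm)
assert isinstance(model[1], torch.nn.modules.batchnorm._BatchNorm)
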
35 changes: 35 additions & 0 deletions pytorch_lightning/utilities/distributed.py
@@ -19,6 +19,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
@@ -399,3 +400,37 @@ def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
if not distributed_available():
return {0: state}
return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}


class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
def _check_input_dim(self, input: torch.Tensor) -> None:
# The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
# is this method that is overwritten by the subclass.
# Here, we are bypassing some tensor sanity checks and trusting that the user
# provides the right input dimensions at inference.
return


def _revert_sync_batchnorm(module: Module) -> Module:
# Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
# Original author: Kapil Yedidi (@kapily)
converted_module = module
if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
# Unfortunately, SyncBatchNorm does not store the original class - if it did
# we could return the one that was originally created.
converted_module = _BatchNormXd(
module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats
)
if module.affine:
with torch.no_grad():
converted_module.weight = module.weight
converted_module.bias = module.bias
converted_module.running_mean = module.running_mean
converted_module.running_var = module.running_var
converted_module.num_batches_tracked = module.num_batches_tracked
if hasattr(module, "qconfig"):
converted_module.qconfig = module.qconfig
for name, child in module.named_children():
converted_module.add_module(name, _revert_sync_batchnorm(child))
del module
return converted_module
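
A short, hedged usage sketch of the helper above (the toy network is invented for illustration): converting a model to `SyncBatchNorm` and reverting it yields `_BatchNormXd` layers that keep the affine parameters and running statistics, and the reverted model can then be scripted, which is the scenario the CHANGELOG "Fixed" entry refers to.

# Hedged sketch exercising _revert_sync_batchnorm in isolation; the toy network is invented.
import torch
import torch.nn as nn
from pytorch_lightning.utilities.distributed import _BatchNormXd, _revert_sync_batchnorm

net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU())
original_weight = net[1].weight.detach().clone()

synced = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
reverted = _revert_sync_batchnorm(synced)

# The SyncBatchNorm wrapper is replaced by the _BatchNormXd shim ...
assert isinstance(reverted[1], _BatchNormXd)
# ... and the affine parameters and running statistics are carried over unchanged.
assert torch.equal(reverted[1].weight, original_weight)
assert torch.equal(reverted[1].running_mean, torch.zeros(4))

# Scripting the reverted model is expected to succeed; scripting the SyncBatchNorm-wrapped
# model is what used to fail after training with sync_batchnorm=True.
scripted = torch.jit.script(reverted)
out = scripted(torch.randn(2, 3, 8, 8))
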
7 changes: 6 additions & 1 deletion tests/models/test_sync_batchnorm.py
@@ -37,6 +37,9 @@ def __init__(self, gpu_count=1, **kwargs):
self.linear = nn.Linear(28 * 28, 10)
self.bn_layer = nn.BatchNorm1d(28 * 28)

def on_train_start(self) -> None:
assert isinstance(self.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)

def forward(self, x, batch_idx):
with torch.no_grad():
out_bn = self.bn_layer(x.view(x.size(0), -1))
@@ -123,4 +126,6 @@ def test_sync_batchnorm_ddp(tmpdir):
)

trainer.fit(model, dm)
assert trainer.state.finished, "Sync batchnorm failing with DDP"
# the strategy is responsible for tearing down the batchnorm wrappers
assert not isinstance(model.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)
assert isinstance(model.bn_layer, torch.nn.modules.batchnorm._BatchNorm)
