Revert "Update setup logic in training type plugins (sharded) [4 / 4] (
Browse files Browse the repository at this point in the history
…#10028)"

This reverts commit 4ea72a9.
awaelchli committed Oct 22, 2021
1 parent 94e2bf5 commit 01bbae9
Showing 3 changed files with 34 additions and 82 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
@@ -213,7 +213,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
 * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
 * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
-* Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
 * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))

63 changes: 19 additions & 44 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Dict, Generator, Optional
 
 import torch
-from torch.nn import Module
-from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -35,70 +33,47 @@
 class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""
 
-    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._precision = None
+    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
 
     def configure_ddp(self) -> None:
-        trainer = self.lightning_module.trainer
+        self._wrap_optimizers()
+
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
-            optimizers=trainer.optimizers,
+        self._model = ShardedDataParallel(
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            **self._ddp_kwargs
         )
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
-
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Wraps the model and optimizers with fairscale components.
+        setattr(self._model, "require_backward_grad_sync", False)
 
-        Currently only one model can be setup at once.
-
-        Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
-            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
-        """
-        if len(models) > 1:
-            raise ValueError(
-                "DDPSharded only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
-        optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
-        return [model], optimizers
-
-    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
+    def _reinit_optimizers_with_oss(self):
+        optimizers = self.lightning_module.trainer.optimizers
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    precision = self._precision or self.lightning_module.trainer.precision
+                    precision = self.lightning_module.trainer.precision
                     is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
                     # the model performance.
                     zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
-        return optimizers
-
-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
-            return optimizers
-        return self._reinit_optimizers_with_oss(optimizers)
+        trainer = self.lightning_module.trainer
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
+
+    def _wrap_optimizers(self):
+        if self.model.trainer.state.fn != TrainerFn.FITTING:
+            return
+        self._reinit_optimizers_with_oss()
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):
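For context: the restored `_reinit_optimizers_with_oss` rebuilds each plain optimizer as a fairscale `OSS` (ZeRO-style) optimizer whose state is sharded across the distributed process group, and `configure_ddp` then wraps the module in `ShardedDataParallel`. Below is a minimal standalone sketch of the OSS re-wrapping step, assuming fairscale is installed; the single-process gloo group, model, and optimizer are illustrative stand-ins, not part of the plugin.

import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# OSS shards optimizer state across a torch.distributed process group,
# so a group must exist before it is constructed (a single local process here).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 2)
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Re-wrap the optimizer the same way the plugin does: keep the original
# param groups and defaults, but let OSS own (and shard) the state.
zero_optimizer = OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)

loss = model(torch.randn(4, 32)).sum()
loss.backward()
zero_optimizer.step()

dist.destroy_process_group()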
52 changes: 15 additions & 37 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -13,11 +13,9 @@
 # limitations under the License.
 from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Dict, Generator, Optional
 
 import torch
-from torch.nn import Module
-from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -38,49 +36,29 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
     def configure_ddp(self) -> None:
-        trainer = self.lightning_module.trainer
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
-            optimizers=trainer.optimizers,
+        self._wrap_optimizers()
+        self._model = ShardedDataParallel(
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            **self._ddp_kwargs
        )
-        trainer.optimizers = optimizers
-
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Wraps the model and optimizers with fairscale components.
+        setattr(self._model, "require_backward_grad_sync", False)
 
-        Currently only one model can be setup at once.
-
-        Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
-            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
-        """
-        if len(models) > 1:
-            raise ValueError(
-                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
-        optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
-        return [model], optimizers
-
-    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
+    def _reinit_optimizers_with_oss(self):
+        optimizers = self.lightning_module.trainer.optimizers
         for x, optimizer in enumerate(optimizers):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 optimizers[x] = zero_optimizer
                 del optimizer
-        return optimizers
-
-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
-            return optimizers
-        return self._reinit_optimizers_with_oss(optimizers)
+        trainer = self.lightning_module.trainer
+        trainer.optimizers = optimizers
+
+    def _wrap_optimizers(self):
+        if self.model.trainer.state.fn != TrainerFn.FITTING:
+            return
+        self._reinit_optimizers_with_oss()
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, OSS):
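For completeness, a hedged usage sketch of how these plugins are selected from the Trainer. The exact flag depends on the Lightning version around this commit (`strategy="ddp_sharded"` or `"ddp_sharded_spawn"` on 1.5-era code, `plugins="ddp_sharded"` on older releases), and the tiny LightningModule, random dataset, and 2-GPU assumption below are illustrative only.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Minimal stand-in LightningModule, for illustration only."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)
    # "ddp_sharded" selects DDPShardedPlugin; "ddp_sharded_spawn" selects DDPSpawnShardedPlugin.
    trainer = pl.Trainer(max_epochs=1, gpus=2, strategy="ddp_sharded")
    trainer.fit(TinyModel(), data)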
