Only import PostLocalSGD-related modules when they're needed #10359

Merged (3 commits) on Nov 8, 2021
26 changes: 12 additions & 14 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -63,21 +63,14 @@
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_10:
    if not _IS_WINDOWS:
        from torch.distributed.optim import DistributedOptimizer
    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
if _HYDRA_AVAILABLE:
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8:
    from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
if _TORCH_GREATER_EQUAL_1_10:
    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
    import torch.distributed.algorithms.model_averaging.averagers as averagers


log = logging.getLogger(__name__)
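
Note on the pattern (not part of the diff): removing these module-level conditional imports means the optional torch.distributed pieces are only imported inside the code paths that use them, so importing this module stays safe on Windows and on PyTorch older than 1.10. Below is a minimal, hypothetical sketch of the same deferred-import idea, using only torch APIs the diff itself references (PeriodicModelAverager, PostLocalSGDOptimizer); the helper name is illustrative, not code from this PR.

def _wrap_with_post_local_sgd(optimizer, period: int, warmup_steps: int):
    # Deferred imports: only touch the optional torch >= 1.10 modules when this
    # code path actually runs, never at module import time.
    import torch.distributed.algorithms.model_averaging.averagers as averagers
    from torch.distributed.optim import PostLocalSGDOptimizer

    # Average model parameters across ranks every `period` steps once `warmup_steps` have passed.
    averager = averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)
    return PostLocalSGDOptimizer(optim=optimizer, averager=averager)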

@@ -324,19 +317,24 @@ def _register_ddp_hooks(self) -> None:
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

            if (
                _TORCH_GREATER_EQUAL_1_10
                and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
                and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
            ):
                self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
            if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

                if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
                    self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)

    def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
        optimizers = self.lightning_module.trainer.optimizers
        if self._model_averaging_period is None:
            raise ValueError(
                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
            )
        if _TORCH_GREATER_EQUAL_1_10:
            if not _IS_WINDOWS:
                from torch.distributed.optim import DistributedOptimizer
            import torch.distributed.algorithms.model_averaging.averagers as averagers
            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer

        averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
        for x, optimizer in enumerate(optimizers):
            if isinstance(optimizer, LightningOptimizer):
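
For completeness, a hedged usage sketch (an assumption for illustration, not code from this PR) of how the paths above get exercised: the DDP plugin is configured with a post-localSGD communication hook state and a model averaging period, which is what `_reinit_optimizers_with_post_localSGD` consumes. The DDPPlugin argument names (ddp_comm_state, ddp_comm_hook, model_averaging_period) are taken from the diff; the Trainer settings and numeric values are illustrative.

import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Assumes a multi-GPU run on torch >= 1.10; start_localSGD_iter and the
# averaging period are example values, not recommendations.
trainer = Trainer(
    gpus=4,
    strategy=DDPPlugin(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    ),
)
# trainer.fit(model)  # model is any LightningModule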