move block_ddp_sync_behaviour to utilities #9192

Merged · 8 commits · Aug 30, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -145,6 +145,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `rank_zero_warn` to `NotImplementedError` in the `{train, val, test, predict}_dataloader` hooks that `Lightning(Data)Module` uses ([#9161](https://github.com/PyTorchLightning/pytorch-lightning/pull/9161))


- Moved `block_ddp_sync_behaviour` out of `TrainingBatchLoop` to loop utilities ([#9192](https://github.com/PyTorchLightning/pytorch-lightning/pull/9192))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
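For code that previously reached this context manager through the batch loop, the entry point described in the changelog entry above is now a free function in the loop utilities. A minimal migration sketch; `run_closure_without_sync`, `trainer`, and `closure` are hypothetical names standing in for whatever the surrounding training code provides:

```python
import pytorch_lightning as pl
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior


def run_closure_without_sync(trainer: "pl.Trainer", closure) -> None:
    # Old (removed by this PR):
    #   with trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(True):
    #       closure()
    # New: a free function that takes the trainer explicitly.
    with _block_parallel_sync_behavior(trainer, block=True):
        closure()  # backward passes inside the block skip the DDP gradient all-reduce
```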
5 changes: 4 additions & 1 deletion pytorch_lightning/core/optimizer.py
@@ -116,7 +116,10 @@ def toggle_model(self, sync_grad: bool = True):
during the accumulation phase.
Setting `sync_grad` to False will block this synchronization and improve performance.
"""
with self._trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
# local import here to avoid circular import
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior

with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
self._toggle_model()
yield
self._untoggle_model()
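For context, `toggle_model` is mainly used from `training_step` under manual optimization. Below is a hedged sketch of gradient accumulation that passes `sync_grad=False` on intermediate micro-batches; the module, loss, and accumulation factor are illustrative assumptions, not part of this PR:

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class ManualAccumulationModule(pl.LightningModule):
    """Accumulate gradients over 4 batches with manual optimization (illustrative)."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # LightningOptimizer wrapper exposing toggle_model()
        x, y = batch
        sync = (batch_idx + 1) % 4 == 0  # only the last micro-batch should all-reduce
        # sync_grad=False blocks the DDP gradient synchronization mentioned in the docstring,
        # so intermediate backward passes accumulate locally without communication.
        with opt.toggle_model(sync_grad=sync):
            loss = F.cross_entropy(self.layer(x), y) / 4
            self.manual_backward(loss)
        if sync:
            opt.step()
            opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```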
32 changes: 4 additions & 28 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -13,10 +13,9 @@
# limitations under the License.

from collections import OrderedDict
from contextlib import contextmanager
from copy import copy
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
@@ -28,11 +27,11 @@
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.utilities import (
_block_parallel_sync_behavior,
_check_training_step_output,
_process_training_step_output,
check_finite_loss,
)
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
@@ -186,9 +185,8 @@ def _run_optimization(
# -------------------
# calculate loss (train step + train step end)
# -------------------
# automatic_optimization=True: perform ddp sync only when performing optimizer_step
# automatic_optimization=False: don't block synchronization here
with self.block_ddp_sync_behaviour():
# automatic_optimization: perform ddp sync only when performing optimizer_step
with _block_parallel_sync_behavior(self._trainer):
closure()

# ------------------------------
@@ -460,28 +458,6 @@ def _run_optimization_end(self, opt_idx: int) -> None:
model = self.trainer.lightning_module
model.untoggle_optimizer(opt_idx)

@contextmanager
def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator[None, None, None]:
"""
automatic_optimization = True
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead

automatic_optimization = False
do not block ddp gradient sync when using manual optimization
as gradients are needed within the training step

Returns:
context manager with sync behaviour off
"""
if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (
self.trainer.lightning_module.automatic_optimization or should_block_sync
):
with self.trainer.training_type_plugin.block_backward_sync():
yield None
else:
yield None

def backward(
self,
loss: Tensor,
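The removed method delegated to `ParallelPlugin.block_backward_sync`, which for DDP-based plugins is typically backed by `torch.nn.parallel.DistributedDataParallel.no_sync()`. A plain-PyTorch sketch of the same idea, assuming an already-initialized process group and a DDP-wrapped model (names are illustrative):

```python
import torch
from torch.nn.parallel import DistributedDataParallel


def accumulate_without_sync(ddp_model: DistributedDataParallel, batches, optimizer) -> None:
    """Accumulate gradients over several batches, all-reducing only on the last backward."""
    *intermediate, last = batches
    # no_sync() is the mechanism that blocking "DDP sync behaviour" ultimately relies on:
    # backward passes inside it skip the gradient all-reduce entirely.
    with ddp_model.no_sync():
        for batch in intermediate:
            ddp_model(batch).sum().backward()  # gradients accumulate locally, no communication
    ddp_model(last).sum().backward()  # leaving no_sync(): this backward triggers the all-reduce
    optimizer.step()
    optimizer.zero_grad()
```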
25 changes: 23 additions & 2 deletions pytorch_lightning/loops/utilities.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterator, Mapping, Optional, Tuple
from contextlib import contextmanager
from typing import Any, Generator, Iterator, Mapping, Optional, Tuple

import torch

import pytorch_lightning as pl
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -113,3 +114,23 @@ def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int)
else:
dataloader_iter = iter(data_fetcher)
return dataloader_iter


@contextmanager
def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
"""
Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`.
This is useful, for example, when accumulating gradients to reduce communication when it is not needed.

Args:
trainer: the trainer instance with a reference to a training type plugin
block: whether the context manager is enabled or not

Returns:
context manager with sync behaviour off
"""
if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
with trainer.training_type_plugin.block_backward_sync():
yield None
else:
yield None
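On a single-device trainer the new utility is a plain no-op yield, so wrapping code in it is always safe. A minimal end-to-end sketch, assuming a Lightning version that includes this PR; `TinyModule` is a throwaway model defined only for illustration:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer()  # default single-device plugin, i.e. not a ParallelPlugin
    model = TinyModule()
    # block=True only suppresses gradient sync when the training type plugin is parallel;
    # here the context manager simply yields and the backward pass runs unchanged.
    with _block_parallel_sync_behavior(trainer, block=True):
        model(torch.randn(2, 4)).sum().backward()
```

Passing the trainer explicitly is what lets `LightningOptimizer.toggle_model` reuse the same helper without reaching into the loop hierarchy.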