Skip to content

Commit

Permalink
Rewrite accelerator_connector (#11448)
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish authored Feb 17, 2022
1 parent a0ca8d0 commit 6e14209
Show file tree
Hide file tree
Showing 32 changed files with 952 additions and 1,042 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ module = [
"pytorch_lightning.profiler.pytorch",
"pytorch_lightning.profiler.simple",
"pytorch_lightning.trainer.callback_hook",
"pytorch_lightning.trainer.connectors.accelerator_connector",
"pytorch_lightning.trainer.connectors.callback_connector",
"pytorch_lightning.trainer.connectors.data_connector",
"pytorch_lightning.trainer.data_loading",
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _AcceleratorType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
Expand Down Expand Up @@ -127,7 +126,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
if not trainer.logger:
raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

if trainer._device_type != _AcceleratorType.GPU:
if trainer.strategy.root_device.type != "cuda":
raise MisconfigurationException(
"You are using GPUStatsMonitor but are not running on GPU"
f" since gpus attribute in Trainer is set to {trainer.gpus}."
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
self._check_strategy_support(strategy)
gpu_ids, tpu_cores = _parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores)
self._accelerator_connector = AcceleratorConnector(
num_processes=1,
num_processes=None,
devices=devices,
tpu_cores=tpu_cores,
ipus=None,
Expand Down
11 changes: 7 additions & 4 deletions pytorch_lightning/strategies/bagua.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
from pytorch_lightning.utilities.seed import reset_seed
Expand Down Expand Up @@ -58,7 +57,7 @@ def __init__(self, pl_module: "pl.LightningModule") -> None:


class BaguaStrategy(DDPStrategy):
distributed_backend = _StrategyType.BAGUA
strategy_name = "bagua"

def __init__(
self,
Expand Down Expand Up @@ -180,8 +179,12 @@ def _setup_model(self, model: Module) -> BaguaDistributedDataParallel:
)

@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:
plugin_registry.register("bagua", cls, description="Default Bagua Plugin")
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
# abort the background communication for async algorithm
Expand Down
10 changes: 6 additions & 4 deletions pytorch_lightning/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
Expand All @@ -63,7 +62,7 @@
class DDPStrategy(ParallelStrategy):
"""Strategy for multi-process single-device training on one or multiple nodes."""

distributed_backend = _StrategyType.DDP
strategy_name = "ddp"

def __init__(
self,
Expand Down Expand Up @@ -96,7 +95,6 @@ def __init__(
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_will_call_children_scripts: bool = False
self.set_world_ranks()

@property
def is_distributed(self) -> bool:
Expand All @@ -114,7 +112,6 @@ def num_nodes(self) -> int:
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
self._num_nodes = num_nodes
self.set_world_ranks()

@property
def num_processes(self):
Expand Down Expand Up @@ -346,6 +343,11 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
description="DDP Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def _should_run_deadlock_detection(self) -> bool:
"""Determines whether the plugin will perform process reconciliation in case of errors.
Expand Down
13 changes: 11 additions & 2 deletions pytorch_lightning/strategies/ddp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import torch

from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class DDP2Strategy(DDPStrategy):
"""DDP2 behaves like DP in one node, but synchronization across nodes behaves like in DDP."""

distributed_backend = _StrategyType.DDP2
strategy_name = "ddp2"

@property
def global_rank(self) -> int:
Expand Down Expand Up @@ -73,3 +74,11 @@ def set_world_ranks(self) -> None:
return
self.cluster_environment.set_global_rank(self.node_rank)
self.cluster_environment.set_world_size(self.num_nodes)

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)
10 changes: 6 additions & 4 deletions pytorch_lightning/strategies/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
Expand All @@ -48,7 +47,7 @@ class DDPSpawnStrategy(ParallelStrategy):
"""Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
finishes."""

distributed_backend = _StrategyType.DDP_SPAWN
strategy_name = "ddp_spawn"

def __init__(
self,
Expand Down Expand Up @@ -76,7 +75,6 @@ def __init__(
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self._local_rank = 0
self.set_world_ranks()

@property
def num_nodes(self) -> int:
Expand All @@ -86,7 +84,6 @@ def num_nodes(self) -> int:
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
self._num_nodes = num_nodes
self.set_world_ranks()

@property
def local_rank(self) -> int:
Expand Down Expand Up @@ -264,6 +261,11 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
description="DDPSpawn Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
super().teardown()
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log
from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
from pytorch_lightning.utilities.enums import AMPType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand Down Expand Up @@ -82,7 +82,7 @@ def _move_float_tensors_to_half(self, batch: Any):


class DeepSpeedStrategy(DDPStrategy):
distributed_backend = _StrategyType.DEEPSPEED
strategy_name = "deepspeed"
DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"

def __init__(
Expand Down
13 changes: 10 additions & 3 deletions pytorch_lightning/strategies/dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional

import torch
from torch.nn import DataParallel, Module
Expand All @@ -22,7 +22,6 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, STEP_OUTPUT

Expand All @@ -31,7 +30,7 @@ class DataParallelStrategy(ParallelStrategy):
"""Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
each gets a split of the data."""

distributed_backend = _StrategyType.DP
strategy_name = "dp"

def __init__(
self,
Expand Down Expand Up @@ -149,6 +148,14 @@ def training_step_end(self, output):

return output

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
super().teardown()
if self.root_device.type == "cuda":
Expand Down
10 changes: 8 additions & 2 deletions pytorch_lightning/strategies/fully_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

Expand All @@ -36,7 +36,7 @@

class DDPFullyShardedStrategy(DDPStrategy):

distributed_backend = _StrategyType.DDP_FULLY_SHARDED
strategy_name = "ddp_fully_sharded"

def __init__(
self,
Expand Down Expand Up @@ -212,3 +212,9 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"fsdp", cls, description="Fully sharded training with checkpointing the full state dict."
)

strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)
13 changes: 10 additions & 3 deletions pytorch_lightning/strategies/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -26,7 +26,6 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

Expand All @@ -37,7 +36,7 @@
class HorovodStrategy(ParallelStrategy):
"""Plugin for Horovod distributed training integration."""

distributed_backend = _StrategyType.HOROVOD
strategy_name = "horovod"

def __init__(
self,
Expand Down Expand Up @@ -196,6 +195,14 @@ def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tup
opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
return [(name, p) for name, p in model.named_parameters() if p in opt_params]

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
super().teardown()
# teardown may be called before `_exit_stack` is set
Expand Down
12 changes: 11 additions & 1 deletion pytorch_lightning/strategies/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import json
import os
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -62,6 +62,8 @@ def _move_float_tensors_to_half(self, batch: Any) -> Any:
class IPUStrategy(ParallelStrategy):
"""Plugin for training on IPU devices."""

strategy_name = "ipu_strategy"

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
Expand Down Expand Up @@ -360,3 +362,11 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra

def broadcast(self, obj: object, src: int = 0) -> object:
return obj

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)
9 changes: 7 additions & 2 deletions pytorch_lightning/strategies/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only
Expand All @@ -37,7 +37,7 @@
class DDPShardedStrategy(DDPStrategy):
"""Optimizer and gradient sharded training provided by FairScale."""

distributed_backend = _StrategyType.DDP_SHARDED
strategy_name = "ddp_sharded"
_REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M

def configure_ddp(self) -> None:
Expand Down Expand Up @@ -135,3 +135,8 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
description="DDP Sharded Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)
Loading

0 comments on commit 6e14209

Please sign in to comment.