Deprecate DistributedType in favor of StrategyType #10505

Merged · 14 commits · Nov 15, 2021
CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))


--
+- Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505))


-
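The CHANGELOG entry above records the user-facing side of this PR: the `DistributedType` enum is deprecated in favor of the internal `_StrategyType`. As a rough, hedged sketch of how such a rename is commonly surfaced (this is not the PR's actual implementation), the old name can be re-exported through a module-level `__getattr__` (PEP 562) that emits a `DeprecationWarning`; the member names and values below are assumed to mirror the ones appearing elsewhere in this diff.

```python
# Illustrative sketch only (assumed, not this PR's actual code): one common way to
# deprecate a public enum name is a module-level __getattr__ (PEP 562) that keeps
# the old attribute importable while warning users to migrate.
import warnings
from enum import Enum


class _StrategyType(str, Enum):
    """Internal strategy identifiers (lowercase string values are assumed here)."""

    DP = "dp"
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"
    DEEPSPEED = "deepspeed"
    DDP_SHARDED = "ddp_sharded"
    DDP_SHARDED_SPAWN = "ddp_sharded_spawn"


def __getattr__(name: str):
    # Accessing `DistributedType` still works but points callers at `_StrategyType`.
    if name == "DistributedType":
        warnings.warn(
            "`DistributedType` is deprecated, use `_StrategyType` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return _StrategyType
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

With a shim like this, importing `DistributedType` from the module keeps working during the deprecation window while nudging callers toward `_StrategyType`.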
pytorch_lightning/lite/lite.py (16 changes: 8 additions & 8 deletions)
@@ -41,7 +41,7 @@
)
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
-from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device
+from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import has_iterable_dataset
from pytorch_lightning.utilities.device_parser import _parse_devices
@@ -477,14 +477,14 @@ def _supported_device_types() -> Sequence[DeviceType]:
)

@staticmethod
-def _supported_strategy_types() -> Sequence[DistributedType]:
+def _supported_strategy_types() -> Sequence[_StrategyType]:
return (
-DistributedType.DP,
-DistributedType.DDP,
-DistributedType.DDP_SPAWN,
-DistributedType.DEEPSPEED,
-DistributedType.DDP_SHARDED,
-DistributedType.DDP_SHARDED_SPAWN,
+_StrategyType.DP,
+_StrategyType.DDP,
+_StrategyType.DDP_SPAWN,
+_StrategyType.DEEPSPEED,
+_StrategyType.DDP_SHARDED,
+_StrategyType.DDP_SHARDED_SPAWN,
)

@staticmethod
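In `lite.py`, the whitelist returned by `_supported_strategy_types()` now consists of `_StrategyType` members. The following sketch shows one way such a whitelist could be used to validate a user-supplied strategy string; `validate_strategy` and `_SUPPORTED` are hypothetical names, and the comparison assumes the enum members carry lowercase string values (e.g. `"ddp"`), which is not shown in this diff.

```python
# Hypothetical helper (not part of the PR): validate a user-supplied strategy string
# against a whitelist of _StrategyType members like the one LightningLite keeps.
from typing import Sequence

from pytorch_lightning.utilities.enums import _StrategyType

_SUPPORTED: Sequence[_StrategyType] = (
    _StrategyType.DP,
    _StrategyType.DDP,
    _StrategyType.DDP_SPAWN,
    _StrategyType.DEEPSPEED,
    _StrategyType.DDP_SHARDED,
    _StrategyType.DDP_SHARDED_SPAWN,
)


def validate_strategy(strategy: str) -> _StrategyType:
    # Assumes each member's `.value` is its lowercase string name, e.g. "ddp_spawn".
    for candidate in _SUPPORTED:
        if candidate.value == strategy.lower():
            return candidate
    raise ValueError(
        f"Strategy {strategy!r} is not supported here; choose from {[c.value for c in _SUPPORTED]}."
    )


# validate_strategy("ddp")      -> _StrategyType.DDP
# validate_strategy("horovod")  -> ValueError
```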
pytorch_lightning/plugins/training_type/ddp.py (4 changes: 2 additions & 2 deletions)
@@ -57,7 +57,7 @@
ReduceOp,
sync_ddp_if_available,
)
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -81,7 +81,7 @@ class DDPPlugin(ParallelPlugin):
devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes.
"""

-distributed_backend = DistributedType.DDP
+distributed_backend = _StrategyType.DDP

def __init__(
self,
pytorch_lightning/plugins/training_type/ddp2.py (4 changes: 2 additions & 2 deletions)
@@ -15,14 +15,14 @@

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class DDP2Plugin(DDPPlugin):
"""DDP2 behaves like DP in one node, but synchronization across nodes behaves like in DDP."""

-distributed_backend = DistributedType.DDP2
+distributed_backend = _StrategyType.DDP2

@property
def global_rank(self) -> int:
pytorch_lightning/plugins/training_type/ddp_spawn.py (4 changes: 2 additions & 2 deletions)
@@ -44,7 +44,7 @@
ReduceOp,
sync_ddp_if_available,
)
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -59,7 +59,7 @@ class DDPSpawnPlugin(ParallelPlugin):
"""Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
finishes."""

-distributed_backend = DistributedType.DDP_SPAWN
+distributed_backend = _StrategyType.DDP_SPAWN

def __init__(
self,
pytorch_lightning/plugins/training_type/deepspeed.py (4 changes: 2 additions & 2 deletions)
@@ -36,7 +36,7 @@
from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -82,7 +82,7 @@ def _move_float_tensors_to_half(self, batch: Any):


class DeepSpeedPlugin(DDPPlugin):
-distributed_backend = DistributedType.DEEPSPEED
+distributed_backend = _StrategyType.DEEPSPEED
DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"

def __init__(
pytorch_lightning/plugins/training_type/dp.py (4 changes: 2 additions & 2 deletions)
@@ -20,7 +20,7 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _METRIC_COLLECTION

@@ -29,7 +29,7 @@ class DataParallelPlugin(ParallelPlugin):
"""Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
each gets a split of the data."""

-distributed_backend = DistributedType.DP
+distributed_backend = _StrategyType.DP

def __init__(
self,
pytorch_lightning/plugins/training_type/fully_sharded.py (4 changes: 2 additions & 2 deletions)
@@ -20,7 +20,7 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@@ -30,7 +30,7 @@

class DDPFullyShardedPlugin(DDPPlugin):

-distributed_backend = DistributedType.DDP_FULLY_SHARDED
+distributed_backend = _StrategyType.DDP_FULLY_SHARDED

def __init__(
self,
pytorch_lightning/plugins/training_type/horovod.py (4 changes: 2 additions & 2 deletions)
@@ -26,7 +26,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType

if _HOROVOD_AVAILABLE:
import horovod.torch as hvd
@@ -35,7 +35,7 @@
class HorovodPlugin(ParallelPlugin):
"""Plugin for Horovod distributed training integration."""

-distributed_backend = DistributedType.HOROVOD
+distributed_backend = _StrategyType.HOROVOD

def __init__(
self,
pytorch_lightning/plugins/training_type/sharded.py (4 changes: 2 additions & 2 deletions)
@@ -23,7 +23,7 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
@@ -36,7 +36,7 @@
class DDPShardedPlugin(DDPPlugin):
"""Optimizer and gradient sharded training provided by FairScale."""

-distributed_backend = DistributedType.DDP_SHARDED
+distributed_backend = _StrategyType.DDP_SHARDED
_REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M

def __init__(self, *args, **kwargs):
pytorch_lightning/plugins/training_type/sharded_spawn.py (4 changes: 2 additions & 2 deletions)
@@ -24,7 +24,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
-from pytorch_lightning.utilities.enums import DistributedType
+from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
@@ -38,7 +38,7 @@
class DDPSpawnShardedPlugin(DDPSpawnPlugin):
"""Optimizer sharded training provided by FairScale."""

-distributed_backend = DistributedType.DDP_SHARDED_SPAWN
+distributed_backend = _StrategyType.DDP_SHARDED_SPAWN

def configure_ddp(self) -> None:
trainer = self.lightning_module.trainer
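Every training-type plugin in this diff follows the same pattern: it advertises its strategy through the `distributed_backend` class attribute, now set to a `_StrategyType` member instead of a `DistributedType` one. The sketch below is illustrative only (`PLUGIN_REGISTRY` and `resolve_plugin` are hypothetical names); it shows how that attribute reduces plugin selection to a dictionary lookup. In the actual codebase this selection logic lives in the `AcceleratorConnector` rather than in a standalone registry.

```python
# Illustrative only: build a lookup table keyed on each plugin's `distributed_backend`
# attribute, then resolve a `_StrategyType` back to its plugin class.
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from pytorch_lightning.utilities.enums import _StrategyType

PLUGIN_REGISTRY = {
    plugin_cls.distributed_backend: plugin_cls
    for plugin_cls in (DataParallelPlugin, DDPPlugin, DDPSpawnPlugin)
}


def resolve_plugin(strategy: _StrategyType):
    # Each plugin advertises its strategy, so dispatch is just a dictionary lookup.
    return PLUGIN_REGISTRY[strategy]


# resolve_plugin(_StrategyType.DDP) returns DDPPlugin
```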