diff --git a/CHANGELOG.md b/CHANGELOG.md index 87ecfdef4d448..0bd409b73255c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,7 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) -- +- Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505)) - diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index d36e874cbae7b..2a2ed9586b420 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -41,7 +41,7 @@ ) from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin -from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device +from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.data import has_iterable_dataset from pytorch_lightning.utilities.device_parser import _parse_devices @@ -477,14 +477,14 @@ def _supported_device_types() -> Sequence[DeviceType]: ) @staticmethod - def _supported_strategy_types() -> Sequence[DistributedType]: + def _supported_strategy_types() -> Sequence[_StrategyType]: return ( - DistributedType.DP, - DistributedType.DDP, - DistributedType.DDP_SPAWN, - DistributedType.DEEPSPEED, - DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, + _StrategyType.DP, + _StrategyType.DDP, + _StrategyType.DDP_SPAWN, + _StrategyType.DEEPSPEED, + _StrategyType.DDP_SHARDED, + _StrategyType.DDP_SHARDED_SPAWN, ) @staticmethod diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index c528be4c8bfef..cc0a109c14b63 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -57,7 +57,7 @@ ReduceOp, sync_ddp_if_available, ) -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -81,7 +81,7 @@ class DDPPlugin(ParallelPlugin): devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes. 
""" - distributed_backend = DistributedType.DDP + distributed_backend = _StrategyType.DDP def __init__( self, diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index ef623a794da42..a142d518a0f2f 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -15,14 +15,14 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DDP2Plugin(DDPPlugin): """DDP2 behaves like DP in one node, but synchronization across nodes behaves like in DDP.""" - distributed_backend = DistributedType.DDP2 + distributed_backend = _StrategyType.DDP2 @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 926409925b9c7..6194787419754 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -44,7 +44,7 @@ ReduceOp, sync_ddp_if_available, ) -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -59,7 +59,7 @@ class DDPSpawnPlugin(ParallelPlugin): """Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes.""" - distributed_backend = DistributedType.DDP_SPAWN + distributed_backend = _StrategyType.DDP_SPAWN def __init__( self, diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 2464a8ba4eeca..94235f361d945 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -36,7 +36,7 @@ from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden @@ -82,7 +82,7 @@ def _move_float_tensors_to_half(self, batch: Any): class DeepSpeedPlugin(DDPPlugin): - distributed_backend = DistributedType.DEEPSPEED + distributed_backend = _StrategyType.DEEPSPEED DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH" def __init__( diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index a0f53791bc373..83328e8c47271 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -20,7 +20,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection -from 
pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _METRIC_COLLECTION @@ -29,7 +29,7 @@ class DataParallelPlugin(ParallelPlugin): """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and each gets a split of the data.""" - distributed_backend = DistributedType.DP + distributed_backend = _StrategyType.DP def __init__( self, diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 704afa1a91aaa..c9601a905df1c 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -20,7 +20,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: @@ -30,7 +30,7 @@ class DDPFullyShardedPlugin(DDPPlugin): - distributed_backend = DistributedType.DDP_FULLY_SHARDED + distributed_backend = _StrategyType.DDP_FULLY_SHARDED def __init__( self, diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 30360e1ab458f..51558189a3d35 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -26,7 +26,7 @@ from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.distributed import group as dist_group from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -35,7 +35,7 @@ class HorovodPlugin(ParallelPlugin): """Plugin for Horovod distributed training integration.""" - distributed_backend = DistributedType.HOROVOD + distributed_backend = _StrategyType.HOROVOD def __init__( self, diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index 5955f3a46f38e..d7563437bd16b 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -23,7 +23,7 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_AVAILABLE: @@ -36,7 +36,7 @@ class DDPShardedPlugin(DDPPlugin): """Optimizer and gradient sharded training provided by FairScale.""" - distributed_backend = DistributedType.DDP_SHARDED + distributed_backend = _StrategyType.DDP_SHARDED _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M def __init__(self, *args, **kwargs): diff --git 
a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index e0ae5c7bba187..12e627edbe5cb 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_AVAILABLE: @@ -38,7 +38,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin): """Optimizer sharded training provided by FairScale.""" - distributed_backend = DistributedType.DDP_SHARDED_SPAWN + distributed_backend = _StrategyType.DDP_SHARDED_SPAWN def configure_ddp(self) -> None: trainer = self.lightning_module.trainer diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index e15f7bb853db8..6be52b83e633f 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -61,10 +61,10 @@ TorchElasticEnvironment, ) from pytorch_lightning.utilities import ( + _StrategyType, AMPType, device_parser, DeviceType, - DistributedType, rank_zero_deprecation, rank_zero_info, rank_zero_warn, @@ -281,7 +281,7 @@ def _set_devices_if_none(self) -> None: self.devices = self.num_processes def _handle_accelerator_and_strategy(self) -> None: - deprecated_types = [t for t in DistributedType if t not in (DistributedType.TPU_SPAWN, DistributedType.DDP_CPU)] + deprecated_types = [t for t in _StrategyType if t not in (_StrategyType.TPU_SPAWN, _StrategyType.DDP_CPU)] if self.distributed_backend is not None and self.distributed_backend in deprecated_types: rank_zero_deprecation( f"Passing `Trainer(accelerator={self.distributed_backend!r})` has been deprecated" @@ -293,12 +293,12 @@ def _handle_accelerator_and_strategy(self) -> None: f" also passed `Trainer(accelerator={self.distributed_backend!r})`." f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead." ) - if self.strategy == DistributedType.TPU_SPAWN: + if self.strategy == _StrategyType.TPU_SPAWN: raise MisconfigurationException( "`Trainer(strategy='tpu_spawn')` is not a valid strategy," " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead." ) - if self.strategy == DistributedType.DDP_CPU: + if self.strategy == _StrategyType.DDP_CPU: raise MisconfigurationException( "`Trainer(strategy='ddp_cpu')` is not a valid strategy," " you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead." 
@@ -508,31 +508,31 @@ def _map_devices_to_accelerator(self, accelerator: str) -> bool: @property def use_dp(self) -> bool: - return self._distrib_type == DistributedType.DP + return self._distrib_type == _StrategyType.DP @property def use_ddp(self) -> bool: return self._distrib_type in ( - DistributedType.DDP, - DistributedType.DDP_SPAWN, - DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, - DistributedType.DDP_FULLY_SHARDED, - DistributedType.DEEPSPEED, - DistributedType.TPU_SPAWN, + _StrategyType.DDP, + _StrategyType.DDP_SPAWN, + _StrategyType.DDP_SHARDED, + _StrategyType.DDP_SHARDED_SPAWN, + _StrategyType.DDP_FULLY_SHARDED, + _StrategyType.DEEPSPEED, + _StrategyType.TPU_SPAWN, ) @property def use_ddp2(self) -> bool: - return self._distrib_type == DistributedType.DDP2 + return self._distrib_type == _StrategyType.DDP2 @property def use_horovod(self) -> bool: - return self._distrib_type == DistributedType.HOROVOD + return self._distrib_type == _StrategyType.HOROVOD @property def use_deepspeed(self) -> bool: - return self._distrib_type == DistributedType.DEEPSPEED + return self._distrib_type == _StrategyType.DEEPSPEED @property def _is_sharded_training_type(self) -> bool: @@ -593,7 +593,7 @@ def root_gpu(self) -> Optional[int]: @staticmethod def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool: - if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)): + if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(_StrategyType)): return True return isinstance(plugin, TrainingTypePlugin) @@ -638,7 +638,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: ) return TPUBf16PrecisionPlugin() - if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): + if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): return DeepSpeedPrecisionPlugin(self.precision) if self.precision == 32: @@ -709,15 +709,15 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic() use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow() - use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN + use_ddp_spawn = self._distrib_type == _StrategyType.DDP_SPAWN use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu - use_tpu_spawn = self.use_tpu and self._distrib_type == DistributedType.TPU_SPAWN + use_tpu_spawn = self.use_tpu and self._distrib_type == _StrategyType.TPU_SPAWN use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic() use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow() use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks - use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED - use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN - use_ddp_fully_sharded = self._distrib_type == DistributedType.DDP_FULLY_SHARDED + use_ddp_sharded = self._distrib_type == _StrategyType.DDP_SHARDED + use_ddp_sharded_spawn = self._distrib_type == _StrategyType.DDP_SHARDED_SPAWN + use_ddp_fully_sharded = self._distrib_type == _StrategyType.DDP_FULLY_SHARDED if use_tpu_spawn: ddp_plugin_cls = TPUSpawnPlugin @@ -842,27 +842,27 @@ def set_distributed_mode(self, strategy: Optional[str] = None): if 
self.has_horovodrun(): self._set_horovod_backend() elif self.num_gpus == 0 and self.num_nodes > 1: - self._distrib_type = DistributedType.DDP + self._distrib_type = _StrategyType.DDP elif self.num_gpus == 0 and self.num_processes > 1: - self.distributed_backend = DistributedType.DDP_SPAWN + self.distributed_backend = _StrategyType.DDP_SPAWN elif self.num_gpus > 1 and not _use_cpu: rank_zero_warn( "You requested multiple GPUs but did not specify a backend, e.g." ' `Trainer(strategy="dp"|"ddp"|"ddp2")`. Setting `strategy="ddp_spawn"` for you.' ) - self.distributed_backend = DistributedType.DDP_SPAWN + self.distributed_backend = _StrategyType.DDP_SPAWN # special case with DDP on CPUs - if self.distributed_backend == DistributedType.DDP_CPU: + if self.distributed_backend == _StrategyType.DDP_CPU: if _TPU_AVAILABLE: raise MisconfigurationException( "`accelerator='ddp_cpu'` is not supported on TPU machines. " "Learn more: https://github.com/PyTorchLightning/pytorch-lightning/issues/7810" ) if self.num_processes == 1 and self.num_nodes > 1: - self._distrib_type = DistributedType.DDP + self._distrib_type = _StrategyType.DDP else: - self._distrib_type = DistributedType.DDP_SPAWN + self._distrib_type = _StrategyType.DDP_SPAWN if self.num_gpus > 0: rank_zero_warn( "You requested one or more GPUs, but set `accelerator='ddp_cpu'`. Training will not use GPUs." @@ -875,25 +875,25 @@ def set_distributed_mode(self, strategy: Optional[str] = None): elif self.has_tpu and not _use_cpu: self._device_type = DeviceType.TPU if isinstance(self.tpu_cores, int): - self._distrib_type = DistributedType.TPU_SPAWN + self._distrib_type = _StrategyType.TPU_SPAWN elif self.has_ipu and not _use_cpu: self._device_type = DeviceType.IPU elif self.distributed_backend and self._distrib_type is None: - self._distrib_type = DistributedType(self.distributed_backend) + self._distrib_type = _StrategyType(self.distributed_backend) if self.num_gpus > 0 and not _use_cpu: self._device_type = DeviceType.GPU - _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) + _gpu_distrib_types = (_StrategyType.DP, _StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2) # DP and DDP2 cannot run without GPU if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _use_cpu: if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): - if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): + if self._distrib_type in (_StrategyType.DP, _StrategyType.DDP2): rank_zero_warn( f"{self._distrib_type.value!r} is not supported on CPUs, hence setting `strategy='ddp'`." ) - self._distrib_type = DistributedType.DDP + self._distrib_type = _StrategyType.DDP else: rank_zero_warn("You are running on single node with no parallelization, so distributed has no effect.") self._distrib_type = None @@ -903,28 +903,28 @@ def set_distributed_mode(self, strategy: Optional[str] = None): # for DDP overwrite nb processes by requested GPUs if self._device_type == DeviceType.GPU and self._distrib_type in ( - DistributedType.DDP, - DistributedType.DDP_SPAWN, + _StrategyType.DDP, + _StrategyType.DDP_SPAWN, ): self.num_processes = self.num_gpus - if self._device_type == DeviceType.GPU and self._distrib_type == DistributedType.DDP2: + if self._device_type == DeviceType.GPU and self._distrib_type == _StrategyType.DDP2: self.num_processes = self.num_nodes # Horovod is an extra case... 
- if self.distributed_backend == DistributedType.HOROVOD: + if self.distributed_backend == _StrategyType.HOROVOD: self._set_horovod_backend() using_valid_distributed = self.use_ddp or self.use_ddp2 if self.num_nodes > 1 and not using_valid_distributed: - # throw error to force user to choose a supported distributed type such as ddp or ddp2 + # throw error to force user to choose a supported strategy type such as ddp or ddp2 raise MisconfigurationException( "Your chosen strategy does not support `num_nodes > 1`. Please set `strategy=('ddp'|'ddp2')`." ) def _set_horovod_backend(self): self.check_horovod() - self._distrib_type = DistributedType.HOROVOD + self._distrib_type = _StrategyType.HOROVOD # Initialize Horovod to get rank / size info hvd.init() @@ -944,7 +944,7 @@ def check_interactive_compatibility(self): f"`Trainer(strategy={self._distrib_type.value!r})` or" f" `Trainer(accelerator={self._distrib_type.value!r})` is not compatible with an interactive" " environment. Run your code as a script, or choose one of the compatible backends:" - f" {', '.join(DistributedType.interactive_compatible_types())}." + f" {', '.join(_StrategyType.interactive_compatible_types())}." " In case you are spawning processes yourself, make sure to include the Trainer" " creation inside the worker function." ) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 37a234f32f711..931f6a92958ee 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -38,7 +38,7 @@ FastForwardSampler, ) from pytorch_lightning.utilities.data import get_len, has_iterable_dataset, has_len_all_ranks -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden @@ -70,7 +70,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: if not isinstance(dataloader, DataLoader): return - using_spawn = self._accelerator_connector._distrib_type == DistributedType.DDP_SPAWN + using_spawn = self._accelerator_connector._distrib_type == _StrategyType.DDP_SPAWN num_cpus = multiprocessing.cpu_count() # ddp_spawn + num_workers > 0 don't mix! 
tell the user diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 396289000251d..c67949a6dde92 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -64,10 +64,10 @@ from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import ( _IPU_AVAILABLE, + _StrategyType, _TPU_AVAILABLE, device_parser, DeviceType, - DistributedType, GradClipAlgorithmType, parsing, rank_zero_deprecation, @@ -1589,7 +1589,7 @@ def should_rank_save_checkpoint(self) -> bool: return self.training_type_plugin.should_rank_save_checkpoint @property - def _distrib_type(self) -> DistributedType: + def _distrib_type(self) -> _StrategyType: return self._accelerator_connector._distrib_type @property @@ -1752,10 +1752,10 @@ def distributed_sampler_kwargs(self) -> Optional[dict]: @property def data_parallel(self) -> bool: return self._distrib_type in ( - DistributedType.DP, - DistributedType.DDP, - DistributedType.DDP_SPAWN, - DistributedType.DDP2, + _StrategyType.DP, + _StrategyType.DDP, + _StrategyType.DDP_SPAWN, + _StrategyType.DDP2, ) @property diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 158d7356c91ce..1528e19c77484 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -18,6 +18,7 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401 from pytorch_lightning.utilities.enums import ( # noqa: F401 + _StrategyType, AMPType, DeviceType, DistributedType, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 436c675c382c2..18b0336b82d5f 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Enumerated utilities.""" -from enum import Enum -from typing import List, Optional, Union +from enum import Enum, EnumMeta +from typing import Any, List, Optional, Union + +from pytorch_lightning.utilities.warnings import rank_zero_deprecation class LightningEnum(str, Enum): @@ -37,6 +39,31 @@ def __hash__(self) -> int: return hash(self.value.lower()) +class _OnAccessEnumMeta(EnumMeta): + """Enum with a hook to run a function whenever a member is accessed. + + Adapted from: + https://www.buzzphp.com/posts/how-do-i-detect-and-invoke-a-function-when-a-python-enum-member-is-accessed + """ + + def __getattribute__(cls, name: str) -> Any: + obj = super().__getattribute__(name) + if isinstance(obj, Enum): + obj.deprecate() + return obj + + def __getitem__(cls, name: str) -> Any: + member = super().__getitem__(name) + member.deprecate() + return member + + def __call__(cls, value: str, *args: Any, **kwargs: Any) -> Any: + obj = super().__call__(value, *args, **kwargs) + if isinstance(obj, Enum): + obj.deprecate() + return obj + + class AMPType(LightningEnum): """Type of Automatic Mixed Precission used for training. @@ -73,8 +100,8 @@ def supported_types() -> List[str]: return [x.value for x in PrecisionType] -class DistributedType(LightningEnum): - """Define type of distributed computing. +class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta): + """Define type of training strategy. 
>>> # you can match the type with string >>> DistributedType.DDP == 'ddp' @@ -82,8 +109,24 @@ class DistributedType(LightningEnum): >>> # which is case invariant >>> DistributedType.DDP2 in ('ddp2', ) True + + Deprecated since v1.6.0 and will be removed in v1.8.0. + + Use `_StrategyType` instead. """ + DP = "dp" + DDP = "ddp" + DDP2 = "ddp2" + DDP_CPU = "ddp_cpu" + DDP_SPAWN = "ddp_spawn" + TPU_SPAWN = "tpu_spawn" + DEEPSPEED = "deepspeed" + HOROVOD = "horovod" + DDP_SHARDED = "ddp_sharded" + DDP_SHARDED_SPAWN = "ddp_sharded_spawn" + DDP_FULLY_SHARDED = "ddp_fully_sharded" + @staticmethod def interactive_compatible_types() -> List["DistributedType"]: """Returns a list containing interactive compatible DistributeTypes.""" @@ -98,17 +141,11 @@ def is_interactive_compatible(self) -> bool: """Returns whether self is interactive compatible.""" return self in DistributedType.interactive_compatible_types() - DP = "dp" - DDP = "ddp" - DDP2 = "ddp2" - DDP_CPU = "ddp_cpu" - DDP_SPAWN = "ddp_spawn" - TPU_SPAWN = "tpu_spawn" - DEEPSPEED = "deepspeed" - HOROVOD = "horovod" - DDP_SHARDED = "ddp_sharded" - DDP_SHARDED_SPAWN = "ddp_sharded_spawn" - DDP_FULLY_SHARDED = "ddp_fully_sharded" + def deprecate(self) -> None: + rank_zero_deprecation( + "`DistributedType` Enum has been deprecated in v1.6 and will be removed in v1.8." + f" Use the string value `{self.value!r}` instead." + ) class DeviceType(LightningEnum): @@ -188,3 +225,41 @@ def get_max_depth(mode: str) -> int: @staticmethod def supported_types() -> List[str]: return [x.value for x in ModelSummaryMode] + + +class _StrategyType(LightningEnum): + """Define type of training strategy. + + >>> # you can match the type with string + >>> _StrategyType.DDP == 'ddp' + True + >>> # which is case invariant + >>> _StrategyType.DDP2 in ('ddp2', ) + True + """ + + DP = "dp" + DDP = "ddp" + DDP2 = "ddp2" + DDP_CPU = "ddp_cpu" + DDP_SPAWN = "ddp_spawn" + TPU_SPAWN = "tpu_spawn" + DEEPSPEED = "deepspeed" + HOROVOD = "horovod" + DDP_SHARDED = "ddp_sharded" + DDP_SHARDED_SPAWN = "ddp_sharded_spawn" + DDP_FULLY_SHARDED = "ddp_fully_sharded" + + @staticmethod + def interactive_compatible_types() -> List["_StrategyType"]: + """Returns a list containing interactive compatible _StrategyTypes.""" + return [ + _StrategyType.DP, + _StrategyType.DDP_SPAWN, + _StrategyType.DDP_SHARDED_SPAWN, + _StrategyType.TPU_SPAWN, + ] + + def is_interactive_compatible(self) -> bool: + """Returns whether self is interactive compatible.""" + return self in _StrategyType.interactive_compatible_types() diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index d95f5c8e6f9ea..e70d862b048e0 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -43,7 +43,7 @@ SLURMEnvironment, TorchElasticEnvironment, ) -from pytorch_lightning.utilities import DeviceType, DistributedType +from pytorch_lightning.utilities import _StrategyType, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -636,7 +636,7 @@ def test_unsupported_distrib_types_on_cpu(training_type): with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"): trainer = Trainer(accelerator=training_type, num_processes=2) - assert trainer._distrib_type == DistributedType.DDP + assert trainer._distrib_type == _StrategyType.DDP def
test_accelerator_ddp_for_cpu(tmpdir): diff --git a/tests/base/model_test_epoch_ends.py b/tests/base/model_test_epoch_ends.py index 746ceb94a5de0..b001298e93dd0 100644 --- a/tests/base/model_test_epoch_ends.py +++ b/tests/base/model_test_epoch_ends.py @@ -15,7 +15,7 @@ import torch -from pytorch_lightning.utilities import DistributedType +from pytorch_lightning.utilities import _StrategyType class TestEpochEndVariations(ABC): @@ -34,13 +34,13 @@ def test_epoch_end(self, outputs): test_loss = self.get_output_metric(output, "test_loss") # reduce manually when using dp - if self.trainer._distrib_type == DistributedType.DP: + if self.trainer._distrib_type == _StrategyType.DP: test_loss = torch.mean(test_loss) test_loss_mean += test_loss # reduce manually when using dp test_acc = self.get_output_metric(output, "test_acc") - if self.trainer._distrib_type == DistributedType.DP: + if self.trainer._distrib_type == _StrategyType.DP: test_acc = torch.mean(test_acc) test_acc_mean += test_acc @@ -69,13 +69,13 @@ def test_epoch_end__multiple_dataloaders(self, outputs): test_loss = output["test_loss"] # reduce manually when using dp - if self.trainer._distrib_type == DistributedType.DP: + if self.trainer._distrib_type == _StrategyType.DP: test_loss = torch.mean(test_loss) test_loss_mean += test_loss # reduce manually when using dp test_acc = output["test_acc"] - if self.trainer._distrib_type == DistributedType.DP: + if self.trainer._distrib_type == _StrategyType.DP: test_acc = torch.mean(test_acc) test_acc_mean += test_acc diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py new file mode 100644 index 0000000000000..f668f63b9f450 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-8.py @@ -0,0 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test deprecated functionality which will be removed in v1.8.0.""" +import pytest + +from pytorch_lightning.utilities.enums import DistributedType + + +def test_v1_8_0_deprecated_distributed_type_enum(): + + with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."): + _ = DistributedType.DDP diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index 643d3e50cb894..6fa3bbb5dc943 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -15,7 +15,7 @@ from torchmetrics.functional import accuracy from pytorch_lightning import LightningDataModule, LightningModule, Trainer -from pytorch_lightning.utilities import DistributedType +from pytorch_lightning.utilities import _StrategyType from tests.helpers import BoringModel from tests.helpers.utils import get_default_logger, load_model_from_checkpoint, reset_seed @@ -82,7 +82,7 @@ def run_model_test( run_prediction_eval_model_template(model, dataloader, min_acc=min_acc) if with_hpc: - if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): + if trainer._distrib_type in (_StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2): # on hpc this would work fine... but need to hack it for the purpose of the test trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers( pretrained_model diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index bd69cf359473e..7c79cb7f2e709 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -31,7 +31,7 @@ _replace_dataloader_init_method, ) from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin -from pytorch_lightning.utilities import DistributedType +from pytorch_lightning.utilities import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import pl_worker_init_function from tests.helpers.runif import RunIf @@ -251,12 +251,12 @@ def test_seed_everything(): @pytest.mark.parametrize( "strategy", [ - DistributedType.DP, - DistributedType.DDP, - DistributedType.DDP_SPAWN, - pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), - pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), - pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), + _StrategyType.DP, + _StrategyType.DDP, + _StrategyType.DDP_SPAWN, + pytest.param(_StrategyType.DEEPSPEED, marks=RunIf(deepspeed=True)), + pytest.param(_StrategyType.DDP_SHARDED, marks=RunIf(fairscale=True)), + pytest.param(_StrategyType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), ], ) def test_setup_dataloaders_replace_custom_sampler(strategy): @@ -279,12 +279,12 @@ def test_setup_dataloaders_replace_custom_sampler(strategy): @pytest.mark.parametrize( "strategy", [ - DistributedType.DP, - DistributedType.DDP, - DistributedType.DDP_SPAWN, - pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), - pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), - pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), + _StrategyType.DP, + _StrategyType.DDP, + _StrategyType.DDP_SPAWN, + pytest.param(_StrategyType.DEEPSPEED, marks=RunIf(deepspeed=True)), + pytest.param(_StrategyType.DDP_SHARDED, marks=RunIf(fairscale=True)), + pytest.param(_StrategyType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), ], ) @pytest.mark.parametrize("shuffle", [True, False]) diff --git 
a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index 0f6abd38e6836..7d35c592b7b62 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -20,7 +20,7 @@ from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler from pytorch_lightning import Trainer -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -137,7 +137,7 @@ def _get_warning_msg(): @pytest.mark.parametrize("num_workers", [0, 1]) def test_dataloader_warnings(tmpdir, num_workers): trainer = Trainer(default_root_dir=tmpdir, strategy="ddp_spawn", num_processes=2, fast_dev_run=4) - assert trainer._accelerator_connector._distrib_type == DistributedType.DDP_SPAWN + assert trainer._accelerator_connector._distrib_type == _StrategyType.DDP_SPAWN trainer.fit(TestSpawnBoringModel(num_workers)) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d2e5f771a9c40..dc0ce2b68452c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -48,7 +48,7 @@ DDPSpawnShardedPlugin, ) from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import DeviceType, DistributedType +from pytorch_lightning.utilities import _StrategyType, DeviceType from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -1154,15 +1154,15 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): ), ( dict(accelerator="ddp", num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="ddp", num_nodes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="ddp_cpu", num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="ddp2", gpus=None), @@ -1174,43 +1174,43 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): ), ( dict(accelerator="dp", gpus=1), - dict(_distrib_type=DistributedType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator="ddp", gpus=1), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator="ddp_cpu", num_processes=2, gpus=1), - dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="ddp2", gpus=1), - 
dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator=None, gpus=2), - dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), ), ( dict(accelerator="dp", gpus=2), - dict(_distrib_type=DistributedType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ( dict(accelerator="ddp", gpus=2), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), ), ( dict(accelerator="ddp2", gpus=2), - dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ( dict(accelerator="ddp2", num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="dp", num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ], ) @@ -2096,11 +2096,11 @@ def training_step(self, batch, batch_idx): ), ( dict(strategy="ddp", num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp", num_nodes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp2", gpus=None), @@ -2112,47 +2112,47 @@ def training_step(self, batch, batch_idx): ), ( dict(strategy="dp", gpus=1), - dict(_distrib_type=DistributedType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp", gpus=1), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp_spawn", gpus=1), - dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp2", gpus=1), - dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy=None, gpus=2), - dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, 
num_processes=2), ), ( dict(strategy="dp", gpus=2), - dict(_distrib_type=DistributedType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy="ddp", gpus=2), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), ), ( dict(strategy="ddp2", gpus=2), - dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy="ddp2", num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="dp", num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp_spawn", num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp_spawn", num_processes=1, gpus=None), @@ -2161,7 +2161,7 @@ def training_step(self, batch, batch_idx): ( dict(strategy="ddp_fully_sharded", gpus=1), dict( - _distrib_type=DistributedType.DDP_FULLY_SHARDED, + _distrib_type=_StrategyType.DDP_FULLY_SHARDED, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1, @@ -2169,32 +2169,32 @@ def training_step(self, batch, batch_idx): ), ( dict(strategy=DDPSpawnPlugin(), num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy=DDPSpawnPlugin(), gpus=2), - dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDPPlugin(), num_processes=2, gpus=None), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy=DDPPlugin(), gpus=2), - dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDP2Plugin(), gpus=2), - dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DataParallelPlugin(), gpus=2), - dict(_distrib_type=DistributedType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDPFullyShardedPlugin(), gpus=2), dict( - _distrib_type=DistributedType.DDP_FULLY_SHARDED, + 
_distrib_type=_StrategyType.DDP_FULLY_SHARDED, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1, @@ -2203,7 +2203,7 @@ def training_step(self, batch, batch_idx): ( dict(strategy=DDPSpawnShardedPlugin(), gpus=2), dict( - _distrib_type=DistributedType.DDP_SHARDED_SPAWN, + _distrib_type=_StrategyType.DDP_SHARDED_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1, @@ -2211,7 +2211,7 @@ def training_step(self, batch, batch_idx): ), ( dict(strategy=DDPShardedPlugin(), gpus=2), - dict(_distrib_type=DistributedType.DDP_SHARDED, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP_SHARDED, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), ), ], )
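
Taken together, the changes above make `_StrategyType` the canonical enum for selecting a training strategy, while `DistributedType` survives only as a deprecated alias whose members emit a warning on access through the `_OnAccessEnumMeta` hooks (`__getattribute__`, `__getitem__` and `__call__`). The short sketch below is illustrative only and not part of the patch; it assumes a single-process (rank-zero) context and that `rank_zero_deprecation` routes through Python's `warnings.warn`, as it does in `pytorch_lightning.utilities.warnings`.

import warnings

from pytorch_lightning.utilities.enums import _StrategyType, DistributedType

# `_StrategyType` keeps the case-insensitive string matching provided by `LightningEnum`,
# so existing comparisons against plain strings such as "ddp" continue to work.
assert _StrategyType.DDP == "ddp"
assert _StrategyType("ddp_spawn") is _StrategyType.DDP_SPAWN

# Every way of reaching a `DistributedType` member goes through `_OnAccessEnumMeta`,
# so attribute access, item lookup and value lookup all call `deprecate()`.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = DistributedType.DDP            # __getattribute__
    _ = DistributedType["DDP_SPAWN"]   # __getitem__
    _ = DistributedType("deepspeed")   # __call__

assert any("deprecated in v1.6" in str(w.message) for w in caught)

This is also why the new `test_v1_8_0_deprecated_distributed_type_enum` test only needs a single attribute access inside `pytest.deprecated_call` to exercise the deprecation path.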