From dc5096670200b0bc42775ac7d7040c7cc8c28b9e Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Tue, 16 Nov 2021 09:56:03 -0800 Subject: [PATCH 01/27] 1/n move precision plugin into strategy - update reference --- docs/source/extensions/accelerators.rst | 3 +- docs/source/extensions/plugins.rst | 3 +- pytorch_lightning/accelerators/accelerator.py | 44 +++++++++++-------- pytorch_lightning/accelerators/tpu.py | 5 ++- pytorch_lightning/lite/lite.py | 2 +- .../plugins/training_type/ddp.py | 3 ++ .../plugins/training_type/ddp_spawn.py | 3 ++ .../plugins/training_type/deepspeed.py | 8 ++-- pytorch_lightning/plugins/training_type/dp.py | 9 +++- .../plugins/training_type/fully_sharded.py | 3 ++ .../plugins/training_type/horovod.py | 9 +++- .../plugins/training_type/ipu.py | 5 ++- .../plugins/training_type/parallel.py | 4 +- .../plugins/training_type/sharded.py | 2 +- .../plugins/training_type/sharded_spawn.py | 2 +- .../plugins/training_type/single_device.py | 4 +- .../plugins/training_type/single_tpu.py | 4 +- .../plugins/training_type/tpu_spawn.py | 6 ++- .../training_type/training_type_plugin.py | 14 +++++- .../connectors/accelerator_connector.py | 12 +++-- pytorch_lightning/trainer/trainer.py | 4 +- tests/accelerators/test_cpu.py | 8 ++-- tests/accelerators/test_gpu.py | 4 +- tests/accelerators/test_ipu.py | 8 ++-- tests/accelerators/test_tpu.py | 12 ++--- ..._ddp_fully_sharded_with_full_state_dict.py | 2 +- tests/plugins/test_deepspeed_plugin.py | 4 +- 27 files changed, 121 insertions(+), 66 deletions(-) diff --git a/docs/source/extensions/accelerators.rst b/docs/source/extensions/accelerators.rst index 11a85cb082af8..66561db0ccef4 100644 --- a/docs/source/extensions/accelerators.rst +++ b/docs/source/extensions/accelerators.rst @@ -26,8 +26,7 @@ One to handle differences from the training routine and one to handle different from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin accelerator = GPUAccelerator( - precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"), - training_type_plugin=DDPPlugin(), + training_type_plugin=DDPPlugin(precision_plugin=NativeMixedPrecisionPlugin(16, "cuda")), ) trainer = Trainer(accelerator=accelerator) diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst index 56e14a97502cc..25a28c6cc1c41 100644 --- a/docs/source/extensions/plugins.rst +++ b/docs/source/extensions/plugins.rst @@ -81,8 +81,7 @@ can then be passed into the Trainer directly or via a (custom) accelerator: # fully custom accelerator and plugins accelerator = MyAccelerator( - precision_plugin=CustomPrecisionPlugin(), - training_type_plugin=CustomDDPPlugin(), + training_type_plugin=CustomDDPPlugin(precision_plugin=CustomPrecisionPlugin()), ) trainer = Trainer(accelerator=accelerator) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 14b6a47c7243f..63caa317b7a27 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -44,15 +44,20 @@ class Accelerator: One to handle differences from the training routine and one to handle different precisions. """ - def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None: + def __init__( + self, training_type_plugin: TrainingTypePlugin, precision_plugin: Optional[PrecisionPlugin] = None + ) -> None: """ Args: precision_plugin: the plugin to handle precision-specific parts training_type_plugin: the plugin to handle different training routines """ - self.precision_plugin = precision_plugin + self.training_type_plugin = training_type_plugin + if precision_plugin: + self.training_type_plugin._precision_plugin = precision_plugin + self.optimizers: List = [] self.lr_schedulers: List = [] self.optimizer_frequencies: List = [] @@ -84,7 +89,7 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None: if self.training_type_plugin.setup_optimizers_in_pre_dispatch: self.setup_optimizers(trainer) - self.precision_plugin.pre_dispatch() + self.training_type_plugin.precision_plugin.pre_dispatch() def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: """Moves the state of the optimizers to the GPU if needed.""" @@ -96,12 +101,12 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: def dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.dispatch(trainer) - self.precision_plugin.dispatch(trainer) + self.training_type_plugin.precision_plugin.dispatch(trainer) def post_dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something after the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch(trainer) - self.precision_plugin.post_dispatch() + self.training_type_plugin.precision_plugin.post_dispatch() @property def model(self) -> Module: @@ -159,7 +164,7 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details """ - with self.precision_plugin.train_step_context(): + with self.training_type_plugin.precision_plugin.train_step_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -167,7 +172,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details """ - with self.precision_plugin.val_step_context(): + with self.training_type_plugin.precision_plugin.val_step_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -175,7 +180,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details """ - with self.precision_plugin.test_step_context(): + with self.training_type_plugin.precision_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: @@ -183,7 +188,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details """ - with self.precision_plugin.predict_step_context(): + with self.training_type_plugin.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: @@ -193,11 +198,11 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: closure_loss: a tensor holding the loss value to backpropagate """ self.training_type_plugin.pre_backward(closure_loss) - closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) + closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss) - self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) + self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) - closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss) self.training_type_plugin.post_backward(closure_loss) return closure_loss @@ -220,7 +225,7 @@ def optimizer_step( **kwargs: Any extra arguments to ``optimizer.step`` """ model = model or self.lightning_module - self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) + self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: """Zeros all model parameter's gradients.""" @@ -248,26 +253,29 @@ def setup_training_type_plugin(self) -> None: def setup_precision_plugin(self) -> None: """Attaches the precision plugin to the accelerator.""" - model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers) + model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect( + self.model, self.optimizers, self.lr_schedulers + ) self.model = model self.optimizers = optimizers self.lr_schedulers = schedulers @property def amp_backend(self) -> Optional[LightningEnum]: - if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin): return AMPType.APEX - if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): + if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin): return AMPType.NATIVE return None @property def precision(self) -> Union[str, int]: - return self.precision_plugin.precision + """deprecated.""" + return self.training_type_plugin.precision @property def scaler(self) -> Optional["GradScaler"]: - return getattr(self.precision_plugin, "scaler", None) + return getattr(self.training_type_plugin.precision_plugin, "scaler", None) def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: """Returns state of an optimizer. diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 6e824a25f6b9d..4ab3b643d443c 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -36,10 +36,11 @@ def setup(self, trainer: "pl.Trainer") -> None: ValueError: If the precision or training type plugin are unsupported. """ - if not isinstance(self.precision_plugin, TPUPrecisionPlugin): + if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin): # this configuration should have been avoided in the accelerator connector raise ValueError( - f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}." + f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`," + f" found: {self.training_type_plugin}." ) if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)): raise ValueError( diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 2a2ed9586b420..bb07c763156aa 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -108,7 +108,7 @@ def __init__( ) self._accelerator = self._accelerator_connector.accelerator self._strategy = self._accelerator.training_type_plugin - self._precision_plugin = self._accelerator.precision_plugin + self._precision_plugin = self._strategy.precision_plugin self._models_setup: int = 0 # wrap the run method so we can inject setup logic or spawn processes for the user diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 0285859a6714a..6d1b168d5ac7a 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -36,6 +36,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( @@ -86,6 +87,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, @@ -96,6 +98,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self.interactive_ddp_procs = [] self._num_nodes = 1 diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index a77027adb6dcf..da724944ade7e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -29,6 +29,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn @@ -65,6 +66,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, @@ -74,6 +76,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self._num_nodes = 1 self.sync_batchnorm = False diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index eb087ad199808..01959bdcee212 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -30,6 +30,7 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn @@ -129,6 +130,7 @@ def __init__( synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, partition_module: bool = True, + precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -273,6 +275,7 @@ def __init__( super().__init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, + precision_plugin=precision_plugin, ) self.config = self._load_config(config) @@ -331,7 +334,7 @@ def __init__( @property def precision(self) -> Union[str, int]: - return self._precision or self.lightning_module.trainer.precision + return self._precision or self.precision_plugin.precision @property def amp_level(self) -> Optional[str]: @@ -456,8 +459,7 @@ def init_deepspeed(self): "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs." ) - precision = self.lightning_module.trainer.accelerator.precision - model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) + model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision) if self.zero_stage_3 and self.partition_module: # Ensure the entire model has been moved to the appropriate device diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 83328e8c47271..3f1b9a3acfa50 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -18,6 +18,7 @@ from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.enums import _StrategyType @@ -35,8 +36,14 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index c9601a905df1c..716e007efc592 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -18,6 +18,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.enums import _StrategyType @@ -46,6 +47,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): """Plugin for Fully Sharded Data Parallel provided by FairScale. @@ -97,6 +99,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 51558189a3d35..961d2764b8ef3 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,6 +21,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available @@ -41,8 +42,14 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) rank_zero_only.rank = self.global_rank @property diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 898e62791d6ee..26e2f381e6300 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -22,6 +22,7 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE @@ -64,6 +65,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: @@ -116,8 +118,7 @@ def setup(self) -> None: self.lightning_module.trainer._update_dataloader = self._convert_to_poptorch_loader def pre_dispatch(self) -> None: - precision = self.lightning_module.trainer.precision - model = LightningIPUModule(self.lightning_module, precision) + model = LightningIPUModule(self.lightning_module, self.precision) self.model = model # reset the backup diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 4f4b2c5b8e3c3..07ede1ae4f833 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -23,6 +23,7 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -36,8 +37,9 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(checkpoint_io) + super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index d7563437bd16b..475d1e44095d8 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -75,7 +75,7 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin optim_class = type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: - precision = self._precision or self.lightning_module.trainer.precision + precision = self._precision or self.precision is_fp16 = precision in ("mixed", 16) # For multi-node training, compressing the model shards in fp16 before broadcasting # improves performance. When using PyTorch AMP, it will not degrade diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 12e627edbe5cb..ae308b2074c74 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -118,7 +118,7 @@ def post_training_step(self): def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process - precision_plugin = trainer.accelerator.precision_plugin + precision_plugin = trainer.precision_plugin if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin): precision_plugin.scaler = ShardedGradScaler() return super().new_process(trainer, mp_queue) diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 1737bf3b41ca8..12a0f625b64fc 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -27,8 +28,9 @@ def __init__( self, device: torch.device, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(checkpoint_io) + super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 9fed2000391dd..e6f6a5f4b26f2 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -16,6 +16,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -33,12 +34,13 @@ def __init__( self, device: int, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, ): device = xm.xla_device(device) checkpoint_io = checkpoint_io or XLACheckpointIO() - super().__init__(device=device, checkpoint_io=checkpoint_io) + super().__init__(device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.debug = debug self.tpu_local_core_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7aa4a67721c04..a3ff3603edd03 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -27,6 +27,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn @@ -56,11 +57,14 @@ def __init__( self, parallel_devices: Optional[List[int]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, **_: Any ) -> None: checkpoint_io = checkpoint_io or XLACheckpointIO() - super().__init__(parallel_devices=parallel_devices, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin + ) self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index c23edf594146f..def1c1f2e43b4 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -25,6 +25,7 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT @@ -33,16 +34,27 @@ class TrainingTypePlugin(ABC): """Base class for all training type plugins that change the behaviour of the training, validation and test- loop.""" - def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: + def __init__( + self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None + ) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io + self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() @property def checkpoint_io(self) -> CheckpointIO: return self._checkpoint_io + @property + def precision(self) -> Union[str, int]: + return self._precision_plugin.precision + + @property + def precision_plugin(self) -> Optional[PrecisionPlugin]: + return self._precision_plugin + @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: self._checkpoint_io = plugin diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 5532385ca1d98..5a5c758826c88 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -1,4 +1,3 @@ -# Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -405,6 +404,7 @@ def training_type_plugin(self) -> TrainingTypePlugin: # attach checkpoint plugin to the training type plugin if self._checkpoint_io is not None: self._training_type_plugin.checkpoint_io = self._checkpoint_io + self._training_type_plugin._precision_plugin = self.precision_plugin self._training_type_plugin_resolved = True return self._training_type_plugin @@ -531,11 +531,11 @@ def use_deepspeed(self) -> bool: @property def _is_sharded_training_type(self) -> bool: - return isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)) + return isinstance(self._training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)) @property def _is_fully_sharded_training_type(self) -> bool: - return isinstance(self.training_type_plugin, DDPFullyShardedPlugin) + return isinstance(self._training_type_plugin, DDPFullyShardedPlugin) @property def is_distributed(self) -> bool: @@ -793,12 +793,10 @@ def select_accelerator(self) -> Accelerator: acc_cls = IPUAccelerator else: acc_cls = CPUAccelerator - # as precision_plugin is dependent on training_type_plugin, make sure - # that we first select training_type_plugin, then precision_plugin - accelerator = acc_cls(training_type_plugin=self.training_type_plugin, precision_plugin=self.precision_plugin) + + accelerator = acc_cls(training_type_plugin=self.training_type_plugin) # transfer ownership of the plugins to the accelerator self._training_type_plugin = proxy(self.training_type_plugin) - self._precision_plugin = proxy(self.precision_plugin) return accelerator diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index be9c71e2fe470..9e296eef4d11b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1568,7 +1568,7 @@ def training_type_plugin(self) -> TrainingTypePlugin: @property def precision_plugin(self) -> PrecisionPlugin: - return self.accelerator.precision_plugin + return self.training_type_plugin.precision_plugin @property def global_rank(self) -> int: @@ -1672,7 +1672,7 @@ def amp_backend(self) -> Optional[str]: @property def precision(self) -> Union[str, int]: - return self.accelerator.precision + return self.training_type_plugin.precision @property def scaler(self): diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 41e73431495b9..e728d50292d9f 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -41,8 +41,8 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: def test_restore_checkpoint_after_pre_dispatch_default(): """Assert default for restore_checkpoint_after_pre_dispatch is False.""" - plugin = SingleDevicePlugin(torch.device("cpu")) - accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) + plugin = SingleDevicePlugin(torch.device("cpu"), precision_plugin=PrecisionPlugin()) + accelerator = CPUAccelerator(training_type_plugin=plugin) assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch assert not plugin.restore_checkpoint_after_pre_dispatch @@ -74,8 +74,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) - plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO()) - accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) + plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO(), precision_plugin=PrecisionPlugin()) + accelerator = CPUAccelerator(training_type_plugin=plugin) assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch diff --git a/tests/accelerators/test_gpu.py b/tests/accelerators/test_gpu.py index 85ce0cd9f0f18..d24797ea20c9b 100644 --- a/tests/accelerators/test_gpu.py +++ b/tests/accelerators/test_gpu.py @@ -12,7 +12,7 @@ def test_get_torch_gpu_stats(tmpdir): """Test GPU get_device_stats with Pytorch >= 1.8.0.""" current_device = torch.device(f"cuda:{torch.cuda.current_device()}") GPUAccel = GPUAccelerator( - training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin() + training_type_plugin=DataParallelPlugin(parallel_devices=[current_device], precision_plugin=PrecisionPlugin()) ) gpu_stats = GPUAccel.get_device_stats(current_device) fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"] @@ -27,7 +27,7 @@ def test_get_nvidia_gpu_stats(tmpdir): """Test GPU get_device_stats with Pytorch < 1.8.0.""" current_device = torch.device(f"cuda:{torch.cuda.current_device()}") GPUAccel = GPUAccelerator( - training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin() + training_type_plugin=DataParallelPlugin(parallel_devices=[current_device], precision_plugin=PrecisionPlugin()) ) gpu_stats = GPUAccel.get_device_stats(current_device) fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"] diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index dfaa1c8042355..d65a7b4e63581 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -193,8 +193,8 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st model = IPUModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) - assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == 16 + assert isinstance(trainer.precision_plugin, IPUPrecisionPlugin) + assert trainer.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) @@ -213,8 +213,8 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) - assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == 16 + assert isinstance(trainer.precision_plugin, IPUPrecisionPlugin) + assert trainer.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index 78e4c505bb99a..c83f60e0c2ec0 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -23,7 +23,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator -from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO +from pytorch_lightning.plugins import DDPPlugin, TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO from pytorch_lightning.utilities import find_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset @@ -288,12 +288,14 @@ def forward(self, x): def test_tpu_invalid_raises(): - accelerator = TPUAccelerator(object(), TPUSpawnPlugin()) - with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): + accelerator = TPUAccelerator(TPUSpawnPlugin(precision_plugin=object())) + with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): accelerator.setup(object()) - accelerator = TPUAccelerator(TPUPrecisionPlugin(), object()) - with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"): + accelerator = TPUAccelerator(DDPPlugin(precision_plugin=TPUPrecisionPlugin())) + with pytest.raises( + ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin" + ): accelerator.setup(object()) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 1468c7f4a4137..910b8329dbb06 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -35,7 +35,7 @@ def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", gpus=1, precision=16) assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) - assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) + assert isinstance(trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) class TestFSDPModel(BoringModel): diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 2d39a3de6b5c5..da7eebc719521 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -170,8 +170,8 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): ) assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) - assert isinstance(trainer.accelerator.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == precision + assert isinstance(trainer.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.precision_plugin.precision == precision @RunIf(deepspeed=True) From 8fd976e63cbd66360e7755dd18210b18e1ec4bc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Nov 2021 00:10:03 +0100 Subject: [PATCH 02/27] update precision plugin reference in tpu_spawn --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index a3ff3603edd03..3ab9a8171aac5 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -171,7 +171,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: set_shared_parameters(self.model.module, shared_params) trainer.accelerator.setup_optimizers(trainer) - trainer.precision_plugin.connect(self._model, None, None) + self.precision_plugin.connect(self._model, None, None) self.barrier("pre-run-stage") From ecadcbd2fd117b2a97767fefa2b53fff6081a01e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Nov 2021 00:13:06 +0100 Subject: [PATCH 03/27] add missing reference in error message --- pytorch_lightning/accelerators/tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 4ab3b643d443c..673e8419ca7fb 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -40,7 +40,7 @@ def setup(self, trainer: "pl.Trainer") -> None: # this configuration should have been avoided in the accelerator connector raise ValueError( f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`," - f" found: {self.training_type_plugin}." + f" found: {self.training_type_plugin.precision_plugin}." ) if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)): raise ValueError( From 846b595f8297667c45604dfa7c141faa9ab2a258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Nov 2021 00:13:18 +0100 Subject: [PATCH 04/27] add back removed license line --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 5a5c758826c88..56a90bf1ba811 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -1,3 +1,4 @@ +# Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 42b0325f61a383340add5cad3d13556e6cb47f78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Nov 2021 00:17:28 +0100 Subject: [PATCH 05/27] update references in tests --- tests/accelerators/test_ipu.py | 8 ++++---- .../test_ddp_fully_sharded_with_full_state_dict.py | 4 ++-- tests/plugins/test_deepspeed_plugin.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index d65a7b4e63581..be2e597c9a2f9 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -193,8 +193,8 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st model = IPUModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) - assert isinstance(trainer.precision_plugin, IPUPrecisionPlugin) - assert trainer.precision_plugin.precision == 16 + assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) @@ -213,8 +213,8 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) - assert isinstance(trainer.precision_plugin, IPUPrecisionPlugin) - assert trainer.precision_plugin.precision == 16 + assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 910b8329dbb06..c0fab297173e7 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -34,8 +34,8 @@ def test_invalid_on_cpu(tmpdir): def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", gpus=1, precision=16) - assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) - assert isinstance(trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) + assert isinstance(trainer.training_type_plugin, DDPFullyShardedPlugin) + assert isinstance(trainer.training_type_plugin.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) class TestFSDPModel(BoringModel): diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index da7eebc719521..480b050c39b36 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -170,8 +170,8 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): ) assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) - assert isinstance(trainer.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.precision_plugin.precision == precision + assert isinstance(trainer.training_type_plugin.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == precision @RunIf(deepspeed=True) From 85f058bb2f539ac4557477e0cbf94c5cd3efed71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Nov 2021 00:17:44 +0100 Subject: [PATCH 06/27] update reference in trainer --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9e296eef4d11b..f81ce0396e5bb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1672,7 +1672,7 @@ def amp_backend(self) -> Optional[str]: @property def precision(self) -> Union[str, int]: - return self.training_type_plugin.precision + return self.training_type_plugin.precision_plugin.precision @property def scaler(self): From cff894ea888059fac06efde0a271c775f343df86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Nov 2021 00:21:10 +0100 Subject: [PATCH 07/27] update return annotation for precision_plugin property on TTP --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index def1c1f2e43b4..dbda2715e4035 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -52,7 +52,7 @@ def precision(self) -> Union[str, int]: return self._precision_plugin.precision @property - def precision_plugin(self) -> Optional[PrecisionPlugin]: + def precision_plugin(self) -> PrecisionPlugin: return self._precision_plugin @checkpoint_io.setter From 7e6d63521144ec05206e41a7fced47212002dc5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 17 Nov 2021 00:26:16 +0100 Subject: [PATCH 08/27] simplify access to precision plugin reference in sharded plug --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index ae308b2074c74..12c06b9dde541 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -118,9 +118,8 @@ def post_training_step(self): def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process - precision_plugin = trainer.precision_plugin - if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin): - precision_plugin.scaler = ShardedGradScaler() + if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): + self.precision_plugin.scaler = ShardedGradScaler() return super().new_process(trainer, mp_queue) @classmethod From 7c0e651a3bb064ec324d8511d4a028e851daa526 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Tue, 16 Nov 2021 15:44:27 -0800 Subject: [PATCH 09/27] add changelog --- CHANGELOG.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ada4815bcf57..5335d906de56c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Moved `precision_plugin` into `Training_type_plugin` and updated reference ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) + + - @@ -50,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505)) -- +- Deprecated `precision_plugin` from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) - @@ -139,6 +142,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `reload_dataloaders_every_epoch` from `Trainer` in favour of `reload_dataloaders_every_n_epochs` ([#10481](https://github.com/PyTorchLightning/pytorch-lightning/pull/10481)) +- Removed `precision_plugin` from `Accelerator` in favor of `precision_plugin` in `training_type_plugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) + ### Fixed From ae6d6c5d0adad4fd98bb1f7492b17788c487dada Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Tue, 16 Nov 2021 15:54:13 -0800 Subject: [PATCH 10/27] remove precision property from ttp and add deprecation message --- pytorch_lightning/accelerators/accelerator.py | 24 ++++++++++++++++--- .../training_type/training_type_plugin.py | 4 ---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 63caa317b7a27..676985e541e3b 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -25,6 +25,7 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -56,6 +57,15 @@ def __init__( self.training_type_plugin = training_type_plugin if precision_plugin: + """ + .. deprecated + precision_plugin parameter is deprecated will be removed soon. + Use :`training_type_plugin(precision_plugin) instead. + """ + rank_zero_deprecation( + f"`{self.__class__.__name__}.precision` was and will be removed soon" + f" Use `training_type_plugin.precision_plugin.precision` instead." + ) self.training_type_plugin._precision_plugin = precision_plugin self.optimizers: List = [] @@ -213,7 +223,7 @@ def optimizer_step( opt_idx: int, closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """performs the actual optimizer step. @@ -270,8 +280,16 @@ def amp_backend(self) -> Optional[LightningEnum]: @property def precision(self) -> Union[str, int]: - """deprecated.""" - return self.training_type_plugin.precision + """ + .. deprecated + This method is deprecated will be removed soon. + Use :`training_type_plugin.precision_plugin.precision` instead. + """ + rank_zero_deprecation( + f"`{self.__class__.__name__}.precision` was and will be removed soon" + f" Use `training_type_plugin.precision_plugin.precision` instead." + ) + return self.training_type_plugin.precision_plugin.precision @property def scaler(self) -> Optional["GradScaler"]: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index dbda2715e4035..7010c0e878dc9 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -47,10 +47,6 @@ def __init__( def checkpoint_io(self) -> CheckpointIO: return self._checkpoint_io - @property - def precision(self) -> Union[str, int]: - return self._precision_plugin.precision - @property def precision_plugin(self) -> PrecisionPlugin: return self._precision_plugin From 9936c5122cf7023d6feab173a50774a2e7d352a1 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Tue, 16 Nov 2021 16:25:31 -0800 Subject: [PATCH 11/27] fix make doc and update precision reference --- pytorch_lightning/accelerators/accelerator.py | 7 +++---- pytorch_lightning/plugins/training_type/ipu.py | 2 +- pytorch_lightning/plugins/training_type/sharded.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 676985e541e3b..eb69180d464e5 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -280,10 +280,9 @@ def amp_backend(self) -> Optional[LightningEnum]: @property def precision(self) -> Union[str, int]: - """ - .. deprecated - This method is deprecated will be removed soon. - Use :`training_type_plugin.precision_plugin.precision` instead. + """This method is deprecated and will be removed soon. + + Use `training_type_plugin.precision_plugin.precision` instead. """ rank_zero_deprecation( f"`{self.__class__.__name__}.precision` was and will be removed soon" diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 26e2f381e6300..78a21980624bb 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -118,7 +118,7 @@ def setup(self) -> None: self.lightning_module.trainer._update_dataloader = self._convert_to_poptorch_loader def pre_dispatch(self) -> None: - model = LightningIPUModule(self.lightning_module, self.precision) + model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision) self.model = model # reset the backup diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index 475d1e44095d8..eb4cb48534708 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -75,7 +75,7 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin optim_class = type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: - precision = self._precision or self.precision + precision = self._precision or self.precision_plugin.precision is_fp16 = precision in ("mixed", 16) # For multi-node training, compressing the model shards in fp16 before broadcasting # improves performance. When using PyTorch AMP, it will not degrade From 6120d058d5fad82b02cbbc18fbbc8639d4c461fd Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 17 Nov 2021 12:39:14 -0800 Subject: [PATCH 12/27] simplify a reference to precision accidentally overridden Adrian's change, now add it back --- pytorch_lightning/plugins/training_type/fully_sharded.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 716e007efc592..73ea87b05835e 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -127,7 +127,7 @@ def setup_distributed(self) -> None: @contextlib.contextmanager def model_sharded_context(self) -> Generator: - precision = self.lightning_module.trainer.precision + precision = self.precision_plugin.precision def wrap_policy(*args, **kwargs): return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params) From d86e212f3959206f7658d7dbaf8f4b468605c1d4 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 17 Nov 2021 12:44:10 -0800 Subject: [PATCH 13/27] Update CHANGELOG.md add Adrian's change back --- CHANGELOG.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5335d906de56c..ae0515cf22703 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,7 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) -- Moved `precision_plugin` into `Training_type_plugin` and updated reference ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) +- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) - @@ -53,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505)) -- Deprecated `precision_plugin` from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) +- Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) - @@ -142,8 +142,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `reload_dataloaders_every_epoch` from `Trainer` in favour of `reload_dataloaders_every_n_epochs` ([#10481](https://github.com/PyTorchLightning/pytorch-lightning/pull/10481)) -- Removed `precision_plugin` from `Accelerator` in favor of `precision_plugin` in `training_type_plugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) - +- Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) ### Fixed From e4e938432311a506849b6257e7c59a8dfe375c38 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 17 Nov 2021 12:50:44 -0800 Subject: [PATCH 14/27] Update accelerator precision Add Adrian's change back --- pytorch_lightning/accelerators/accelerator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index eb69180d464e5..9ea3a60085cde 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -50,8 +50,12 @@ def __init__( ) -> None: """ Args: - precision_plugin: the plugin to handle precision-specific parts training_type_plugin: the plugin to handle different training routines + precision_plugin: the plugin to handle precision-specific parts + + Notes: + precision_plugin is deprecated and will be removed soon + User `training_type_plugin(precision_plugi)n` instead """ self.training_type_plugin = training_type_plugin From abd3ac8b080fa5d131ccbae8ecad134dd3373b97 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Nov 2021 20:51:55 +0000 Subject: [PATCH 15/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 9ea3a60085cde..8f88461254941 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -52,7 +52,7 @@ def __init__( Args: training_type_plugin: the plugin to handle different training routines precision_plugin: the plugin to handle precision-specific parts - + Notes: precision_plugin is deprecated and will be removed soon User `training_type_plugin(precision_plugi)n` instead From 2711145a452ea8bcf7d30b2e693a15abd9a87300 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 17 Nov 2021 12:53:17 -0800 Subject: [PATCH 16/27] Add none check for precision plugin just to be safe --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 56a90bf1ba811..19c5d1c9b175b 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -405,7 +405,9 @@ def training_type_plugin(self) -> TrainingTypePlugin: # attach checkpoint plugin to the training type plugin if self._checkpoint_io is not None: self._training_type_plugin.checkpoint_io = self._checkpoint_io - self._training_type_plugin._precision_plugin = self.precision_plugin + precision_plugin = self.precision_plugin + if precision_plugin is not None: + self._training_type_plugin._precision_plugin = self.precision_plugin self._training_type_plugin_resolved = True return self._training_type_plugin From 173c4df8fda1d2551df2a2e93aae22a5e31208bd Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 17 Nov 2021 13:30:38 -0800 Subject: [PATCH 17/27] Update ipu.py --- pytorch_lightning/plugins/training_type/ipu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 78a21980624bb..c24008ac3ee4f 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -86,6 +86,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) if not _IPU_AVAILABLE: raise MisconfigurationException( From 39ee3146c75826460723840e5f3d6b891d9a7a21 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:00:48 -0800 Subject: [PATCH 18/27] update precision_plugin param deprecation message --- pytorch_lightning/accelerators/accelerator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8f88461254941..320831f59eef4 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -55,7 +55,7 @@ def __init__( Notes: precision_plugin is deprecated and will be removed soon - User `training_type_plugin(precision_plugi)n` instead + User `training_type_plugin(precision_plugin)` instead """ self.training_type_plugin = training_type_plugin @@ -67,8 +67,8 @@ def __init__( Use :`training_type_plugin(precision_plugin) instead. """ rank_zero_deprecation( - f"`{self.__class__.__name__}.precision` was and will be removed soon" - f" Use `training_type_plugin.precision_plugin.precision` instead." + f"`precision_plugin` was deprecated and will be removed soon" + f" Use `training_type_plugin(precision_plugin)` instead." ) self.training_type_plugin._precision_plugin = precision_plugin From 9cd599bedcde54ba55ebfe50232d8d27064a666e Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:06:16 -0800 Subject: [PATCH 19/27] Update accelerator.py --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 320831f59eef4..0941dd1119a49 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -67,7 +67,7 @@ def __init__( Use :`training_type_plugin(precision_plugin) instead. """ rank_zero_deprecation( - f"`precision_plugin` was deprecated and will be removed soon" + f"`precision_plugin` in {self.__class__.__name__} was deprecated and will be removed soon" f" Use `training_type_plugin(precision_plugin)` instead." ) self.training_type_plugin._precision_plugin = precision_plugin From 5de312088486d3ee2ea1504bb7dc2ec3044c8f8f Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:39:51 -0800 Subject: [PATCH 20/27] Remove deprecated warning Tests will fail after 9940 --- pytorch_lightning/accelerators/accelerator.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 0941dd1119a49..c8b4e7c3e011b 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -66,10 +66,6 @@ def __init__( precision_plugin parameter is deprecated will be removed soon. Use :`training_type_plugin(precision_plugin) instead. """ - rank_zero_deprecation( - f"`precision_plugin` in {self.__class__.__name__} was deprecated and will be removed soon" - f" Use `training_type_plugin(precision_plugin)` instead." - ) self.training_type_plugin._precision_plugin = precision_plugin self.optimizers: List = [] From c3ca785a36ce60ca656fdfc5e0ddfe7906dc4085 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 18 Nov 2021 09:59:05 -0800 Subject: [PATCH 21/27] keep accelerator api --- pytorch_lightning/accelerators/accelerator.py | 13 ++++--------- .../trainer/connectors/accelerator_connector.py | 2 +- tests/accelerators/test_cpu.py | 8 ++++---- tests/accelerators/test_gpu.py | 4 ++-- tests/accelerators/test_tpu.py | 14 ++++++++++++-- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c8b4e7c3e011b..cad80aba0251e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -45,9 +45,7 @@ class Accelerator: One to handle differences from the training routine and one to handle different precisions. """ - def __init__( - self, training_type_plugin: TrainingTypePlugin, precision_plugin: Optional[PrecisionPlugin] = None - ) -> None: + def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None: """ Args: training_type_plugin: the plugin to handle different training routines @@ -60,12 +58,9 @@ def __init__( self.training_type_plugin = training_type_plugin - if precision_plugin: - """ - .. deprecated - precision_plugin parameter is deprecated will be removed soon. - Use :`training_type_plugin(precision_plugin) instead. - """ + if precision_plugin is not None: + """precision_plugin is deprecated and will be removed soon User + `training_type_plugin(precision_plugin)` instead.""" self.training_type_plugin._precision_plugin = precision_plugin self.optimizers: List = [] diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 19c5d1c9b175b..0694cb6fbce9b 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -797,7 +797,7 @@ def select_accelerator(self) -> Accelerator: else: acc_cls = CPUAccelerator - accelerator = acc_cls(training_type_plugin=self.training_type_plugin) + accelerator = acc_cls(precision_plugin=None, training_type_plugin=self.training_type_plugin) # transfer ownership of the plugins to the accelerator self._training_type_plugin = proxy(self.training_type_plugin) diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index e728d50292d9f..41e73431495b9 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -41,8 +41,8 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: def test_restore_checkpoint_after_pre_dispatch_default(): """Assert default for restore_checkpoint_after_pre_dispatch is False.""" - plugin = SingleDevicePlugin(torch.device("cpu"), precision_plugin=PrecisionPlugin()) - accelerator = CPUAccelerator(training_type_plugin=plugin) + plugin = SingleDevicePlugin(torch.device("cpu")) + accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch assert not plugin.restore_checkpoint_after_pre_dispatch @@ -74,8 +74,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) - plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO(), precision_plugin=PrecisionPlugin()) - accelerator = CPUAccelerator(training_type_plugin=plugin) + plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO()) + accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch diff --git a/tests/accelerators/test_gpu.py b/tests/accelerators/test_gpu.py index d24797ea20c9b..85ce0cd9f0f18 100644 --- a/tests/accelerators/test_gpu.py +++ b/tests/accelerators/test_gpu.py @@ -12,7 +12,7 @@ def test_get_torch_gpu_stats(tmpdir): """Test GPU get_device_stats with Pytorch >= 1.8.0.""" current_device = torch.device(f"cuda:{torch.cuda.current_device()}") GPUAccel = GPUAccelerator( - training_type_plugin=DataParallelPlugin(parallel_devices=[current_device], precision_plugin=PrecisionPlugin()) + training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin() ) gpu_stats = GPUAccel.get_device_stats(current_device) fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"] @@ -27,7 +27,7 @@ def test_get_nvidia_gpu_stats(tmpdir): """Test GPU get_device_stats with Pytorch < 1.8.0.""" current_device = torch.device(f"cuda:{torch.cuda.current_device()}") GPUAccel = GPUAccelerator( - training_type_plugin=DataParallelPlugin(parallel_devices=[current_device], precision_plugin=PrecisionPlugin()) + training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin() ) gpu_stats = GPUAccel.get_device_stats(current_device) fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"] diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index c83f60e0c2ec0..fc1ce413cd494 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -288,11 +288,21 @@ def forward(self, x): def test_tpu_invalid_raises(): - accelerator = TPUAccelerator(TPUSpawnPlugin(precision_plugin=object())) + accelerator = TPUAccelerator(object(), TPUSpawnPlugin()) + with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): + accelerator.setup(object()) + + accelerator = TPUAccelerator(TPUPrecisionPlugin(), DDPPlugin()) + with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"): + accelerator.setup(object()) + + +def test_tpu_invalid_raises_set_precision_with_strategy(): + accelerator = TPUAccelerator(object(), TPUSpawnPlugin(precision_plugin=object())) with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): accelerator.setup(object()) - accelerator = TPUAccelerator(DDPPlugin(precision_plugin=TPUPrecisionPlugin())) + accelerator = TPUAccelerator(None, DDPPlugin(precision_plugin=TPUPrecisionPlugin())) with pytest.raises( ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin" ): From 9f06093ce21109f00569855cecde9d84074acdc9 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 18 Nov 2021 10:28:39 -0800 Subject: [PATCH 22/27] udpate deprecation message and docs --- docs/source/extensions/accelerators.rst | 3 ++- docs/source/extensions/plugins.rst | 3 ++- pytorch_lightning/accelerators/accelerator.py | 18 +++++++++--------- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/source/extensions/accelerators.rst b/docs/source/extensions/accelerators.rst index 66561db0ccef4..11a85cb082af8 100644 --- a/docs/source/extensions/accelerators.rst +++ b/docs/source/extensions/accelerators.rst @@ -26,7 +26,8 @@ One to handle differences from the training routine and one to handle different from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin accelerator = GPUAccelerator( - training_type_plugin=DDPPlugin(precision_plugin=NativeMixedPrecisionPlugin(16, "cuda")), + precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"), + training_type_plugin=DDPPlugin(), ) trainer = Trainer(accelerator=accelerator) diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst index 25a28c6cc1c41..56e14a97502cc 100644 --- a/docs/source/extensions/plugins.rst +++ b/docs/source/extensions/plugins.rst @@ -81,7 +81,8 @@ can then be passed into the Trainer directly or via a (custom) accelerator: # fully custom accelerator and plugins accelerator = MyAccelerator( - training_type_plugin=CustomDDPPlugin(precision_plugin=CustomPrecisionPlugin()), + precision_plugin=CustomPrecisionPlugin(), + training_type_plugin=CustomDDPPlugin(), ) trainer = Trainer(accelerator=accelerator) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index cad80aba0251e..c94192eb6d22a 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -48,18 +48,17 @@ class Accelerator: def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None: """ Args: - training_type_plugin: the plugin to handle different training routines precision_plugin: the plugin to handle precision-specific parts - - Notes: - precision_plugin is deprecated and will be removed soon - User `training_type_plugin(precision_plugin)` instead + .. deprecated:: + The ``precision_plugin`` parameter has been deprecated and will be removed soon. + Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. + training_type_plugin: the plugin to handle different training routines """ self.training_type_plugin = training_type_plugin if precision_plugin is not None: - """precision_plugin is deprecated and will be removed soon User + """precision_plugin is deprecated and will be removed soon, use `training_type_plugin(precision_plugin)` instead.""" self.training_type_plugin._precision_plugin = precision_plugin @@ -275,12 +274,13 @@ def amp_backend(self) -> Optional[LightningEnum]: @property def precision(self) -> Union[str, int]: - """This method is deprecated and will be removed soon. + """The type of precision being used with this accelerator. - Use `training_type_plugin.precision_plugin.precision` instead. + .. deprecated:: The ``precision_plugin`` parameter has been deprecated and will be removed soon. Pass + the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. """ rank_zero_deprecation( - f"`{self.__class__.__name__}.precision` was and will be removed soon" + f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon" f" Use `training_type_plugin.precision_plugin.precision` instead." ) return self.training_type_plugin.precision_plugin.precision From 8eccdc01c78d0bada441572ee907d138f290351c Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 18 Nov 2021 10:48:15 -0800 Subject: [PATCH 23/27] fix comments format --- pytorch_lightning/accelerators/accelerator.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c94192eb6d22a..73f0d40799e34 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -49,17 +49,21 @@ def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_pl """ Args: precision_plugin: the plugin to handle precision-specific parts + .. deprecated:: The ``precision_plugin`` parameter has been deprecated and will be removed soon. Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. + training_type_plugin: the plugin to handle different training routines """ self.training_type_plugin = training_type_plugin if precision_plugin is not None: - """precision_plugin is deprecated and will be removed soon, use - `training_type_plugin(precision_plugin)` instead.""" + """precision_plugin is deprecated and will be removed soon. + + Use `training_type_plugin(precision_plugin)` instead. + """ self.training_type_plugin._precision_plugin = precision_plugin self.optimizers: List = [] @@ -276,8 +280,11 @@ def amp_backend(self) -> Optional[LightningEnum]: def precision(self) -> Union[str, int]: """The type of precision being used with this accelerator. - .. deprecated:: The ``precision_plugin`` parameter has been deprecated and will be removed soon. Pass - the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. + Use `training_type_plugin.precision_plugin.precision` instead. + + .. deprecated:: + The ``precision_plugin`` parameter has been deprecated and will be removed soon. + Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. """ rank_zero_deprecation( f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon" From 19dde8609e38a040836a7f17660893c6ff3f3a74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 18 Nov 2021 21:47:01 +0100 Subject: [PATCH 24/27] remove string comment --- pytorch_lightning/accelerators/accelerator.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 73f0d40799e34..cd4efe2a3ddf1 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -60,10 +60,6 @@ def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_pl self.training_type_plugin = training_type_plugin if precision_plugin is not None: - """precision_plugin is deprecated and will be removed soon. - - Use `training_type_plugin(precision_plugin)` instead. - """ self.training_type_plugin._precision_plugin = precision_plugin self.optimizers: List = [] From 06e33adda15f04a4a9249120953e816ac6c68fcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 18 Nov 2021 21:50:31 +0100 Subject: [PATCH 25/27] fix duplicated deprecation message --- pytorch_lightning/accelerators/accelerator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index cd4efe2a3ddf1..52ad899ef8a12 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -276,11 +276,10 @@ def amp_backend(self) -> Optional[LightningEnum]: def precision(self) -> Union[str, int]: """The type of precision being used with this accelerator. - Use `training_type_plugin.precision_plugin.precision` instead. - .. deprecated:: - The ``precision_plugin`` parameter has been deprecated and will be removed soon. - Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. + This property been deprecated and will be removed soon. + Use ``training_type_plugin.precision_plugin.precision`` instead. + """ rank_zero_deprecation( f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon" From 9c6fd3e593a4446fa8a09f121dacf4540c275b2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Nov 2021 20:52:18 +0000 Subject: [PATCH 26/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/accelerators/accelerator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 52ad899ef8a12..eb3886b209503 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -279,7 +279,6 @@ def precision(self) -> Union[str, int]: .. deprecated:: This property been deprecated and will be removed soon. Use ``training_type_plugin.precision_plugin.precision`` instead. - """ rank_zero_deprecation( f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon" From 97ba08bb7f50b873d66a4bfcc5adf80703e03421 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Thu, 18 Nov 2021 15:33:43 -0800 Subject: [PATCH 27/27] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 0694cb6fbce9b..e5df9c3b84898 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -407,7 +407,7 @@ def training_type_plugin(self) -> TrainingTypePlugin: self._training_type_plugin.checkpoint_io = self._checkpoint_io precision_plugin = self.precision_plugin if precision_plugin is not None: - self._training_type_plugin._precision_plugin = self.precision_plugin + self._training_type_plugin._precision_plugin = precision_plugin self._training_type_plugin_resolved = True return self._training_type_plugin