From 23e8b59ae7f586cc7e4eff3e73538ce878b991e6 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Wed, 13 Oct 2021 20:15:13 +0530
Subject: [PATCH] Add `configure_gradient_clipping` hook in `LightningModule`
 (#9584)

* init hook
* docs
* dep train args
* update tests
* doc
* doc
* .gitignore
* not dep
* add trainer args
* add & update tests
* fix tests
* pre-commit
* docs
* add docs
* add exception
* code review
* deepspeed
* update tests
* not
* try fix
* Apply suggestions from code review
* update deepspeed
* disable some tests
* disable some tests
* enable all tests
---
 .gitignore                                     |  1 +
 docs/source/common/lightning_module.rst        |  7 ++
 docs/source/common/optimizers.rst              | 46 ++++++++-
 pytorch_lightning/core/lightning.py            | 99 ++++++++++++++++++-
 .../loops/optimization/optimizer_loop.py       | 14 ++-
 .../plugins/training_type/deepspeed.py         | 16 ++-
 .../trainer/configuration_validator.py         |  2 +-
 .../connectors/training_trick_connector.py     | 16 ++-
 pytorch_lightning/trainer/trainer.py           | 11 ++-
 tests/core/test_lightning_module.py            | 71 +++++++++++++
 tests/models/test_hooks.py                     | 19 ++++
 tests/plugins/test_deepspeed_plugin.py         | 35 +++++++
 12 files changed, 317 insertions(+), 20 deletions(-)

diff --git a/.gitignore b/.gitignore
index 6ad0671fb3306..7b1247433e7b4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -156,3 +156,4 @@ cifar-10-batches-py
 *.pt
 # ctags
 tags
+.tags
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index ba2694286739e..6ee0ebe7b1110 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1195,6 +1195,7 @@ for more information.
 
     on_after_backward()
     on_before_optimizer_step()
+    configure_gradient_clipping()
     optimizer_step()
     on_train_batch_end()
 
@@ -1452,6 +1453,12 @@ on_before_optimizer_step
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
     :noindex:
 
+configure_gradient_clipping
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping
+    :noindex:
+
 optimizer_step
 ~~~~~~~~~~~~~~
 
diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst
index 39a583d9c94d8..0405b9a4365af 100644
--- a/docs/source/common/optimizers.rst
+++ b/docs/source/common/optimizers.rst
@@ -69,7 +69,7 @@ Here is a minimal example of manual optimization.
 Gradient accumulation
 ---------------------
 You can accumulate gradients over batches similarly to
-:attr:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
+:attr:`~pytorch_lightning.trainer.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
 To perform gradient accumulation with one optimizer, you can do as such.
 
 .. testcode:: python
@@ -516,3 +516,47 @@ to perform a step, Lightning won't be able to support accelerators and precision
         ):
             optimizer = optimizer.optimizer
             optimizer.step(closure=optimizer_closure)
+
+-----
+
+Configure gradient clipping
+---------------------------
+To configure custom gradient clipping, consider overriding
+the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` method.
+The :attr:`~pytorch_lightning.trainer.trainer.Trainer.gradient_clip_val` and
+:attr:`~pytorch_lightning.trainer.trainer.Trainer.gradient_clip_algorithm` attributes will be passed to this
+hook through the respective arguments, and Lightning will handle gradient clipping for you.
+If you want to use different values for these arguments and still let Lightning handle the gradient
+clipping, you can use the inbuilt :meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients`
+method and pass the arguments along with your optimizer.
+
+.. note::
+    Make sure not to override the :meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients`
+    method. If you want to customize gradient clipping, consider overriding the
+    :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` method instead.
+
+For example, here we will apply gradient clipping only to the gradients associated with optimizer A.
+
+.. testcode:: python
+
+    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+        if optimizer_idx == 0:
+            # Lightning will handle the gradient clipping
+            self.clip_gradients(
+                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+            )
+
+Here we configure gradient clipping differently for optimizer B.
+
+.. testcode:: python
+
+    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+        if optimizer_idx == 0:
+            # Lightning will handle the gradient clipping
+            self.clip_gradients(
+                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+            )
+        elif optimizer_idx == 1:
+            self.clip_gradients(
+                optimizer, gradient_clip_val=gradient_clip_val * 2, gradient_clip_algorithm=gradient_clip_algorithm
+            )
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 08c3309762944..995d4f7ace3cf 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -36,7 +36,12 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
-from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE, rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import (
+    _TORCH_SHARDED_TENSOR_AVAILABLE,
+    GradClipAlgorithmType,
+    rank_zero_deprecation,
+    rank_zero_warn,
+)
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
@@ -1460,7 +1465,7 @@ def untoggle_optimizer(self, optimizer_idx: int):
             optimizer_idx: Current optimizer idx in the training loop
 
         Note:
-            Only called when using multiple_optimizers
+            Only called when using multiple optimizers
         """
         for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
             if optimizer_idx != opt_idx:
@@ -1471,6 +1476,96 @@ def untoggle_optimizer(self, optimizer_idx: int):
         # save memory
         self._param_requires_grad_state = {}
 
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[Union[str, GradClipAlgorithmType]] = None,
+    ):
+        """Handles gradient clipping internally.
+
+        Note:
+            Do not override this method. If you want to customize gradient clipping, consider
+            overriding :meth:`configure_gradient_clipping` instead.
+
+        Args:
+            optimizer: Current optimizer being used.
+            gradient_clip_val: The value at which to clip gradients.
+            gradient_clip_algorithm: The gradient clipping algorithm to use.
+                Pass ``gradient_clip_algorithm="value"`` to clip by value,
+                and ``gradient_clip_algorithm="norm"`` to clip by norm.
+        """
+        if gradient_clip_val is None:
+            gradient_clip_val = self.trainer.gradient_clip_val or 0.0
+        elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
+            raise MisconfigurationException(
+                "You have set `Trainer(gradient_clip_val)` and have passed"
+                " `gradient_clip_val` inside `clip_gradients`. Please use only one of them."
+            )
+
+        if gradient_clip_algorithm is None:
+            gradient_clip_algorithm = self.trainer.gradient_clip_algorithm or "norm"
+        else:
+            gradient_clip_algorithm = gradient_clip_algorithm.lower()
+            if (
+                self.trainer.gradient_clip_algorithm is not None
+                and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
+            ):
+                raise MisconfigurationException(
+                    "You have set `Trainer(gradient_clip_algorithm)` and have passed"
+                    " `gradient_clip_algorithm` inside `clip_gradients`. Please use only one of them."
+                )
+
+        if not isinstance(gradient_clip_val, (int, float)):
+            raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
+
+        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
+            raise MisconfigurationException(
+                f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid."
+                f" Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
+            )
+
+        gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
+        self.trainer.accelerator.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
+
+    def configure_gradient_clipping(
+        self,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
+    ):
+        """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.
+
+        Note:
+            This hook won't be called when using deepspeed since it handles gradient clipping internally.
+            Consider setting ``gradient_clip_val`` and ``gradient_clip_algorithm`` inside ``Trainer``.
+
+        Args:
+            optimizer: Current optimizer being used.
+            optimizer_idx: Index of the current optimizer being used.
+            gradient_clip_val: The value at which to clip gradients. By default, the value passed in
+                ``Trainer`` will be available here.
+            gradient_clip_algorithm: The gradient clipping algorithm to use. By default, the value
+                passed in ``Trainer`` will be available here.
+
+        Example::
+
+            # Perform gradient clipping on gradients associated with discriminator (optimizer_idx=1) in GAN
+            def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+                if optimizer_idx == 1:
+                    # Lightning will handle the gradient clipping
+                    self.clip_gradients(
+                        optimizer,
+                        gradient_clip_val=gradient_clip_val,
+                        gradient_clip_algorithm=gradient_clip_algorithm
+                    )
+                else:
+                    # implement your own custom logic to clip gradients for generator (optimizer_idx=0)
+                    ...
+        """
+        self.clip_gradients(
+            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )
+
     def optimizer_step(
         self,
         epoch: int = None,
diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index f0ab8b915b29f..3a73795014c80 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -240,7 +240,7 @@ def _backward(
 
         if not self.trainer.fit_loop._should_accumulate():
             # track gradients
-            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
+            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer, opt_idx=opt_idx)
             if grad_norm_dict:
                 self.trainer.lightning_module._current_fx_name = "on_after_backward"
                 self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
@@ -470,7 +470,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
 
         return result
 
-    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
+    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer, opt_idx: int) -> Dict[str, float]:
         """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

         Args:
@@ -484,7 +484,11 @@ def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, fl
             grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)
 
         # clip gradients
-        self.trainer.accelerator.clip_gradients(
-            optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
-        )
+        if not self.trainer.accelerator_connector.use_deepspeed:
+            self.trainer.lightning_module.configure_gradient_clipping(
+                optimizer,
+                opt_idx,
+                gradient_clip_val=self.trainer.gradient_clip_val,
+                gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
+            )
         return grad_norm_dict
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index a1d9e346f1217..e2e8c316f48d1 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -34,8 +34,10 @@
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
+from pytorch_lightning.utilities.enums import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
 from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
@@ -376,6 +378,18 @@ def pre_dispatch(self):
         self.barrier()
 
     def init_deepspeed(self):
+        # check that the `configure_gradient_clipping` hook isn't overridden, since deepspeed
+        # handles gradient clipping internally
+        if is_overridden("configure_gradient_clipping", self.lightning_module):
+            rank_zero_warn(
+                "Since deepspeed handles gradient clipping internally, this hook will"
+                " be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
+                " inside `Trainer`."
+            )
+
+        if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
+            raise MisconfigurationException("Deepspeed does not support clipping gradients by value.")
+
         accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
 
         if accumulation_scheduler.epochs != [0]:
@@ -569,7 +583,7 @@ def _format_batch_size_and_grad_accum_config(self):
             batch_size = self._auto_select_batch_size()
             self.config["train_micro_batch_size_per_gpu"] = batch_size
         if "gradient_clipping" not in self.config:
-            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val
+            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
 
     def _auto_select_batch_size(self):
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 3da05d69c1ff2..58260e67b77f6 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -201,7 +201,7 @@ def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> Non
     def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> None:
         if model.automatic_optimization:
             return
-        if self.trainer.gradient_clip_val > 0:
+        if self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val > 0:
             raise MisconfigurationException(
                 "Automatic gradient clipping is not supported for manual optimization."
                 f" Remove `Trainer(gradient_clip_val={self.trainer.gradient_clip_val})`"
diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index ffa11ef1985a8..6cf17f7ac3e6e 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -23,8 +23,8 @@ def __init__(self, trainer):
 
     def on_trainer_init(
         self,
-        gradient_clip_val: Union[int, float],
-        gradient_clip_algorithm: str,
+        gradient_clip_val: Optional[Union[int, float]],
+        gradient_clip_algorithm: Optional[str],
         track_grad_norm: Union[int, float, str],
         terminate_on_nan: Optional[bool],
     ):
@@ -37,10 +37,12 @@ def on_trainer_init(
             raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
 
         # gradient clipping
-        if not isinstance(gradient_clip_val, (int, float)):
+        if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
             raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
-        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
+        if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
+            gradient_clip_algorithm.lower()
+        ):
             raise MisconfigurationException(
                 f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
                 f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
@@ -54,5 +56,9 @@ def on_trainer_init(
         self.trainer._terminate_on_nan = terminate_on_nan
         self.trainer.gradient_clip_val = gradient_clip_val
-        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
+        self.trainer.gradient_clip_algorithm = (
+            GradClipAlgorithmType(gradient_clip_algorithm.lower())
+            if gradient_clip_algorithm is not None
+            else gradient_clip_algorithm
+        )
         self.trainer.track_grad_norm = float(track_grad_norm)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6699f3554b80c..01acae35fd46c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -124,8 +124,8 @@ def __init__(
         enable_checkpointing: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
-        gradient_clip_val: Union[int, float] = 0.0,
-        gradient_clip_algorithm: str = "norm",
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
         process_position: int = 0,
         num_nodes: int = 1,
         num_processes: int = 1,
@@ -254,11 +254,12 @@ def __init__(
             gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node
 
-            gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=0`` disables gradient
-                clipping.
+            gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
+                gradient clipping.
 
             gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
-                for clip_by_value, and ``gradient_clip_algorithm="norm"`` for clip_by_norm.
+                to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
+                be set to ``"norm"``.
 
             limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches)
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 8b787e0f57fcb..692044d91b894 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -22,6 +22,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -335,3 +336,73 @@ def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     assert torch.allclose(
         m_1.sharded_tensor.local_shards()[0].tensor, m_0.sharded_tensor.local_shards()[0].tensor
     ), "Expect the shards to be same after `m_1` loading `m_0`'s state dict"
+
+
+def test_lightning_module_configure_gradient_clipping(tmpdir):
+    """Test custom gradient clipping inside `configure_gradient_clipping` hook."""
+
+    class TestModel(BoringModel):
+
+        has_validated_gradients = False
+        custom_gradient_clip_val = 1e-2
+
+        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+            assert gradient_clip_val == self.trainer.gradient_clip_val
+            assert gradient_clip_algorithm == self.trainer.gradient_clip_algorithm
+
+            for pg in optimizer.param_groups:
+                for p in pg["params"]:
+                    p.grad[p.grad > self.custom_gradient_clip_val] = self.custom_gradient_clip_val
+                    p.grad[p.grad <= 0] = 0
+
+        def on_before_optimizer_step(self, optimizer, optimizer_idx):
+            for pg in optimizer.param_groups:
+                for p in pg["params"]:
+                    if p.grad is not None and p.grad.abs().sum() > 0:
+                        self.has_validated_gradients = True
+                        assert p.grad.min() >= 0
+                        assert p.grad.max() <= self.custom_gradient_clip_val
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0, gradient_clip_val=1e-4
+    )
+    trainer.fit(model)
+    assert model.has_validated_gradients
+
+
+def test_lightning_module_configure_gradient_clipping_different_argument_values(tmpdir):
+    """Test that setting gradient clipping arguments in `Trainer` and customizing gradient clipping inside
+    `configure_gradient_clipping` with different values raises an exception."""
+
+    class TestModel(BoringModel):
+        custom_gradient_clip_val = 1e-2
+
+        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+            self.clip_gradients(optimizer, gradient_clip_val=self.custom_gradient_clip_val)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0, gradient_clip_val=1e-4
+    )
+    with pytest.raises(MisconfigurationException, match=r".*have set `Trainer\(gradient_clip_val\)` and have passed.*"):
+        trainer.fit(model)
+
+    class TestModel(BoringModel):
+        custom_gradient_clip_algorithm = "value"
+
+        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+            self.clip_gradients(optimizer, gradient_clip_algorithm=self.custom_gradient_clip_algorithm)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        gradient_clip_algorithm="norm",
+    )
+    with pytest.raises(
+        MisconfigurationException, match=r".*have set `Trainer\(gradient_clip_algorithm\)` and have passed.*"
+    ):
+        trainer.fit(model)
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 1575ea1daaec3..565ca001a9e52 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -281,6 +281,24 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
             dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
             dict(name="on_before_optimizer_step", args=(ANY, 0)),
         ]
+
+    # deepspeed handles gradient clipping internally
+    configure_gradient_clipping = (
+        []
+        if using_deepspeed
+        else [
+            dict(
+                name="clip_gradients",
+                args=(ANY,),
+                kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
+            ),
+            dict(
+                name="configure_gradient_clipping",
+                args=(ANY, 0),
+                kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
+            ),
+        ]
+    )
     for i in range(batches):
         out.extend(
             [
@@ -305,6 +323,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                 *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
                 dict(name="Callback.on_after_backward", args=(trainer, model)),
                 dict(name="on_after_backward"),
+                *configure_gradient_clipping,
                 *(on_before_optimizer_step if using_plugin else []),
                 dict(
                     name="optimizer_step",
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index f3b4733fdc803..53b7bdbf7f0f0 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -962,6 +962,41 @@ def configure_optimizers(self):
     assert mock_step.call_count == 1 + (max_epoch * limit_train_batches)
 
 
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_configure_gradient_clipping(tmpdir):
+    """Test to ensure that a warning is raised when `LightningModule.configure_gradient_clipping` is overridden
+    in case of deepspeed."""
+
+    class TestModel(BoringModel):
+        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+            if optimizer_idx == 0:
+                self.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        plugins="deepspeed",
+        fast_dev_run=True,
+    )
+    with pytest.warns(UserWarning, match="handles gradient clipping internally"):
+        trainer.fit(model)
+
+
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_gradient_clip_by_value(tmpdir):
+    """Test to ensure that an exception is raised when using `gradient_clip_algorithm='value'`."""
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        plugins="deepspeed",
+        gradient_clip_algorithm="value",
+    )
+    with pytest.raises(MisconfigurationException, match="does not support clipping gradients by value"):
+        trainer.fit(model)
+
+
 @RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_different_accumulate_grad_batches_fails(tmpdir):
     model = BoringModel()
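
For reference, the sketch below shows how the hook introduced by this patch could be used from a
``LightningModule`` once the patch is applied. It is a minimal, illustrative example: the module, the
warm-up logic, and the threshold values are assumptions made for the sketch and are not part of the
patch itself::

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule, Trainer


    class LitModel(LightningModule):
        """Hypothetical module that skips gradient clipping during a warm-up phase."""

        def __init__(self, warmup_steps: int = 100):
            super().__init__()
            self.layer = nn.Linear(32, 2)
            self.warmup_steps = warmup_steps

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
            # Skip clipping while warming up, then defer to Lightning's built-in clipping with the
            # `gradient_clip_val` / `gradient_clip_algorithm` values forwarded from the Trainer.
            if self.global_step >= self.warmup_steps:
                self.clip_gradients(
                    optimizer,
                    gradient_clip_val=gradient_clip_val,
                    gradient_clip_algorithm=gradient_clip_algorithm,
                )


    # The Trainer arguments below flow into the hook's `gradient_clip_val` / `gradient_clip_algorithm`.
    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm", max_epochs=1)

Because the default implementation of the hook simply calls ``self.clip_gradients`` with the forwarded
arguments, overriding it and falling back to ``self.clip_gradients`` preserves the stock clipping behaviour
outside the customized branch.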