Add configure_gradient_clipping hook in LightningModule (#9584)
* init hook

* docs

* dep train args

* update tests

* doc

* doc

* .gitignore

* not dep

* add trainer args

* add & update tests

* fix tests

* pre-commit

* docs

* add docs

* add exception

* code review

* deepspeed

* update tests

* not

* try fix

* Apply suggestions from code review

* update deepspeed

* disable some tests

* disable some tests

* enable all tests
rohitgr7 authored Oct 13, 2021
1 parent 05b15e6 commit 23e8b59
Showing 12 changed files with 317 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -156,3 +156,4 @@ cifar-10-batches-py
*.pt
# ctags
tags
.tags
7 changes: 7 additions & 0 deletions docs/source/common/lightning_module.rst
@@ -1195,6 +1195,7 @@ for more information.
on_after_backward()
on_before_optimizer_step()
configure_gradient_clipping()
optimizer_step()
on_train_batch_end()
@@ -1452,6 +1453,12 @@ on_before_optimizer_step
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
    :noindex:

configure_gradient_clipping
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping
    :noindex:

optimizer_step
~~~~~~~~~~~~~~

46 changes: 45 additions & 1 deletion docs/source/common/optimizers.rst
@@ -69,7 +69,7 @@ Here is a minimal example of manual optimization.
Gradient accumulation
---------------------
You can accumulate gradients over batches similarly to
:attr:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
:attr:`~pytorch_lightning.trainer.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
To perform gradient accumulation with one optimizer, you can do it as follows.

.. testcode:: python
@@ -516,3 +516,47 @@ to perform a step, Lightning won't be able to support accelerators and precision
):
optimizer = optimizer.optimizer
optimizer.step(closure=optimizer_closure)

-----

Configure gradient clipping
---------------------------
To configure custom gradient clipping, override the
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` method.
The :attr:`~pytorch_lightning.trainer.trainer.Trainer.gradient_clip_val` and
:attr:`~pytorch_lightning.trainer.trainer.Trainer.gradient_clip_algorithm` settings are passed to this hook
through its corresponding arguments, and Lightning handles the gradient clipping for you. If you want to
clip with values of your own choosing while still letting Lightning perform the clipping, use the built-in
:meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients` method and pass your arguments
along with the optimizer.

.. note::
    Do not override the :meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients` method.
    If you want to customize gradient clipping, override
    :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` instead.

For example, here we apply gradient clipping only to the gradients associated with optimizer A (``optimizer_idx=0``).

.. testcode:: python

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        if optimizer_idx == 0:
            # Lightning will handle the gradient clipping
            self.clip_gradients(
                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
            )

Here we configure gradient clipping differently for optimizer B (``optimizer_idx=1``), clipping its gradients with twice the value.

.. testcode:: python

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        if optimizer_idx == 0:
            # Lightning will handle the gradient clipping
            self.clip_gradients(
                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
            )
        elif optimizer_idx == 1:
            self.clip_gradients(
                optimizer, gradient_clip_val=gradient_clip_val * 2, gradient_clip_algorithm=gradient_clip_algorithm
            )
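
You can also bypass the built-in clipping entirely and implement your own logic inside the hook. The following
sketch is illustrative only (it assumes ``torch`` is importable in the snippet and uses an arbitrary clip value):
it clips the module's gradients by value manually for optimizer A and defers to Lightning for optimizer B.

.. testcode:: python

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        if optimizer_idx == 0:
            # custom logic: clip every gradient of this module's parameters to [-0.5, 0.5] by value
            torch.nn.utils.clip_grad_value_(self.parameters(), clip_value=0.5)
        elif optimizer_idx == 1:
            # let Lightning handle clipping for optimizer B
            self.clip_gradients(
                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
            )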
99 changes: 97 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -36,7 +36,12 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities import (
    _TORCH_SHARDED_TENSOR_AVAILABLE,
    GradClipAlgorithmType,
    rank_zero_deprecation,
    rank_zero_warn,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
@@ -1460,7 +1465,7 @@ def untoggle_optimizer(self, optimizer_idx: int):
optimizer_idx: Current optimizer idx in the training loop
Note:
Only called when using multiple optimizers
Only called when using multiple_optimizers
"""
for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
if optimizer_idx != opt_idx:
@@ -1471,6 +1476,96 @@ def untoggle_optimizer(self, optimizer_idx: int):
# save memory
self._param_requires_grad_state = {}

def clip_gradients(
    self,
    optimizer: Optimizer,
    gradient_clip_val: Optional[Union[int, float]] = None,
    gradient_clip_algorithm: Optional[Union[str, GradClipAlgorithmType]] = None,
):
    """Handles gradient clipping internally.

    Note:
        Do not override this method. If you want to customize gradient clipping, consider
        using the :meth:`configure_gradient_clipping` method instead.

    Args:
        optimizer: Current optimizer being used.
        gradient_clip_val: The value at which to clip gradients.
        gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
            to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm.
    """
    if gradient_clip_val is None:
        gradient_clip_val = self.trainer.gradient_clip_val or 0.0
    elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
        raise MisconfigurationException(
            "You have set `Trainer(gradient_clip_val)` and have passed"
            " `gradient_clip_val` inside `clip_gradients`. Please use only one of them."
        )

    if gradient_clip_algorithm is None:
        gradient_clip_algorithm = self.trainer.gradient_clip_algorithm or "norm"
    else:
        gradient_clip_algorithm = gradient_clip_algorithm.lower()
        if (
            self.trainer.gradient_clip_algorithm is not None
            and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
        ):
            raise MisconfigurationException(
                "You have set `Trainer(gradient_clip_algorithm)` and have passed"
                " `gradient_clip_algorithm` inside `clip_gradients`. Please use only one of them."
            )

    if not isinstance(gradient_clip_val, (int, float)):
        raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

    if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
        raise MisconfigurationException(
            f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid."
            f" Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
        )

    gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
    self.trainer.accelerator.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)

def configure_gradient_clipping(
    self,
    optimizer: Optimizer,
    optimizer_idx: int,
    gradient_clip_val: Optional[Union[int, float]] = None,
    gradient_clip_algorithm: Optional[str] = None,
):
    """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.

    Note:
        This hook won't be called when using DeepSpeed, since DeepSpeed handles gradient clipping internally.
        Consider setting ``gradient_clip_val`` and ``gradient_clip_algorithm`` inside ``Trainer``.

    Args:
        optimizer: Current optimizer being used.
        optimizer_idx: Index of the current optimizer being used.
        gradient_clip_val: The value at which to clip gradients. By default, the value passed to the Trainer
            is available here.
        gradient_clip_algorithm: The gradient clipping algorithm to use. By default, the value passed to the
            Trainer is available here.

    Example::

        # Perform gradient clipping on gradients associated with discriminator (optimizer_idx=1) in GAN
        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
            if optimizer_idx == 1:
                # Lightning will handle the gradient clipping
                self.clip_gradients(
                    optimizer,
                    gradient_clip_val=gradient_clip_val,
                    gradient_clip_algorithm=gradient_clip_algorithm
                )
            else:
                # implement your own custom logic to clip gradients for generator (optimizer_idx=0)
                ...
    """
    self.clip_gradients(
        optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
    )

def optimizer_step(
    self,
    epoch: int = None,
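Taken together, the two methods added above split responsibilities: ``configure_gradient_clipping`` is the
user-facing hook (it receives the Trainer's settings and, by default, forwards them), while ``clip_gradients``
performs the actual clipping and rejects conflicting settings. A minimal usage sketch (the values are arbitrary
examples, not part of this diff):

import pytorch_lightning as pl

# Option 1: configure clipping on the Trainer and keep the default hook; the value and
# algorithm are forwarded to clip_gradients and then to the accelerator.
trainer = pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")

# Option 2: leave both Trainer flags at their new default of None and call
# self.clip_gradients(...) yourself inside configure_gradient_clipping.
# Passing a value to clip_gradients that differs from one already set on the Trainer
# raises a MisconfigurationException ("Please use only one of them").
trainer = pl.Trainer()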
14 changes: 9 additions & 5 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -240,7 +240,7 @@ def _backward(

if not self.trainer.fit_loop._should_accumulate():
# track gradients
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer, opt_idx=opt_idx)
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
@@ -470,7 +470,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

return result

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer, opt_idx: int) -> Dict[str, float]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
Args:
@@ -484,7 +484,11 @@ def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)

# clip gradients
self.trainer.accelerator.clip_gradients(
optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
)
if not self.trainer.accelerator_connector.use_deepspeed:
self.trainer.lightning_module.configure_gradient_clipping(
optimizer,
opt_idx,
gradient_clip_val=self.trainer.gradient_clip_val,
gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
)
return grad_norm_dict
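
A condensed paraphrase of the new control flow in the optimizer loop (an illustrative function, not the literal
implementation): after a non-accumulating backward pass, clipping is now routed through the user-overridable
hook instead of calling the accelerator directly, and is skipped entirely under DeepSpeed.

def _clip_after_backward(trainer, optimizer, opt_idx):
    # DeepSpeed clips internally, so the hook is bypassed there
    if not trainer.accelerator_connector.use_deepspeed:
        # default hook -> LightningModule.clip_gradients -> trainer.accelerator.clip_gradients(...)
        trainer.lightning_module.configure_gradient_clipping(
            optimizer,
            opt_idx,
            gradient_clip_val=trainer.gradient_clip_val,
            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
        )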
16 changes: 15 additions & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -34,8 +34,10 @@
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.enums import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
@@ -376,6 +378,18 @@ def pre_dispatch(self):
self.barrier()

def init_deepspeed(self):
    # check that the `configure_gradient_clipping` hook isn't overridden, since DeepSpeed
    # handles gradient clipping internally
    if is_overridden("configure_gradient_clipping", self.lightning_module):
        rank_zero_warn(
            "Since deepspeed handles gradient clipping internally, this hook will"
            " be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
            " inside `Trainer`."
        )

    if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
        raise MisconfigurationException("Deepspeed does not support clipping gradients by value.")

    accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

    if accumulation_scheduler.epochs != [0]:
@@ -569,7 +583,7 @@ def _format_batch_size_and_grad_accum_config(self):
batch_size = self._auto_select_batch_size()
self.config["train_micro_batch_size_per_gpu"] = batch_size
if "gradient_clipping" not in self.config:
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0

def _auto_select_batch_size(self):
# train_micro_batch_size_per_gpu is used for throughput logging purposes
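For reference, a sketch of the relevant slice of the generated DeepSpeed config (the batch size is a
placeholder; only the ``gradient_clipping`` key comes from the code above):

deepspeed_config = {
    "train_micro_batch_size_per_gpu": 2,  # placeholder value
    # mirrors Trainer(gradient_clip_val=...), falling back to 0.0 (disabled)
    # now that the Trainer default is None
    "gradient_clipping": 0.0,
}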
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/configuration_validator.py
@@ -201,7 +201,7 @@ def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> None:
def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> None:
if model.automatic_optimization:
return
if self.trainer.gradient_clip_val > 0:
if self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val > 0:
raise MisconfigurationException(
"Automatic gradient clipping is not supported for manual optimization."
f" Remove `Trainer(gradient_clip_val={self.trainer.gradient_clip_val})`"
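The added ``is not None`` guard matters because the Trainer default changed from ``0.0`` to ``None``; the old
bare comparison would now fail, as this small check illustrates:

# Comparing None against an int raises under Python 3, hence the explicit None check above.
try:
    None > 0
except TypeError as err:
    print(err)  # '>' not supported between instances of 'NoneType' and 'int'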
16 changes: 11 additions & 5 deletions pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -23,8 +23,8 @@ def __init__(self, trainer):

def on_trainer_init(
self,
gradient_clip_val: Union[int, float],
gradient_clip_algorithm: str,
gradient_clip_val: Optional[Union[int, float]],
gradient_clip_algorithm: Optional[str],
track_grad_norm: Union[int, float, str],
terminate_on_nan: Optional[bool],
):
@@ -37,10 +37,12 @@ def on_trainer_init(
raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")

# gradient clipping
if not isinstance(gradient_clip_val, (int, float)):
if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
gradient_clip_algorithm.lower()
):
raise MisconfigurationException(
f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
@@ -54,5 +56,9 @@

self.trainer._terminate_on_nan = terminate_on_nan
self.trainer.gradient_clip_val = gradient_clip_val
self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
self.trainer.gradient_clip_algorithm = (
GradClipAlgorithmType(gradient_clip_algorithm.lower())
if gradient_clip_algorithm is not None
else gradient_clip_algorithm
)
self.trainer.track_grad_norm = float(track_grad_norm)
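
A short sketch of how the connector treats the now-``Optional`` flags (values are examples; the commented-out
lines show the errors raised by the validation above):

import pytorch_lightning as pl

pl.Trainer()                                 # both flags default to None; clipping stays disabled
pl.Trainer(gradient_clip_val=0.5)            # algorithm stays None; clip_gradients later falls back to "norm"
pl.Trainer(gradient_clip_algorithm="VALUE")  # accepted, matching is case-insensitive

# pl.Trainer(gradient_clip_val="0.5")        # TypeError: `gradient_clip_val` should be an int or a float
# pl.Trainer(gradient_clip_algorithm="foo")  # MisconfigurationException: invalid algorithm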
11 changes: 6 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -124,8 +124,8 @@
enable_checkpointing: bool = True,
callbacks: Optional[Union[List[Callback], Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: str = "norm",
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
process_position: int = 0,
num_nodes: int = 1,
num_processes: int = 1,
@@ -254,11 +254,12 @@
gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node
gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=0`` disables gradient
clipping.
gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
gradient clipping.
gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
for clip_by_value, and ``gradient_clip_algorithm="norm"`` for clip_by_norm.
to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
be set to ``"norm"``.
limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).