Add configure_gradient_clipping hook in LightningModule #9584

Merged (26 commits) on Oct 13, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -156,3 +156,4 @@ cifar-10-batches-py
*.pt
# ctags
tags
.tags
7 changes: 7 additions & 0 deletions docs/source/common/lightning_module.rst
@@ -1195,6 +1195,7 @@ for more information.
on_after_backward()

on_before_optimizer_step()
configure_gradient_clipping()
optimizer_step()

on_train_batch_end()
@@ -1452,6 +1453,12 @@ on_before_optimizer_step
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
:noindex:

configure_gradient_clipping
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping
:noindex:

optimizer_step
~~~~~~~~~~~~~~

46 changes: 45 additions & 1 deletion docs/source/common/optimizers.rst
@@ -69,7 +69,7 @@ Here is a minimal example of manual optimization.
Gradient accumulation
---------------------
You can accumulate gradients over batches similarly to
:attr:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
:attr:`~pytorch_lightning.trainer.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
To perform gradient accumulation with one optimizer, you can do as such.

.. testcode:: python
@@ -516,3 +516,47 @@ to perform a step, Lightning won't be able to support accelerators and precision
):
optimizer = optimizer.optimizer
optimizer.step(closure=optimizer_closure)

-----

Configure gradient clipping
---------------------------
To configure custom gradient clipping, override the
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` method.
The :attr:`~pytorch_lightning.trainer.trainer.Trainer.gradient_clip_val` and
:attr:`~pytorch_lightning.trainer.trainer.Trainer.gradient_clip_algorithm` values set on the Trainer are passed in
as the corresponding arguments, and Lightning handles the gradient clipping for you. If you want to clip with
different values while still letting Lightning perform the clipping, call the built-in
:meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients` method with your values and the optimizer.

.. note::
Do not override the :meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients` method.
To customize gradient clipping, override
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` instead.

For example, here we will apply gradient clipping only to the gradients associated with optimizer A.

.. testcode:: python

def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
if optimizer_idx == 0:
# Lightning will handle the gradient clipping
self.clip_gradients(
optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)

Here we configure gradient clipping differently for optimizer B.

.. testcode:: python

def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
if optimizer_idx == 0:
# Lightning will handle the gradient clipping
self.clip_gradients(
optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)
elif optimizer_idx == 1:
self.clip_gradients(
optimizer, gradient_clip_val=gradient_clip_val * 2, gradient_clip_algorithm=gradient_clip_algorithm
)
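
If :meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients` does not cover your use case, you can
also clip manually inside the hook. Below is a minimal sketch, assuming a single optimizer, norm-based clipping and
that ``torch`` is imported; clipping this way bypasses Lightning's precision-aware handling.

.. testcode:: python

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        # hypothetical manual clipping with plain PyTorch utilities;
        # skipping self.clip_gradients() means no precision-plugin handling is applied
        if gradient_clip_val is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=gradient_clip_val)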
99 changes: 97 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -36,7 +36,12 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities import (
_TORCH_SHARDED_TENSOR_AVAILABLE,
GradClipAlgorithmType,
rank_zero_deprecation,
rank_zero_warn,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
@@ -1460,7 +1465,7 @@ def untoggle_optimizer(self, optimizer_idx: int):
optimizer_idx: Current optimizer idx in the training loop

Note:
Only called when using multiple optimizers
"""
for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
if optimizer_idx != opt_idx:
@@ -1471,6 +1476,96 @@ def untoggle_optimizer(self, optimizer_idx: int):
# save memory
self._param_requires_grad_state = {}

def clip_gradients(
self,
optimizer: Optimizer,
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[Union[str, GradClipAlgorithmType]] = None,
):
"""Handles gradient clipping internally.

Note:
Do not override this method. If you want to customize gradient clipping, consider
using :meth:`configure_gradient_clipping` method.

Args:
optimizer: Current optimizer being used.
gradient_clip_val: The value at which to clip gradients.
gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm.
"""
if gradient_clip_val is None:
gradient_clip_val = self.trainer.gradient_clip_val or 0.0
elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
raise MisconfigurationException(
"You have set `Trainer(gradient_clip_val)` and have passed"
" `gradient_clip_val` inside `clip_gradients`. Please use only one of them."
)

if gradient_clip_algorithm is None:
gradient_clip_algorithm = self.trainer.gradient_clip_algorithm or "norm"
else:
gradient_clip_algorithm = gradient_clip_algorithm.lower()
if (
self.trainer.gradient_clip_algorithm is not None
and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
):
raise MisconfigurationException(
"You have set `Trainer(gradient_clip_algorithm)` and have passed"
" `gradient_clip_algorithm` inside `clip_gradients`. Please use only one of them."
)

if not isinstance(gradient_clip_val, (int, float)):
raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
raise MisconfigurationException(
f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid."
f" Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
)

gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
self.trainer.accelerator.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)

def configure_gradient_clipping(
self,
optimizer: Optimizer,
optimizer_idx: int,
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
):
"""Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.

Note:
This hook won't be called when using deepspeed since it handles gradient clipping internally.
Consider setting ``gradient_clip_val`` and ``gradient_clip_algorithm`` inside ``Trainer``.

Args:
optimizer: Current optimizer being used.
optimizer_idx: Index of the current optimizer being used.
gradient_clip_val: The value at which to clip gradients. By default value passed in Trainer
will be available here.
gradient_clip_algorithm: The gradient clipping algorithm to use. By default value
passed in Trainer will be available here.

Example::

# Perform gradient clipping on gradients associated with discriminator (optimizer_idx=1) in GAN
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
if optimizer_idx == 1:
# Lightning will handle the gradient clipping
self.clip_gradients(
optimizer,
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm
)
else:
# implement your own custom logic to clip gradients for generator (optimizer_idx=0)
"""
self.clip_gradients(
optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)

def optimizer_step(
self,
epoch: int = None,
14 changes: 9 additions & 5 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -240,7 +240,7 @@ def _backward(

if not self.trainer.fit_loop._should_accumulate():
# track gradients
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer, opt_idx=opt_idx)
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
@@ -470,7 +470,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

return result

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer, opt_idx: int) -> Dict[str, float]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

Args:
@@ -484,7 +484,11 @@ def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, fl
grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)

# clip gradients
self.trainer.accelerator.clip_gradients(
optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
)
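# deepspeed clips gradients through its own engine config, so the hook is only called for other plugins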
if not self.trainer.accelerator_connector.use_deepspeed:
self.trainer.lightning_module.configure_gradient_clipping(
optimizer,
opt_idx,
gradient_clip_val=self.trainer.gradient_clip_val,
gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
)
return grad_norm_dict
16 changes: 15 additions & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -34,8 +34,10 @@
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.enums import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
@@ -376,6 +378,18 @@ def pre_dispatch(self):
self.barrier()

def init_deepspeed(self):
# check that `configure_gradient_clipping` hook isn't overridden since deepspeed handles
# gradient clipping internally
if is_overridden("configure_gradient_clipping", self.lightning_module):
rank_zero_warn(
"Since deepspeed handles gradient clipping internally, this hook will"
" be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
" inside `Trainer`."
)

if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
raise MisconfigurationException("Deepspeed does not support clipping gradients by value.")

accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

if accumulation_scheduler.epochs != [0]:
@@ -569,7 +583,7 @@ def _format_batch_size_and_grad_accum_config(self):
batch_size = self._auto_select_batch_size()
self.config["train_micro_batch_size_per_gpu"] = batch_size
if "gradient_clipping" not in self.config:
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0

def _auto_select_batch_size(self):
# train_micro_batch_size_per_gpu is used for throughput logging purposes
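For reference, a minimal sketch of driving deepspeed's clipping through the Trainer flags rather than the hook
(``MyModel``, the ``plugins="deepspeed"`` selector and the GPU/precision settings are illustrative assumptions;
only norm clipping is supported with deepspeed):

    import pytorch_lightning as pl

    # the value set here ends up in the deepspeed config as `gradient_clipping`
    trainer = pl.Trainer(
        plugins="deepspeed", gpus=1, precision=16, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
    )
    trainer.fit(MyModel())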
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/configuration_validator.py
@@ -201,7 +201,7 @@ def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> Non
def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> None:
if model.automatic_optimization:
return
if self.trainer.gradient_clip_val > 0:
if self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val > 0:
raise MisconfigurationException(
"Automatic gradient clipping is not supported for manual optimization."
f" Remove `Trainer(gradient_clip_val={self.trainer.gradient_clip_val})`"
@@ -23,8 +23,8 @@ def __init__(self, trainer):

def on_trainer_init(
self,
gradient_clip_val: Union[int, float],
gradient_clip_algorithm: str,
gradient_clip_val: Optional[Union[int, float]],
gradient_clip_algorithm: Optional[str],
track_grad_norm: Union[int, float, str],
terminate_on_nan: Optional[bool],
):
@@ -37,10 +37,12 @@ def on_trainer_init(
raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")

# gradient clipping
if not isinstance(gradient_clip_val, (int, float)):
if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
gradient_clip_algorithm.lower()
):
raise MisconfigurationException(
f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
@@ -54,5 +56,9 @@

self.trainer.terminate_on_nan = terminate_on_nan
self.trainer.gradient_clip_val = gradient_clip_val
self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
self.trainer.gradient_clip_algorithm = (
GradClipAlgorithmType(gradient_clip_algorithm.lower())
if gradient_clip_algorithm is not None
else gradient_clip_algorithm
)
self.trainer.track_grad_norm = float(track_grad_norm)
11 changes: 6 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -124,8 +124,8 @@
enable_checkpointing: bool = True,
callbacks: Optional[Union[List[Callback], Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: str = "norm",
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
process_position: int = 0,
num_nodes: int = 1,
num_processes: int = 1,
@@ -252,11 +252,12 @@ def __init__(

gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node

gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=0`` disables gradient
clipping.
gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
gradient clipping.

gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
for clip_by_value, and ``gradient_clip_algorithm="norm"`` for clip_by_norm.
to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
be set to ``"norm"``.

limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).

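For illustration, a minimal sketch of the updated Trainer flags (the values are arbitrary examples):

    import pytorch_lightning as pl

    # gradient clipping is disabled by default (gradient_clip_val=None)
    trainer = pl.Trainer()

    # clip each gradient value into [-0.5, 0.5];
    # omit gradient_clip_algorithm to fall back to the default "norm"
    trainer = pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="value")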