Skip to content

Commit

Permalink
Rename master_params to main_params (Lightning-AI#10105)
Browse files Browse the repository at this point in the history
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
  • Loading branch information
3 people authored and ninginthecloud committed Oct 27, 2021
1 parent 07e5bd3 commit 44db122
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `ClusterEnvironment.creates_children()` in favor of `ClusterEnvironment.creates_processes_externally` (property) ([#10106](https://github.com/PyTorchLightning/pytorch-lightning/pull/10106))



- Deprecated `PrecisionPlugin.master_params()` in favor of `PrecisionPlugin.main_params()` ([#10105](https://github.com/PyTorchLightning/pytorch-lightning/pull/10105))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/apex_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, amp_level: str = "O2") -> None:
self.amp_level = amp_level
self._connected = False

def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
return amp.master_params(optimizer)

def dispatch(self, trainer: "pl.Trainer") -> None:
Expand Down
21 changes: 17 additions & 4 deletions pytorch_lightning/plugins/precision/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
from pytorch_lightning.utilities.types import _PARAMETERS


Expand All @@ -34,7 +34,20 @@ class PrecisionPlugin(CheckpointHooks):
precision: Union[str, int] = 32

def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
"""The master params of the model.
"""The main params of the model.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6. Use :meth:`main_params` instead.
"""
rank_zero_deprecation(
f"`{self.__class__.__name__}.master_params` was deprecated in v1.5 and will be removed in v1.6."
f" Use `main_params` instead."
)
return self.main_params(optimizer)

def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
"""The main params of the model.
Returns the plain model params here. Maybe different in other precision plugins.
"""
Expand Down Expand Up @@ -126,12 +139,12 @@ def clip_gradients(

def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by value."""
parameters = self.master_params(optimizer)
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by norm."""
parameters = self.master_params(optimizer)
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_norm_(parameters, clip_val)

def pre_dispatch(self) -> None:
Expand Down
7 changes: 7 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

import pytest
import torch
from torch.optim import Optimizer

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.plugins.environments import (
KubeflowEnvironment,
LightningEnvironment,
Expand Down Expand Up @@ -441,3 +443,8 @@ def test_v1_6_0_is_slurm_managing_tasks():
def test_v1_6_0_cluster_environment_creates_children(cluster_environment):
with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.6"):
cluster_environment.creates_children()


def test_v1_6_0_master_params():
with pytest.deprecated_call(match="`PrecisionPlugin.master_params` was deprecated in v1.5"):
PrecisionPlugin().master_params(Mock(spec=Optimizer))

0 comments on commit 44db122

Please sign in to comment.