Rename master_params to main_params #10105

Merged · 15 commits · Oct 26, 2021
Changes from 11 commits
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -423,6 +423,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated passing `resume_from_checkpoint` to the `Trainer` constructor in favor of `trainer.fit(ckpt_path=)` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))



- Deprecated `PrecisionPlugin.master_params()` in favor of `PrecisionPlugin.main_params()` ([#10105](https://github.com/PyTorchLightning/pytorch-lightning/pull/10105))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
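For downstream code the migration is a one-line rename. A minimal sketch of a call site before and after, assuming a plain `PrecisionPlugin` and a standard optimizer (the model/optimizer setup here is illustrative, not from the PR):

import torch
from pytorch_lightning.plugins import PrecisionPlugin

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
plugin = PrecisionPlugin()

# Before: emits a deprecation warning as of v1.5, removed in v1.6.
# params = plugin.master_params(optimizer)

# After: the renamed method, identical behavior.
params = plugin.main_params(optimizer)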
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/apex_amp.py
@@ -43,7 +43,7 @@ def __init__(self, amp_level: str = "O2") -> None:
self.amp_level = amp_level
self._connected = False

def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
return amp.master_params(optimizer)

def dispatch(self, trainer: "pl.Trainer") -> None:
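For the Apex plugin the rename is cosmetic on the Lightning side: NVIDIA's `apex` library itself still exposes `amp.master_params`, which yields the FP32 master copies maintained under `O2`. A hedged sketch of what the delegation returns, assuming `apex` is installed and a CUDA device is available:

import torch
from apex import amp  # NVIDIA Apex, assumed installed

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

# Under O2 the model weights are FP16, but the optimizer steps on
# FP32 master copies; amp.master_params(optimizer) yields those copies,
# which is what main_params() now returns for this plugin.
for p in amp.master_params(optimizer):
    assert p.dtype == torch.float32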
20 changes: 16 additions & 4 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -21,7 +21,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
from pytorch_lightning.utilities.types import _PARAMETERS


@@ -34,7 +34,19 @@ class PrecisionPlugin(CheckpointHooks):
precision: Union[str, int] = 32

def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
"""The master params of the model.
"""The main params of the model.

.. deprecated:: v1.5
    This method is deprecated in v1.5 and will be removed in v1.6. Use :meth:`main_params` instead.
"""
rank_zero_deprecation(
f"`{self.__class__.__name__}.master_params` was deprecated in v1.5 and will be removed in v1.6."
f" Use `main_params` instead."
)
return self.main_params(optimizer)

def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
"""The main params of the model.

Returns the plain model params here. May be different in other precision plugins.
"""
@@ -126,12 +138,12 @@ def clip_gradients(

def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by value."""
parameters = self.master_params(optimizer)
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by norm."""
parameters = self.master_params(optimizer)
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_norm_(parameters, clip_val)

def pre_dispatch(self) -> None:
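Because both clipping helpers now route through `main_params`, a plugin that overrides it automatically clips the parameters the optimizer actually steps on. An end-to-end sketch with the base plugin (the setup code is illustrative, not from the PR):

import torch
from pytorch_lightning.plugins import PrecisionPlugin

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
plugin = PrecisionPlugin()

loss = model(torch.randn(8, 4)).sum()
loss.backward()

# Clips the gradients of whatever main_params(optimizer) yields --
# here the plain model parameters -- to a maximum L2 norm of 1.0.
plugin.clip_grad_by_norm(optimizer, 1.0)
optimizer.step()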
7 changes: 7 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
@@ -16,10 +16,12 @@

import pytest
import torch
from torch.optim import Optimizer

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -421,3 +423,8 @@ def test_v1_6_0_is_slurm_managing_tasks():

with pytest.deprecated_call(match=r"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5"):
trainer._accelerator_connector.is_slurm_managing_tasks = False


def test_v1_6_0_master_params():
with pytest.deprecated_call(match="`PrecisionPlugin.master_params` was deprecated in v1.5"):
PrecisionPlugin().master_params(Mock(spec=Optimizer))
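The test leans on `pytest.deprecated_call` to intercept the warning raised by `rank_zero_deprecation`; the same check can be reproduced outside pytest with the standard library (a sketch, not part of the PR):

import warnings
from unittest.mock import Mock

from torch.optim import Optimizer
from pytorch_lightning.plugins import PrecisionPlugin

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PrecisionPlugin().master_params(Mock(spec=Optimizer))

assert any("master_params" in str(w.message) for w in caught)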