Add typing to lightning.tuner #7117

Closed · wants to merge 24 commits
Changes from 8 commits
6 changes: 4 additions & 2 deletions pytorch_lightning/tuner/auto_gpu_select.py
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def pick_multiple_gpus(nb):
def pick_multiple_gpus(nb: int) -> List[int]:
'''
Raises:
MisconfigurationException:
@@ -37,7 +39,7 @@ def pick_multiple_gpus(nb):
return picked


def pick_single_gpu(exclude_gpus: list):
def pick_single_gpu(exclude_gpus: list) -> int:
'''
Raises:
RuntimeError:
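For context on what the newly typed signatures promise, here is a simplified sketch of the probing strategy these helpers use. `pick_single_gpu_sketch` is illustrative only; the real implementation lives in this file and may differ in detail.

```python
from typing import List

import torch


def pick_single_gpu_sketch(exclude_gpus: List[int]) -> int:
    """Return the index of the first CUDA device that accepts a tiny allocation."""
    for i in range(torch.cuda.device_count()):
        if i in exclude_gpus:
            continue
        try:
            # Probe the device with a trivial tensor; failures surface as RuntimeError.
            torch.zeros(1, device=torch.device(f"cuda:{i}"))
            return i
        except RuntimeError:
            continue
    raise RuntimeError("No single GPU available.")
```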
47 changes: 26 additions & 21 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -13,8 +13,10 @@
# limitations under the License
import logging
import os
from typing import Optional, Tuple
from typing import Optional, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
import pytorch_lightning as pl
Member:
this has to be imported without the if TYPE_CHECKING guard
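For readers following along, a minimal sketch of the two import patterns under discussion (not necessarily the code that was ultimately merged): the guarded import only exists for static type checkers, while a plain `import pytorch_lightning as pl`, as the reviewer asks, also binds the name at runtime.

```python
# Sketch of the two variants; names match the diff above.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Variant in this commit: 'pl' exists only for type checkers,
    # so 'pl.Trainer' annotations must stay as strings.
    import pytorch_lightning as pl


def scale_batch_size_sketch(trainer: 'pl.Trainer') -> None:  # illustrative signature only
    ...


# Variant the reviewer asks for: a plain module import, so 'pl' is also
# bound at runtime (commented out here so the sketch runs without the package).
# import pytorch_lightning as pl
```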

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
@@ -28,15 +30,15 @@


def scale_batch_size(
trainer,
model: LightningModule,
mode: str = 'power',
steps_per_trial: int = 3,
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = 'batch_size',
**fit_kwargs
):
trainer: 'pl.Trainer',
model: LightningModule,
mode: str = 'power',
steps_per_trial: int = 3,
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = 'batch_size',
**fit_kwargs
) -> Optional[int]:
r"""
Will iteratively try to find the largest batch size for a given model
that does not give an out of memory (OOM) error.
@@ -139,7 +141,7 @@ def scale_batch_size(
return new_size


def __scale_batch_dump_params(trainer):
def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None:
# Prevent going into infinite loop
trainer.__dumped_params = {
'auto_lr_find': trainer.auto_lr_find,
@@ -155,7 +157,7 @@ def __scale_batch_dump_params(trainer):
}


def __scale_batch_reset_params(trainer, model, steps_per_trial):
def __scale_batch_reset_params(trainer: 'pl.Trainer', model: LightningModule, steps_per_trial: int) -> None:
trainer.auto_scale_batch_size = None # prevent recursion
trainer.auto_lr_find = False # avoid lr find being called multiple times
trainer.current_epoch = 0
@@ -168,7 +170,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
trainer.model = model # required for saving


def __scale_batch_restore_params(trainer):
def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None:
trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
trainer.current_epoch = trainer.__dumped_params['current_epoch']
trainer.max_steps = trainer.__dumped_params['max_steps']
@@ -181,7 +183,8 @@ def __scale_batch_restore_params(trainer):
del trainer.__dumped_params


def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
def _run_power_scaling(trainer: 'pl.Trainer', model: LightningModule, new_size, batch_arg_name: str, max_trials: int,
**fit_kwargs) -> int:
""" Batch scaling mode where the size is doubled at each iteration until an
OOM error is encountered. """
for _ in range(max_trials):
@@ -207,7 +210,9 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
return new_size


def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
def _run_binsearch_scaling(trainer: 'pl.Trainer', model: LightningModule, new_size, batch_arg_name: str,
max_trials: int,
**fit_kwargs) -> int:
""" Batch scaling mode where the size is initially is doubled at each iteration
until an OOM error is encountered. Hereafter, the batch size is further
refined using a binary search """
@@ -252,11 +257,11 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,


def _adjust_batch_size(
trainer,
batch_arg_name: str = 'batch_size',
factor: float = 1.0,
value: Optional[int] = None,
desc: Optional[str] = None
trainer: 'pl.Trainer',
batch_arg_name: str = 'batch_size',
factor: float = 1.0,
value: Optional[int] = None,
desc: Optional[str] = None
) -> Tuple[int, bool]:
""" Helper function for adjusting the batch size.

@@ -291,5 +296,5 @@ def _adjust_batch_size(
return new_size, changed


def _is_valid_batch_size(current_size, dataloader):
def _is_valid_batch_size(current_size: int, dataloader) -> bool:
return not has_len(dataloader) or current_size <= len(dataloader)
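For reviewers wanting a mental model of what these newly annotated helpers do, here is a compressed, self-contained sketch of the power-then-binary-search strategy described in the docstrings above. `try_fit` stands in for the trainer-side fit/OOM handling and is not part of the real API; the real code also restores trainer state and writes the size back to the model via `_adjust_batch_size`.

```python
from typing import Callable, Optional, Tuple


def _adjust(size: int, factor: float = 1.0, value: Optional[int] = None) -> Tuple[int, bool]:
    # Mirrors _adjust_batch_size's return contract: (new_size, changed).
    new_size = value if value is not None else int(size * factor)
    return new_size, new_size != size


def scale_batch_size_sketch(try_fit: Callable[[int], bool], init_val: int = 2, max_trials: int = 25) -> int:
    # Double the size until try_fit reports an OOM (returns False), then
    # binary-search between the largest fitting and smallest failing size.
    low, size, high = 1, init_val, None
    for _ in range(max_trials):
        if try_fit(size):
            low = size
            size, _ = _adjust(size, value=size * 2 if high is None else (low + high) // 2)
        else:
            high = size
            size, _ = _adjust(size, value=(low + high) // 2)
        if high is not None and high - low <= 1:
            break
    return low
```

For example, `scale_batch_size_sketch(lambda bs: bs <= 96)` settles on 96 well within the default trial budget.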
64 changes: 34 additions & 30 deletions pytorch_lightning/tuner/lr_finder.py
@@ -15,14 +15,16 @@
import logging
import os
from functools import wraps
from typing import Callable, List, Optional, Sequence, Union
from typing import Callable, List, Optional, Sequence, Union, TYPE_CHECKING

import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

if TYPE_CHECKING:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
@@ -42,7 +44,7 @@
log = logging.getLogger(__name__)


def _determine_lr_attr_name(trainer, model: LightningModule) -> str:
def _determine_lr_attr_name(trainer: 'pl.Trainer', model: LightningModule) -> str:
if isinstance(trainer.auto_lr_find, str):
if not lightning_hasattr(model, trainer.auto_lr_find):
raise MisconfigurationException(
@@ -63,18 +65,18 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str:


def lr_find(
Contributor:
Could you move it to the old place? Easier for us to review the changes :)

Contributor Author:
Done @awaelchli

trainer,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
min_lr: float = 1e-8,
max_lr: float = 1,
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None,
update_attr: bool = False,
):
trainer: 'pl.Trainer',
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
min_lr: float = 1e-8,
max_lr: float = 1,
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None,
update_attr: bool = False,
) -> '_LRFinder':
r"""
``lr_find`` enables the user to do a range test of good initial learning rates,
to reduce the amount of guesswork in picking a good starting learning rate.
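For orientation, a typical call site (as documented for the tuner of this era) looks roughly like the following; `model` is a placeholder for a user-defined LightningModule instance, and the attribute name on the last line is an assumption.

```python
# Sketch only: `model` stands in for a LightningModule, `trainer` for a configured Trainer.
lr_finder = trainer.tuner.lr_find(model)   # returns the '_LRFinder' annotated above

fig = lr_finder.plot(suggest=True)         # matplotlib figure; see the 'plt.Figure' discussion below
new_lr = lr_finder.suggestion()            # lr at the steepest negative gradient of the loss curve

model.lr = new_lr                          # assumes the model exposes an `lr` attribute
```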
@@ -209,7 +211,7 @@ def lr_find(
return lr_finder


def __lr_finder_dump_params(trainer, model):
def __lr_finder_dump_params(trainer: 'pl.Trainer', model: LightningModule) -> None:
# Prevent going into infinite loop
trainer.__dumped_params = {
'auto_lr_find': trainer.auto_lr_find,
Expand All @@ -221,7 +223,7 @@ def __lr_finder_dump_params(trainer, model):
}


def __lr_finder_restore_params(trainer, model):
def __lr_finder_restore_params(trainer: 'pl.Trainer', model: LightningModule) -> None:
trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
trainer.logger = trainer.__dumped_params['logger']
trainer.callbacks = trainer.__dumped_params['callbacks']
@@ -268,7 +270,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
self.results = {}
self._total_batch_idx = 0 # for debug purpose

def _exchange_scheduler(self, configure_optimizers: Callable):
def _exchange_scheduler(self, configure_optimizers: Callable) -> callable:
""" Decorate configure_optimizers methods such that it returns the users
originally specified optimizer together with a new scheduler that
that takes care of the learning rate search.
@@ -311,7 +313,7 @@ def func():

return func

def plot(self, suggest: bool = False, show: bool = False):
def plot(self, suggest: bool = False, show: bool = False) -> 'plt.Figure':
Contributor Author:
'plt.Figure' is causing PEP8 to fail: ./pytorch_lightning/tuner/lr_finder.py:315: [F821] undefined name 'plt'
What should be done here? @justusschock

Member:
Is pyplot imported there?

Contributor Author:
pyplot is imported inside the function.
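One conventional way to resolve this (a sketch of the pattern, not necessarily what this PR ended up doing) is to import pyplot under `TYPE_CHECKING` at module scope: pyflakes then sees `plt` defined for the string annotation, while the runtime import stays lazy inside the method.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import matplotlib.pyplot as plt  # only used in annotations, which silences F821


class _LRFinderSketch:
    def plot(self, suggest: bool = False, show: bool = False) -> 'plt.Figure':
        import matplotlib.pyplot as plt  # runtime import remains optional and lazy

        fig, ax = plt.subplots()
        ax.set_xscale("log")  # lr sweeps are usually plotted on a log axis
        # ... plot self.results here ...
        if show:
            plt.show()
        return fig
```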

""" Plot results from lr_find run
Args:
suggest: if True, will mark suggested lr to use with a red point
@@ -342,7 +344,7 @@ def plot(self, suggest: bool = False, show: bool = False):

return fig

def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> dict:
""" This will propose a suggestion for choice of initial learning rate
as the point with the steepest negative gradient.
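As a reminder of what `suggestion` computes, here is a stripped-down sketch of the steepest-negative-gradient rule (illustrative names only; the real method also filters non-finite losses and reads from `self.results`).

```python
import numpy as np


def suggest_lr(lrs: list, losses: list, skip_begin: int = 10, skip_end: int = 1) -> float:
    # Drop the noisy head and tail of the sweep, then pick the lr at which
    # the loss curve has its most negative slope.
    losses_arr = np.asarray(losses[skip_begin:-skip_end], dtype=float)
    lrs_arr = np.asarray(lrs[skip_begin:-skip_end], dtype=float)
    idx = int(np.gradient(losses_arr).argmin())
    return float(lrs_arr[idx])
```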

@@ -383,11 +385,11 @@ class _LRCallback(Callback):
"""

def __init__(
self,
num_training: int,
early_stop_threshold: float = 4.0,
progress_bar_refresh_rate: int = 0,
beta: float = 0.98
self,
num_training: int,
early_stop_threshold: float = 4.0,
progress_bar_refresh_rate: int = 0,
beta: float = 0.98
):
self.num_training = num_training
self.early_stop_threshold = early_stop_threshold
@@ -399,7 +401,7 @@ def __init__(
self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar = None

def on_batch_start(self, trainer, pl_module):
def on_batch_start(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
""" Called before each training batch, logs the lr that will be used """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return
@@ -409,7 +411,9 @@ def on_batch_start(self, trainer, pl_module):

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule, outputs, batch,
batch_idx: Optional[int],
dataloader_idx: Optional[int]) -> None:
""" Called when the training batch ends, logs the calculated loss """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return
@@ -422,7 +426,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data

# Avg loss (loss with momentum) + smoothing
self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
smoothed_loss = self.avg_loss / (1 - self.beta**(current_step + 1))
smoothed_loss = self.avg_loss / (1 - self.beta ** (current_step + 1))

# Check if we are diverging
if self.early_stop_threshold is not None:
@@ -459,7 +463,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
self.num_iter = num_iter
super(_LinearLR, self).__init__(optimizer, last_epoch)

def get_lr(self):
def get_lr(self) -> list:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter

@@ -497,12 +501,12 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
self.num_iter = num_iter
super(_ExponentialLR, self).__init__(optimizer, last_epoch)

def get_lr(self):
def get_lr(self) -> list:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter

if self.last_epoch > 0:
val = [base_lr * (self.end_lr / base_lr)**r for base_lr in self.base_lrs]
val = [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
else:
val = [base_lr for base_lr in self.base_lrs]
self._lr = val
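For reference, the interpolation rules the two schedulers sweep through, written out as plain functions. The exponential one mirrors `get_lr` above; the linear one is inferred from `_LinearLR`'s name and constructor and may differ in detail.

```python
def linear_lr(base_lr: float, end_lr: float, step: int, num_iter: int) -> float:
    # Evenly spaced sweep from base_lr to end_lr, used for mode='linear'.
    r = step / num_iter
    return base_lr + r * (end_lr - base_lr)


def exponential_lr(base_lr: float, end_lr: float, step: int, num_iter: int) -> float:
    # Geometric sweep matching _ExponentialLR.get_lr above, used for mode='exponential'.
    r = step / num_iter
    return base_lr * (end_lr / base_lr) ** r
```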