Add typing to lightning.tuner #7117

Status: Closed · wants to merge 24 commits
8 changes: 5 additions & 3 deletions pytorch_lightning/tuner/auto_gpu_select.py
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def pick_multiple_gpus(nb):
def pick_multiple_gpus(nb: int) -> List[int]:
"""
Raises:
MisconfigurationException:
@@ -30,14 +32,14 @@ def pick_multiple_gpus(nb):

nb = torch.cuda.device_count() if nb == -1 else nb

picked = []
picked: List[int] = []
for _ in range(nb):
picked.append(pick_single_gpu(exclude_gpus=picked))

return picked


def pick_single_gpu(exclude_gpus: list):
def pick_single_gpu(exclude_gpus: list) -> int:
"""
Raises:
RuntimeError:
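The typing change above leans on two common mypy patterns: annotating the public signature (`nb: int` in, `List[int]` out) and annotating the empty list before it is filled. Below is a minimal, hypothetical sketch of the same pattern (not the library implementation; GPU probing is replaced by a plain list of candidate ids):

```python
from typing import List


def pick_multiple_gpus_sketch(nb: int, candidates: List[int]) -> List[int]:
    """Toy stand-in for pick_multiple_gpus: take the first `nb` candidate ids."""
    picked: List[int] = []  # explicit element type; mypy may otherwise ask for an annotation here
    for gpu_id in candidates:
        if len(picked) == nb:
            break
        picked.append(gpu_id)
    return picked


if __name__ == "__main__":
    print(pick_multiple_gpus_sketch(2, [0, 1, 2, 3]))  # [0, 1]
```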
8 changes: 5 additions & 3 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -15,6 +15,8 @@
import os
from typing import Optional, Tuple

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
@@ -39,7 +41,7 @@ def scale_batch_size(
"""See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
if trainer.fast_dev_run:
rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning)
return
return None

if not lightning_hasattr(model, batch_arg_name):
raise MisconfigurationException(f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
@@ -244,12 +246,12 @@ def _adjust_batch_size(
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')

if not _is_valid_batch_size(new_size, trainer.train_dataloader):
new_size = min(new_size, len(trainer.train_dataloader.dataset))
new_size = min(new_size, len(trainer.train_dataloader.dataset)) # type: ignore

changed = new_size != batch_size
lightning_setattr(model, batch_arg_name, new_size)
return new_size, changed


def _is_valid_batch_size(current_size, dataloader):
def _is_valid_batch_size(current_size: int, dataloader: 'DataLoader') -> bool:
return not has_len(dataloader) or current_size <= len(dataloader)
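For context, here is a simplified, self-contained sketch of the `_is_valid_batch_size` check annotated above. The `has_len` helper below is a stand-in for `pytorch_lightning.utilities.data.has_len`, not its real implementation:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def has_len(dataloader: DataLoader) -> bool:
    """Simplified stand-in: True if the dataloader reports a usable length."""
    try:
        return len(dataloader) > 0
    except TypeError:
        return False


def is_valid_batch_size(current_size: int, dataloader: DataLoader) -> bool:
    # Mirrors the snippet above: accept the size when no length is available
    # (e.g. an IterableDataset) or when it does not exceed the reported length.
    return not has_len(dataloader) or current_size <= len(dataloader)


if __name__ == "__main__":
    dl = DataLoader(TensorDataset(torch.arange(10.0)), batch_size=2)  # len(dl) == 5
    print(is_valid_batch_size(4, dl), is_valid_batch_size(64, dl))    # True False
```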
48 changes: 27 additions & 21 deletions pytorch_lightning/tuner/lr_finder.py
@@ -15,7 +15,7 @@
import logging
import os
from functools import wraps
from typing import Callable, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -95,17 +95,17 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
self.lr_max = lr_max
self.num_training = num_training

self.results = {}
self.results: Dict = {}
self._total_batch_idx = 0 # for debug purpose

def _exchange_scheduler(self, configure_optimizers: Callable):
def _exchange_scheduler(self, configure_optimizers: Callable) -> Callable:
""" Decorate configure_optimizers methods such that it returns the users
originally specified optimizer together with a new scheduler that
that takes care of the learning rate search.
"""

@wraps(configure_optimizers)
def func():
def func() -> Tuple:
# Decide the structure of the output from configure_optimizers
# Same logic as method `init_optimizers` in trainer/optimizers.py
optim_conf = configure_optimizers()
@@ -119,7 +119,7 @@ def func():
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
elif isinstance(optim_conf, (list, tuple)):
optimizers = [optim_conf]
optimizers = [optim_conf] # type: ignore

if len(optimizers) != 1:
raise MisconfigurationException(
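The `# type: ignore` on `optimizers = [optim_conf]` is most likely needed because mypy fixes the type of `optimizers` at its first assignment, so the later re-assignments with differently typed lists look incompatible. A hedged, standalone illustration of the problem and of the annotate-up-front alternative (the function and its string placeholders are hypothetical, not the PR's code):

```python
from typing import Any, Dict, List, Sequence, Union


def collect_optimizers(optim_conf: Union[Dict[str, Any], Sequence[Any]]) -> List[Any]:
    # Annotating the variable with the widest type up front avoids the
    # "incompatible types in assignment" error that `# type: ignore` silences.
    optimizers: List[Any] = []
    if isinstance(optim_conf, dict):
        optimizers = [optim_conf["optimizer"]]
    elif isinstance(optim_conf, (list, tuple)) and optim_conf and isinstance(optim_conf[0], dict):
        optimizers = [d["optimizer"] for d in optim_conf]
    elif isinstance(optim_conf, (list, tuple)):
        optimizers = list(optim_conf)
    return optimizers


if __name__ == "__main__":
    print(collect_optimizers({"optimizer": "adam"}))   # ['adam']
    print(collect_optimizers([{"optimizer": "sgd"}]))  # ['sgd']
    print(collect_optimizers(("adam", "sgd")))         # ['adam', 'sgd']
```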
@@ -141,7 +141,7 @@

return func

def plot(self, suggest: bool = False, show: bool = False):
def plot(self, suggest: bool = False, show: bool = False): # type: ignore
""" Plot results from lr_find run
Args:
suggest: if True, will mark suggested lr to use with a red point
@@ -172,7 +172,7 @@ def plot(self, suggest: bool = False, show: bool = False):

return fig

def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
""" This will propose a suggestion for choice of initial learning rate
as the point with the steepest negative gradient.

@@ -192,6 +192,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
except Exception:
log.exception('Failed to compute suggesting for `lr`. There might not be enough points.')
self._optimal_idx = None
return None


def lr_find(
@@ -207,7 +208,7 @@
"""See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
if trainer.fast_dev_run:
rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
return
return None

# Determine lr attr
if update_attr:
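Several hunks in this PR (here and in `batch_size_scaling.py`) turn a bare `return` into `return None`. That is a type-checking requirement rather than a behaviour change: with the return type now written as `Optional[...]`, mypy under the project's settings expects every code path to return a value explicitly. A small hypothetical illustration:

```python
from typing import List, Optional


def suggestion_sketch(losses: List[float]) -> Optional[float]:
    """Hypothetical example, not the library code."""
    if len(losses) < 2:
        # a bare `return` (or silently falling off the end) gets flagged by mypy,
        # so the None is spelled out
        return None
    return min(losses)
```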
@@ -244,7 +245,7 @@ def lr_find(
trainer.save_checkpoint(str(save_path))

# Configure optimizer and scheduler
model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)
model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) # type: ignore

# Fit, lr & loss logged in callback
trainer.tuner._run(model)
@@ -280,7 +281,7 @@ def lr_find(
return lr_finder


def __lr_finder_dump_params(trainer, model):
def __lr_finder_dump_params(trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
# Prevent going into infinite loop
trainer.__dumped_params = {
'auto_lr_find': trainer.auto_lr_find,
@@ -293,13 +294,13 @@ def __lr_finder_dump_params(trainer, model):
}


def __lr_finder_restore_params(trainer, model):
def __lr_finder_restore_params(trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
trainer.logger = trainer.__dumped_params['logger']
trainer.callbacks = trainer.__dumped_params['callbacks']
trainer.train_loop.max_steps = trainer.__dumped_params['max_steps']
trainer.train_loop.current_epoch = trainer.__dumped_params['current_epoch']
model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
model.configure_optimizers = trainer.__dumped_params['configure_optimizers'] # type: ignore
del trainer.__dumped_params
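The new annotations such as `trainer: 'pl.Trainer'` are written as strings, so they are only resolved by the type checker; this sidesteps circular-import and definition-order problems between the tuner modules and the top-level `pytorch_lightning` package. A hedged sketch of the same idea in isolation (the helper below is hypothetical and not part of the PR):

```python
from typing import Any, Dict, TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by mypy and IDEs only, never at runtime
    import pytorch_lightning as pl


def dump_params_sketch(trainer: "pl.Trainer") -> Dict[str, Any]:
    # At runtime the annotation is just a string, so this function works even
    # though the import above never executes.
    return {"auto_lr_find": getattr(trainer, "auto_lr_find", None)}
```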


@@ -331,24 +332,28 @@ def __init__(
self.num_training = num_training
self.early_stop_threshold = early_stop_threshold
self.beta = beta
self.losses = []
self.lrs = []
self.losses: List[float] = []
self.lrs: List[float] = []
self.avg_loss = 0.0
self.best_loss = 0.0
self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar = None

def on_batch_start(self, trainer, pl_module):
def on_batch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
""" Called before each training batch, logs the lr that will be used """
if (trainer.train_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return

if self.progress_bar_refresh_rate and self.progress_bar is None:
self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) # type: ignore

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: Optional[Union[torch.Tensor, Dict[str,
Any]]],
batch: Any, batch_idx: Optional[int], dataloader_idx: Optional[int]
) -> None:
""" Called when the training batch ends, logs the calculated loss """
if (trainer.train_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return
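Both hooks above return early unless the batch completes a gradient-accumulation cycle, so the recorded learning rates and losses line up with actual optimizer steps. A tiny sketch of that guard (the function name is hypothetical):

```python
def completes_accumulation_cycle(batch_idx: int, accumulate_grad_batches: int) -> bool:
    # Same modulo test as in the callbacks above.
    return (batch_idx + 1) % accumulate_grad_batches == 0


if __name__ == "__main__":
    logged = [i for i in range(8) if completes_accumulation_cycle(i, accumulate_grad_batches=4)]
    print(logged)  # [3, 7]: only batches that finish an accumulation cycle are recorded
```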
@@ -399,7 +404,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
self.num_iter = num_iter
super(_LinearLR, self).__init__(optimizer, last_epoch)

def get_lr(self):
def get_lr(self) -> List[float]: # type: ignore
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter

@@ -411,7 +416,7 @@ def get_lr(self):
return val

@property
def lr(self):
def lr(self) -> List[float]:
return self._lr


@@ -437,17 +442,18 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
self.num_iter = num_iter
super(_ExponentialLR, self).__init__(optimizer, last_epoch)

def get_lr(self):
def get_lr(self) -> List[float]: # type: ignore
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter

if self.last_epoch > 0:
val = [base_lr * (self.end_lr / base_lr)**r for base_lr in self.base_lrs]
else:
val = [base_lr for base_lr in self.base_lrs]
# todo: why not `val = self.base_lrs`?
self._lr = val
return val

@property
def lr(self):
def lr(self) -> List[float]:
return self._lr
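For reference, the two schedulers annotated above sweep the learning rate from `base_lr` to `end_lr` over `num_iter` steps: `_LinearLR` interpolates linearly and `_ExponentialLR` geometrically (the latter formula is visible in the hunk). A hedged standalone sketch; the linear formula is assumed from the class name rather than shown in this diff:

```python
from typing import List


def linear_lr(base_lrs: List[float], end_lr: float, r: float) -> List[float]:
    # assumed form: straight-line interpolation, r in [0, 1] is the sweep fraction
    return [base_lr + r * (end_lr - base_lr) for base_lr in base_lrs]


def exponential_lr(base_lrs: List[float], end_lr: float, r: float) -> List[float]:
    # matches the expression shown above: geometric interpolation
    return [base_lr * (end_lr / base_lr) ** r for base_lr in base_lrs]


if __name__ == "__main__":
    print(linear_lr([1e-5], end_lr=1.0, r=0.5))       # ~0.5 (arithmetic midpoint)
    print(exponential_lr([1e-5], end_lr=1.0, r=0.5))  # ~3.2e-3 (geometric midpoint)
```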
10 changes: 5 additions & 5 deletions pytorch_lightning/tuner/tuning.py
@@ -36,7 +36,7 @@ def _tune(
model: 'pl.LightningModule',
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[Union[int, _LRFinder]]]:
) -> Dict[str, Union[int, _LRFinder, None]]:
scale_batch_size_kwargs = scale_batch_size_kwargs or {}
lr_find_kwargs = lr_find_kwargs or {}
# return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
@@ -51,11 +51,11 @@ def _tune(
# Run learning rate finder:
if self.trainer.auto_lr_find:
lr_find_kwargs.setdefault('update_attr', True)
result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs)
result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs) # type: ignore

self.trainer.state.status = TrainerStatus.FINISHED

return result
return result # type: ignore

def _run(self, *args: Any, **kwargs: Any) -> None:
"""`_run` wrapper to set the proper state during tuning, as this can be called multiple times"""
@@ -75,7 +75,7 @@ def scale_batch_size(
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = 'batch_size',
) -> Optional[int]:
) -> Union[int, _LRFinder, None]:
"""
Iteratively try to find the largest batch size for a given model
that does not give an out of memory (OOM) error.
@@ -198,4 +198,4 @@ def lr_find(
}
)
self.trainer.auto_lr_find = False
return result['lr_find']
return result['lr_find'] # type: ignore
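With the tuner entry points typed, a caller can rely on the results being either a value or `None` (for example when `fast_dev_run` makes the tuner skip its work). A minimal usage sketch, assuming a pytorch_lightning release of roughly this PR's vintage and a toy model defined inline:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-3, batch_size: int = 2):
        super().__init__()
        self.lr = lr                      # picked up by lr_find
        self.batch_size = batch_size      # picked up by scale_batch_size
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        x, y = torch.randn(256, 8), torch.randn(256, 1)
        return DataLoader(TensorDataset(x, y), batch_size=self.batch_size)


if __name__ == "__main__":
    model = ToyModel()
    trainer = pl.Trainer(max_epochs=1)

    new_batch_size = trainer.tuner.scale_batch_size(model)  # Optional[int]
    lr_finder = trainer.tuner.lr_find(model)                 # Optional[_LRFinder]

    if lr_finder is not None:
        suggested = lr_finder.suggestion()                   # Optional[float]
        if suggested is not None:
            model.lr = suggested
```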
4 changes: 0 additions & 4 deletions setup.cfg
@@ -165,10 +165,6 @@ ignore_errors = False
[mypy-pytorch_lightning.distributed.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-pytorch_lightning.tuner.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-pytorch_lightning.utilities.*]
ignore_errors = True