Add typing to lightning.tuner #7117
Changes from 8 commits
@@ -15,14 +15,16 @@
 import logging
 import os
 from functools import wraps
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Callable, List, Optional, Sequence, Union, TYPE_CHECKING
 
 import numpy as np
 import torch
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
+if TYPE_CHECKING:
+    import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -42,7 +44,7 @@
 log = logging.getLogger(__name__)
 
 
-def _determine_lr_attr_name(trainer, model: LightningModule) -> str:
+def _determine_lr_attr_name(trainer: 'pl.Trainer', model: LightningModule) -> str:
     if isinstance(trainer.auto_lr_find, str):
         if not lightning_hasattr(model, trainer.auto_lr_find):
             raise MisconfigurationException(
@@ -63,18 +65,18 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str:
 
 
 def lr_find(
-    trainer,
-    model: LightningModule,
-    train_dataloader: Optional[DataLoader] = None,
-    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
-    min_lr: float = 1e-8,
-    max_lr: float = 1,
-    num_training: int = 100,
-    mode: str = 'exponential',
-    early_stop_threshold: float = 4.0,
-    datamodule: Optional[LightningDataModule] = None,
-    update_attr: bool = False,
-):
+    trainer: 'pl.Trainer',
+    model: LightningModule,
+    train_dataloader: Optional[DataLoader] = None,
+    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+    min_lr: float = 1e-8,
+    max_lr: float = 1,
+    num_training: int = 100,
+    mode: str = 'exponential',
+    early_stop_threshold: float = 4.0,
+    datamodule: Optional[LightningDataModule] = None,
+    update_attr: bool = False,
+) -> '_LRFinder':
     r"""
     ``lr_find`` enables the user to do a range test of good initial learning rates,
     to reduce the amount of guesswork in picking a good starting learning rate.

Review thread on the lr_find signature:
  Comment: could you move it to the old place. easier for us to review the changes :)
  Reply: done @awaelchli
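For context, a minimal end-to-end sketch of how the `lr_find` entry point annotated above is typically reached; the `TinyModel` definition and the `trainer.tuner.lr_find(...)` wiring are assumptions for illustration, not part of this diff:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Assumed toy model, only here to make the example self-contained."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


trainer = pl.Trainer(max_epochs=1)
# runs the range test between min_lr and max_lr for num_training steps
lr_finder = trainer.tuner.lr_find(TinyModel(), min_lr=1e-8, max_lr=1.0, num_training=100)
print(lr_finder.results.keys())  # the 'lr' and 'loss' series collected during the sweep
```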
@@ -209,7 +211,7 @@ def lr_find(
     return lr_finder
 
 
-def __lr_finder_dump_params(trainer, model):
+def __lr_finder_dump_params(trainer: 'pl.Trainer', model: LightningModule) -> None:
     # Prevent going into infinite loop
     trainer.__dumped_params = {
         'auto_lr_find': trainer.auto_lr_find,
@@ -221,7 +223,7 @@ def __lr_finder_dump_params(trainer, model):
     }
 
 
-def __lr_finder_restore_params(trainer, model):
+def __lr_finder_restore_params(trainer: 'pl.Trainer', model: LightningModule) -> None:
     trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
     trainer.logger = trainer.__dumped_params['logger']
     trainer.callbacks = trainer.__dumped_params['callbacks']
@@ -268,7 +270,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self.results = {}
         self._total_batch_idx = 0  # for debug purpose
 
-    def _exchange_scheduler(self, configure_optimizers: Callable):
+    def _exchange_scheduler(self, configure_optimizers: Callable) -> callable:
         """ Decorate configure_optimizers methods such that it returns the users
         originally specified optimizer together with a new scheduler that
         that takes care of the learning rate search.
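The docstring above describes wrapping the user's `configure_optimizers` so it hands back the original optimizer plus a learning-rate-sweeping scheduler. A simplified, self-contained sketch of that idea, using a plain `LambdaLR` as a stand-in for the library's `_ExponentialLR`; `end_lr` and `num_iter` mirror the attributes the finder keeps, and the helper name is illustrative:

```python
from functools import wraps
from typing import Callable, List, Tuple

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def exchange_scheduler(configure_optimizers: Callable[[], Optimizer],
                       end_lr: float, num_iter: int) -> Callable[[], Tuple[List, List]]:
    @wraps(configure_optimizers)
    def func():
        optimizer = configure_optimizers()  # the user's original optimizer
        start_lr = optimizer.param_groups[0]['lr']
        # scale the base lr geometrically from start_lr towards end_lr over num_iter steps
        sweep = LambdaLR(optimizer, lambda step: (end_lr / start_lr) ** (step / num_iter))
        return [optimizer], [{'scheduler': sweep, 'interval': 'step'}]
    return func


params = [torch.nn.Parameter(torch.zeros(1))]
wrapped = exchange_scheduler(lambda: torch.optim.SGD(params, lr=1e-6), end_lr=1.0, num_iter=100)
optimizers, schedulers = wrapped()
```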
@@ -311,7 +313,7 @@ def func():
 
         return func
 
-    def plot(self, suggest: bool = False, show: bool = False):
+    def plot(self, suggest: bool = False, show: bool = False) -> 'plt.Figure':
         """ Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point

Review thread on the 'plt.Figure' annotation:
  Comment: 'plt.Figure' is causing PEP8 to fail.
  Reply: is pyplot imported there?
  Reply: pyplot is imported inside the function
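One way to keep the `'plt.Figure'` return annotation without a module-level matplotlib dependency is the same `TYPE_CHECKING` trick used for `pl.Trainer`; the sketch below only illustrates that option, it is not necessarily how the PR resolved the flake8 complaint:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import matplotlib.pyplot as plt  # seen only by the type checker


class Example:
    def plot(self, suggest: bool = False, show: bool = False) -> 'plt.Figure':
        import matplotlib.pyplot as plt  # runtime import stays inside the method

        fig, ax = plt.subplots()
        if show:
            plt.show()
        return fig
```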
@@ -342,7 +344,7 @@ def plot(self, suggest: bool = False, show: bool = False):
 
         return fig
 
-    def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
+    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> dict:
         """ This will propose a suggestion for choice of initial learning rate
         as the point with the steepest negative gradient.
 
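The `suggestion` docstring above refers to picking the learning rate at the steepest negative loss gradient. A hedged standalone sketch of that heuristic over a results dict shaped like the one `_LRFinder` collects (the helper name and demo data are illustrative):

```python
import numpy as np


def suggest_lr(results: dict, skip_begin: int = 10, skip_end: int = 1) -> float:
    lrs = np.array(results['lr'])
    losses = np.array(results['loss'])
    # most negative gradient of the loss curve, ignoring the noisy ends of the sweep
    grad = np.gradient(losses[skip_begin:-skip_end])
    return float(lrs[skip_begin:-skip_end][int(np.argmin(grad))])


demo = {'lr': list(np.logspace(-6, 0, 100)), 'loss': list(np.linspace(2.0, 1.0, 100))}
print(suggest_lr(demo))
```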
@@ -383,11 +385,11 @@ class _LRCallback(Callback):
     """
 
     def __init__(
-            self,
-            num_training: int,
-            early_stop_threshold: float = 4.0,
-            progress_bar_refresh_rate: int = 0,
-            beta: float = 0.98
+        self,
+        num_training: int,
+        early_stop_threshold: float = 4.0,
+        progress_bar_refresh_rate: int = 0,
+        beta: float = 0.98
     ):
         self.num_training = num_training
         self.early_stop_threshold = early_stop_threshold
@@ -399,7 +401,7 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.progress_bar = None
 
-    def on_batch_start(self, trainer, pl_module):
+    def on_batch_start(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
         """ Called before each training batch, logs the lr that will be used """
         if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
@@ -409,7 +411,9 @@ def on_batch_start(self, trainer, pl_module):
 
         self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule, outputs, batch,
+                           batch_idx: Optional[int],
+                           dataloader_idx: Optional[int]) -> None:
         """ Called when the training batch ends, logs the calculated loss """
         if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
@@ -422,7 +426,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
 
         # Avg loss (loss with momentum) + smoothing
         self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
-        smoothed_loss = self.avg_loss / (1 - self.beta**(current_step + 1))
+        smoothed_loss = self.avg_loss / (1 - self.beta ** (current_step + 1))
 
         # Check if we diverging
         if self.early_stop_threshold is not None:
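The two lines touched above compute a bias-corrected exponential moving average of the loss; a standalone illustration with arbitrary loss values:

```python
beta = 0.98
avg_loss = 0.0
for current_step, current_loss in enumerate([2.0, 1.8, 1.7, 1.5]):
    avg_loss = beta * avg_loss + (1 - beta) * current_loss
    # dividing by (1 - beta ** (step + 1)) removes the bias towards the zero initialisation
    smoothed_loss = avg_loss / (1 - beta ** (current_step + 1))
    print(f"step {current_step}: smoothed_loss={smoothed_loss:.4f}")
```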
@@ -459,7 +463,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super(_LinearLR, self).__init__(optimizer, last_epoch)
 
-    def get_lr(self):
+    def get_lr(self) -> list:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
|
@@ -497,12 +501,12 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in | |
self.num_iter = num_iter | ||
super(_ExponentialLR, self).__init__(optimizer, last_epoch) | ||
|
||
def get_lr(self): | ||
def get_lr(self) -> list: | ||
curr_iter = self.last_epoch + 1 | ||
r = curr_iter / self.num_iter | ||
|
||
if self.last_epoch > 0: | ||
val = [base_lr * (self.end_lr / base_lr)**r for base_lr in self.base_lrs] | ||
val = [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] | ||
else: | ||
val = [base_lr for base_lr in self.base_lrs] | ||
self._lr = val | ||
|
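A quick numerical check of the exponential schedule above, with illustrative values: as r = curr_iter / num_iter goes from 0 to 1, the learning rate moves geometrically from base_lr to end_lr.

```python
base_lr, end_lr, num_iter = 1e-6, 1.0, 100
for curr_iter in (1, 50, 100):
    r = curr_iter / num_iter
    print(curr_iter, base_lr * (end_lr / base_lr) ** r)  # ~1.1e-6, 1e-3, 1.0
```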
Review comment: this has to be imported without the if TYPE_CHECKING guard.
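For reference, a minimal sketch of the pattern this comment is about: names used only in annotations can sit behind `if TYPE_CHECKING:` and be written as string annotations, while anything needed at runtime (an `isinstance` check, a base class, a default value) must be imported unguarded. The function below is illustrative, not code from this PR.

```python
from typing import TYPE_CHECKING

from pytorch_lightning.core.lightning import LightningModule  # needed at runtime

if TYPE_CHECKING:
    # evaluated only by type checkers, so it adds no import cost and avoids circular imports
    import pytorch_lightning as pl


def check(trainer: 'pl.Trainer', model: LightningModule) -> None:
    # 'pl.Trainer' stays a string annotation; LightningModule is a real class here
    assert isinstance(model, LightningModule)
```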