
Commit

fix mypy
aniketmaurya committed May 13, 2021
1 parent 3700cea commit 432dd03
Showing 3 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/batch_size_scaling.py
@@ -246,7 +246,7 @@ def _adjust_batch_size(
     log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
 
     if not _is_valid_batch_size(new_size, trainer.train_dataloader):
-        new_size = min(new_size, len(trainer.train_dataloader.dataset))
+        new_size = min(new_size, len(trainer.train_dataloader.dataset))  # type: ignore
 
     changed = new_size != batch_size
     lightning_setattr(model, batch_arg_name, new_size)
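The ignore above suppresses a typing complaint on the len() call; the diff does not show the exact mypy message, but a common case is calling len() on a value whose declared type has no __len__. A minimal standalone sketch of that situation and of a cast-based alternative (toy function and names invented for illustration, not the project's code):

    from typing import Iterable, Sized, cast

    def clamp_batch_size(new_size: int, dataset: Iterable) -> int:
        # mypy rejects len(dataset) because Iterable declares no __len__;
        # casting to Sized (or a trailing `# type: ignore`, as in the commit)
        # silences the error without changing runtime behaviour.
        return min(new_size, len(cast(Sized, dataset)))

    print(clamp_batch_size(128, range(50)))  # -> 50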
25 changes: 13 additions & 12 deletions pytorch_lightning/tuner/lr_finder.py
@@ -119,7 +119,7 @@ def func() -> Tuple:
             elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
                 optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
             elif isinstance(optim_conf, (list, tuple)):
-                optimizers = [optim_conf]
+                optimizers = [optim_conf]  # type: ignore
 
             if len(optimizers) != 1:
                 raise MisconfigurationException(
@@ -141,7 +141,7 @@ def func() -> Tuple:
 
         return func
 
-    def plot(self, suggest: bool = False, show: bool = False):
+    def plot(self, suggest: bool = False, show: bool = False):  # type: ignore
         """ Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point
@@ -172,7 +172,7 @@ def plot(self, suggest: bool = False, show: bool = False):
 
         return fig
 
-    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> float:
+    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
         """ This will propose a suggestion for choice of initial learning rate
         as the point with the steepest negative gradient.
 
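The docstring describes the suggestion as the point with the steepest negative gradient, and the annotation change to Optional[float] reflects that no such point may exist. An illustrative re-implementation of that described behaviour (a numpy-based sketch, not the project's exact code):

    from typing import List, Optional
    import numpy as np

    def suggest_lr(lrs: List[float], losses: List[float],
                   skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
        # Pick the lr at which the recorded loss falls fastest, ignoring the
        # noisy head and tail of the sweep; return None when the sweep is too
        # short, which is why Optional[float] is the honest annotation.
        section = np.asarray(losses[skip_begin:-skip_end])
        if section.size < 2:
            return None
        idx = int(np.argmin(np.gradient(section)))
        return lrs[skip_begin + idx]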
@@ -245,7 +245,7 @@ def lr_find(
     trainer.save_checkpoint(str(save_path))
 
     # Configure optimizer and scheduler
-    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)
+    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)  # type: ignore
 
     # Fit, lr & loss logged in callback
     trainer.tuner._run(model)
@@ -300,7 +300,7 @@ def __lr_finder_restore_params(trainer: 'pl.Trainer', model: 'pl.LightningModule
     trainer.callbacks = trainer.__dumped_params['callbacks']
     trainer.train_loop.max_steps = trainer.__dumped_params['max_steps']
     trainer.train_loop.current_epoch = trainer.__dumped_params['current_epoch']
-    model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
+    model.configure_optimizers = trainer.__dumped_params['configure_optimizers']  # type: ignore
     del trainer.__dumped_params
 
 
@@ -332,8 +332,8 @@ def __init__(
         self.num_training = num_training
         self.early_stop_threshold = early_stop_threshold
         self.beta = beta
-        self.losses = []
-        self.lrs = []
+        self.losses: List[float] = []
+        self.lrs: List[float] = []
        self.avg_loss = 0.0
         self.best_loss = 0.0
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
@@ -347,7 +347,7 @@ def on_batch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule')
         if self.progress_bar_refresh_rate and self.progress_bar is None:
             self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)
 
-        self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
+        self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])  # type: ignore
 
     def on_train_batch_end(
         self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: Optional[Union[torch.Tensor, Dict[str,
@@ -404,7 +404,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super(_LinearLR, self).__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> List:
+    def get_lr(self) -> List[float]:  # type: ignore
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
@@ -416,7 +416,7 @@ def get_lr(self) -> List:
         return val
 
     @property
-    def lr(self):
+    def lr(self) -> List[float]:
         return self._lr
 
 
@@ -442,17 +442,18 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super(_ExponentialLR, self).__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> list:
+    def get_lr(self) -> List[float]:  # type: ignore
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
         if self.last_epoch > 0:
             val = [base_lr * (self.end_lr / base_lr)**r for base_lr in self.base_lrs]
         else:
             val = [base_lr for base_lr in self.base_lrs]
+            # todo: why not `val = self.base_lrs`?
         self._lr = val
         return val
 
     @property
-    def lr(self) -> float:
+    def lr(self) -> List[float]:
         return self._lr
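Seen from the caller's side, the annotations above are consumed roughly like this; a minimal usage sketch with a toy LightningModule invented for illustration, where the None check mirrors the new Optional[float] return type of suggestion():

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class TinyModel(pl.LightningModule):  # toy model, not from the commit
        def __init__(self, lr: float = 1e-3):
            super().__init__()
            self.save_hyperparameters()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def train_dataloader(self):
            x, y = torch.randn(256, 8), torch.randn(256, 1)
            return DataLoader(TensorDataset(x, y), batch_size=32)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    model = TinyModel()
    trainer = pl.Trainer(max_epochs=1)
    lr_finder = trainer.tuner.lr_find(model)

    new_lr = lr_finder.suggestion()      # Optional[float]: may be None
    if new_lr is not None:
        model.hparams.lr = new_lr
    fig = lr_finder.plot(suggest=True)   # returns the matplotlib figure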
6 changes: 3 additions & 3 deletions pytorch_lightning/tuner/tuning.py
@@ -51,11 +51,11 @@ def _tune(
         # Run learning rate finder:
         if self.trainer.auto_lr_find:
             lr_find_kwargs.setdefault('update_attr', True)
-            result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs)
+            result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs)  # type: ignore
 
         self.trainer.state.status = TrainerStatus.FINISHED
 
-        return result
+        return result  # type: ignore
 
     def _run(self, *args: Any, **kwargs: Any) -> None:
         """`_run` wrapper to set the proper state during tuning, as this can be called multiple times"""
@@ -198,4 +198,4 @@ def lr_find(
             }
         )
         self.trainer.auto_lr_find = False
-        return result['lr_find']
+        return result['lr_find']  # type: ignore
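The _tune hunk above is the path exercised by Trainer(auto_lr_find=True): update_attr defaults to True there, so the suggested rate is written back onto the model before fitting. A short sketch reusing the toy TinyModel from the earlier example:

    import pytorch_lightning as pl

    model = TinyModel()  # toy model defined in the sketch above
    trainer = pl.Trainer(auto_lr_find=True, max_epochs=1)
    trainer.tune(model)  # runs lr_find via Tuner._tune and updates the model's lr attribute
    trainer.fit(model)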
