Commit ef4c4eb

update types

aniketmaurya committed May 11, 2021
1 parent 29d8c36
Showing 4 changed files with 19 additions and 14 deletions.
8 changes: 5 additions & 3 deletions pytorch_lightning/tuner/auto_gpu_select.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List
+
 import torch

 from pytorch_lightning.utilities.exceptions import MisconfigurationException


-def pick_multiple_gpus(nb):
+def pick_multiple_gpus(nb: int) -> List[int]:
     """
     Raises:
         MisconfigurationException:
@@ -30,14 +32,14 @@ def pick_multiple_gpus(nb):

     nb = torch.cuda.device_count() if nb == -1 else nb

-    picked = []
+    picked: List[int] = []
     for _ in range(nb):
         picked.append(pick_single_gpu(exclude_gpus=picked))

     return picked


-def pick_single_gpu(exclude_gpus: list):
+def pick_single_gpu(exclude_gpus: list) -> int:
     """
     Raises:
         RuntimeError:
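As a quick illustration of what the new List[int] annotation buys callers, here is a hedged usage sketch (not part of this commit; it assumes a CUDA machine with at least two free GPUs):

    from typing import List

    from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus

    # With the annotated signature, a type checker can verify this assignment
    # and every element access on the result.
    gpus: List[int] = pick_multiple_gpus(2)   # e.g. [0, 1]
    device_str = ",".join(str(i) for i in gpus)  # each element is known to be an int
    print(device_str)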
6 changes: 4 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -15,6 +15,8 @@
 import os
 from typing import Optional, Tuple

+from torch.utils.data import DataLoader
+
 import pytorch_lightning as pl
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
@@ -39,7 +41,7 @@ def scale_batch_size(
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
     if trainer.fast_dev_run:
         rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning)
-        return
+        return None

     if not lightning_hasattr(model, batch_arg_name):
         raise MisconfigurationException(f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
@@ -251,5 +253,5 @@ def _adjust_batch_size(
     return new_size, changed


-def _is_valid_batch_size(current_size: int, dataloader) -> bool:
+def _is_valid_batch_size(current_size: int, dataloader: 'DataLoader') -> bool:
     return not has_len(dataloader) or current_size <= len(dataloader)
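The quoted 'DataLoader' annotation behaves like the plain class name: a string annotation is a forward reference that type checkers resolve lazily. A small standalone sketch of the same pattern, using an illustrative helper rather than the module's private one:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def batch_count(dataloader: 'DataLoader') -> int:
        # The quoted annotation is an ordinary forward reference; type checkers
        # resolve it to torch.utils.data.DataLoader at analysis time.
        return len(dataloader)

    loader = DataLoader(TensorDataset(torch.zeros(10, 3)), batch_size=4)
    print(batch_count(loader))  # 3 batches for 10 samples at batch_size=4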
15 changes: 8 additions & 7 deletions pytorch_lightning/tuner/lr_finder.py
@@ -15,7 +15,7 @@
 import logging
 import os
 from functools import wraps
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -95,7 +95,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self.lr_max = lr_max
         self.num_training = num_training

-        self.results = {}
+        self.results: Dict = {}
         self._total_batch_idx = 0  # for debug purpose

     def _exchange_scheduler(self, configure_optimizers: Callable) -> Callable:
@@ -105,7 +105,7 @@ def _exchange_scheduler(self, configure_optimizers: Callable) -> Callable:
         """

         @wraps(configure_optimizers)
-        def func():
+        def func() -> Tuple:
             # Decide the structure of the output from configure_optimizers
             # Same logic as method `init_optimizers` in trainer/optimizers.py
             optim_conf = configure_optimizers()
@@ -172,7 +172,7 @@ def plot(self, suggest: bool = False, show: bool = False):

         return fig

-    def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
+    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> float:
        """ This will propose a suggestion for choice of initial learning rate
        as the point with the steepest negative gradient.

@@ -192,6 +192,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
         except Exception:
             log.exception('Failed to compute suggesting for `lr`. There might not be enough points.')
             self._optimal_idx = None
+            return None


 def lr_find(
@@ -207,7 +208,7 @@
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
     if trainer.fast_dev_run:
         rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
-        return
+        return None

     # Determine lr attr
     if update_attr:
@@ -403,7 +404,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int
         self.num_iter = num_iter
         super(_LinearLR, self).__init__(optimizer, last_epoch)

-    def get_lr(self) -> list:
+    def get_lr(self) -> List:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter

@@ -453,5 +454,5 @@ def get_lr(self) -> list:
         return val

     @property
-    def lr(self):
+    def lr(self) -> float:
         return self._lr
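Several of the touched code paths still return None on early exit (fast_dev_run enabled, or a failed suggestion computation), so callers generally guard the result before using it. A hedged sketch of that calling pattern with a stand-in function, not the library's own:

    from typing import List, Optional

    def suggest_lr(losses: List[float]) -> Optional[float]:
        # Illustrative stand-in: mirrors the "return None when no value can be
        # computed" behaviour of the early-return paths in this commit.
        if len(losses) < 2:
            return None
        return 1e-3

    lr = suggest_lr([])
    if lr is not None:  # guard, since the annotated return may be None
        print(f"using lr={lr}")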
4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/tuning.py
@@ -36,7 +36,7 @@ def _tune(
         model: 'pl.LightningModule',
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
+    ) -> Dict[str, Union[int, _LRFinder, None]]:
         scale_batch_size_kwargs = scale_batch_size_kwargs or {}
         lr_find_kwargs = lr_find_kwargs or {}
         # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
@@ -75,7 +75,7 @@ def scale_batch_size(
         init_val: int = 2,
         max_trials: int = 25,
         batch_arg_name: str = 'batch_size',
-    ) -> Optional[int]:
+    ) -> Union[int, _LRFinder, None]:
         """
         Iteratively try to find the largest batch size for a given model
         that does not give an out of memory (OOM) error.
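For reference, Dict[str, Optional[Union[int, _LRFinder]]] and Dict[str, Union[int, _LRFinder, None]] describe the same type; the rewrite only flattens the nesting, since Optional[T] is shorthand for Union[T, None]. A tiny sketch with a stand-in class (not the library's real _LRFinder):

    from typing import Dict, Optional, Union

    class FakeLRFinder:  # stand-in for _LRFinder, for illustration only
        pass

    # Equivalent annotations: Optional[T] is just Union[T, None].
    NestedResult = Dict[str, Optional[Union[int, FakeLRFinder]]]
    FlatResult = Dict[str, Union[int, FakeLRFinder, None]]

    results: FlatResult = {'scale_batch_size': 64, 'lr_find': None}
    print(results)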
