diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index ca93c5dc283c6..b23947faee3c8 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -194,26 +194,15 @@ def __init__( def train_dataloader(self): return DALIClassificationLoader( - self.pipe_train, - size=len(self.mnist_train), - auto_reset=True, - fill_last_batch=True + self.pipe_train, size=len(self.mnist_train), auto_reset=True, fill_last_batch=True ) def val_dataloader(self): - return DALIClassificationLoader( - self.pipe_val, - size=len(self.mnist_val), - auto_reset=True, - fill_last_batch=False - ) + return DALIClassificationLoader(self.pipe_val, size=len(self.mnist_val), auto_reset=True, fill_last_batch=False) def test_dataloader(self): return DALIClassificationLoader( - self.pipe_test, - size=len(self.mnist_test), - auto_reset=True, - fill_last_batch=False + self.pipe_test, size=len(self.mnist_test), auto_reset=True, fill_last_batch=False ) diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py index 0a6f855b2f109..c79214af93581 100644 --- a/pl_examples/basic_examples/profiler_example.py +++ b/pl_examples/basic_examples/profiler_example.py @@ -44,10 +44,7 @@ class ModelToProfile(LightningModule): - def __init__( - self, - name: str = "resnet50" - ): + def __init__(self, name: str = "resnet50"): super().__init__() self.model = getattr(models, name)(pretrained=True) self.criterion = torch.nn.CrossEntropyLoss() diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py index 0ce2273febf00..b1c2ff5f892d3 100644 --- a/pytorch_lightning/__about__.py +++ b/pytorch_lightning/__about__.py @@ -1,7 +1,7 @@ import time _this_year = time.strftime("%Y") -__version__ = '1.3.0rc2' +__version__ = "1.3.0rc2" __author__ = 'William Falcon et al.' 
__author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 66e92ae006533..cd32c232fe85f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -24,14 +24,11 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum +from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT -if _NATIVE_AMP_AVAILABLE: - from torch.cuda.amp import GradScaler - class Accelerator: """ @@ -374,7 +371,7 @@ def to_device(self, batch: Any) -> Any: return self.batch_to_device(batch, self.root_device) @property - def amp_backend(self) -> Optional[LightningEnum]: + def amp_backend(self) -> Optional['AMPType']: if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): return AMPType.APEX elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): @@ -386,7 +383,7 @@ def precision(self) -> Union[str, int]: return self.precision_plugin.precision @property - def scaler(self) -> Optional['GradScaler']: + def scaler(self) -> Optional['torch.cuda.amp.GradScaler']: return getattr(self.precision_plugin, 'scaler', None) @property diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 458f058c274f2..8b919b8394661 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.plugins.precision import MixedPrecisionPlugin diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 5b6c55e1b70ac..25b18e185c064 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -51,7 +51,7 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None: return super().setup(trainer, model) def run_optimizer_step( - self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any + self, optimizer: 'Optimizer', optimizer_idx: int, lambda_closure: Callable, **kwargs: Any ) -> None: xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs}) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 768e4ebca30ee..eda70a6ae25b0 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,9 +17,12 @@ """ import abc -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional -from pytorch_lightning.core.lightning import LightningModule +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT class Callback(abc.ABC): @@ -29,158 +32,165 @@ class Callback(abc.ABC): Subclass this class and override any of the relevant hooks """ - def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None: + def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called before configure sharded model""" - def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None: + def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called before accelerator is being setup""" pass - def setup(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None: """Called when fit, validate, test, predict, or tune begins""" pass - def teardown(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + def teardown(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None: """Called when fit, validate, test, predict, or tune ends""" pass - def on_init_start(self, trainer) -> None: + def on_init_start(self, trainer: 'pl.Trainer') -> None: """Called when the trainer initialization begins, model has not yet been set.""" pass - def on_init_end(self, trainer) -> None: + def on_init_end(self, trainer: 'pl.Trainer') -> None: """Called when the trainer initialization ends, model has not yet been set.""" pass - def on_fit_start(self, trainer, pl_module: LightningModule) -> None: + def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when fit begins""" pass - def on_fit_end(self, trainer, pl_module: LightningModule) -> None: + def on_fit_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when fit ends""" pass - def on_sanity_check_start(self, trainer, pl_module: LightningModule) -> None: + def on_sanity_check_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the validation sanity check starts.""" pass - def 
on_sanity_check_end(self, trainer, pl_module: LightningModule) -> None: + def on_sanity_check_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the validation sanity check ends.""" pass def on_train_batch_start( - self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int ) -> None: """Called when the train batch begins.""" pass def on_train_batch_end( - self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int ) -> None: """Called when the train batch ends.""" pass - def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None: + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the train epoch begins.""" pass - def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None: """Called when the train epoch ends.""" pass - def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None: + def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the val epoch begins.""" pass - def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + def on_validation_epoch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT + ) -> None: """Called when the val epoch ends.""" pass - def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: + def on_test_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the test epoch begins.""" pass - def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None: """Called when the test epoch ends.""" pass - def on_epoch_start(self, trainer, pl_module: LightningModule) -> None: + def on_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when either of train/val/test epoch begins.""" pass - def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: + def on_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when either of train/val/test epoch ends.""" pass - def on_batch_start(self, trainer, pl_module: LightningModule) -> None: + def on_batch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the training batch begins.""" pass def on_validation_batch_start( - self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int ) -> None: """Called when the validation batch begins.""" pass def on_validation_batch_end( - self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int ) -> None: 
"""Called when the validation batch ends.""" pass def on_test_batch_start( - self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int ) -> None: """Called when the test batch begins.""" pass def on_test_batch_end( - self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int ) -> None: """Called when the test batch ends.""" pass - def on_batch_end(self, trainer, pl_module: LightningModule) -> None: + def on_batch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the training batch ends.""" pass - def on_train_start(self, trainer, pl_module: LightningModule) -> None: + def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the train begins.""" pass - def on_train_end(self, trainer, pl_module: LightningModule) -> None: + def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the train ends.""" pass - def on_pretrain_routine_start(self, trainer, pl_module: LightningModule) -> None: + def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the pretrain routine begins.""" pass - def on_pretrain_routine_end(self, trainer, pl_module: LightningModule) -> None: + def on_pretrain_routine_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the pretrain routine ends.""" pass - def on_validation_start(self, trainer, pl_module: LightningModule) -> None: + def on_validation_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the validation loop begins.""" pass - def on_validation_end(self, trainer, pl_module: LightningModule) -> None: + def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the validation loop ends.""" pass - def on_test_start(self, trainer, pl_module: LightningModule) -> None: + def on_test_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the test begins.""" pass - def on_test_end(self, trainer, pl_module: LightningModule) -> None: + def on_test_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the test ends.""" pass - def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None: + def on_keyboard_interrupt(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the training is interrupted by ``KeyboardInterrupt``.""" pass - def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict: + def on_save_checkpoint( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', checkpoint: Dict[str, Any] + ) -> dict: """ Called when saving a model checkpoint, use to persist state. 
@@ -202,10 +212,12 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: """ pass - def on_after_backward(self, trainer, pl_module: LightningModule) -> None: + def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called after ``loss.backward()`` and before optimizers do anything.""" pass - def on_before_zero_grad(self, trainer, pl_module: LightningModule, optimizer) -> None: + def on_before_zero_grad( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: 'Optimizer' + ) -> None: """Called after ``optimizer.step()`` and before ``optimizer.zero_grad()``.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 9af576aafd596..8f5a87f2b8b89 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -19,11 +19,12 @@ """ import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Mapping, Optional, Tuple import numpy as np import torch +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -115,7 +116,7 @@ def __init__( torch_inf = torch.tensor(np.Inf) self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - def _validate_condition_metric(self, logs): + def _validate_condition_metric(self, logs: Mapping) -> bool: monitor_val = logs.get(self.monitor) error_msg = ( @@ -135,10 +136,11 @@ def _validate_condition_metric(self, logs): return True @property - def monitor_op(self): + def monitor_op(self) -> Callable: return self.mode_dict[self.mode] - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: + def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', + checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { 'wait_count': self.wait_count, 'stopped_epoch': self.stopped_epoch, @@ -146,20 +148,20 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> 'patience': self.patience } - def on_load_checkpoint(self, callback_state: Dict[str, Any]): + def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: self.wait_count = callback_state['wait_count'] self.stopped_epoch = callback_state['stopped_epoch'] self.best_score = callback_state['best_score'] self.patience = callback_state['patience'] - def on_validation_end(self, trainer, pl_module): + def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: from pytorch_lightning.trainer.states import TrainerState if trainer.state != TrainerState.FITTING or trainer.sanity_checking: return self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): + def _run_early_stopping_check(self, trainer: 'pl.Trainer') -> None: """ Checks whether the early stopping condition is met and if so tells the trainer to stop the training. 
@@ -172,7 +174,7 @@ def _run_early_stopping_check(self, trainer): ): return # short circuit if metric not present - current = logs.get(self.monitor) + current = logs[self.monitor] # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) @@ -187,7 +189,7 @@ def _run_early_stopping_check(self, trainer): if reason: log.info(f"[{trainer.global_rank}] {reason}") - def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]: + def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index ea508775d126f..6cef11d2e42aa 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -22,17 +22,18 @@ import torch from torch.nn import Module from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parameter import Parameter from torch.optim.optimizer import Optimizer +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException log = logging.getLogger(__name__) -def multiplicative(epoch): +def multiplicative(epoch: int) -> int: return 2 @@ -99,7 +100,7 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) - _modules.extend(BaseFinetuning.flatten_modules(m)) else: - _modules = modules.modules() + _modules = list(modules.modules()) # Leaf nodes in the graph have no children, so we use that to filter return [m for m in _modules if not list(m.children())] @@ -109,7 +110,7 @@ def filter_params( modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True - ) -> Generator: + ) -> Generator['Parameter', None, None]: """Yields the `requires_grad` parameters of a given module or list of modules. Args: @@ -162,7 +163,7 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: param.requires_grad = False @staticmethod - def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: + def filter_on_optimizer(optimizer: 'Optimizer', params: Iterable['Parameter']) -> List['Parameter']: """ This function is used to exclude any parameter which already exists in this optimizer @@ -194,7 +195,7 @@ def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: @staticmethod def unfreeze_and_add_param_group( modules: Union[Module, Iterable[Union[Module, Iterable]]], - optimizer: Optimizer, + optimizer: 'Optimizer', lr: Optional[float] = None, initial_denom_lr: float = 10., train_bn: bool = True, @@ -223,7 +224,7 @@ def unfreeze_and_add_param_group( BaseFinetuning.make_trainable(modules) params_lr = optimizer.param_groups[0]['lr'] if lr is None else float(lr) denom_lr = initial_denom_lr if lr is None else 1. 
- params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True) + params: Iterable = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True) params = BaseFinetuning.filter_on_optimizer(optimizer, params) if params: optimizer.add_param_group({ @@ -231,21 +232,23 @@ def unfreeze_and_add_param_group( 'lr': params_lr / denom_lr, }) - def on_before_accelerator_backend_setup(self, trainer, pl_module): + def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self.freeze_before_training(pl_module) - def on_train_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the epoch begins.""" for opt_idx, optimizer in trainer.train_loop.prepare_optimizers(): self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) - def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + def finetune_function( + self, pl_module: 'pl.LightningModule', epoch: int, optimizer: 'Optimizer', opt_idx: int + ) -> None: """ Override to add your unfreeze logic """ raise NotImplementedError - def freeze_before_training(self, pl_module: LightningModule): + def freeze_before_training(self, pl_module: 'pl.LightningModule') -> None: """ Override to add your freeze logic """ @@ -315,7 +318,7 @@ def __init__( self.round = round self.verbose = verbose - def on_fit_start(self, trainer, pl_module): + def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """ Raises: MisconfigurationException: @@ -325,10 +328,12 @@ def on_fit_start(self, trainer, pl_module): return raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute") - def freeze_before_training(self, pl_module: LightningModule): + def freeze_before_training(self, pl_module: 'pl.LightningModule') -> None: self.freeze(pl_module.backbone) - def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + def finetune_function( + self, pl_module: 'pl.LightningModule', epoch: int, optimizer: 'Optimizer', opt_idx: int + ) -> None: """Called when the epoch begins.""" if epoch == self.unfreeze_backbone_at_epoch: diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index ffd39e9af4c16..5324a770f8b29 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -23,12 +23,14 @@ import shutil import subprocess import time -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import DeviceType, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.types import STEP_OUTPUT class GPUStatsMonitor(Callback): @@ -100,8 +102,10 @@ def __init__( 'fan_speed': fan_speed, 'temperature': temperature }) + self._snap_intra_step_time: Optional[float] = None + self._snap_inter_step_time: Optional[float] = None - def on_train_start(self, trainer, pl_module) -> None: + def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: if not trainer.logger: raise MisconfigurationException('Cannot use 
GPUStatsMonitor callback with Trainer that has no logger.') @@ -111,14 +115,18 @@ def on_train_start(self, trainer, pl_module) -> None: f' since gpus attribute in Trainer is set to {trainer.gpus}.' ) - self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids)) + self._gpu_ids = ','.join( + map(str, trainer.data_parallel_device_ids if trainer.data_parallel_device_ids is not None else []) + ) - def on_train_epoch_start(self, trainer, pl_module) -> None: + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._snap_intra_step_time = None self._snap_inter_step_time = None @rank_zero_only - def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_start( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: if self._log_stats.intra_step_time: self._snap_intra_step_time = time.time() @@ -137,7 +145,8 @@ def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, d @rank_zero_only def on_train_batch_end( - self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int ) -> None: if self._log_stats.inter_step_time: self._snap_inter_step_time = time.time() @@ -159,7 +168,10 @@ def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]: gpu_query = ','.join(queries) format = 'csv,nounits,noheader' result = subprocess.run( - [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={self._gpu_ids}'], + [ + str(shutil.which('nvidia-smi')), f'--query-gpu={gpu_query}', f'--format={format}', + f'--id={self._gpu_ids}' + ], encoding="utf-8", stdout=subprocess.PIPE, stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 @@ -172,8 +184,8 @@ def _to_float(x: str) -> float: except ValueError: return 0. 
- stats = result.stdout.strip().split(os.linesep) - stats = [[_to_float(x) for x in s.split(', ')] for s in stats] + stats_str = result.stdout.strip().split(os.linesep) + stats = [[_to_float(x) for x in s.split(', ')] for s in stats_str] return stats @staticmethod @@ -210,7 +222,7 @@ def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]: return stat_keys @staticmethod - def _should_log(trainer) -> bool: + def _should_log(trainer: 'pl.Trainer') -> bool: should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) return should_log diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index b1885087f4da0..4b78f4056abc9 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -22,6 +22,7 @@ from typing import Dict +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback @@ -71,12 +72,12 @@ def __init__(self, scheduling: Dict[int, int]): self.scheduling = scheduling self.epochs = sorted(scheduling.keys()) - def going_to_accumulate_grad_batches(self): + def going_to_accumulate_grad_batches(self) -> bool: return any([v > 1 for v in self.scheduling.values()]) - def on_train_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: epoch = trainer.current_epoch for i in reversed(range(len(self.epochs))): if epoch >= self.epochs[i]: - trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i]) + trainer.accumulate_grad_batches = self.scheduling[self.epochs[i]] break diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index a7485814b1b17..e2112f9ac46f8 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -80,81 +80,83 @@ def __init__( on_after_backward: Optional[Callable] = None, on_before_zero_grad: Optional[Callable] = None, ): + # need to call setattr here for mypy: https://github.com/python/mypy/issues/2427 + # TODO: Add Callable argument and return types to init signature variables if on_before_accelerator_backend_setup is not None: - self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup + setattr(self, 'on_before_accelerator_backend_setup', on_before_accelerator_backend_setup) if setup is not None: - self.setup = setup + setattr(self, 'setup', setup) if on_configure_sharded_model is not None: - self.on_configure_sharded_model = on_configure_sharded_model + setattr(self, 'on_configure_sharded_model', on_configure_sharded_model) if teardown is not None: - self.teardown = teardown + setattr(self, 'teardown', teardown) if on_init_start is not None: - self.on_init_start = on_init_start + setattr(self, 'on_init_start', on_init_start) if on_init_end is not None: - self.on_init_end = on_init_end + setattr(self, 'on_init_end', on_init_end) if on_fit_start is not None: - self.on_fit_start = on_fit_start + setattr(self, 'on_fit_start', on_fit_start) if on_fit_end is not None: - self.on_fit_end = on_fit_end + setattr(self, 'on_fit_end', on_fit_end) if on_sanity_check_start is not None: - self.on_sanity_check_start = on_sanity_check_start + setattr(self, 'on_sanity_check_start', on_sanity_check_start) if on_sanity_check_end is not None: - self.on_sanity_check_end = on_sanity_check_end + setattr(self, 'on_sanity_check_end', 
on_sanity_check_end) if on_train_batch_start is not None: - self.on_train_batch_start = on_train_batch_start + setattr(self, 'on_train_batch_start', on_train_batch_start) if on_train_batch_end is not None: - self.on_train_batch_end = on_train_batch_end + setattr(self, 'on_train_batch_end', on_train_batch_end) if on_train_epoch_start is not None: - self.on_train_epoch_start = on_train_epoch_start + setattr(self, 'on_train_epoch_start', on_train_epoch_start) if on_train_epoch_end is not None: - self.on_train_epoch_end = on_train_epoch_end + setattr(self, 'on_train_epoch_end', on_train_epoch_end) if on_validation_epoch_start is not None: - self.on_validation_epoch_start = on_validation_epoch_start + setattr(self, 'on_validation_epoch_start', on_validation_epoch_start) if on_validation_epoch_end is not None: - self.on_validation_epoch_end = on_validation_epoch_end + setattr(self, 'on_validation_epoch_end', on_validation_epoch_end) if on_test_epoch_start is not None: - self.on_test_epoch_start = on_test_epoch_start + setattr(self, 'on_test_epoch_start', on_test_epoch_start) if on_test_epoch_end is not None: - self.on_test_epoch_end = on_test_epoch_end + setattr(self, 'on_test_epoch_end', on_test_epoch_end) if on_epoch_start is not None: - self.on_epoch_start = on_epoch_start + setattr(self, 'on_epoch_start', on_epoch_start) if on_epoch_end is not None: - self.on_epoch_end = on_epoch_end + setattr(self, 'on_epoch_end', on_epoch_end) if on_batch_start is not None: - self.on_batch_start = on_batch_start + setattr(self, 'on_batch_start', on_batch_start) if on_validation_batch_start is not None: - self.on_validation_batch_start = on_validation_batch_start + setattr(self, 'on_validation_batch_start', on_validation_batch_start) if on_validation_batch_end is not None: - self.on_validation_batch_end = on_validation_batch_end + setattr(self, 'on_validation_batch_end', on_validation_batch_end) if on_test_batch_start is not None: - self.on_test_batch_start = on_test_batch_start + setattr(self, 'on_test_batch_start', on_test_batch_start) if on_test_batch_end is not None: - self.on_test_batch_end = on_test_batch_end + setattr(self, 'on_test_batch_end', on_test_batch_end) if on_batch_end is not None: - self.on_batch_end = on_batch_end + setattr(self, 'on_batch_end', on_batch_end) if on_train_start is not None: - self.on_train_start = on_train_start + setattr(self, 'on_train_start', on_train_start) if on_train_end is not None: - self.on_train_end = on_train_end + setattr(self, 'on_train_end', on_train_end) if on_pretrain_routine_start is not None: - self.on_pretrain_routine_start = on_pretrain_routine_start + setattr(self, 'on_pretrain_routine_start', on_pretrain_routine_start) if on_pretrain_routine_end is not None: - self.on_pretrain_routine_end = on_pretrain_routine_end + setattr(self, 'on_pretrain_routine_end', on_pretrain_routine_end) if on_validation_start is not None: - self.on_validation_start = on_validation_start + setattr(self, 'on_validation_start', on_validation_start) if on_validation_end is not None: - self.on_validation_end = on_validation_end + setattr(self, 'on_validation_end', on_validation_end) if on_test_start is not None: - self.on_test_start = on_test_start + setattr(self, 'on_test_start', on_test_start) if on_test_end is not None: - self.on_test_end = on_test_end + setattr(self, 'on_test_end', on_test_end) if on_keyboard_interrupt is not None: - self.on_keyboard_interrupt = on_keyboard_interrupt + setattr(self, 'on_keyboard_interrupt', on_keyboard_interrupt) if on_save_checkpoint is not 
None: - self.on_save_checkpoint = on_save_checkpoint + setattr(self, 'on_save_checkpoint', on_save_checkpoint) if on_load_checkpoint is not None: - self.on_load_checkpoint = on_load_checkpoint + setattr(self, 'on_load_checkpoint', on_load_checkpoint) if on_after_backward is not None: - self.on_after_backward = on_after_backward + setattr(self, 'on_after_backward', on_after_backward) if on_before_zero_grad is not None: - self.on_before_zero_grad = on_before_zero_grad + setattr(self, 'on_before_zero_grad', on_before_zero_grad) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 7530bfaa9d21e..b1c45ce402cba 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -20,8 +20,9 @@ """ -from typing import Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -73,10 +74,11 @@ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = self.logging_interval = logging_interval self.log_momentum = log_momentum - self.lrs = None - self.lr_sch_names = [] + self.lrs: Dict[str, List] = {} + self.lr_sch_names: List[str] = [] + self.last_momentum_values: Dict[str, Optional[float]] = {} - def on_train_start(self, trainer, *args, **kwargs): + def on_train_start(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """ Called before training, determines unique names for all lr schedulers in the case of multiple of the same type or in @@ -100,8 +102,9 @@ def on_train_start(self, trainer, *args, **kwargs): if self.log_momentum: - def _check_no_key(key): - return any(key not in sch['scheduler'].optimizer.defaults for sch in trainer.lr_schedulers) + def _check_no_key(key: str) -> bool: + return any(key not in sch['scheduler'].optimizer.defaults + for sch in trainer.lr_schedulers) if isinstance(trainer.lr_schedulers, Iterable) else True if _check_no_key('momentum') and _check_no_key('betas'): rank_zero_warn( @@ -116,7 +119,7 @@ def _check_no_key(key): self.lrs = {name: [] for name in names} self.last_momentum_values = {name + "-momentum": None for name in names} - def on_train_batch_start(self, trainer, *args, **kwargs): + def on_train_batch_start(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: if not self._should_log(trainer): return @@ -127,7 +130,7 @@ def on_train_batch_start(self, trainer, *args, **kwargs): if latest_stat: trainer.logger.log_metrics(latest_stat, step=trainer.global_step) - def on_train_epoch_start(self, trainer, *args, **kwargs): + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: if self.logging_interval != 'step': interval = 'epoch' if self.logging_interval is None else 'any' latest_stat = self._extract_stats(trainer, interval) @@ -135,10 +138,12 @@ def on_train_epoch_start(self, trainer, *args, **kwargs): if latest_stat: trainer.logger.log_metrics(latest_stat, step=trainer.global_step) - def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: + def _extract_stats(self, trainer: 'pl.Trainer', interval: str) -> Dict[str, float]: latest_stat = {} - for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): + for name, scheduler in zip( + self.lr_sch_names, trainer.lr_schedulers if trainer.lr_schedulers is not None else 
[] + ): if scheduler['interval'] == interval or interval == 'any': opt = scheduler['scheduler'].optimizer param_groups = opt.param_groups @@ -155,53 +160,59 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: return latest_stat - def _extract_lr(self, param_group, name: str) -> Dict[str, float]: - lr = param_group.get('lr') + def _extract_lr(self, param_group: Dict[str, float], name: str) -> Dict[str, float]: + lr = param_group['lr'] self.lrs[name].append(lr) return {name: lr} - def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]: + def _extract_momentum(self, param_group: Dict[str, Union[Sequence[float], float]], name: str, + use_betas: bool) -> Dict[str, float]: if not self.log_momentum: return {} - momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0) + _momentum = param_group['betas'] if use_betas else param_group.get('momentum', 0.) + if isinstance(_momentum, Sequence): + momentum = _momentum[0] + else: + momentum = _momentum self.last_momentum_values[name] = momentum return {name: momentum} - def _find_names(self, lr_schedulers) -> List[str]: + def _find_names(self, lr_schedulers: Optional[List[Any]]) -> List[str]: # Create uniqe names in the case we have multiple of the same learning # rate schduler + multiple parameter groups names = [] - for scheduler in lr_schedulers: - sch = scheduler['scheduler'] - if scheduler['name'] is not None: - name = scheduler['name'] - else: - opt_name = 'lr-' + sch.optimizer.__class__.__name__ - i, name = 1, opt_name - - # Multiple schduler of the same type - while True: - if name not in names: - break - i, name = i + 1, f'{opt_name}-{i}' - - # Multiple param groups for the same schduler - param_groups = sch.optimizer.param_groups - - if len(param_groups) != 1: - for i, pg in enumerate(param_groups): - temp = f'{name}/pg{i + 1}' - names.append(temp) - else: - names.append(name) - - self.lr_sch_names.append(name) + if lr_schedulers is not None: + for scheduler in lr_schedulers: + sch = scheduler['scheduler'] + if scheduler['name'] is not None: + name = scheduler['name'] + else: + opt_name = 'lr-' + sch.optimizer.__class__.__name__ + i, name = 1, opt_name + + # Multiple schduler of the same type + while True: + if name not in names: + break + i, name = i + 1, f'{opt_name}-{i}' + + # Multiple param groups for the same schduler + param_groups = sch.optimizer.param_groups + + if len(param_groups) != 1: + for i, pg in enumerate(param_groups): + temp = f'{name}/pg{i + 1}' + names.append(temp) + else: + names.append(name) + + self.lr_sch_names.append(name) return names @staticmethod - def _should_log(trainer) -> bool: - should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) + def _should_log(trainer: "pl.Trainer") -> bool: + should_log: bool = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) return should_log diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a1a44fd70b139..79a86d4830902 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -23,16 +23,18 @@ import re from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import numpy as np import torch import yaml +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from 
pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache log = logging.getLogger(__name__) @@ -171,6 +173,7 @@ class ModelCheckpoint(Callback): CHECKPOINT_NAME_LAST = "last" FILE_EXTENSION = ".ckpt" STARTING_VERSION = 1 + _period: Optional[int] def __init__( self, @@ -195,20 +198,20 @@ def __init__( self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name self._last_global_step_saved = -1 - self.current_score = None - self.best_k_models = {} + self.current_score: Optional[torch.Tensor] = None + self.best_k_models: Dict[str, torch.Tensor] = {} self.kth_best_model_path = "" - self.best_model_score = None + self.best_model_score: Optional[torch.Tensor] = None self.best_model_path = "" self.last_model_path = "" - self.save_function = None + self.save_function: Optional[Callable] = None self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) self.__init_triggers(every_n_train_steps, every_n_val_epochs, period) self.__validate_init_configuration() - def on_pretrain_routine_start(self, trainer, pl_module): + def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """ When pretrain routine starts we build the ckpt dir on the fly """ @@ -216,7 +219,8 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.save_function = trainer.save_checkpoint def on_train_batch_end( - self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int ) -> None: """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """ if self._should_skip_saving_checkpoint(trainer): @@ -227,7 +231,7 @@ def on_train_batch_end( return self.save_checkpoint(trainer) - def on_validation_end(self, trainer, pl_module) -> None: + def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """ checkpoints can be saved at the end of the val loop """ @@ -239,7 +243,8 @@ def on_validation_end(self, trainer, pl_module) -> None: return self.save_checkpoint(trainer) - def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: + def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', + checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { "monitor": self.monitor, "best_model_score": self.best_model_score, @@ -248,11 +253,11 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> "dirpath": self.dirpath } - def on_load_checkpoint(self, callback_state: Dict[str, Any]): + def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused: Optional = None): + def save_checkpoint(self, trainer: 'pl.Trainer', unused: Optional['pl.LightningModule'] = None) -> None: """ Performs the main logic around saving a checkpoint. 
This method runs on all ranks, it is the responsibility of `self.save_function` @@ -284,7 +289,7 @@ def save_checkpoint(self, trainer, unused: Optional = None): # Mode 3: save last checkpoints self._save_last_checkpoint(trainer, monitor_candidates) - def _should_skip_saving_checkpoint(self, trainer) -> bool: + def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool: from pytorch_lightning.trainer.states import TrainerState return ( trainer.fast_dev_run # disable checkpointing with fast_dev_run @@ -293,7 +298,7 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: or self._last_global_step_saved == trainer.global_step # already saved at the last step ) - def __validate_init_configuration(self): + def __validate_init_configuration(self) -> None: if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') if self._every_n_train_steps < 0: @@ -328,7 +333,9 @@ def __validate_init_configuration(self): ' will duplicate the last checkpoint saved.' ) - def __init_ckpt_dir(self, dirpath, filename, save_top_k): + def __init_ckpt_dir( + self, dirpath: Optional[Union[str, Path]], filename: Optional[str], save_top_k: Optional[int] + ) -> None: self._fs = get_filesystem(str(dirpath) if dirpath else '') @@ -341,10 +348,10 @@ def __init_ckpt_dir(self, dirpath, filename, save_top_k): if dirpath and self._fs.protocol == 'file': dirpath = os.path.realpath(dirpath) - self.dirpath: Union[str, None] = dirpath or None + self.dirpath: Union[str, None, Path] = dirpath or None self.filename = filename or None - def __init_monitor_mode(self, monitor, mode): + def __init_monitor_mode(self, monitor: Optional[str], mode: str) -> None: torch_inf = torch.tensor(np.Inf) mode_dict = { "min": (torch_inf, "min"), @@ -397,21 +404,23 @@ def period(self, value: Optional[int]) -> None: self._period = value @rank_zero_only - def _del_model(self, filepath: str): + def _del_model(self, filepath: str) -> None: if self._fs.exists(filepath): self._fs.rm(filepath) log.debug(f"Removed checkpoint: {filepath}") - def _save_model(self, trainer, filepath: str): + def _save_model(self, trainer: 'pl.Trainer', filepath: str) -> None: if trainer.training_type_plugin.rpc_enabled: # RPCPlugin manages saving all model states # TODO: the rpc plugin should wrap trainer.save_checkpoint # instead of us having to do it here manually - trainer.training_type_plugin.rpc_save_model(trainer, self._do_save, filepath) + getattr( + trainer.training_type_plugin, 'rpc_save_model', lambda trainer, func, filepath: func(trainer, filepath) + ) else: self._do_save(trainer, filepath) - def _do_save(self, trainer, filepath: str): + def _do_save(self, trainer: 'pl.Trainer', filepath: str) -> None: # in debugging, track when we save checkpoints trainer.dev_debugger.track_checkpointing_history(filepath) @@ -425,8 +434,8 @@ def _do_save(self, trainer, filepath: str): else: raise ValueError(".save_function() not set") - def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool: - if current is None: + def check_monitor_top_k(self, trainer: 'pl.Trainer', current: Optional[torch.Tensor] = None) -> bool: + if current is None or self.save_top_k is None: return False if self.save_top_k == -1: @@ -445,7 +454,7 @@ def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) - current = torch.tensor(current) monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode] - should_update_best_and_save = 
monitor_op(current, self.best_k_models[self.kth_best_model_path]) + should_update_best_and_save: bool = bool(monitor_op(current, self.best_k_models[self.kth_best_model_path])) # If using multiple devices, make sure all processes are unanimous on the decision. should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) @@ -523,7 +532,7 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], ckpt_name = f"{filename}{self.FILE_EXTENSION}" return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name - def __resolve_ckpt_dir(self, trainer): + def __resolve_ckpt_dir(self, trainer: 'pl.Trainer') -> None: """ Determines model checkpoint save directory at runtime. References attributes from the trainer's logger to determine where to save checkpoints. @@ -548,12 +557,15 @@ def __resolve_ckpt_dir(self, trainer): else: save_dir = trainer.logger.save_dir or trainer.default_root_dir - version = ( + version: str = ( trainer.logger.version if isinstance(trainer.logger.version, str) else f"version_{trainer.logger.version}" ) + name: str = trainer.logger.name - version, name = trainer.training_type_plugin.broadcast((version, trainer.logger.name)) + broadcasted: Tuple = tuple(trainer.training_type_plugin.broadcast((version, name))) + version = str(broadcasted[0]) + name = str(broadcasted[1]) ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") else: @@ -564,7 +576,7 @@ def __resolve_ckpt_dir(self, trainer): if not trainer.fast_dev_run and trainer.is_global_zero: self._fs.makedirs(self.dirpath, exist_ok=True) - def _add_backward_monitor_support(self, trainer): + def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None: metrics = trainer.logger_connector.callback_metrics deprecation_warning = False @@ -583,7 +595,7 @@ def _add_backward_monitor_support(self, trainer): " and use it as `Trainer(callbacks=[mc])`.", DeprecationWarning ) - def _validate_monitor_key(self, trainer): + def _validate_monitor_key(self, trainer: 'pl.Trainer') -> None: metrics = trainer.logger_connector.callback_metrics # validate metric @@ -600,7 +612,7 @@ def _get_metric_interpolated_filepath_name( monitor_candidates: Dict[str, Any], epoch: int, step: int, - trainer, + trainer: 'pl.Trainer', del_filepath: Optional[str] = None, ) -> str: filepath = self.format_checkpoint_name(epoch, step, monitor_candidates) @@ -612,12 +624,12 @@ def _get_metric_interpolated_filepath_name( return filepath - def _monitor_candidates(self, trainer): + def _monitor_candidates(self, trainer: 'pl.Trainer') -> Dict[str, Any]: monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics) monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch) return monitor_candidates - def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): + def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[str, Any]) -> None: if not self.save_last: return @@ -627,7 +639,7 @@ def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): trainer.global_step, monitor_candidates, ) - filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}") + filepath = os.path.join(str(self.dirpath), f"{filepath}{self.FILE_EXTENSION}") self._save_model(trainer, filepath) @@ -636,7 +648,7 @@ def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): self.last_model_path = filepath - def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, 
Any]): + def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[str, Any]) -> None: if self.monitor is None or self.save_top_k == 0: return @@ -645,11 +657,11 @@ def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): step = monitor_candidates.get("step") if self.check_monitor_top_k(trainer, current): - self._update_best_and_save(current, epoch, step, trainer, monitor_candidates) + self._update_best_and_save(current, epoch, step, trainer, monitor_candidates) # type: ignore elif self.verbose: rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}") - def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): + def _save_none_monitor_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[str, Any]) -> None: if self.monitor is not None or self.save_top_k == 0: return @@ -669,12 +681,12 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A self.best_model_path = filepath - def _is_valid_monitor_key(self, metrics): + def _is_valid_monitor_key(self, metrics: Dict[str, Any]) -> bool: return self.monitor in metrics or len(metrics) == 0 def _update_best_and_save( - self, current: torch.Tensor, epoch: int, step: int, trainer, monitor_candidates: Dict[str, Any] - ): + self, current: torch.Tensor, epoch: int, step: int, trainer: 'pl.Trainer', monitor_candidates: Dict[str, Any] + ) -> None: k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k del_filepath = None @@ -695,11 +707,11 @@ def _update_best_and_save( if len(self.best_k_models) == k: # monitor dict has reached k elements _op = max if self.mode == "min" else min - self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) + self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore self.kth_value = self.best_k_models[self.kth_best_model_path] _op = min if self.mode == "min" else max - self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) + self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore self.best_model_score = self.best_k_models[self.best_model_path] if self.verbose: @@ -712,18 +724,18 @@ def _update_best_and_save( if del_filepath is not None and filepath != del_filepath: self._del_model(del_filepath) - def to_yaml(self, filepath: Optional[Union[str, Path]] = None): + def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None: """ Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML file. """ best_k = {k: v.item() for k, v in self.best_k_models.items()} if filepath is None: - filepath = os.path.join(self.dirpath, "best_k_models.yaml") + filepath = os.path.join(str(self.dirpath), "best_k_models.yaml") with self._fs.open(filepath, "w") as fp: yaml.dump(best_k, fp) - def file_exists(self, filepath: Union[str, Path], trainer) -> bool: + def file_exists(self, filepath: Union[str, Path], trainer: 'pl.Trainer') -> bool: """ Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state to diverge between ranks. 
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index aee75c3fb1cff..c39f05fd20909 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -25,13 +25,16 @@ # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed -from typing import Optional, Union +from typing import Any, Dict, Optional, Union + +from pytorch_lightning.utilities.types import STEP_OUTPUT if importlib.util.find_spec('ipywidgets') is not None: from tqdm.auto import tqdm as _tqdm else: from tqdm import tqdm as _tqdm +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback _PAD_SIZE = 5 @@ -43,20 +46,23 @@ class tqdm(_tqdm): """ @staticmethod - def format_num(n) -> str: + def format_num(n: Union[float, str]) -> str: """ Add additional padding to the formatted numbers """ should_be_padded = isinstance(n, (float, str)) - if not isinstance(n, str): - n = _tqdm.format_num(n) - if should_be_padded and 'e' not in n: - if '.' not in n and len(n) < _PAD_SIZE: + if isinstance(n, str): + formatted_n = n + else: + formatted_n = _tqdm.format_num(n) + + if should_be_padded and 'e' not in formatted_n: + if '.' not in formatted_n and len(formatted_n) < _PAD_SIZE: try: - _ = float(n) + _ = float(formatted_n) except ValueError: - return n - n += '.' - n += "0" * (_PAD_SIZE - len(n)) - return n + return formatted_n + formatted_n += '.' + formatted_n += "0" * (_PAD_SIZE - len(formatted_n)) + return formatted_n class ProgressBarBase(Callback): @@ -87,16 +93,20 @@ def on_train_batch_end(self, trainer, pl_module, outputs): """ - def __init__(self): + def __init__(self) -> None: - self._trainer = None + self._trainer: Optional['pl.Trainer'] = None self._train_batch_idx = 0 self._val_batch_idx = 0 self._test_batch_idx = 0 self._predict_batch_idx = 0 @property - def trainer(self): + def is_enabled(self) -> bool: + return True + + @property + def trainer(self) -> Optional['pl.Trainer']: return self._trainer @property @@ -132,13 +142,13 @@ def predict_batch_idx(self) -> int: return self._predict_batch_idx @property - def total_train_batches(self) -> int: + def total_train_batches(self) -> float: """ The total number of training batches during training, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training dataloader is of infinite size. """ - return self.trainer.num_training_batches + return self.trainer.num_training_batches if self.trainer is not None else 0 @property def total_val_batches(self) -> int: @@ -148,9 +158,11 @@ def total_val_batches(self) -> int: validation dataloader is of infinite size. """ total_val_batches = 0 + if self.trainer is None: + return 0 if self.trainer.enable_validation: is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 - total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 + total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 # type: ignore return total_val_batches @@ -161,7 +173,7 @@ def total_test_batches(self) -> int: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. 
""" - return sum(self.trainer.num_test_batches) + return sum(self.trainer.num_test_batches) if self.trainer is not None else 0 # type: ignore @property def total_predict_batches(self) -> int: @@ -170,9 +182,9 @@ def total_predict_batches(self) -> int: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader is of infinite size. """ - return sum(self.trainer.num_predict_batches) + return sum(self.trainer.num_predict_batches) if self.trainer is not None else 0 # type: ignore - def disable(self): + def disable(self) -> None: """ You should provide a way to disable the progress bar. The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the @@ -180,7 +192,7 @@ def disable(self): """ raise NotImplementedError - def enable(self): + def enable(self) -> None: """ You should provide a way to enable the progress bar. The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training @@ -189,40 +201,52 @@ def enable(self): """ raise NotImplementedError - def print(self, *args, **kwargs): + def print(self, *args: Any, **kwargs: Any) -> None: """ You should provide a way to print without breaking the progress bar. """ print(*args, **kwargs) - def on_init_end(self, trainer): + def on_init_end(self, trainer: 'pl.Trainer') -> None: self._trainer = trainer - def on_train_start(self, trainer, pl_module): + def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._train_batch_idx = trainer.batch_idx - def on_train_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._train_batch_idx = 0 - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: self._train_batch_idx += 1 - def on_validation_start(self, trainer, pl_module): + def on_validation_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._val_batch_idx = 0 - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_validation_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: self._val_batch_idx += 1 - def on_test_start(self, trainer, pl_module): + def on_test_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._test_batch_idx = 0 - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_test_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: self._test_batch_idx += 1 - def on_predict_start(self, trainer, pl_module): + def on_predict_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._predict_batch_idx = 0 - def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_predict_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: self._predict_batch_idx += 1 @@ -274,16 +298,16 @@ def init_validation_tqdm(self): """ - def __init__(self, refresh_rate: int = 1, 
process_position: int = 0): + def __init__(self, refresh_rate: int = 1, process_position: int = 0) -> None: super().__init__() self._refresh_rate = refresh_rate self._process_position = process_position self._enabled = True - self.main_progress_bar = None - self.val_progress_bar = None - self.test_progress_bar = None + self.main_progress_bar: Optional[tqdm] = None + self.val_progress_bar: Optional[tqdm] = None + self.test_progress_bar: Optional[tqdm] = None - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: # can't pickle the tqdm objects state = self.__dict__.copy() state['main_progress_bar'] = None @@ -379,39 +403,49 @@ def init_test_tqdm(self) -> tqdm: ) return bar - def on_sanity_check_start(self, trainer, pl_module): + def on_sanity_check_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() self.main_progress_bar = tqdm(disable=True) # dummy progress bar - def on_sanity_check_end(self, trainer, pl_module): + def on_sanity_check_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_sanity_check_end(trainer, pl_module) - self.main_progress_bar.close() - self.val_progress_bar.close() + self.close_prog_bar(self.main_progress_bar) + self.close_prog_bar(self.val_progress_bar) + + def close_prog_bar(self, prog_bar: Optional[tqdm]) -> None: + if prog_bar is not None: + prog_bar.close() - def on_train_start(self, trainer, pl_module): + def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_train_start(trainer, pl_module) self.main_progress_bar = self.init_train_tqdm() - def on_train_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches - total_val_batches = self.total_val_batches + total_val_batches: Union[float, int] = self.total_val_batches if total_train_batches != float('inf'): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches + reset(self.main_progress_bar, total_batches) - self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}') + if self.main_progress_bar is not None: + self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}') - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): self._update_bar(self.main_progress_bar) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + if self.main_progress_bar is not None: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) - def on_validation_start(self, trainer, pl_module): + def on_validation_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_validation_start(trainer, pl_module) if trainer.sanity_checking: reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) @@ -420,66 +454,81 
@@ def on_validation_start(self, trainer, pl_module): self.val_progress_bar = self.init_validation_tqdm() reset(self.val_progress_bar, self.total_val_batches) - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_validation_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.val_batch_idx, self.total_val_batches): self._update_bar(self.val_progress_bar) self._update_bar(self.main_progress_bar) - def on_validation_end(self, trainer, pl_module): + def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_validation_end(trainer, pl_module) if self.main_progress_bar is not None: self.main_progress_bar.set_postfix(trainer.progress_bar_dict) - self.val_progress_bar.close() + self.close_prog_bar(self.val_progress_bar) - def on_train_end(self, trainer, pl_module): + def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_train_end(trainer, pl_module) - self.main_progress_bar.close() + self.close_prog_bar(self.main_progress_bar) - def on_test_start(self, trainer, pl_module): + def on_test_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_test_start(trainer, pl_module) self.test_progress_bar = self.init_test_tqdm() self.test_progress_bar.total = convert_inf(self.total_test_batches) - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_test_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.test_batch_idx, self.total_test_batches): self._update_bar(self.test_progress_bar) - def on_test_end(self, trainer, pl_module): + def on_test_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_test_end(trainer, pl_module) - self.test_progress_bar.close() + self.close_prog_bar(self.test_progress_bar) - def on_predict_start(self, trainer, pl_module): + def on_predict_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: super().on_predict_start(trainer, pl_module) self.predict_progress_bar = self.init_predict_tqdm() self.predict_progress_bar.total = convert_inf(self.total_predict_batches) - def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_predict_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.predict_batch_idx, self.total_predict_batches): self._update_bar(self.predict_progress_bar) - def on_predict_end(self, trainer, pl_module): + def on_predict_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self.predict_progress_bar.close() def print( - self, *args, sep: str = ' ', end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False - ): + self, + *args: Any, + sep: str = ' ', + end: str = os.linesep, + file: Optional[io.TextIOBase] = None, + nolock: bool = 
False, + **kwargs: Any + ) -> None: active_progress_bar = None - if not self.main_progress_bar.disable: + if not (self.main_progress_bar is not None and self.main_progress_bar.disable): active_progress_bar = self.main_progress_bar - elif not self.val_progress_bar.disable: + elif not (self.val_progress_bar is not None and self.val_progress_bar.disable): active_progress_bar = self.val_progress_bar - elif not self.test_progress_bar.disable: + elif not (self.test_progress_bar is not None and self.test_progress_bar.disable): active_progress_bar = self.test_progress_bar if active_progress_bar is not None: s = sep.join(map(str, args)) - active_progress_bar.write(s, end=end, file=file, nolock=nolock) + active_progress_bar.write(s, end=end, file=file, nolock=nolock, **kwargs) - def _should_update(self, current, total): + def _should_update(self, current: int, total: Union[float, int]) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def _update_bar(self, bar: Optional[tqdm]) -> None: @@ -502,7 +551,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: return x -def reset(bar: tqdm, total: Optional[int] = None) -> None: +def reset(bar: Optional[tqdm], total: Optional[Union[int, float]] = None) -> None: """ Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """ - if not bar.disable: + if bar is not None and not bar.disable: bar.reset(total=convert_inf(total)) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 6c9fa8b4776c6..59a0b2d71859d 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -19,16 +19,19 @@ import logging from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Optional, Sequence, Tuple, Union import torch import torch.nn.utils.prune as pytorch_prune from torch import nn +from torch.nn.modules.module import Module +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import EPOCH_OUTPUT log = logging.getLogger(__name__) @@ -47,7 +50,7 @@ } _PARAM_TUPLE = Tuple[nn.Module, str] -_PARAM_LIST = Union[List[_PARAM_TUPLE], Tuple[_PARAM_TUPLE]] +_PARAM_LIST = Sequence[_PARAM_TUPLE] _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) @@ -57,7 +60,7 @@ class ModelPruning(Callback): def __init__( self, pruning_fn: Union[Callable, str], - parameters_to_prune: Optional[_PARAM_LIST] = None, + parameters_to_prune: _PARAM_LIST = (), parameter_names: Optional[List[str]] = None, use_global_unstructured: bool = True, amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, @@ -153,9 +156,9 @@ def __init__( self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis self._resample_parameters = resample_parameters self._parameter_names = parameter_names or self.PARAMETER_NAMES - self._global_kwargs = {} - self._original_layers = None - self._pruning_fn_name = None + self._global_kwargs: Dict[str, Any] = {} + self._original_layers: Optional[Dict[Hashable, Dict[str, Union[torch.nn.Module, List[Tuple[int, str]]]]]] = None + self._pruning_fn_name: 
Optional[str] = None for name in self._parameter_names: if name not in self.PARAMETER_NAMES: @@ -196,9 +199,10 @@ def __init__( " HINT: if passing a `BasePruningMethod`, pass the the class, not an instance" ) - if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": + # need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute + if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": # type: ignore raise MisconfigurationException( - 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' + 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' # type: ignore f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. " ) @@ -206,7 +210,7 @@ def __init__( self._apply_pruning = apply_pruning self._make_pruning_permanent = make_pruning_permanent - if not isinstance(amount, (int, float, Callable)): + if not (isinstance(amount, (int, float)) or callable(amount)): raise MisconfigurationException( "`amount` should be provided and be either an int, a float or Callable function." ) @@ -218,13 +222,13 @@ def __init__( self._verbose = verbose - def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]: + def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST: """ This function can be overridden to control which module to prune. """ return parameters_to_prune - def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytorch_prune.BasePruningMethod]: + def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]: """ This function takes `pruning_fn`, a function name. @@ -232,11 +236,13 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`. """ + pruning_fn = ( + _PYTORCH_PRUNING_METHOD[pruning_fn] + if self._use_global_unstructured else _PYTORCH_PRUNING_FUNCTIONS[pruning_fn] + ) + assert callable(pruning_fn) if self._use_global_unstructured: - pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn] self._global_kwargs = kwargs - else: - pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn] # save the function __name__ now because partial does not include it # and there are issues setting the attribute manually in ddp. 
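The `_create_pruning_fn` hunk above resolves a string name into either a `torch.nn.utils.prune` function (local pruning) or a `BasePruningMethod` class (global unstructured pruning), freezing extra keyword arguments with `functools.partial` in the local case. A rough sketch of that resolution, using a one-entry registry and a made-up `create_pruning_fn` helper:

```python
from functools import partial

import torch.nn as nn
import torch.nn.utils.prune as pytorch_prune

# Illustrative one-entry versions of the callback's registries.
_PYTORCH_PRUNING_FUNCTIONS = {"l1_unstructured": pytorch_prune.l1_unstructured}
_PYTORCH_PRUNING_METHOD = {"l1_unstructured": pytorch_prune.L1Unstructured}

def create_pruning_fn(name: str, use_global_unstructured: bool, **kwargs):
    fn = _PYTORCH_PRUNING_METHOD[name] if use_global_unstructured else _PYTORCH_PRUNING_FUNCTIONS[name]
    assert callable(fn)
    # the class is later handed to `global_unstructured` as-is; the function gets its kwargs frozen
    return fn if use_global_unstructured else partial(fn, **kwargs)

layer = nn.Linear(4, 2)
local_fn = create_pruning_fn("l1_unstructured", use_global_unstructured=False, amount=0.5)
local_fn(layer, name="weight")        # prunes 50% of the weights in-place
print(hasattr(layer, "weight_mask"))  # True - pruning registered a mask buffer
```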
self._pruning_fn_name = pruning_fn.__name__ @@ -245,10 +251,10 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs) @staticmethod - def _wrap_pruning_fn(pruning_fn, **kwargs): + def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> partial: return partial(pruning_fn, **kwargs) - def make_pruning_permanent(self, pl_module: LightningModule): + def make_pruning_permanent(self, pl_module: LightningModule) -> None: """ Removes pruning buffers from any pruned modules @@ -261,14 +267,14 @@ def make_pruning_permanent(self, pl_module: LightningModule): hook.remove(module) del module._forward_pre_hooks[k] - def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str): + def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str) -> None: trained = getattr(module, tensor_name) orig = getattr(orig_module, tensor_name) if trained is None or orig is None: return trained.data = orig.data.to(trained.device) - def apply_lottery_ticket_hypothesis(self): + def apply_lottery_ticket_hypothesis(self) -> None: r""" Lottery ticket hypothesis algorithm (see page 2 of the paper): @@ -282,33 +288,36 @@ def apply_lottery_ticket_hypothesis(self): The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` """ # noqa: E501 - def copy_param(new, old, name: str) -> None: + def copy_param(new: nn.Module, old: nn.Module, name: str) -> None: dst = getattr(new, name) src = getattr(old, name) if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor): return dst.data = src.data.to(dst.device) + if self._original_layers is None: + raise RuntimeError for d in self._original_layers.values(): - copy, names = d["data"], d["names"] - if self._resample_parameters and hasattr(copy, "reset_parameters"): + copy: Module = d['data'] + names: List[Tuple[int, str]] = d["names"] + if self._resample_parameters and hasattr(copy, "reset_parameters") and callable(copy.reset_parameters): copy = deepcopy(copy) # keep the original parameters copy.reset_parameters() for i, name in names: new, new_name = self._parameters_to_prune[i] copy_param(new, copy, name) - def _apply_local_pruning(self, amount: float): + def _apply_local_pruning(self, amount: float) -> None: for module, name in self._parameters_to_prune: self.pruning_fn(module, name=name, amount=amount) - def _resolve_global_kwargs(self, amount: float): + def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]: self._global_kwargs["amount"] = amount params = set(inspect.signature(self.pruning_fn).parameters) params.discard("self") return {k: v for k, v in self._global_kwargs.items() if k in params} - def _apply_global_pruning(self, amount: float): + def _apply_global_pruning(self, amount: float) -> None: pytorch_prune.global_unstructured( self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount) ) @@ -321,7 +330,7 @@ def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: mask = getattr(module, attr) return (mask == 0).sum().item(), mask.numel() - def apply_pruning(self, amount: Union[int, float]): + def apply_pruning(self, amount: Union[int, float]) -> None: """ Applies pruning to ``parameters_to_prune``. 
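Tying the pruning pieces together: `sanitize_parameters_to_prune` collects `(module, name)` pairs by default, `_apply_global_pruning` prunes across all of them at once, `_get_pruned_stats` reads the generated `<name>_mask` buffers, and `make_pruning_permanent` strips the re-parametrisation. A hedged end-to-end sketch with a made-up toy model; `prune.remove` is used here as the public per-parameter equivalent of the hook removal the callback performs:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))

# Default collection: every (module, parameter_name) pair that exists, skipping containers.
parameter_names = ("weight", "bias")  # mirrors ModelPruning.PARAMETER_NAMES
leaves = [m for m in model.modules() if not isinstance(m, (nn.Sequential, nn.ModuleList, nn.ModuleDict))]
parameters_to_prune = [(m, p) for p in parameter_names for m in leaves if getattr(m, p, None) is not None]

# Global unstructured pruning over all collected tensors at once.
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.25)

# Sparsity report from the `<name>_mask` buffers, as `_get_pruned_stats` does per layer.
zeros = sum(int((getattr(m, f"{n}_mask") == 0).sum()) for m, n in parameters_to_prune)
total = sum(getattr(m, f"{n}_mask").numel() for m, n in parameters_to_prune)
print(f"global sparsity: {zeros}/{total} ({zeros / total:.2%})")  # roughly 25%

# Make the pruning permanent so only the pruned weights remain as plain parameters.
for module, name in parameters_to_prune:
    prune.remove(module, name)
```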
""" if self._verbose: prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] @@ -338,7 +347,7 @@ def apply_pruning(self, amount: Union[int, float]): @rank_zero_only def _log_sparsity_stats( self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 - ): + ) -> None: total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) prev_total_zeros = sum(zeros for zeros, _ in prev) curr_total_zeros = sum(zeros for zeros, _ in curr) @@ -357,7 +366,7 @@ def _log_sparsity_stats( f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})" ) - def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule): + def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None: parameters_to_prune = self.sanitize_parameters_to_prune( pl_module, self._parameters_to_prune, parameter_names=self._parameter_names ) @@ -373,26 +382,31 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []}) self._original_layers[id_]["names"].append((i, name)) - def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs): - current_epoch = trainer.current_epoch - prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning - amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount + def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule, outputs: EPOCH_OUTPUT) -> None: + current_epoch = pl_module.current_epoch + prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning + amount = self.amount(current_epoch) if callable(self.amount) else self.amount if not prune or not amount: return self.apply_pruning(amount) if ( self._use_lottery_ticket_hypothesis(current_epoch) - if isinstance(self._use_lottery_ticket_hypothesis, Callable) else self._use_lottery_ticket_hypothesis + if callable(self._use_lottery_ticket_hypothesis) else self._use_lottery_ticket_hypothesis ): self.apply_lottery_ticket_hypothesis() - def on_train_end(self, trainer, pl_module: LightningModule): + def on_train_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None: if self._make_pruning_permanent: rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.") self.make_pruning_permanent(pl_module) - def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]): + def on_save_checkpoint( + self, + trainer: 'pl.Trainer', + pl_module: LightningModule, + checkpoint: Dict[str, Any], + ) -> Dict[str, Any]: if self._make_pruning_permanent: rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.") prev_device = pl_module.device @@ -402,11 +416,13 @@ def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Di checkpoint["state_dict"] = copy.state_dict() pl_module.to(prev_device) + return checkpoint + @staticmethod def sanitize_parameters_to_prune( pl_module: LightningModule, - parameters_to_prune: Optional[_PARAM_LIST] = None, - parameter_names: Optional[List[str]] = None, + parameters_to_prune: _PARAM_LIST = (), + parameter_names: Sequence[str] = (), ) -> _PARAM_LIST: """ This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. 
@@ -415,13 +431,13 @@ def sanitize_parameters_to_prune( Raises: MisconfigurationException: If ``parameters_to_prune`` doesn't exist in the model, or - if ``parameters_to_prune`` is neither a list of tuple nor ``None``. + if ``parameters_to_prune`` is neither a list nor a tuple. """ parameters = parameter_names or ModelPruning.PARAMETER_NAMES current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)] - if parameters_to_prune is None: + if not parameters_to_prune: parameters_to_prune = [(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None] elif ( diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index 2b6064e232da7..ba2dafb428432 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -29,8 +29,8 @@ def wrap_qat_forward_context( - quant_cb, - model: pl.core.LightningModule, + quant_cb: Callback, + model: 'pl.LightningModule', func: Callable, trigger_condition: Optional[Union[Callable, int]] = None ) -> Callable: @@ -40,33 +40,38 @@ def wrap_qat_forward_context( """ # todo: consider using registering hook before/after forward @functools.wraps(func) - def wrapper(data) -> Any: - _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer) - _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition + def wrapper(data: torch.Tensor) -> Any: + _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer) + _is_count_true = isinstance( + trigger_condition, int + ) and quant_cb._forward_calls < trigger_condition # type: ignore _quant_run = trigger_condition is None or _is_func_true or _is_count_true # apply custom trigger if _quant_run: - quant_cb._forward_calls += 1 - data = model.quant(data) + quant_cb._forward_calls += 1 # type: ignore + if callable(model.quant): + data = model.quant(data) data = func(data) # apply custom trigger - if _quant_run: + if _quant_run and callable(model.dequant): data = model.dequant(data) return data return wrapper -def wrap_quantize_forward_context(model: pl.core.LightningModule, func: Callable) -> Callable: +def wrap_quantize_forward_context(model: 'pl.LightningModule', func: Callable) -> Callable: """ Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out compatibility """ # todo: consider using registering hook before/after forward @functools.wraps(func) - def wrapper(data) -> Any: - data = model.quant(data) + def wrapper(data: torch.Tensor) -> Any: + if callable(model.quant): + data = model.quant(data) data = func(data) - data = model.dequant(data) + if callable(model.dequant): + data = model.dequant(data) return data return wrapper @@ -152,7 +157,9 @@ def custom_trigger_last(trainer): raise MisconfigurationException(f'For using {observer_type} you need to be using pytorch>=1.5.') self._observer_type = observer_type - if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)): + if collect_quantization is not None and not ( + isinstance(collect_quantization, int) or callable(collect_quantization) + ): raise MisconfigurationException( f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.' 
) @@ -162,7 +169,7 @@ def custom_trigger_last(trainer): self._input_compatible = input_compatible self._forward_calls = 0 - def _check_feasible_fuse(self, model): + def _check_feasible_fuse(self, model: 'pl.LightningModule') -> bool: if not self.modules_to_fuse: return False for group in self.modules_to_fuse: @@ -172,7 +179,7 @@ def _check_feasible_fuse(self, model): ) return True - def on_fit_start(self, trainer, pl_module): + def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: # QuantStub converts tensors from floating point to quantized pl_module.quant = torch.quantization.QuantStub() # DeQuantStub converts tensors from quantized to floating point @@ -180,8 +187,12 @@ def on_fit_start(self, trainer, pl_module): # manually specify where tensors will be converted from quantized # to floating point in the quantized model self.__module_forward = pl_module.forward - pl_module.forward = wrap_qat_forward_context( - quant_cb=self, model=pl_module, func=pl_module.forward, trigger_condition=self._collect_quantization + + setattr( + pl_module, 'forward', + wrap_qat_forward_context( + quant_cb=self, model=pl_module, func=pl_module.forward, trigger_condition=self._collect_quantization + ) ) # attach a global qconfig, which contains information about what kind @@ -201,7 +212,7 @@ def on_fit_start(self, trainer, pl_module): # the model that will observe weight and activation tensors during calibration. torch.quantization.prepare_qat(pl_module, inplace=True) - def on_fit_end(self, trainer, pl_module): + def on_fit_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: pl_module.eval() # Convert the observed model to a quantized model. This does several things: # quantizes the weights, computes and stores the scale and bias value to be @@ -210,6 +221,6 @@ def on_fit_end(self, trainer, pl_module): torch.quantization.convert(pl_module, inplace=True) # check we shall preserve wrapper if self._input_compatible: - pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward) + setattr(pl_module, 'forward', wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)) else: - pl_module.forward = self.__module_forward + setattr(pl_module, 'forward', self.__module_forward) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index df3edf17729bd..700265fbd3b23 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -16,7 +16,7 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ from copy import deepcopy -from typing import Callable, Optional, Union +from typing import Any, Callable, cast, Optional, Union import torch from torch import nn @@ -30,10 +30,12 @@ if _TORCH_GREATER_EQUAL_1_6: from torch.optim.swa_utils import SWALR -_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] +_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] class StochasticWeightAveraging(Callback): + _swa_epoch_start: Union[int, float] + _max_epochs: int def __init__( self, @@ -110,7 +112,7 @@ def __init__( if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)): raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.") - if avg_fn is not None and not isinstance(avg_fn, Callable): + if avg_fn is not None and not callable(avg_fn): raise MisconfigurationException("The `avg_fn` should be 
callable.") if device is not None and not isinstance(device, (torch.device, str)): @@ -122,11 +124,11 @@ def __init__( self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn self._device = device - self._model_contains_batch_norm = None - self._average_model = None + self._model_contains_batch_norm: Optional[bool] = None + self._average_model: Optional['pl.LightningModule'] = None @property - def swa_start(self) -> int: + def swa_start(self) -> float: return max(self._swa_epoch_start - 1, 0) # 0-based @property @@ -134,21 +136,21 @@ def swa_end(self) -> int: return self._max_epochs - 1 # 0-based @staticmethod - def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'): + def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule') -> bool: return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) - def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: # copy the model before moving it to accelerator device. self._average_model = deepcopy(pl_module) - def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: optimizers = trainer.optimizers lr_schedulers = trainer.lr_schedulers - if len(optimizers) != 1: + if optimizers is None or len(optimizers) != 1: raise MisconfigurationException("SWA currently works with 1 `optimizer`.") - if len(lr_schedulers) > 1: + if lr_schedulers is not None and len(lr_schedulers) > 1: raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") if isinstance(self._swa_epoch_start, float): @@ -161,12 +163,13 @@ def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): # virtually increase max_epochs to perform batch norm update on latest epoch. trainer.max_epochs += 1 - def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: if trainer.current_epoch == self.swa_start: # move average model to request device. 
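Once `swa_start` is reached, `update_parameters` folds the current weights into the averaged model with `avg_fn`, a running mean over the epochs seen so far (the same rule as `torch.optim.swa_utils`). A minimal sketch of that update, detached from the callback and with made-up tensors:

```python
import torch

def avg_fn(averaged: torch.Tensor, current: torch.Tensor, num_averaged: torch.Tensor) -> torch.Tensor:
    # running mean: new_avg = avg + (x - avg) / (n + 1)
    return averaged + (current - averaged) / (num_averaged + 1)

n_averaged = torch.tensor(0, dtype=torch.long)
swa_param = torch.zeros(3)
for epoch_param in (torch.ones(3), 3 * torch.ones(3)):
    swa_param = avg_fn(swa_param, epoch_param, n_averaged) if n_averaged > 0 else epoch_param.clone()
    n_averaged += 1

print(swa_param)  # tensor([2., 2., 2.]) - the mean of the two per-epoch parameters
```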
- self._average_model = self._average_model.to(self._device or pl_module.device) + if self._average_model is not None: + self._average_model = self._average_model.to(self._device or pl_module.device) - optimizers = trainer.optimizers + optimizers = trainer.optimizers if trainer.optimizers is not None else [] for param_group in optimizers[0].param_groups: if self._swa_lrs is None: @@ -200,15 +203,17 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo rank_zero_info(f"Swapping scheduler {scheduler_cfg['scheduler']} for {self._swa_scheduler}") trainer.lr_schedulers[0] = default_scheduler_cfg else: + if trainer.lr_schedulers is None: + trainer.lr_schedulers = [] trainer.lr_schedulers.append(default_scheduler_cfg) self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) - if self.swa_start <= trainer.current_epoch <= self.swa_end: + if self.swa_start <= trainer.current_epoch <= self.swa_end and self._average_model is not None: self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn) # Note: No > here in case the callback is saved with the model and training continues - if trainer.current_epoch == self.swa_end + 1: + if trainer.current_epoch == self.swa_end + 1 and self._average_model is not None: # Transfer weights from average model to pl_module self.transfer_weights(self._average_model, pl_module) @@ -224,26 +229,26 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo self._accumulate_grad_batches = trainer.accumulate_grad_batches trainer.accumulate_grad_batches = len(trainer.train_dataloader) - def on_train_epoch_end(self, trainer: 'pl.Trainer', *args): + def on_train_epoch_end(self, trainer: 'pl.Trainer', *args: Any) -> None: trainer.train_loop._skip_backward = False - def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1: # BatchNorm epoch update. Reset state trainer.accumulate_grad_batches = self._accumulate_grad_batches trainer.num_training_batches -= 1 trainer.max_epochs -= 1 self.reset_momenta() - elif trainer.current_epoch == self.swa_end: + elif trainer.current_epoch == self.swa_end and self._average_model is not None: # Last SWA epoch. 
Transfer weights from average model to pl_module self.transfer_weights(self._average_model, pl_module) @staticmethod - def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'): + def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule') -> None: for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): dst_param.detach().copy_(src_param.to(dst_param.device)) - def reset_batch_norm_and_save_state(self, pl_module: 'pl.LightningModule'): + def reset_batch_norm_and_save_state(self, pl_module: 'pl.LightningModule') -> None: """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154 """ @@ -252,16 +257,20 @@ def reset_batch_norm_and_save_state(self, pl_module: 'pl.LightningModule'): if not isinstance(module, nn.modules.batchnorm._BatchNorm): continue module.running_mean = torch.zeros_like( - module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype + cast(torch.Tensor, module.running_mean), + device=pl_module.device, + dtype=cast(torch.dtype, module.running_mean.dtype) ) module.running_var = torch.ones_like( - module.running_var, device=pl_module.device, dtype=module.running_var.dtype + cast(torch.Tensor, module.running_var), + device=pl_module.device, + dtype=cast(torch.dtype, module.running_var.dtype) ) self.momenta[module] = module.momentum module.momentum = None module.num_batches_tracked *= 0 - def reset_momenta(self): + def reset_momenta(self) -> None: """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165 """ @@ -270,8 +279,8 @@ def reset_momenta(self): @staticmethod def update_parameters( - average_model: 'pl.LightningModule', model: 'pl.LightningModule', n_averaged: torch.LongTensor, avg_fn: _AVG_FN - ): + average_model: 'pl.LightningModule', model: 'pl.LightningModule', n_averaged: torch.Tensor, avg_fn: _AVG_FN + ) -> None: """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112 """ @@ -285,8 +294,8 @@ def update_parameters( @staticmethod def avg_fn( - averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor - ) -> torch.FloatTensor: + averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.Tensor + ) -> torch.Tensor: """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97 """ diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index 9b93499c82ea1..7b684af3e51ec 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities import LightningEnum from pytorch_lightning.utilities.distributed import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT log = logging.getLogger(__name__) @@ -95,8 +96,8 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time = {stage: None for stage in RunningStage} - self._end_time = {stage: None for stage in RunningStage} + self._start_time: Dict[RunningStage, Union[None, float]] = {stage: None for stage in RunningStage} + self._end_time: Dict[RunningStage, Union[None, float]] = {stage: None for stage in RunningStage} self._offset 
= 0 def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: @@ -125,30 +126,35 @@ def time_remaining(self, stage: str = RunningStage.TRAINING) -> Optional[float]: if self._duration is not None: return self._duration - self.time_elapsed(stage) - def on_train_start(self, *args, **kwargs) -> None: + return None + + def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._start_time[RunningStage.TRAINING] = time.monotonic() - def on_train_end(self, *args, **kwargs) -> None: + def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._end_time[RunningStage.TRAINING] = time.monotonic() - def on_validation_start(self, *args, **kwargs) -> None: + def on_validation_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._start_time[RunningStage.VALIDATING] = time.monotonic() - def on_validation_end(self, *args, **kwargs) -> None: + def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._end_time[RunningStage.VALIDATING] = time.monotonic() - def on_test_start(self, *args, **kwargs) -> None: + def on_test_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._start_time[RunningStage.TESTING] = time.monotonic() - def on_test_end(self, *args, **kwargs) -> None: + def on_test_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: self._end_time[RunningStage.TESTING] = time.monotonic() - def on_train_batch_end(self, trainer: 'pl.Trainer', *args, **kwargs) -> None: + def on_train_batch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: STEP_OUTPUT, batch: Any, batch_idx: int, + dataloader_idx: int + ) -> None: if self._interval != Interval.step or self._duration is None: return self._check_time_remaining(trainer) - def on_train_epoch_end(self, trainer: 'pl.Trainer', *args, **kwargs) -> None: + def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None: if self._interval != Interval.epoch or self._duration is None: return self._check_time_remaining(trainer) @@ -166,8 +172,10 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0) def _check_time_remaining(self, trainer: 'pl.Trainer') -> None: + if self._duration is None: + raise MisconfigurationException('Cannot calculate remaining time if duration is None!') should_stop = self.time_elapsed() >= self._duration should_stop = trainer.accelerator.broadcast(should_stop) - trainer.should_stop = trainer.should_stop or should_stop + trainer.should_stop = bool(trainer.should_stop or should_stop) if should_stop and self._verbose: rank_zero_info(f"Time limit reached. Elapsed time is {self.time_elapsed}. 
Signaling Trainer to stop.") diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 2dd6f0b76a1b0..5eb7f437ce6f7 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -15,22 +15,24 @@ import functools from argparse import ArgumentParser, Namespace -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union from torch.utils.data import DataLoader, Dataset +import pytorch_lightning as pl from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types +from pytorch_lightning.utilities.exceptions import MisconfigurationException class _DataModuleWrapper(type): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.__has_added_checks = False - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: Any, **kwargs: Any) -> Any: """A wrapper for LightningDataModule that: 1. Runs user defined subclass's __init__ @@ -40,11 +42,11 @@ def __call__(cls, *args, **kwargs): if not cls.__has_added_checks: cls.__has_added_checks = True # Track prepare_data calls and make sure it runs on rank zero - cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) + cls.prepare_data: Callable = track_data_hook_calls(rank_zero_only(cls.prepare_data)) # Track setup calls - cls.setup = track_data_hook_calls(cls.setup) + cls.setup: Callable = track_data_hook_calls(cls.setup) # Track teardown calls - cls.teardown = track_data_hook_calls(cls.teardown) + cls.teardown: Callable = track_data_hook_calls(cls.teardown) # Get instance of LightningDataModule by mocking its __init__ via __call__ obj = type.__call__(cls, *args, **kwargs) @@ -52,7 +54,7 @@ def __call__(cls, *args, **kwargs): return obj -def track_data_hook_calls(fn): +def track_data_hook_calls(fn: Callable) -> Callable: """A decorator that checks if prepare_data/setup/teardown has been called. - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True @@ -69,7 +71,7 @@ def track_data_hook_calls(fn): """ @functools.wraps(fn) - def wrapped_fn(*args, **kwargs): + def wrapped_fn(*args: Any, **kwargs: Any) -> Any: # The object instance from which setup or prepare_data was called obj = args[0] @@ -141,23 +143,25 @@ def teardown(self): """ - name: str = ... 
+ name: str def __init__( self, - train_transforms=None, - val_transforms=None, - test_transforms=None, - dims=None, + train_transforms: Optional[Callable] = None, + val_transforms: Optional[Callable] = None, + test_transforms: Optional[Callable] = None, + predict_transforms: Optional[Callable] = None, + dims: Optional[Tuple] = None, ): super().__init__() self._train_transforms = train_transforms self._val_transforms = val_transforms self._test_transforms = test_transforms + self._predict_transforms = predict_transforms self._dims = dims if dims is not None else () # Pointer to the trainer object - self.trainer = None + self.trainer: Optional['pl.Trainer'] = None # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False @@ -173,50 +177,61 @@ def __init__( self._has_teardown_predict = False @property - def train_transforms(self): + def train_transforms(self) -> Optional[Callable]: """ Optional transforms (or collection of transforms) you can apply to train dataset """ return self._train_transforms @train_transforms.setter - def train_transforms(self, t): + def train_transforms(self, t: Optional[Callable]) -> None: self._train_transforms = t @property - def val_transforms(self): + def val_transforms(self) -> Optional[Callable]: """ Optional transforms (or collection of transforms) you can apply to validation dataset """ return self._val_transforms @val_transforms.setter - def val_transforms(self, t): + def val_transforms(self, t: Optional[Callable]) -> None: self._val_transforms = t @property - def test_transforms(self): + def test_transforms(self) -> Optional[Callable]: """ Optional transforms (or collection of transforms) you can apply to test dataset """ return self._test_transforms @test_transforms.setter - def test_transforms(self, t): + def test_transforms(self, t: Optional[Callable]) -> None: self._test_transforms = t @property - def dims(self): + def predict_transforms(self) -> Optional[Callable]: + """ + Optional transforms (or collection of transforms) you can apply to predict dataset + """ + return self._predict_transforms + + @predict_transforms.setter + def predict_transforms(self, t: Optional[Callable]) -> None: + self._predict_transforms = t + + @property + def dims(self) -> tuple: """ A tuple describing the shape of your data. Extra functionality exposed in ``size``. """ return self._dims @dims.setter - def dims(self, d): + def dims(self, d: tuple) -> None: self._dims = d - def size(self, dim=None) -> Union[Tuple, int]: + def size(self, dim: Optional[int] = None) -> Union[Tuple, int]: """ Return the dimension of each input either as a tuple or list of tuples. You can index this just as you would with a torch tensor. @@ -309,12 +324,12 @@ def has_teardown_predict(self) -> bool: return self._has_teardown_predict @classmethod - def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> ArgumentParser: """Extends existing argparse by default `LightningDataModule` attributes.""" return add_argparse_args(cls, parent_parser, **kwargs) @classmethod - def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): + def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> ArgumentParser: """Create an instance from CLI arguments. 
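`add_argparse_args` / `from_argparse_args` above mirror the `Trainer` helpers: the first exposes the datamodule's `__init__` arguments on an `ArgumentParser`, the second builds an instance from the parsed namespace (note it returns a datamodule instance, despite the `ArgumentParser` return annotation in the hunk). A hedged usage sketch; `MyDataModule` and its `data_dir`/`batch_size` arguments are invented:

```python
from argparse import ArgumentParser

from pytorch_lightning import LightningDataModule

class MyDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        self.data_dir, self.batch_size = data_dir, batch_size

parser = ArgumentParser()
parser = MyDataModule.add_argparse_args(parser)
args = parser.parse_args(["--batch_size", "64"])
dm = MyDataModule.from_argparse_args(args)
print(dm.batch_size)  # 64
```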
Args: @@ -347,9 +362,10 @@ def from_datasets( train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None, val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None, test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None, + predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None, batch_size: int = 1, num_workers: int = 0, - ): + ) -> 'LightningDataModule': r""" Create an instance from torch.utils.data.Dataset. @@ -357,13 +373,14 @@ def from_datasets( train_dataset: (optional) Dataset to be used for train_dataloader() val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader() test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader() + predict_dataset: (optional) Dataset or list of Dataset to be used for predict_dataloader() batch_size: Batch size to use for each dataloader. Default is 1. num_workers: Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Number of CPUs available. """ - def dataloader(ds, shuffle=False): + def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader: return DataLoader( ds, batch_size=batch_size, @@ -372,28 +389,45 @@ def dataloader(ds, shuffle=False): pin_memory=True, ) - def train_dataloader(): + def train_dataloader() -> Union[DataLoader, Sequence[DataLoader], Mapping[str, DataLoader]]: if isinstance(train_dataset, Mapping): return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()} if isinstance(train_dataset, Sequence): return [dataloader(ds, shuffle=True) for ds in train_dataset] + + if train_dataset is None: + raise MisconfigurationException('Expected a Dataset, but got None') return dataloader(train_dataset, shuffle=True) - def val_dataloader(): + def val_dataloader() -> Union[Sequence[DataLoader], DataLoader]: if isinstance(val_dataset, Sequence): return [dataloader(ds) for ds in val_dataset] + if val_dataset is None: + raise MisconfigurationException('Expected a Dataset, but got None') return dataloader(val_dataset) - def test_dataloader(): + def test_dataloader() -> Union[Sequence[DataLoader], DataLoader]: if isinstance(test_dataset, Sequence): return [dataloader(ds) for ds in test_dataset] + if test_dataset is None: + raise MisconfigurationException('Expected a Dataset, but got None') return dataloader(test_dataset) + def predict_dataloader() -> Union[Sequence[DataLoader], DataLoader]: + if isinstance(predict_dataset, Sequence): + return [dataloader(ds) for ds in predict_dataset] + if predict_dataset is None: + raise MisconfigurationException('Expected a Dataset, but got None') + return dataloader(predict_dataset) + datamodule = cls() if train_dataset is not None: - datamodule.train_dataloader = train_dataloader + setattr(datamodule, 'train_dataloader', train_dataloader) if val_dataset is not None: - datamodule.val_dataloader = val_dataloader + setattr(datamodule, 'val_dataloader', val_dataloader) if test_dataset is not None: - datamodule.test_dataloader = test_dataloader + setattr(datamodule, 'test_dataloader', test_dataloader) + if predict_dataset is not None: + setattr(datamodule, 'predict_dataloader', predict_dataloader) + return datamodule diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 51c602add9541..2781a52210b9a 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -14,8 +14,9 @@ """Decorator for LightningModule methods.""" from functools import wraps -from typing import 
Callable +from typing import Any, Callable, Union +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn @@ -53,7 +54,7 @@ def forward(self, x): """ @wraps(fn) - def auto_transfer_args(self, *args, **kwargs): + def auto_transfer_args(self: Union[Any, 'pl.LightningModule'], *args: Any, **kwargs: Any) -> Any: from pytorch_lightning.core.lightning import LightningModule if not isinstance(self, LightningModule): return fn(self, *args, **kwargs) @@ -89,7 +90,7 @@ def parameter_validation(fn: Callable) -> Callable: """ @wraps(fn) - def inner_fn(self, *args, **kwargs): + def inner_fn(self: Union[Any, 'pl.LightningModule'], *args: Any, **kwargs: Any) -> 'pl.LightningModule': pre_layer_count = len(list(self.parameters())) module = fn(self, *args, **kwargs) self.on_post_move_to_device() diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 72ee3b3c52e4a..5eab3460e8d0f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -19,11 +19,15 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT class ModelHooks: + __jit_unused_properties__ = ['trainer'] """Hooks to be used in LightningModule.""" + trainer: Optional['pl.Trainer'] def on_fit_start(self) -> None: """ @@ -98,7 +102,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) """ # do something when the batch starts - def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: """ Called in the training loop after the batch. @@ -114,13 +118,15 @@ def on_validation_model_eval(self) -> None: """ Sets the model to eval during the val loop """ - self.trainer.model.eval() + if self.trainer is not None: + self.trainer.model.eval() def on_validation_model_train(self) -> None: """ Sets the model to train during the val loop """ - self.trainer.model.train() + if self.trainer is not None: + self.trainer.model.train() def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: """ @@ -133,7 +139,7 @@ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: """ # do something when the batch starts - def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_validation_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: """ Called in the validation loop after the batch. @@ -156,7 +162,7 @@ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) - """ # do something when the batch starts - def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_test_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: """ Called in the test loop after the batch. 
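`auto_transfer_args` above is the machinery behind the `auto_move_data` decorator: it moves the call's inputs onto the module's device before invoking the wrapped method. A simplified standalone version of the same idea (not the library decorator; it only handles bare tensors, whereas the real one uses `move_data_to_device` on arbitrary collections):

```python
from functools import wraps

import torch
from torch import nn

def auto_move_inputs(fn):
    """Move tensor arguments onto the module's device before calling `fn`."""
    @wraps(fn)
    def wrapper(self: nn.Module, *args, **kwargs):
        device = next(self.parameters()).device
        args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)
        kwargs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return fn(self, *args, **kwargs)
    return wrapper

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    @auto_move_inputs
    def forward(self, x):
        return self.fc(x)

print(TinyNet()(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
```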
@@ -172,19 +178,22 @@ def on_test_model_train(self) -> None: """ Sets the model to train during the test loop """ - self.trainer.model.train() + if self.trainer is not None: + self.trainer.model.train() def on_test_model_eval(self) -> None: """ Sets the model to eval during the test loop """ - self.trainer.model.eval() + if self.trainer is not None: + self.trainer.model.eval() def on_predict_model_eval(self) -> None: """ Sets the model to eval during the predict loop """ - self.trainer.model.eval() + if self.trainer is not None: + self.trainer.model.eval() def on_epoch_start(self) -> None: """ @@ -204,7 +213,7 @@ def on_train_epoch_start(self) -> None: """ # do something when the epoch starts - def on_train_epoch_end(self, outputs: List[Any]) -> None: + def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """ Called in the training loop at the very end of the epoch. """ @@ -216,7 +225,7 @@ def on_validation_epoch_start(self) -> None: """ # do something when the epoch starts - def on_validation_epoch_end(self, outputs: List[Any]) -> None: + def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """ Called in the validation loop at the very end of the epoch. """ @@ -228,7 +237,7 @@ def on_test_epoch_start(self) -> None: """ # do something when the epoch starts - def on_test_epoch_end(self, outputs: List[Any]) -> None: + def on_test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """ Called in the test loop at the very end of the epoch. """ @@ -330,6 +339,18 @@ def configure_sharded_model(self) -> None: class DataHooks: """Hooks to be used for data related stuff.""" + _device: Union[str, torch.device] + + @property + def device(self) -> Union[torch.device, str]: + """ + This has to be implemented here for mypy. Actual implementation in DeviceDtypeModuleMixin + """ + return self._device + + @device.setter + def device(self, new_device: Union[str, torch.device]) -> None: + self._device = new_device def prepare_data(self) -> None: """ @@ -635,7 +656,7 @@ def on_test_dataloader(self) -> None: def on_predict_dataloader(self) -> None: """Called before requesting the predict dataloader.""" - def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: + def transfer_batch_to_device(self, batch: Any, device: Optional[Union[str, torch.device]] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data structure. @@ -687,7 +708,7 @@ def transfer_batch_to_device(self, batch, device): device = device or self.device return move_data_to_device(batch, device) - def on_before_batch_transfer(self, batch, dataloader_idx): + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: """ Override to alter or apply batch augmentations to your batch before it is transferred to the device. @@ -720,7 +741,7 @@ def on_before_batch_transfer(self, batch, dataloader_idx): """ return batch - def on_after_batch_transfer(self, batch, dataloader_idx): + def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: """ Override to alter or apply batch augmentations to your batch after it is transferred to the device. 
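The `transfer_batch_to_device` hook typed above is the override point for dataloaders that return custom batch objects. A hedged sketch (the `CustomBatch` dataclass and `MyModule` are invented; `move_data_to_device` is the same utility the default implementation falls back to):

```python
from dataclasses import dataclass

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import move_data_to_device

@dataclass
class CustomBatch:
    inputs: torch.Tensor
    targets: torch.Tensor

class MyModule(LightningModule):
    def transfer_batch_to_device(self, batch, device=None):
        device = device or self.device
        if isinstance(batch, CustomBatch):
            return CustomBatch(batch.inputs.to(device), batch.targets.to(device))
        # anything else keeps the default behaviour
        return move_data_to_device(batch, device)
```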
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c7f61da7b01c6..5711b5402d985 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -25,25 +25,37 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union import torch from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.loggers.base import LightningLoggerBase +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import BATCH, EPOCH_OUTPUT, STEP_OUTPUT + +if _OMEGACONF_AVAILABLE: + from omegaconf import OmegaConf +_HYPERPARAMS = Union[Namespace, Dict, 'OmegaConf'] + +LRSCHED = Union[Sequence[Any], Sequence[Dict[str, Union[Any, str, int]]]] +OPTIM_LR_DICT = Dict[str, Union[Optimizer, LRSCHED]] +BACKWARD_ARGS = Union[None, bool, torch.Tensor, Sequence[torch.Tensor]] + log = logging.getLogger(__name__) @@ -61,21 +73,13 @@ class LightningModule( # Below is for property support of JIT in PyTorch 1.7 # since none of them is important when using JIT, we are going to ignore them. 
__jit_unused_properties__ = [ - "datamodule", - "example_input_array", - "hparams", - "hparams_initial", - "on_gpu", - "current_epoch", - "global_step", - "global_rank", - "local_rank", - "logger", - "model_size", - ] + DeviceDtypeModuleMixin.__jit_unused_properties__ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + "datamodule", "example_input_array", "hparams", "hparams_initial", "on_gpu", "current_epoch", "global_step", + "global_rank", "local_rank", "logger", "model_size", "_trainer" + ] + DeviceDtypeModuleMixin.__jit_unused_properties__ + ModelHooks.__jit_unused_properties__ + + def __init__(self, *args: Any, **kwargs: Any): + # https://github.com/python/mypy/issues/5887 + super().__init__(*args, **kwargs) # type: ignore # see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/ # torch/nn/modules/module.py#L227) @@ -83,10 +87,10 @@ def __init__(self, *args, **kwargs): self.exp_save_path = None - self.loaded_optimizer_states_dict = {} + self.loaded_optimizer_states_dict: Dict = {} #: Pointer to the trainer object - self.trainer = None + self._trainer: Optional['pl.Trainer'] = None self._distrib_type = None self._device_type = None @@ -106,13 +110,33 @@ def __init__(self, *args, **kwargs): self._current_hook_fx_name = None self._current_dataloader_idx = None self._automatic_optimization: bool = True - self._param_requires_grad_state = dict() + self._param_requires_grad_state: Dict = {} + + # Necessary since torchscript does not support forward references + @property # type: ignore + @torch.jit.unused + def trainer(self) -> Optional['pl.Trainer']: + return self._trainer + + @trainer.setter # type: ignore + @torch.jit.unused + def trainer(self, trainer: Optional['pl.Trainer']) -> None: + self._trainer = trainer + + def optimizers(self, + use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer], List]: + opts: Union[List[LightningOptimizer], List[Optimizer], List] - def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: - opts = list(self.trainer.lightning_optimizers.values()) + _opts = self.trainer.lightning_optimizers if ( + self.trainer is not None and self.trainer.lightning_optimizers is not None + ) else [] + if isinstance(_opts, Mapping): + opts = [_opts[idx] for idx in range(len(_opts))] + else: + opts = list(_opts) else: - opts = self.trainer.optimizers + opts = self.trainer.optimizers if (self.trainer is not None and self.trainer.optimizers is not None) else [] # single optimizer if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], Optimizer): @@ -121,7 +145,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Opt return opts def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: - if not self.trainer.lr_schedulers: + if self.trainer is None or not self.trainer.lr_schedulers: return None # ignore other keys "interval", "frequency", etc. @@ -138,6 +162,10 @@ def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: def example_input_array(self) -> Any: return self._example_input_array + @example_input_array.setter + def example_input_array(self, example: Any) -> None: + self._example_input_array = example + @property def current_epoch(self) -> int: """The current epoch""" @@ -158,10 +186,6 @@ def local_rank(self) -> int: """ The index of the current process within a single node. 
""" return self.trainer.local_rank if self.trainer else 0 - @example_input_array.setter - def example_input_array(self, example: Any) -> None: - self._example_input_array = example - @property def datamodule(self) -> Any: return self._datamodule @@ -171,12 +195,12 @@ def datamodule(self, datamodule: Any) -> None: self._datamodule = datamodule @property - def on_gpu(self): + def on_gpu(self) -> bool: """ True if your model is currently running on GPUs. Useful to set flags around the LightningModule for different CPU vs GPU behavior. """ - return self.device.type == "cuda" + return self.device.type == "cuda" if isinstance(self.device, torch.device) else 'cuda' in self.device @property def automatic_optimization(self) -> bool: @@ -190,17 +214,19 @@ def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization @property - def logger(self): + def logger(self) -> Optional['LightningLoggerBase']: """ Reference to the logger object in the Trainer. """ return self.trainer.logger if self.trainer else None - def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0): + def _apply_batch_transfer_handler( + self, batch: Any, device: Optional[Union[str, torch.device]] = None, dataloader_idx: int = 0 + ) -> Any: batch = self.on_before_batch_transfer(batch, dataloader_idx) batch = self.transfer_batch_to_device(batch, device) batch = self.on_after_batch_transfer(batch, dataloader_idx) return batch - def print(self, *args, **kwargs) -> None: + def print(self, *args: Any, **kwargs: Any) -> None: r""" Prints only from process 0. Use this in any distributed mode to log only once. @@ -214,10 +240,13 @@ def forward(self, x): self.print(x, 'in forward') """ - if self.trainer.is_global_zero: - progress_bar = self.trainer.progress_bar_callback - if progress_bar is not None and progress_bar.is_enabled: - progress_bar.print(*args, **kwargs) + if self.trainer is None or self.trainer.is_global_zero: + if self.trainer is not None: + progress_bar = self.trainer.progress_bar_callback + if progress_bar is not None and progress_bar.is_enabled: + progress_bar.print(*args, **kwargs) + else: + print(*args, **kwargs) else: print(*args, **kwargs) @@ -237,7 +266,7 @@ def log( sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, - ): + ) -> None: """ Log a key, value @@ -304,7 +333,14 @@ def log( f"Logged key: {name} should not contain information about dataloader_idx." 
) - training_type_plugin = self.trainer.training_type_plugin + reduce_fn: Callable + if self.trainer is None: + + def reduce_fn(x: Union[torch.Tensor, Any], *_: Any, **__: Any) -> Union[torch.Tensor, Any]: + return x + + else: + reduce_fn = self.trainer.training_type_plugin.reduce # Determine if dataloader index should be added dataloader_idx = self._current_dataloader_idx if add_dataloader_idx else None @@ -323,7 +359,7 @@ def log( sync_dist, sync_dist_op, sync_dist_group, - training_type_plugin.reduce, + reduce_fn, dataloader_idx, self.device, ) @@ -343,7 +379,7 @@ def log_dict( sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, - ): + ) -> None: """ Log a dictonary of values at once @@ -389,7 +425,7 @@ def log_dict( def write_prediction( self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt' - ): + ) -> None: """ Write predictions to disk using ``torch.save`` @@ -407,9 +443,11 @@ def write_prediction( each device with respective names: ``filename_rank_0.pt``, ``filename_rank_1.pt``, ... """ + if self.trainer is None: + raise MisconfigurationException('Cannot write predictions when not attached to a trainer') self.trainer.evaluation_loop.predictions._add_prediction(name, value, filename) - def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str = 'predictions.pt'): + def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str = 'predictions.pt') -> None: """ Write a dictonary of predictions to disk at once using ``torch.save`` @@ -430,7 +468,7 @@ def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str for k, v in predictions_dict.items(): self.write_prediction(k, v, filename) - def __auto_choose_log_on_step(self, on_step): + def __auto_choose_log_on_step(self, on_step: Optional[bool]) -> bool: if on_step is None: if self._current_fx_name in {'training_step', 'training_step_end'}: on_step = True @@ -443,7 +481,7 @@ def __auto_choose_log_on_step(self, on_step): return on_step - def __auto_choose_log_on_epoch(self, on_epoch): + def __auto_choose_log_on_epoch(self, on_epoch: Optional[bool]) -> bool: if on_epoch is None: if self._current_fx_name in {'training_step', 'training_step_end'}: on_epoch = False @@ -461,7 +499,7 @@ def all_gather( data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False, - ): + ) -> Union[torch.Tensor, Dict, List, Tuple]: r""" Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ```all_gather``` operation accelerator agnostic. @@ -479,12 +517,16 @@ def all_gather( the output will also be a collection with tensors of this shape. """ group = group if group is not None else torch.distributed.group.WORLD + if self.trainer is None: + raise MisconfigurationException( + 'Cannot all_gather data without being attached to a trainer that orchestrates the processes' + ) all_gather = self.trainer.accelerator.all_gather data = convert_to_tensors(data, device=self.device) all_gather = partial(all_gather, group=group, sync_grads=sync_grads) return apply_to_collection(data, torch.Tensor, all_gather) - def forward(self, *args, **kwargs) -> Any: + def forward(self, *args: Any, **kwargs: Any) -> Any: r""" Same as :meth:`torch.nn.Module.forward()`. 
@@ -497,7 +539,26 @@ def forward(self, *args, **kwargs) -> Any: """ return super().forward(*args, **kwargs) - def training_step(self, *args, **kwargs) -> STEP_OUTPUT: + @overload + def training_step(self, batch: BATCH, batch_idx: int) -> STEP_OUTPUT: + ... + + @overload + def training_step(self, batch: BATCH, batch_idx: int, optimizer_idx: int) -> STEP_OUTPUT: + ... + + @overload + def training_step(self, batch: BATCH, batch_idx: int, hiddens: torch.Tensor) -> STEP_OUTPUT: + ... + + @overload + def training_step(self, batch: BATCH, batch_idx: int, optimizer_idx: int, hiddens: torch.Tensor) -> STEP_OUTPUT: + ... + + def training_step( + self, batch: BATCH, batch_idx: int, *args: Optional[Union[int, torch.Tensor]], + **kwargs: Optional[Union[int, torch.Tensor]] + ) -> STEP_OUTPUT: r""" Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger. @@ -562,8 +623,9 @@ def training_step(self, batch, batch_idx, hiddens): so it differs from the actual loss returned in train/validation step. """ rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer") + raise NotImplementedError - def training_step_end(self, *args, **kwargs) -> STEP_OUTPUT: + def training_step_end(self, batch_part_outputs: STEP_OUTPUT) -> Any: """ Use this when training with dp or ddp2 because :meth:`training_step` will operate on only part of the batch. However, this is still optional @@ -666,7 +728,15 @@ def training_epoch_end(self, training_step_outputs): # do something here """ - def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + @overload + def validation_step(self, batch: BATCH, batch_idx: int) -> STEP_OUTPUT: + ... + + @overload + def validation_step(self, batch: BATCH, batch_idx: int, dataloader_idx: int) -> STEP_OUTPUT: + ... + + def validation_step(self, batch: BATCH, batch_idx: int, *args: int, **kwargs: int) -> STEP_OUTPUT: r""" Operates on a single batch of data from the validation set. In this step you'd might generate examples or calculate anything of interest like accuracy. @@ -753,7 +823,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx): the model goes back to training mode and gradients are enabled. """ - def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def validation_step_end(self, batch_parts_outputs: STEP_OUTPUT) -> Union[Any, None]: """ Use this when validating with dp or ddp2 because :meth:`validation_step` will operate on only part of the batch. However, this is still optional @@ -852,7 +922,15 @@ def validation_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + @overload + def test_step(self, batch: BATCH, batch_idx: int) -> Union[Any, None]: + ... + + @overload + def test_step(self, batch: BATCH, batch_idx: int, dataloader_idx: int) -> Union[Any, None]: + ... + + def test_step(self, batch: BATCH, batch_idx: int, *args: int, **kwargs: int) -> Union[Any, None]: r""" Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest @@ -928,7 +1006,7 @@ def test_step(self, batch, batch_idx, dataloader_idx): to training mode and gradients are enabled. """ - def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def test_step_end(self, batch_part_outputs: STEP_OUTPUT) -> Union[Any, None]: """ Use this when testing with dp or ddp2 because :meth:`test_step` will operate on only part of the batch. 
However, this is still optional @@ -1033,7 +1111,15 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: + @overload + def predict_step(self, batch: BATCH, batch_idx: int) -> STEP_OUTPUT: + ... + + @overload + def predict_step(self, batch: BATCH, batch_idx: int, dataloader_idx: int) -> STEP_OUTPUT: + ... + + def predict_step(self, batch: BATCH, batch_idx: int, *args: int, **kwargs: int) -> STEP_OUTPUT: """ Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict` By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`. @@ -1049,7 +1135,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] """ return self(batch) - def configure_callbacks(self): + def configure_callbacks(self) -> Union[List, List['Callback']]: """ Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets called, @@ -1075,7 +1161,10 @@ def configure_callbacks(self): """ return [] - def configure_optimizers(self): + def configure_optimizers( + self + ) -> Union[Optimizer, Sequence[Optimizer], Sequence[Union[Sequence[Optimizer], Sequence[LRSCHED], OPTIM_LR_DICT]], + Sequence[OPTIM_LR_DICT], None]: r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. @@ -1184,8 +1273,15 @@ def configure_optimizers(self): """ rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer") + return None - def manual_backward(self, loss: Tensor, optimizer: Optional[Optimizer] = None, *args, **kwargs) -> None: + def manual_backward( + self, + loss: Tensor, + optimizer: Optional[Optimizer] = None, + *args: BACKWARD_ARGS, + **kwargs: BACKWARD_ARGS + ) -> None: """ Call this directly from your training_step when doing optimizations manually. By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you. @@ -1209,15 +1305,20 @@ def training_step(...): "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4" ) + if self.trainer is None: + raise RuntimeError('Cannot dispatch to trainer backward without trainer!') + # make sure we're using manual opt self._verify_is_manual_optimization('manual_backward') # backward self._running_manual_backward = True - self.trainer.train_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) + self.trainer.train_loop.backward(loss, None, None, *args, **kwargs) self._running_manual_backward = False - def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: + def backward( + self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args: BACKWARD_ARGS, **kwargs: BACKWARD_ARGS + ) -> None: """ Override backward with your own implementation if you need to. 
@@ -1236,10 +1337,10 @@ def backward(self, loss, optimizer, optimizer_idx): loss.backward() """ - if self.trainer.train_loop.automatic_optimization or self._running_manual_backward: + if self.trainer is None or self.trainer.train_loop.automatic_optimization or self._running_manual_backward: loss.backward(*args, **kwargs) - def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): + def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int) -> None: """ Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup. @@ -1258,6 +1359,7 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): # Iterate over all optimizer parameters to preserve their `requires_grad` information # in case these are pre-defined during `configure_optimizers` param_requires_grad_state = {} + # TODO: Fix typing here, mypy complains about args not being iterable for opt in self.optimizers(use_pl_optimizer=False): for group in opt.param_groups: for param in group['params']: @@ -1274,7 +1376,7 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): param.requires_grad = param_requires_grad_state[param] self._param_requires_grad_state = param_requires_grad_state - def untoggle_optimizer(self, optimizer_idx: int): + def untoggle_optimizer(self, optimizer_idx: int) -> None: """ .. note:: Only called when using multiple optimizers @@ -1283,6 +1385,7 @@ def untoggle_optimizer(self, optimizer_idx: int): Args: optimizer_idx: Current optimizer idx in training_loop """ + for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)): if optimizer_idx != opt_idx: for group in opt.param_groups: @@ -1365,12 +1468,20 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer.step(closure=optimizer_closure) """ + if optimizer is None: + raise RuntimeError('Cannot run the optimizer step without an actual optimizer!') + if not isinstance(optimizer, LightningOptimizer): # wraps into LightingOptimizer only for running step - optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx) - optimizer.step(closure=optimizer_closure) + if self.trainer is None: + raise RuntimeError('Cannot wrap the optimizer to a LightningOptimizer without a trainer') + lightning_optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx) + else: + lightning_optimizer = optimizer - def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + lightning_optimizer.step(closure=optimizer_closure) + + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None: """Override this method to change the default behaviour of ``optimizer.zero_grad()``. Args: @@ -1443,6 +1554,7 @@ def tbptt_split_batch(self, batch, split_size): for t in range(0, time_dims[0], split_size): batch_split = [] for i, x in enumerate(batch): + split_x: Union[torch.Tensor, List] if isinstance(x, torch.Tensor): split_x = x[:, t:t + split_size] elif isinstance(x, collections.Sequence): @@ -1521,6 +1633,10 @@ def get_progress_bar_dict(self): Dictionary with the items to be displayed in the progress bar. 
""" # call .item() only once but store elements without graphs + + if self.trainer is None: + raise RuntimeError('Cannot et the progressbar dict without an attached trainer!') + running_train_loss = self.trainer.train_loop.running_loss.mean() avg_training_loss = None if running_train_loss is not None: @@ -1528,7 +1644,7 @@ def get_progress_bar_dict(self): elif self.trainer.train_loop.automatic_optimization: avg_training_loss = float('NaN') - tqdm_dict = {} + tqdm_dict: Dict[str, Union[str, int]] = {} if avg_training_loss is not None: tqdm_dict["loss"] = f"{avg_training_loss:.3g}" @@ -1543,15 +1659,15 @@ def get_progress_bar_dict(self): return tqdm_dict - def _verify_is_manual_optimization(self, fn_name): - if self.trainer.train_loop.automatic_optimization: + def _verify_is_manual_optimization(self, fn_name: str) -> None: + if self.trainer is not None and self.trainer.train_loop.automatic_optimization: raise MisconfigurationException( f'to use {fn_name}, please disable automatic optimization:' ' set model property `automatic_optimization` as False' ) @classmethod - def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: + def _auto_collect_arguments(cls, frame: Optional[types.FrameType] = None) -> Tuple[Dict, Dict]: """ Collect all module arguments in the current constructor and all child constructors. The child constructors are all the ``__init__`` methods that reach the current class through @@ -1567,6 +1683,9 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: if not frame: frame = inspect.currentframe() + if frame is None: + raise ValueError('Frame for inspection cannot be None!') + frame_args = collect_init_args(frame.f_back, []) self_arguments = frame_args[-1] @@ -1581,7 +1700,7 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: def save_hyperparameters( self, - *args, + *args: '_HYPERPARAMS', ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: @@ -1650,7 +1769,7 @@ class ``__init__`` to be ignored frame = inspect.currentframe().f_back save_hyperparameters(self, *args, ignore=ignore, frame=frame) - def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: + def _set_hparams(self, hp: Union['_HYPERPARAMS', str]) -> None: if isinstance(hp, Namespace): hp = vars(hp) if isinstance(hp, dict): @@ -1670,8 +1789,9 @@ def to_onnx( self, file_path: Union[str, Path], input_sample: Optional[Any] = None, - **kwargs, - ): + **kwargs: Union[bool, List, List[str], torch.onnx.OperatorExportTypes, int, Tuple[torch.Tensor], + Dict[str, Union[int, Dict[str, int], List[int]]]], + ) -> None: """ Saves the model in ONNX format @@ -1721,7 +1841,7 @@ def to_torchscript( file_path: Optional[Union[str, Path]] = None, method: Optional[str] = 'script', example_inputs: Optional[Any] = None, - **kwargs, + **kwargs: Union[bool, List[Tuple[Any]], float], ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. 
@@ -1795,13 +1915,13 @@ def to_torchscript( return torchscript_module @property - def hparams(self) -> Union[AttributeDict, dict, Namespace]: + def hparams(self) -> Union[AttributeDict, '_HYPERPARAMS']: if not hasattr(self, "_hparams"): self._hparams = AttributeDict() return self._hparams @property - def hparams_initial(self) -> AttributeDict: + def hparams_initial(self) -> Union[AttributeDict, '_HYPERPARAMS']: if not hasattr(self, "_hparams_initial"): return AttributeDict() # prevent any change diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index a3eab728f8ea8..3e49180fd1cdc 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -16,13 +16,14 @@ import shutil import subprocess from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union import numpy as np import torch import torch.nn as nn from torch.utils.hooks import RemovableHandle +import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, DeviceType PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] @@ -65,10 +66,10 @@ def __init__(self, module: nn.Module): super().__init__() self._module = module self._hook_handle = self._register_hook() - self._in_size = None - self._out_size = None + self._in_size: Optional[Union[str, List]] = None + self._out_size: Optional[Union[str, List]] = None - def __del__(self): + def __del__(self) -> None: self.detach_hook() def _register_hook(self) -> Optional[RemovableHandle]: @@ -82,19 +83,20 @@ def _register_hook(self) -> Optional[RemovableHandle]: A handle for the installed hook, or ``None`` if registering the hook is not possible. """ - def hook(module, inp, out): + def hook(module: torch.nn.Module, inp: Any, out: Any) -> None: if len(inp) == 1: inp = inp[0] self._in_size = parse_batch_shape(inp) self._out_size = parse_batch_shape(out) - self._hook_handle.remove() + if self._hook_handle is not None: + self._hook_handle.remove() handle = None if not isinstance(self._module, torch.jit.ScriptModule): handle = self._module.register_forward_hook(hook) return handle - def detach_hook(self): + def detach_hook(self) -> None: """ Removes the forward hook if it was not already removed in the forward pass. Will be called after the summary is created. 
@@ -182,7 +184,7 @@ class ModelSummary(object): MODE_DEFAULT = MODE_TOP MODES = [MODE_FULL, MODE_TOP] - def __init__(self, model, mode: str = MODE_DEFAULT): + def __init__(self, model: 'pl.LightningModule', mode: str = MODE_DEFAULT) -> None: self._model = model self._mode = mode self._layer_summary = self.summarize() @@ -193,6 +195,7 @@ def __init__(self, model, mode: str = MODE_DEFAULT): @property def named_modules(self) -> List[Tuple[str, nn.Module]]: + mods: Union[Iterable, Iterator] if self._mode == ModelSummary.MODE_FULL: mods = self._model.named_modules() mods = list(mods)[1:] # do not include root module (LightningModule) @@ -252,8 +255,14 @@ def _forward_example_input(self) -> None: input_ = model.example_input_array input_ = model._apply_batch_transfer_handler(input_, model.device) - if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: - model.forward = torch.cuda.amp.autocast()(model.forward) + trainer_and_amp_backend_available = trainer is not None and trainer.amp_backend is not None + if trainer_and_amp_backend_available: + native_amp_tpu = trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU + else: + native_amp_tpu = False + + if (trainer_and_amp_backend_available and native_amp_tpu): + setattr(model, 'forward', torch.cuda.amp.autocast()(model.forward)) mode = model.training model.eval() @@ -267,7 +276,7 @@ def _forward_example_input(self) -> None: model(input_) model.train(mode) # restore mode of module - def __str__(self): + def __str__(self) -> str: """ Makes a summary listing with: @@ -288,7 +297,7 @@ def __str__(self): return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays) - def __repr__(self): + def __repr__(self) -> str: return str(self) @@ -303,7 +312,7 @@ def parse_batch_shape(batch: Any) -> Union[str, List]: return UNKNOWN_SIZE -def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str: +def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols: Sequence) -> str: """ Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big @@ -345,7 +354,7 @@ def _format_summary_table(total_parameters: int, trainable_parameters: int, mode return summary -def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: +def get_memory_profile(mode: str) -> Union[Dict[str, float], Dict[int, float]]: """ Get a profile of the current memory usage. Args: @@ -373,7 +382,7 @@ def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: return memory_map -def get_gpu_memory_map() -> Dict[str, int]: +def get_gpu_memory_map() -> Dict[str, float]: """ Get the current gpu usage. @@ -382,7 +391,7 @@ def get_gpu_memory_map() -> Dict[str, int]: values are memory usage as integers in MB. 
""" result = subprocess.run( - [shutil.which("nvidia-smi"), "--query-gpu=memory.used", "--format=csv,nounits,noheader"], + [str(shutil.which("nvidia-smi")), "--query-gpu=memory.used", "--format=csv,nounits,noheader"], encoding="utf-8", # capture_output=True, # valid for python version >=3.7 stdout=subprocess.PIPE, @@ -396,7 +405,7 @@ def get_gpu_memory_map() -> Dict[str, int]: return gpu_memory_map -def get_formatted_model_size(total_model_size: float) -> float: +def get_formatted_model_size(total_model_size: float) -> str: return f"{total_model_size:,.3f}" diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 162e17ca47bf5..d1ab49092116e 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -13,20 +13,21 @@ # limitations under the License. import types from contextlib import contextmanager -from typing import Callable, Optional +from typing import Any, Callable, Dict, Generator, List, Optional, Union from weakref import proxy from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException -def is_lightning_optimizer(optimizer): +def is_lightning_optimizer(optimizer: Union['LightningOptimizer', Optimizer, Any]) -> bool: return isinstance(optimizer, LightningOptimizer) -def do_nothing_closure(): +def do_nothing_closure() -> None: return @@ -36,7 +37,7 @@ class LightningOptimizer: the backward and optimizer_step logic across accelerators, AMP, accumulate_grad_batches """ - def __init__(self, optimizer: Optimizer): + def __init__(self, optimizer: Optimizer) -> None: self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', "__del__")} @@ -45,71 +46,80 @@ def __init__(self, optimizer: Optimizer): self.__class__ = type( "Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__.__bases__[0]), {} ) - self.skip_synchronize = optimizer.skip_synchronize - self.synchronize = optimizer.synchronize + self.skip_synchronize = getattr(optimizer, 'skip_synchronize') + self.synchronize = getattr(optimizer, 'synchronize') else: self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer self._trainer = None - self._optimizer_idx = None + self._optimizer_idx: Optional[int] = None self._total_optimizer_step_calls = 0 @property - def optimizer(self): + def optimizer(self) -> Optimizer: return self._optimizer @property - def defaults(self): + def defaults(self) -> Dict[str, Any]: return self._optimizer.defaults @defaults.setter - def defaults(self, defaults): + def defaults(self, defaults: Dict[str, Any]) -> None: self._optimizer.defaults = defaults @property - def state(self): + def state(self) -> Dict: return self._optimizer.state @state.setter - def state(self, state): + def state(self, state: Dict) -> None: self._optimizer.state = state @property - def param_groups(self): + def param_groups(self) -> List[Dict]: return self._optimizer.param_groups @param_groups.setter - def param_groups(self, param_groups): + def param_groups(self, param_groups: List[Dict]) -> None: self._optimizer.param_groups = param_groups - def _on_trainer_init(self, trainer): + def _on_trainer_init(self, trainer: 'pl.Trainer') -> None: self._trainer = proxy(trainer) + if trainer.optimizers is None: + raise ValueError('Expected the trainer to have at least one optimizer, got None') for opt_idx, opt in 
enumerate(trainer.optimizers): if opt == self._optimizer: self._optimizer_idx = opt_idx break @classmethod - def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx): + def _to_lightning_optimizer(cls, optimizer: Optimizer, trainer: 'pl.Trainer', + opt_idx: Optional[int]) -> Union['LightningOptimizer', Optimizer]: # apex overrides .step function and need to be wrapped on each step - if trainer.amp_backend == AMPType.APEX: - optimizer = cls(optimizer) - optimizer._on_trainer_init(trainer) + if trainer.amp_backend is not None and trainer.amp_backend == AMPType.APEX: + new_optimizer = cls(optimizer) + new_optimizer._on_trainer_init(trainer) else: - optimizer = trainer.lightning_optimizers[opt_idx] + if opt_idx is None: + raise RuntimeError('Cannot get the correct optimizer without a proper index') + new_optimizer = trainer.lightning_optimizers[opt_idx] - return optimizer + return new_optimizer - def _toggle_model(self): + def _toggle_model(self) -> None: + if self._trainer is None: + raise ValueError('Expected to have trainer reference, but got None') model_ref = self._trainer.lightning_module model_ref.toggle_optimizer(self, self._optimizer_idx) - def _untoggle_model(self): + def _untoggle_model(self) -> None: + if self._trainer is None: + raise ValueError('Expected to have trainer reference, but got None') model_ref = self._trainer.lightning_module model_ref.untoggle_optimizer(self) @contextmanager - def toggle_model(self, sync_grad: bool = True): + def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: """ This function is just a helper for advanced users. @@ -121,19 +131,24 @@ def toggle_model(self, sync_grad: bool = True): during the accumulation phase. Setting `sync_grad` to False will block this synchronization and improve performance. """ + if self._trainer is None: + raise ValueError('Expected to have trainer reference, but got None') with self._trainer.train_loop.block_ddp_sync_behaviour(not sync_grad): self._toggle_model() yield self._untoggle_model() - def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs): + def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs: Any) -> None: trainer = self._trainer optimizer = self._optimizer + if trainer is None: + raise ValueError('Expected to have trainer reference, but got None') + with trainer.profiler.profile(profiler_name): trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs) - def step(self, *args, closure: Optional[Callable] = None, **kwargs): + def step(self, closure: Optional[Callable] = None, **kwargs: Any) -> None: """ Call this directly from your training_step when doing optimizations manually. 
By using this we can ensure that all the proper scaling when using 16-bit, accelerator etc @@ -211,10 +226,10 @@ def closure_dis(): raise MisconfigurationException("When closure is provided, it should be a function") profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" - self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs) + self.__optimizer_step(closure=closure, profiler_name=profiler_name, **kwargs) self._total_optimizer_step_calls += 1 - def __repr__(self): + def __repr__(self) -> str: groups = [{k: round(v, 12) if isinstance(v, float) else v for k, v in sorted(group.items()) if k != "params"} for group in self.param_groups] return f"{self.__class__.__name__}(groups={groups})" diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index ffa9b0a1359ee..ec3e2acb42261 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -60,8 +60,8 @@ def load_from_checkpoint( map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, hparams_file: Optional[str] = None, strict: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> 'ModelIO': r""" Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to `__init__` in the checkpoint under `hyper_parameters` @@ -158,15 +158,15 @@ def load_from_checkpoint( return model @classmethod - def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new): + def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new: Any) -> 'ModelIO': cls_spec = inspect.getfullargspec(cls.__init__) - cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() + cls_init_args_name = list(inspect.signature(cls.__init__).parameters.keys()) self_var, args_var, kwargs_var = parse_class_init_keys(cls) drop_names = [n for n in (self_var, args_var, kwargs_var) if n] - cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name)) + cls_init_args_name = [n for n in cls_init_args_name if n not in drop_names] - cls_kwargs_loaded = {} + cls_kwargs_loaded: Dict = {} # pass in the values we saved automatically if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: @@ -176,7 +176,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl # 2. Try to restore model hparams from checkpoint using the new key _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY - cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key)) + cls_kwargs_loaded.update(checkpoint[_new_hparam_key]) # 3. 
Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace cls_kwargs_loaded = _convert_loaded_hparams( @@ -196,13 +196,13 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl # filter kwargs according to class init unless it allows any argument via kwargs _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name} - model = cls(**_cls_kwargs) + model = cls(**_cls_kwargs) # type: ignore # give model a chance to load something model.on_load_checkpoint(checkpoint) # load the state_dict on the model automatically - model.load_state_dict(checkpoint['state_dict'], strict=strict) + model.load_state_dict(checkpoint['state_dict'], strict=strict) # type: ignore return model @@ -246,7 +246,7 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: """ -def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object: +def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> Dict: """Convert hparams according given type in callable or string (past) format.""" # if not hparams type define if not hparams_type: @@ -350,7 +350,7 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict return {} with fs.open(config_yaml, "r") as fp: - hparams = yaml.load(fp, Loader=yaml.UnsafeLoader) + hparams = yaml.load(fp, Loader=yaml.UnsafeLoader) # type: ignore if _OMEGACONF_AVAILABLE: if use_omegaconf: @@ -361,7 +361,7 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict return hparams -def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: +def save_hparams_to_yaml(config_yaml: str, hparams: Union[dict, Namespace]) -> None: """ Args: config_yaml: path to new YAML file diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index f2cdd31ab739f..873f993d92cc9 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -15,7 +15,7 @@ import numbers from copy import copy -from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -26,7 +26,7 @@ class Result(Dict): - def __init__(self, minimize: Optional[Tensor] = None): + def __init__(self, minimize: Optional[Tensor] = None) -> None: super().__init__() if minimize is not None: @@ -57,19 +57,19 @@ def __getattr__(self, key: str) -> Any: except KeyError: return None - def __setattr__(self, key: str, val: Union[Tensor, Any]): + def __setattr__(self, key: str, val: Union[Tensor, Any]) -> None: # ensure tensors are detached if isinstance(val, torch.Tensor) and key != 'minimize': val = val.detach() self[key] = val - def __getstate__(self): + def __getstate__(self) -> 'Result': return self - def __setstate__(self, d): + def __setstate__(self, d: Mapping) -> None: self.update(d) - def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''): + def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = '') -> None: if x is not None: if not isinstance(x, Tensor): raise TypeError(f'{name} must be a torch.Tensor') @@ -97,8 +97,8 @@ def log( sync_dist_group: Optional[Any] = None, sync_fn: Callable = None, dataloader_idx: Optional[int] = None, - device: torch.device = None, 
- ): + device: Optional[Union[str, torch.device]] = None, + ) -> None: # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() @@ -195,7 +195,7 @@ def __set_meta( tbptt_reduce_fx: Callable, forked: bool, dataloader_idx: Union[int, None], - ): + ) -> None: # set the meta for the item meta_value = value meta = dict( @@ -217,12 +217,12 @@ def __set_meta( _internal = self['meta']['_internal'] _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch) - def track_batch_size(self, batch): + def track_batch_size(self, batch: Any) -> None: batch_size = Result.extract_batch_size(batch) Result.attach_batch_size(batch_size, self) @staticmethod - def extract_batch_size(batch): + def extract_batch_size(batch: Any) -> int: try: batch_size = Result.unpack_batch_size(batch) except RecursionError: @@ -235,7 +235,7 @@ def attach_batch_size(batch_size: Union[int, None], result: 'Result') -> None: meta = result['meta'] meta['_internal']['batch_sizes'].append(batch_size) - def get_batch_sizes(self): + def get_batch_sizes(self) -> torch.Tensor: meta = self['meta'] return torch.tensor(meta['_internal']['batch_sizes']) @@ -244,7 +244,7 @@ def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_data return f"{k}/dataloader_idx_{dataloader_idx}" return k - def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict: + def get_batch_log_metrics(self, include_forked_originals: bool = True, add_dataloader_idx: bool = False) -> dict: """ Gets the metrics to log at the end of the batch step @@ -269,7 +269,7 @@ def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_id return result - def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: + def get_epoch_log_metrics(self, add_dataloader_idx: bool = False) -> dict: """ Gets the metrics to log at the end of epoch """ @@ -296,7 +296,7 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: return result - def get_epoch_pbar_metrics(self, add_dataloader_idx=False): + def get_epoch_pbar_metrics(self, add_dataloader_idx: bool = False) -> dict: """ Gets the metrics to log at the end of epoch """ @@ -324,7 +324,7 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False): return result - def get_forked_metrics(self, add_dataloader_idx=False): + def get_forked_metrics(self, add_dataloader_idx: bool = False) -> dict: """ Gets the metrics to log at the end of epoch """ @@ -345,7 +345,7 @@ def get_forked_metrics(self, add_dataloader_idx=False): return result - def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False): + def get_batch_pbar_metrics(self, include_forked_originals: bool = True, add_dataloader_idx: bool = False) -> dict: """ Gets the metrics to log at the end of the batch step """ @@ -375,10 +375,10 @@ def detach(self) -> 'Result': self.__setitem__(k, v.detach()) return self - def to(self, *args, **kwargs) -> 'Result': + def to(self, *args: Any, **kwargs: Any) -> 'Result': """Move all self attributes to the given device.""" for k, v in self.items(): - if isinstance(v, torch.Tensor): + if isinstance(v, (torch.Tensor, torch.nn.Module)): self.__setitem__(k, v.to(*args, **kwargs)) return self @@ -386,7 +386,7 @@ def cpu(self) -> 'Result': """Move all self attributes to CPU.""" return self.to(torch.device("cpu")) - def __repr__(self): + def __repr__(self) -> str: self_copy = self.copy() if 'meta' in self_copy: @@ -394,13 +394,13 @@ def __repr__(self): 
return str(self_copy) - def __str__(self): + def __str__(self) -> str: copy = self.copy() del copy['meta'] return str(copy) - def __copy__(self): + def __copy__(self) -> 'Result': newone = type(self)() for k, v in self.items(): if isinstance(v, torch.Tensor): @@ -409,7 +409,7 @@ def __copy__(self): return newone @staticmethod - def unpack_batch_size(sample): + def unpack_batch_size(sample: Union[torch.Tensor, str, Mapping, Iterable, Any]) -> int: """ Recursively unpack sample to find a torch.Tensor. returns len(tensor) when found, or 1 when it hits an empty or non iterable. @@ -418,7 +418,7 @@ def unpack_batch_size(sample): size = sample.size(0) elif isinstance(sample, str): return len(sample) - elif isinstance(sample, dict): + elif isinstance(sample, Mapping): sample = next(iter(sample.values()), 1) size = Result.unpack_batch_size(sample) elif isinstance(sample, Iterable): @@ -429,9 +429,9 @@ def unpack_batch_size(sample): return size @classmethod - def gather(cls, outputs): + def gather(cls, outputs: List) -> Dict: meta = outputs[0].get('meta') - result = cls() + result: Dict = cls() result = recursive_gather(outputs, result) recursive_stack(result) @@ -440,9 +440,9 @@ def gather(cls, outputs): return result @classmethod - def padded_gather(cls, outputs): + def padded_gather(cls, outputs: List) -> 'Dict': meta = outputs[0].get('meta') - result = cls() + result: Dict = cls() result = recursive_gather(outputs, result) # find the padding used for other values @@ -470,7 +470,7 @@ def padded_gather(cls, outputs): return result @classmethod - def reduce_on_epoch_end(cls, outputs): + def reduce_on_epoch_end(cls, outputs: List) -> 'Dict': # get the batch sizes for all outputs batch_sizes = [] meta = {} @@ -478,9 +478,9 @@ def reduce_on_epoch_end(cls, outputs): batch_sizes.append(x.get_batch_sizes()) meta.update(x['meta']) - batch_sizes = torch.stack(batch_sizes).view(-1) + tensor_batch_sizes = torch.stack(batch_sizes).view(-1) - result = cls() + result: Dict = cls() result = recursive_gather(outputs, result) recursive_stack(result) @@ -496,10 +496,10 @@ def reduce_on_epoch_end(cls, outputs): if option['on_epoch']: fx = option['reduce_fx'] if fx == torch.mean: - if isinstance(result[k], list): + if isinstance(result[k], (list, tuple)): result[k] = torch.tensor(result[k]).float() try: - reduced_val = weighted_mean(result[k], batch_sizes) + reduced_val: Optional[Union[torch.Tensor, Dict]] = weighted_mean(result[k], tensor_batch_sizes) # todo: specify the expected Exceptions to come except Exception: reduced_val = torch.mean(result[k]) @@ -514,11 +514,11 @@ def reduce_on_epoch_end(cls, outputs): return result @classmethod - def reduce_across_time(cls, time_outputs): + def reduce_across_time(cls, time_outputs: List) -> Dict: # auto-reduce across time for tbptt meta = time_outputs[0]['meta'] - result = cls() + result: Dict = cls() result = recursive_gather(time_outputs, result) recursive_stack(result) @@ -541,7 +541,7 @@ def reduce_across_time(cls, time_outputs): result['meta'] = meta return result - def dp_reduce(self): + def dp_reduce(self) -> None: for k, value in self.items(): if k == 'meta' or isinstance(value, Metric): continue @@ -555,7 +555,7 @@ def dp_reduce(self): def should_reduce_on_epoch_end(self) -> bool: return self['meta']['_internal']['_reduce_on_epoch'] - def rename_keys(self, map_dict: dict): + def rename_keys(self, map_dict: dict) -> None: """ Maps key values to the target values. Useful when renaming variables in mass. 
@@ -572,7 +572,7 @@ def rename_keys(self, map_dict: dict): meta[dest] = meta[source] del meta[source] - def get_non_metrics_keys(self): + def get_non_metrics_keys(self) -> List: """ This function is used to filter metric keys for which the value isn't a Metric """ @@ -587,15 +587,17 @@ def reset(self) -> None: value.reset() -def choose_last(x): - if isinstance(x, (torch.Tensor, list)): - return x[-1] - if isinstance(x, dict): +def choose_last(x: Union[torch.Tensor, MutableMapping, Sequence]) -> Union[torch.Tensor, Mapping, Sequence, Any]: + if isinstance(x, MutableMapping): for k, v in x.items(): x[k] = x[k][-1] + return x + + if isinstance(x, (torch.Tensor, Sequence)): + return x[-1] -def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]: +def recursive_gather(outputs: Sequence[dict], result: Dict) -> Dict: for out in outputs: if 'meta' in out: del out['meta'] @@ -605,7 +607,7 @@ def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = if k == 'minimize' and v is None: continue - if isinstance(v, dict): + if isinstance(v, Dict): in_d = result.get(k, {}) v = recursive_gather([v], in_d) result[k] = v @@ -622,7 +624,7 @@ def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = return result -def recursive_stack(result: MutableMapping): +def recursive_stack(result: MutableMapping) -> None: for k, v in result.items(): if isinstance(v, dict): recursive_stack(v) @@ -630,9 +632,9 @@ def recursive_stack(result: MutableMapping): result[k] = collate_tensors(v) -def _recursive_fx_apply(input: dict, fx): +def _recursive_fx_apply(input: dict, fx: Callable) -> None: for k, v in input.items(): - if isinstance(v, list): + if isinstance(v, (list, tuple)): v = torch.tensor(v) if isinstance(v, torch.Tensor): @@ -658,21 +660,21 @@ def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]: return items -def weighted_mean(result, weights): +def weighted_mean(result: Union[Dict, List, Tuple], weights: torch.Tensor) -> Optional[Union[torch.Tensor, Dict]]: - if isinstance(result, dict): + if isinstance(result, Mapping): _process_dataloader_aggregated_steps(result, weights) - else: - if isinstance(result, list): - result = torch.tensor(result) + return result + if isinstance(result, (list, tuple)): + tensor_result = torch.tensor(result) - weights = weights.to(result.device)[:result.size(0)] - numerator = torch.dot(result.float(), weights.transpose(-1, 0).float()) - result = numerator / weights.sum().float() - return result + weights = weights.to(tensor_result.device)[:tensor_result.size(0)] + numerator = torch.dot(tensor_result.float(), weights.transpose(-1, 0).float()) + tensor_result = numerator / weights.sum().float() + return tensor_result -def _process_dataloader_aggregated_steps(result, weights): +def _process_dataloader_aggregated_steps(result: MutableMapping, weights: torch.Tensor) -> None: internal_keys = {'meta'} moved = False diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 035a42338fe68..958ed73fd8870 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -24,7 +24,7 @@ import numpy as np import torch -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only @@ -289,7 +289,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): kwargs: Optional keywoard arguments, depends on the specific logger 
being used """ - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: """ Record model graph @@ -381,7 +381,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: for logger in self._logger_iterable: logger.log_hyperparams(params) - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: for logger in self._logger_iterable: logger.log_graph(model, input_array) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 148e512f5e439..498a16a9daa29 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -24,7 +24,7 @@ import torch from torch import is_tensor -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -318,6 +318,6 @@ def __getstate__(self): state["_experiment"] = None return state - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: if self._experiment is not None: self._experiment.set_model_graph(model) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index af3802476571b..eeea73ee1632b 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -25,7 +25,7 @@ from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn @@ -210,7 +210,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> raise ValueError(m) from ex @rank_zero_only - def log_graph(self, model: LightningModule, input_array=None): + def log_graph(self, model: 'pl.LightningModule', input_array=None): if self._log_graph: if input_array is None: input_array = model.example_input_array diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 84f231b0f16d7..0d176b4bcc279 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -18,7 +18,7 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn @@ -153,7 +153,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log(metrics, global_step=step) @rank_zero_only - def log_graph(self, model: LightningModule, input_array=None): + def log_graph(self, model: 'pl.LightningModule', input_array=None): if self._log_graph: if input_array is None: input_array = 
model.example_input_array diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index c1ea3287964a8..86c428cc6243e 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -135,7 +135,7 @@ def clip_grad_by_norm( device = parameters[0].device if norm_type == math.inf: - total_norm = max(p.grad.data.abs().max() for p in parameters) + total_norm = max([p.grad.data.abs().max() for p in parameters]) else: out = torch.empty(len(parameters), device=device) for i, p in enumerate(parameters): diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index fa2c2917f98a2..462527a014304 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -17,23 +17,19 @@ import os from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union import torch from torch import nn, Tensor -from torch.autograd.profiler import record_function +from torch.autograd.profiler import EventList, record_function +from torch.utils.hooks import RemovableHandle +import pytorch_lightning as pl from pytorch_lightning.profiler.profilers import BaseProfiler from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE -if TYPE_CHECKING: - from torch.autograd.profiler import EventList - from torch.utils.hooks import RemovableHandle - - from pytorch_lightning.core.lightning import LightningModule - if _KINETO_AVAILABLE: from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler @@ -301,7 +297,7 @@ def __init__( self.profiler: Optional[_PROFILER] = None self.function_events: Optional['EventList'] = None - self._lightning_module: Optional['LightningModule'] = None # set by ProfilerConnector + self._lightning_module: Optional['pl.LightningModule'] = None # set by ProfilerConnector self._register: Optional[RegisterRecordFunction] = None self._parent_profiler: Optional[_PROFILER] = None self._recording_map: Dict[str, record_function] = {} diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 544b229a21728..08d6698aa3f3c 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -13,7 +13,7 @@ # limitations under the License. import os from datetime import timedelta -from typing import List, Union, Optional, Dict +from typing import Dict, List, Optional, Union from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.timer import Timer @@ -115,9 +115,7 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic if max_time is None: return if any(isinstance(cb, Timer) for cb in self.trainer.callbacks): - rank_zero_info( - "Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer." 
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 9f10ca8306ff3..90d6f31739b87 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -49,7 +49,9 @@ def on_trainer_init(self) -> None:
         # when true, print evaluation results in .validate() and .test()
         self.trainer.verbose_evaluate = True
 
-    def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]:
+    def get_evaluation_dataloaders(
+        self
+    ) -> Tuple[Optional[List[DataLoader]], Union[List[int], List[Union[int, float]]]]:
         model = self.trainer.lightning_module
 
         # select dataloaders
@@ -57,7 +59,7 @@ def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[U
             self.trainer.reset_test_dataloader(model)
             dataloaders = self.trainer.test_dataloaders
-            max_batches = self.trainer.num_test_batches
+            max_batches: Union[List[int], List[Union[int, float]]] = self.trainer.num_test_batches
         else:
             # val
             if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch:
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 2f01363086478..5515f21a5ad7a 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -15,7 +15,7 @@
 import os
 from abc import ABC
 from argparse import ArgumentParser, Namespace
-from typing import cast, List, Optional, Type, TypeVar, Union
+from typing import cast, Dict, List, Optional, Type, TypeVar, Union
 
 import torch
 from torch.optim import Optimizer
@@ -58,6 +58,18 @@ class TrainerProperties(ABC):
     limit_val_batches: int
     logger: LightningLoggerBase
     logger_connector: LoggerConnector
+    current_epoch: int
+    global_step: int
+    fast_dev_run: bool
+    should_stop: bool
+    log_every_n_steps: int
+    accumulate_grad_batches: Union[int, Dict[int, int], List[list]]
+    max_epochs: int
+    check_val_every_n_epoch: int
+    batch_idx: int
+    num_sanity_val_batches: List[int]
+    split_idx: int
+    truncated_bptt_steps: Optional[int]
 
     @property
     def accelerator(self) -> Accelerator:
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index e100a803bcd00..c517f161d8b3f 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -123,7 +123,7 @@ def __subclasshook__(cls, subclass):
         return NotImplemented
 
 
-def move_data_to_device(batch: Any, device: torch.device):
+def move_data_to_device(batch: Any, device: Union[str, torch.device]):
     """
     Transfers a collection of data to the given device. Any object that defines a method
     ``to(device)`` will be moved and all other objects in the collection will be left untouched.
@@ -161,7 +161,7 @@ def batch_to(data):
     return apply_to_collection(batch, dtype=dtype, function=batch_to)
 
 
-def convert_to_tensors(data, device: torch.device = None):
+def convert_to_tensors(data, device: Optional[Union[torch.device, str]] = None):
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
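With the widened signature above, move_data_to_device accepts either a torch.device or a plain device string. A quick usage sketch (CPU-only so it runs anywhere; the nested batch is made up for illustration):

import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device

batch = {"images": torch.zeros(2, 3), "targets": [torch.ones(2), 7]}

# A plain string is now accepted by the annotation...
moved = move_data_to_device(batch, "cpu")
# ...as well as an explicit torch.device; non-tensor leaves (the int 7) are left untouched.
moved_again = move_data_to_device(batch, torch.device("cpu"))

print(moved["images"].device, moved_again["targets"][1])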
diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py
index c1c40b98c71c7..5d2668287c8ec 100644
--- a/pytorch_lightning/utilities/types.py
+++ b/pytorch_lightning/utilities/types.py
@@ -1,13 +1,16 @@
-from typing import Any, Dict, Iterator, List, Union
-
-import torch
-from torchmetrics import Metric
 """
 Convention:
  - Do not include any `_TYPE` suffix
  - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
 """
+
+from typing import Any, Dict, Iterator, List, Sequence, Union
+
+import torch
+from torchmetrics import Metric
+
 _METRIC = Union[Metric, torch.Tensor, int, float]
 STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
 EPOCH_OUTPUT = List[STEP_OUTPUT]
 _PARAMETERS = Iterator[torch.nn.Parameter]
+BATCH = Union[Dict[str, Union[torch.Tensor, Any]], Sequence[Union[torch.Tensor, Any]], torch.Tensor, Any]
diff --git a/setup.cfg b/setup.cfg
index 3fa6e39076725..391840ce7f30d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -115,14 +115,6 @@ allow_redefinition = True
 # disable this rule as the Trainer attributes are defined in the connectors, not in its __init__
 disable_error_code = attr-defined
 
-# todo: add proper typing to this module...
-[mypy-pytorch_lightning.callbacks.*]
-ignore_errors = True
-
-# todo: add proper typing to this module...
-[mypy-pytorch_lightning.core.*]
-ignore_errors = True
-
 # todo: add proper typing to this module...
 [mypy-pytorch_lightning.loggers.*]
 ignore_errors = True
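The new BATCH alias above is intentionally broad: a dict, a sequence, a bare tensor, or anything else a dataloader may yield. A hedged sketch of annotating with it; describe_batch is a made-up helper, not a Lightning hook:

import torch

from pytorch_lightning.utilities.types import BATCH


def describe_batch(batch: BATCH) -> str:
    # Only inspects the batch; the alias places no runtime constraint on it.
    if isinstance(batch, torch.Tensor):
        return f"tensor of shape {tuple(batch.shape)}"
    if isinstance(batch, dict):
        return f"dict with keys {sorted(batch)}"
    return type(batch).__name__


print(describe_batch(torch.zeros(4, 3)))
print(describe_batch({"x": torch.zeros(1), "y": 0}))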
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 3844d16edb517..68e8aaee94037 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -213,11 +213,13 @@ def test_early_stopping_no_val_step(tmpdir):
     assert trainer.current_epoch < trainer.max_epochs - 1
 
 
-@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [
-    (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
-    (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
-    (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
-])
+@pytest.mark.parametrize(
+    "stopping_threshold,divergence_theshold,losses,expected_epoch", [
+        (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
+        (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
+        (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
+    ]
+)
 def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_theshold, losses, expected_epoch):
 
     class CurrentModel(BoringModel):
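The parametrized cases above exercise EarlyStopping's stopping_threshold and divergence_threshold arguments. A hedged configuration sketch; "val_loss" is only an example metric name, not something the test defines:

from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor="val_loss",
    stopping_threshold=2.9,  # stop once the monitored value is already good enough
    divergence_threshold=15.9,  # stop immediately if the monitored value drifts past this
)
# The callback would then be handed to the Trainer via `callbacks=[early_stop]`.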
diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index 1e39cc9c330b0..098289175ebc2 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -183,25 +183,26 @@ def test_pruning_callback_ddp_cpu(tmpdir):
     train_with_pruning_callback(tmpdir, parameters_to_prune=True, accelerator="ddp_cpu", num_processes=2)
 
 
-@pytest.mark.parametrize("resample_parameters", (False, True))
-def test_pruning_lth_callable(tmpdir, resample_parameters: bool):
-    model = TestModel()
+class ModelPruningTestCallback(ModelPruning):
+    lth_calls = 0
 
-    class ModelPruningTestCallback(ModelPruning):
-        lth_calls = 0
+    def apply_lottery_ticket_hypothesis(self):
+        super().apply_lottery_ticket_hypothesis()
+        self.lth_calls += 1
 
-        def apply_lottery_ticket_hypothesis(self):
-            super().apply_lottery_ticket_hypothesis()
-            self.lth_calls += 1
+        for d in self._original_layers.values():
+            copy, names = d["data"], d["names"]
+            for i, name in names:
+                curr, curr_name = self._parameters_to_prune[i]
+                assert name == curr_name
+                actual, expected = getattr(curr, name).data, getattr(copy, name).data
+                allclose = torch.allclose(actual, expected)
+                assert not allclose if self._resample_parameters else allclose
 
-            for d in self._original_layers.values():
-                copy, names = d["data"], d["names"]
-                for i, name in names:
-                    curr, curr_name = self._parameters_to_prune[i]
-                    assert name == curr_name
-                    actual, expected = getattr(curr, name).data, getattr(copy, name).data
-                    allclose = torch.allclose(actual, expected)
-                    assert not allclose if self._resample_parameters else allclose
+
+@pytest.mark.parametrize("resample_parameters", (False, True))
+def test_pruning_lth_callable(tmpdir, resample_parameters: bool):
+    model = TestModel()
 
     pruning = ModelPruningTestCallback(
         "l1_unstructured", use_lottery_ticket_hypothesis=lambda e: bool(e % 2), resample_parameters=resample_parameters
@@ -270,6 +271,15 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
     assert not has_pruning if make_pruning_permanent else has_pruning
 
 
+class TestPruning(ModelPruning):
+
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+        ret_val = super().on_save_checkpoint(trainer, pl_module, checkpoint)
+        assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
+        assert hasattr(pl_module.layer.mlp_3, "weight_orig")
+        return ret_val
+
+
 def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
     """
     When a model is saved multiple times and make_permanent=True, we need to
@@ -278,13 +288,6 @@ def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
     """
     seed_everything(0)
 
-    class TestPruning(ModelPruning):
-
-        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-            super().on_save_checkpoint(trainer, pl_module, checkpoint)
-            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
-            assert hasattr(pl_module.layer.mlp_3, "weight_orig")
-
     model = TestModel()
     pruning_callback = TestPruning(
         "random_unstructured",