Typing callbacks #7042

Closed
wants to merge 48 commits into from
Changes from 28 commits
Commits
48 commits
ab22de1  update typing accelerators  (justusschock, Apr 13, 2021)
45414a8  add typing to early stopping and pruning  (justusschock, Apr 14, 2021)
31816d7  type model_checkpoint  (justusschock, Apr 14, 2021)
827fd33  add typing lr monitor  (justusschock, Apr 14, 2021)
2c619ed  add typing  (justusschock, Apr 14, 2021)
3baddcf  typing gradient accumulation callback  (justusschock, Apr 14, 2021)
735d7c9  type gpu stats monitor  (justusschock, Apr 14, 2021)
d373526  add typing to swa  (justusschock, Apr 14, 2021)
4ba5828  type quantization  (justusschock, Apr 14, 2021)
3c9d143  add typing for progbar  (justusschock, Apr 14, 2021)
bb79b10  type datamodule  (justusschock, Apr 14, 2021)
14a92bb  type decorators  (justusschock, Apr 14, 2021)
d588b08  type hooks  (justusschock, Apr 14, 2021)
f9844ec  type memory  (justusschock, Apr 14, 2021)
0c84926  type lightning optimizer  (justusschock, Apr 14, 2021)
ac37abe  type saving.py  (justusschock, Apr 14, 2021)
9c16fa3  type results  (justusschock, Apr 14, 2021)
5696ffb  fix typing of core  (justusschock, Apr 15, 2021)
6aa531e  add typing to early stopping and pruning  (justusschock, Apr 14, 2021)
7ee888c  fix annotation  (justusschock, Apr 15, 2021)
51badf0  unused imports  (justusschock, Apr 15, 2021)
219bd3d  remove forward declaration for docs  (justusschock, Apr 19, 2021)
662b7e1  fix forward declarations  (justusschock, Apr 19, 2021)
ccba3e5  fix forward declarations  (justusschock, Apr 20, 2021)
f21cd0e  pre-commit  (justusschock, Apr 20, 2021)
580b1d1  remove trainingmode from typing  (justusschock, Apr 20, 2021)
d38048e  update typing again  (justusschock, Apr 20, 2021)
0b886b3  Pruning fix  (carmocca, Apr 20, 2021)
f4abb15  fix callback failures except pruning  (justusschock, Apr 20, 2021)
f4d9ca6  Merge branch 'typing_callbacks' of https://github.com/PyTorchLightnin…  (justusschock, Apr 20, 2021)
80b6939  update typing  (justusschock, Apr 20, 2021)
5e8c177  fix about  (justusschock, Apr 20, 2021)
282245a  Apply suggestions from code review  (justusschock, Apr 20, 2021)
7b16d8b  fix return type annotations with outputs  (justusschock, Apr 20, 2021)
92ce22b  Merge branch 'typing_callbacks' of https://github.com/PyTorchLightnin…  (justusschock, Apr 20, 2021)
a09bcce  fix returns with None  (justusschock, Apr 20, 2021)
f4e2eb0  Fix pruning mypy complain  (carmocca, Apr 20, 2021)
4fc0029  fix torchscript  (justusschock, Apr 20, 2021)
14f0f9f  fix tests  (justusschock, Apr 20, 2021)
f5e6353  fix tests  (justusschock, Apr 20, 2021)
3b00bed  mypy + pep8  (justusschock, Apr 20, 2021)
1298300  pep8  (justusschock, Apr 20, 2021)
113612b  Merge branch 'master' into typing_callbacks  (justusschock, Apr 20, 2021)
0d249bc  unused imports  (justusschock, Apr 20, 2021)
612043f  Merge branch 'typing_callbacks' of https://github.com/PyTorchLightnin…  (justusschock, Apr 20, 2021)
813a328  tempfile import  (justusschock, Apr 20, 2021)
f9321a6  doctest fix  (justusschock, Apr 20, 2021)
7c7a443  Merge branch 'master' into typing_callbacks  (justusschock, Apr 21, 2021)
2 changes: 1 addition & 1 deletion docs/source/common/optimizers.rst
@@ -35,7 +35,7 @@ To manually optimize, do the following:
* ``optimizer.step()`` to update your model parameters

Here is a minimal example of manual optimization.

.. testcode:: python

from pytorch_lightning import LightningModule
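Note: the hunk above switches the docs' manual optimization example to a testcode directive; the example body itself is collapsed in this view. As a reference point, here is a minimal sketch of manual optimization against the Lightning API of this era (automatic_optimization, self.optimizers(), self.manual_backward()); it is an illustration, not the exact snippet from optimizers.rst:

import torch
from pytorch_lightning import LightningModule


class ManualOptModule(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # opt out of Lightning's automatic optimization loop
        self.automatic_optimization = False

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # optimizer(s) from configure_optimizers()
        opt.zero_grad()
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.manual_backward(loss)  # replaces loss.backward()
        opt.step()
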
17 changes: 3 additions & 14 deletions pl_examples/basic_examples/dali_image_classifier.py
@@ -194,26 +194,15 @@ def __init__(

def train_dataloader(self):
return DALIClassificationLoader(
self.pipe_train,
size=len(self.mnist_train),
auto_reset=True,
fill_last_batch=True
self.pipe_train, size=len(self.mnist_train), auto_reset=True, fill_last_batch=True
)

def val_dataloader(self):
return DALIClassificationLoader(
self.pipe_val,
size=len(self.mnist_val),
auto_reset=True,
fill_last_batch=False
)
return DALIClassificationLoader(self.pipe_val, size=len(self.mnist_val), auto_reset=True, fill_last_batch=False)

def test_dataloader(self):
return DALIClassificationLoader(
self.pipe_test,
size=len(self.mnist_test),
auto_reset=True,
fill_last_batch=False
self.pipe_test, size=len(self.mnist_test), auto_reset=True, fill_last_batch=False
)


5 changes: 1 addition & 4 deletions pl_examples/basic_examples/profiler_example.py
@@ -44,10 +44,7 @@

class ModelToProfile(LightningModule):

def __init__(
self,
name: str = "resnet50"
):
def __init__(self, name: str = "resnet50"):
super().__init__()
self.model = getattr(models, name)(pretrained=True)
self.criterion = torch.nn.CrossEntropyLoss()
2 changes: 1 addition & 1 deletion pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = '1.3.0rc2'
__version__ = "20210420"
__author__ = 'William Falcon et al.'
__author_email__ = '[email protected]'
__license__ = 'Apache-2.0'
11 changes: 4 additions & 7 deletions pytorch_lightning/accelerators/accelerator.py
@@ -24,14 +24,11 @@
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT

if _NATIVE_AMP_AVAILABLE:
from torch.cuda.amp import GradScaler


class Accelerator:
"""
@@ -374,7 +371,7 @@ def to_device(self, batch: Any) -> Any:
return self.batch_to_device(batch, self.root_device)

@property
def amp_backend(self) -> Optional[LightningEnum]:
def amp_backend(self) -> Optional['AMPType']:
if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
return AMPType.APEX
elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
@@ -386,7 +383,7 @@ def precision(self) -> Union[str, int]:
return self.precision_plugin.precision

@property
def scaler(self) -> Optional['GradScaler']:
def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:
return getattr(self.precision_plugin, 'scaler', None)

@property
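Note on the scaler annotation change above: the _NATIVE_AMP_AVAILABLE-guarded import of GradScaler is dropped and the return type is spelled as the quoted string 'torch.cuda.amp.GradScaler', which type checkers resolve but the interpreter never evaluates. A small sketch of that pattern using a toy holder class (not the real Accelerator code):

from typing import Optional

import torch


class ScalerHolder:
    """Toy stand-in for the accelerator's scaler property."""

    def __init__(self) -> None:
        self.precision_plugin = object()

    @property
    def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:
        # The quoted annotation is never evaluated at runtime, so no
        # conditional import of GradScaler is needed on CPU-only builds.
        return getattr(self.precision_plugin, 'scaler', None)
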
2 changes: 2 additions & 0 deletions pytorch_lightning/accelerators/cpu.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/gpu.py
@@ -13,17 +13,18 @@
# limitations under the License.
import logging
import os
from typing import Any
from typing import Any, TYPE_CHECKING

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)

import pytorch_lightning as pl


class GPUAccelerator(Accelerator):
""" Accelerator for GPU devices. """
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/tpu.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union
from typing import Any, Callable, TYPE_CHECKING, Union

from torch.optim import Optimizer

@@ -51,7 +51,7 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
return super().setup(trainer, model)

def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
self, optimizer: 'Optimizer', optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
) -> None:
xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

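The accelerator modules above (cpu.py, gpu.py, tpu.py) add a TYPE_CHECKING import; the guarded usages themselves fall outside the visible hunks. The usual intent of this pattern is to make an import visible to mypy without executing it at runtime. A generic sketch of the idea, not the code in these files:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated only by static type checkers; skipped at runtime, so
    # circular or heavyweight imports never actually execute
    from torch.optim import Optimizer


def describe(optimizer: 'Optimizer') -> str:
    # the quoted annotation keeps runtime independent of the guarded import
    return type(optimizer).__name__
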
91 changes: 51 additions & 40 deletions pytorch_lightning/callbacks/base.py
@@ -17,9 +17,11 @@
"""

import abc
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from pytorch_lightning.core.lightning import LightningModule
from torch.optim import Optimizer

import pytorch_lightning as pl


class Callback(abc.ABC):
@@ -29,158 +31,165 @@ class Callback(abc.ABC):
Subclass this class and override any of the relevant hooks
"""

def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None:
def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called before configure sharded model"""

def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None:
def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called before accelerator is being setup"""
pass

def setup(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None:
"""Called when fit, validate, test, predict, or tune begins"""
pass

def teardown(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
def teardown(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage: Optional[str] = None) -> None:
"""Called when fit, validate, test, predict, or tune ends"""
pass

def on_init_start(self, trainer) -> None:
def on_init_start(self, trainer: 'pl.Trainer') -> None:
"""Called when the trainer initialization begins, model has not yet been set."""
pass

def on_init_end(self, trainer) -> None:
def on_init_end(self, trainer: 'pl.Trainer') -> None:
"""Called when the trainer initialization ends, model has not yet been set."""
pass

def on_fit_start(self, trainer, pl_module: LightningModule) -> None:
def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when fit begins"""
pass

def on_fit_end(self, trainer, pl_module: LightningModule) -> None:
def on_fit_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when fit ends"""
pass

def on_sanity_check_start(self, trainer, pl_module: LightningModule) -> None:
def on_sanity_check_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the validation sanity check starts."""
pass

def on_sanity_check_end(self, trainer, pl_module: LightningModule) -> None:
def on_sanity_check_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the validation sanity check ends."""
pass

def on_train_batch_start(
self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""Called when the train batch begins."""
pass

def on_train_batch_end(
self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: Any, batch: Any, batch_idx: int,
dataloader_idx: int
) -> None:
"""Called when the train batch ends."""
pass

def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None:
def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the train epoch begins."""
pass

def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: List[Any]) -> None:
"""Called when the train epoch ends."""
pass

def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None:
def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the val epoch begins."""
pass

def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
def on_validation_epoch_end(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: List[Any]
) -> None:
"""Called when the val epoch ends."""
pass

def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None:
def on_test_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the test epoch begins."""
pass

def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: List[Any]) -> None:
"""Called when the test epoch ends."""
pass

def on_epoch_start(self, trainer, pl_module: LightningModule) -> None:
def on_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when either of train/val/test epoch begins."""
pass

def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
def on_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when either of train/val/test epoch ends."""
pass

def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
def on_batch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the training batch begins."""
pass

def on_validation_batch_start(
self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""Called when the validation batch begins."""
pass

def on_validation_batch_end(
self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: Any, batch: Any, batch_idx: int,
dataloader_idx: int
) -> None:
"""Called when the validation batch ends."""
pass

def on_test_batch_start(
self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""Called when the test batch begins."""
pass

def on_test_batch_end(
self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: Any, batch: Any, batch_idx: int,
dataloader_idx: int
) -> None:
"""Called when the test batch ends."""
pass

def on_batch_end(self, trainer, pl_module: LightningModule) -> None:
def on_batch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the training batch ends."""
pass

def on_train_start(self, trainer, pl_module: LightningModule) -> None:
def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the train begins."""
pass

def on_train_end(self, trainer, pl_module: LightningModule) -> None:
def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the train ends."""
pass

def on_pretrain_routine_start(self, trainer, pl_module: LightningModule) -> None:
def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the pretrain routine begins."""
pass

def on_pretrain_routine_end(self, trainer, pl_module: LightningModule) -> None:
def on_pretrain_routine_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the pretrain routine ends."""
pass

def on_validation_start(self, trainer, pl_module: LightningModule) -> None:
def on_validation_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the validation loop begins."""
pass

def on_validation_end(self, trainer, pl_module: LightningModule) -> None:
def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the validation loop ends."""
pass

def on_test_start(self, trainer, pl_module: LightningModule) -> None:
def on_test_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the test begins."""
pass

def on_test_end(self, trainer, pl_module: LightningModule) -> None:
def on_test_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the test ends."""
pass

def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None:
def on_keyboard_interrupt(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the training is interrupted by ``KeyboardInterrupt``."""
pass

def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict:
def on_save_checkpoint(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', checkpoint: Dict[str, Any]
) -> dict:
"""
Called when saving a model checkpoint, use to persist state.

@@ -202,10 +211,12 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
"""
pass

def on_after_backward(self, trainer, pl_module: LightningModule) -> None:
def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called after ``loss.backward()`` and before optimizers do anything."""
pass

def on_before_zero_grad(self, trainer, pl_module: LightningModule, optimizer) -> None:
def on_before_zero_grad(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: 'Optimizer'
) -> None:
"""Called after ``optimizer.step()`` and before ``optimizer.zero_grad()``."""
pass
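
Throughout the Callback hooks above, the direct LightningModule import is replaced by a module-level import pytorch_lightning as pl plus quoted annotations such as 'pl.Trainer' and 'pl.LightningModule'. Because the strings are only resolved by type checkers (or lazily via typing.get_type_hints), callbacks/base.py no longer has to import the Trainer module, avoiding a circular import inside the package. A minimal sketch of the same pattern from user code, assuming only the public Callback API:

import pytorch_lightning as pl  # top-level alias, as used throughout this PR

from pytorch_lightning.callbacks import Callback


class PrintingCallback(Callback):

    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        # 'pl.Trainer' is a string, so nothing beyond the top-level package
        # needs to be imported for this annotation to type-check
        print(f"starting fit for {type(pl_module).__name__}")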