From 10d0b41977be99ed0d71d5c5eba2bce19b21f149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 14:58:19 +0200 Subject: [PATCH] Introduce `PrecisionPlugin.forward_context()` (#9988) Co-authored-by: thomas chaton --- CHANGELOG.md | 4 +++ pytorch_lightning/plugins/precision/double.py | 32 +------------------ .../plugins/precision/native_amp.py | 20 +----------- .../plugins/precision/precision_plugin.py | 25 ++++++++++----- 4 files changed, 23 insertions(+), 58 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 43ebc6464ac51..a10c5cb41cc18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -194,6 +194,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597)) +- LightningLite: + * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988)) + + ### Changed - Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)). diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index 179daf9e91db8..5e9e8bd43b820 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -92,37 +92,7 @@ def connect( return super().connect(model, optimizers, lr_schedulers) @contextmanager - def train_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def val_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def test_step_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) - - @contextmanager - def predict_step_context(self) -> Generator[None, None, None]: + def forward_context(self) -> Generator[None, None, None]: """A context manager to change the default tensor type. See: :meth:`torch.set_default_tensor_type` diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 8f93b63588c19..50c527f5f407d 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -102,25 +102,7 @@ def autocast_context_manager(self) -> torch.cuda.amp.autocast: return torch.cuda.amp.autocast() @contextmanager - def train_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def val_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def test_step_context(self) -> Generator[None, None, None]: - """Enable autocast context.""" - with self.autocast_context_manager(): - yield - - @contextmanager - def predict_step_context(self) -> Generator[None, None, None]: + def forward_context(self) -> Generator[None, None, None]: """Enable autocast context.""" with self.autocast_context_manager(): yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 5138bb660b9cd..c81a474faad34 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -143,21 +143,30 @@ def post_dispatch(self) -> None: """Hook to do something after the training/evaluation/prediction finishes.""" @contextlib.contextmanager - def train_step_context(self) -> Generator: - """A contextmanager for the training step.""" + def forward_context(self) -> Generator[None, None, None]: + """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" yield @contextlib.contextmanager - def val_step_context(self) -> Generator: + def train_step_context(self) -> Generator[None, None, None]: + """A contextmanager for the training step.""" + with self.forward_context(): + yield + + @contextlib.contextmanager + def val_step_context(self) -> Generator[None, None, None]: """A contextmanager for the validation step.""" - yield + with self.forward_context(): + yield @contextlib.contextmanager - def test_step_context(self) -> Generator: + def test_step_context(self) -> Generator[None, None, None]: """A contextmanager for the test step.""" - yield + with self.forward_context(): + yield @contextlib.contextmanager - def predict_step_context(self) -> Generator: + def predict_step_context(self) -> Generator[None, None, None]: """A contextmanager for the predict step.""" - yield + with self.forward_context(): + yield