From 148f9980359e263ac0ed93f115a7ba80e801949b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 23 Nov 2021 15:23:55 +0100
Subject: [PATCH 1/3] update

---
 pytorch_lightning/lite/wrappers.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 3cd2f5eb69712..c8142642ea2e2 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -23,9 +23,8 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
-from pytorch_lightning.accelerators import Accelerator
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
-from pytorch_lightning.plugins import PrecisionPlugin
+from pytorch_lightning.plugins import PrecisionPlugin, TrainingTypePlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 
 
@@ -34,22 +33,22 @@ def _do_nothing_closure() -> None:
 
 
 class _LiteOptimizer:
-    def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
+    def __init__(self, optimizer: Optimizer, strategy: TrainingTypePlugin) -> None:
         """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
-        step calls to the accelerator/strategy plugin.
+        step calls to the strategy plugin.
 
         The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.
 
         Args:
             optimizer: The optimizer to wrap
-            accelerator: Reference to the accelerator for handling the optimizer step
+            strategy: Reference to the strategy for handling the optimizer step
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_LiteOptimizer
         self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
         self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
-        self._accelerator = accelerator
+        self._strategy = strategy
 
     @property
     def optimizer(self) -> Optimizer:
@@ -57,11 +56,11 @@ def optimizer(self) -> Optimizer:
 
     def step(self, closure: Optional[Callable] = None) -> None:
         closure = closure or _do_nothing_closure
-        self._accelerator.optimizer_step(
+        self._strategy.optimizer_step(
             self.optimizer,
             opt_idx=0,
             closure=closure,
-            model=self._accelerator.model,
+            model=self._strategy.model,
         )
 
 
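Reviewer note: the essence of this first patch is that the optimizer wrapper now holds a reference to the strategy and routes every step() call through it, so the strategy can apply gradient scaling, synchronization, or step skipping before the underlying optimizer runs. A minimal, self-contained sketch of that delegation pattern follows; ToyStrategy and ToyLiteOptimizer are illustrative stand-ins, not Lightning APIs.

    from typing import Callable, Optional

    import torch
    from torch.optim import Optimizer


    class ToyStrategy:
        """Stands in for a TrainingTypePlugin: it owns the model and the step logic."""

        def __init__(self, model: torch.nn.Module) -> None:
            self.model = model

        def optimizer_step(self, optimizer: Optimizer, opt_idx: int, closure: Callable, model: torch.nn.Module) -> None:
            # A real strategy could unscale gradients or all-reduce here before stepping.
            optimizer.step(closure)


    class ToyLiteOptimizer:
        """Routes step() through the strategy instead of calling the optimizer directly."""

        def __init__(self, optimizer: Optimizer, strategy: ToyStrategy) -> None:
            self._optimizer = optimizer
            self._strategy = strategy

        def step(self, closure: Optional[Callable] = None) -> None:
            closure = closure or (lambda: None)
            self._strategy.optimizer_step(self._optimizer, opt_idx=0, closure=closure, model=self._strategy.model)


    model = torch.nn.Linear(2, 2)
    strategy = ToyStrategy(model)
    optimizer = ToyLiteOptimizer(torch.optim.SGD(model.parameters(), lr=0.1), strategy)
    optimizer.step()  # lands in ToyStrategy.optimizer_step, not directly in SGD.step

Cutting the Accelerator out of this path is what lets the import of Accelerator disappear from wrappers.py above.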
From f4a5a82a5d2bd1d58531eb5aebb7ab2ada048cab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 12:01:04 +0100
Subject: [PATCH 2/3] update references

---
 pytorch_lightning/lite/lite.py     |  2 +-
 pytorch_lightning/lite/wrappers.py |  2 +-
 tests/lite/test_wrappers.py        | 14 ++++++--------
 3 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index fede7f5df7291..f1f9b4eef5785 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -171,7 +171,7 @@ def setup(
         # Let accelerator/plugin wrap and connect the models and optimizers
         model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
         model = _LiteModule(model, self._precision_plugin)
-        optimizers = [_LiteOptimizer(optimizer=optimizer, accelerator=self._accelerator) for optimizer in optimizers]
+        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
         self._models_setup += 1
         if optimizers:
             # join both types in a list for API convenience
diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 1594124f82a0f..c238526778d80 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -51,7 +51,7 @@ def optimizer(self) -> Optimizer:
         return self._optimizer
 
     def state_dict(self) -> Dict[str, Tensor]:
-        return self._accelerator.optimizer_state(self.optimizer)
+        return self._strategy.optimizer_state(self.optimizer)
 
     def step(self, closure: Optional[Callable] = None) -> None:
         closure = closure or _do_nothing_closure
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py
index ff9b9e2ddb8ce..8d31e673f446a 100644
--- a/tests/lite/test_wrappers.py
+++ b/tests/lite/test_wrappers.py
@@ -144,21 +144,19 @@ def test_lite_optimizer_wraps():
 
 
 def test_lite_optimizer_state_dict():
-    """Test that the LiteOptimizer calls into the accelerator/strategy to collect the state."""
+    """Test that the LiteOptimizer calls into the strategy to collect the state."""
     optimizer = Mock()
-    accelerator = Mock()
-    lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator)
+    strategy = Mock()
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.state_dict()
-    accelerator.optimizer_state.assert_called_with(optimizer)
+    strategy.optimizer_state.assert_called_with(optimizer)
 
 
 def test_lite_optimizer_steps():
     """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
     optimizer = Mock()
     strategy = Mock()
-    accelerator = Accelerator(None, strategy)
-    lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator)
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.step()
-    strategy = accelerator.training_type_plugin
     strategy.optimizer_step.assert_called_once()
-    strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model)
+    strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model)

From 075d6b45b3f5779dc1d1b7e865dde8ebc3fd82ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 13:46:04 +0100
Subject: [PATCH 3/3] unused import

---
 tests/lite/test_wrappers.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py
index 8d31e673f446a..65522b74936c8 100644
--- a/tests/lite/test_wrappers.py
+++ b/tests/lite/test_wrappers.py
@@ -17,7 +17,6 @@
 import torch
 from torch.utils.data.dataloader import DataLoader
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
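Reviewer note: the simplified tests in patch 2 work because _LiteOptimizer only touches the strategy's interface, so a bare Mock can stand in for it and record the delegated calls; no Accelerator has to be constructed, which is what makes the import in patch 3 removable. A sketch of that verification pattern, assuming this branch of pytorch_lightning is importable (it mirrors test_lite_optimizer_steps in the diff above):

    from unittest.mock import ANY, Mock

    from pytorch_lightning.lite.wrappers import _LiteOptimizer

    optimizer = Mock()
    strategy = Mock()  # auto-creates optimizer_step and model attributes on access
    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)

    # step() must not call the wrapped optimizer directly; it has to go through
    # the strategy. ANY matches the internally created do-nothing closure.
    lite_optimizer.step()
    strategy.optimizer_step.assert_called_once()
    strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model)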