diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index fede7f5df7291..f1f9b4eef5785 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -171,7 +171,7 @@ def setup(
         # Let accelerator/plugin wrap and connect the models and optimizers
         model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
         model = _LiteModule(model, self._precision_plugin)
-        optimizers = [_LiteOptimizer(optimizer=optimizer, accelerator=self._accelerator) for optimizer in optimizers]
+        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
         self._models_setup += 1
         if optimizers:
             # join both types in a list for API convenience
diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 6a37453eec93d..c238526778d80 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -19,9 +19,8 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
-from pytorch_lightning.accelerators import Accelerator
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
-from pytorch_lightning.plugins import PrecisionPlugin
+from pytorch_lightning.plugins import PrecisionPlugin, TrainingTypePlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 
 
@@ -30,31 +29,29 @@ def _do_nothing_closure() -> None:
 
 
 class _LiteOptimizer:
-    def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
+    def __init__(self, optimizer: Optimizer, strategy: TrainingTypePlugin) -> None:
         """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
-        step calls to the accelerator/strategy plugin.
+        step calls to the strategy plugin.
 
         The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.
 
         Args:
             optimizer: The optimizer to wrap
-            accelerator: Reference to the accelerator for handling the optimizer step
+            strategy: Reference to the strategy for handling the optimizer step
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_LiteOptimizer
         self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
         self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
-        self._accelerator = accelerator
-        # TODO (@awaelchli) refactor to take Strategy as param
-        self._strategy = self._accelerator.training_type_plugin
+        self._strategy = strategy
 
     @property
     def optimizer(self) -> Optimizer:
         return self._optimizer
 
     def state_dict(self) -> Dict[str, Tensor]:
-        return self._accelerator.optimizer_state(self.optimizer)
+        return self._strategy.optimizer_state(self.optimizer)
 
     def step(self, closure: Optional[Callable] = None) -> None:
         closure = closure or _do_nothing_closure
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py
index ff9b9e2ddb8ce..65522b74936c8 100644
--- a/tests/lite/test_wrappers.py
+++ b/tests/lite/test_wrappers.py
@@ -17,7 +17,6 @@
 import torch
 from torch.utils.data.dataloader import DataLoader
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -144,21 +143,19 @@ def test_lite_optimizer_wraps():
 
 
 def test_lite_optimizer_state_dict():
-    """Test that the LiteOptimizer calls into the accelerator/strategy to collect the state."""
+    """Test that the LiteOptimizer calls into the strategy to collect the state."""
     optimizer = Mock()
-    accelerator = Mock()
-    lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator)
+    strategy = Mock()
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.state_dict()
-    accelerator.optimizer_state.assert_called_with(optimizer)
+    strategy.optimizer_state.assert_called_with(optimizer)
 
 
 def test_lite_optimizer_steps():
     """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
     optimizer = Mock()
     strategy = Mock()
-    accelerator = Accelerator(None, strategy)
-    lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator)
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.step()
-    strategy = accelerator.training_type_plugin
     strategy.optimizer_step.assert_called_once()
-    strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model)
+    strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model)
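For reference, a minimal sketch of how `_LiteOptimizer` is wired after this change, mirroring the updated test above. The `Mock` objects stand in for a real optimizer and `TrainingTypePlugin`; they are only illustrative, not how the wrapper is constructed inside Lite itself:

```python
from unittest.mock import Mock

from pytorch_lightning.lite.wrappers import _LiteOptimizer

# Illustrative stand-ins; in real use the strategy comes from LightningLite internals.
optimizer = Mock()
strategy = Mock()

# The wrapper now receives the strategy directly instead of the accelerator.
lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)

# state_dict() is delegated straight to the strategy.
lite_optimizer.state_dict()
strategy.optimizer_state.assert_called_with(optimizer)
```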