Update LiteOptimizer signature after optimizer changes in TrainingTypePlugin #10708

Merged 4 commits on Nov 30, 2021
13 changes: 5 additions & 8 deletions pytorch_lightning/lite/wrappers.py
@@ -19,9 +19,8 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

-from pytorch_lightning.accelerators import Accelerator
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
-from pytorch_lightning.plugins import PrecisionPlugin
+from pytorch_lightning.plugins import PrecisionPlugin, TrainingTypePlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

@@ -30,24 +29,22 @@ def _do_nothing_closure() -> None:


 class _LiteOptimizer:
-    def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
+    def __init__(self, optimizer: Optimizer, strategy: TrainingTypePlugin) -> None:
         """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
-        step calls to the accelerator/strategy plugin.
+        step calls to the strategy plugin.

         The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.

         Args:
             optimizer: The optimizer to wrap
-            accelerator: Reference to the accelerator for handling the optimizer step
+            strategy: Reference to the strategy for handling the optimizer step
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_LiteOptimizer
         self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
         self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
-        self._accelerator = accelerator
-        # TODO (@awaelchli) refactor to take Strategy as param
-        self._strategy = self._accelerator.training_type_plugin
+        self._strategy = strategy

     @property
     def optimizer(self) -> Optimizer:
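For context on what the diff changes: `_LiteOptimizer` copies the wrapped optimizer's attributes, dynamically subclasses the optimizer's class so `isinstance` checks still pass, and routes `step()` through the strategy instead of reaching through the accelerator. Below is a minimal, self-contained sketch of that delegation pattern, assuming a hypothetical `MinimalStrategy` stand-in exposing an `optimizer_step(optimizer, closure=...)` hook; `MinimalStrategy` and `WrappedOptimizer` are illustrative names, not the actual pytorch_lightning API.

```python
from typing import Any, Callable, Optional

import torch
from torch.optim import Optimizer


def _do_nothing_closure() -> None:
    return None


class MinimalStrategy:
    """Hypothetical stand-in for TrainingTypePlugin (not the real Lightning class)."""

    def optimizer_step(self, optimizer: Optimizer, closure: Callable, **kwargs: Any) -> Any:
        # A real strategy could scale gradients, sync across processes, etc.
        return optimizer.step(closure=closure, **kwargs)


class WrappedOptimizer:
    def __init__(self, optimizer: Optimizer, strategy: MinimalStrategy) -> None:
        # Copy the wrapped optimizer's attributes so the wrapper exposes the same
        # state (param_groups, defaults, ...) without re-implementing them.
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
        # Dynamically subclass the optimizer's class so isinstance() checks on the
        # wrapper still succeed; the wrapper's own methods win in the MRO.
        self.__class__ = type("Wrapped" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
        self._optimizer = optimizer
        self._strategy = strategy

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    def step(self, closure: Optional[Callable] = None) -> Any:
        # Delegate to the strategy instead of stepping the optimizer directly.
        closure = closure if closure is not None else _do_nothing_closure
        return self._strategy.optimizer_step(self._optimizer, closure=closure)


# Usage: the wrapper still looks like an SGD instance, but step() goes
# through the strategy.
model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
wrapped = WrappedOptimizer(opt, MinimalStrategy())
assert isinstance(wrapped, torch.optim.SGD)
wrapped.step()
```

Listing the wrapper class first in the bases of the dynamic `type(...)` call means its `step` wins in the method resolution order, while inheriting from the optimizer's class keeps `isinstance(wrapped, torch.optim.SGD)` working. The strategy hook is then the single place where precision scaling or distributed synchronization can intercept the step, which is why the wrapper now takes the strategy directly rather than fetching it from the accelerator.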