More robust FabricOptimizer/LightningOptimizer wrapping logic (#18516)
awaelchli authored Sep 12, 2023
1 parent 12af771 commit e958c6f
Showing 5 changed files with 12 additions and 48 deletions.
24 changes: 4 additions & 20 deletions src/lightning/fabric/wrappers.py
@@ -48,13 +48,12 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional
            optimizer: The optimizer to wrap
            strategy: Reference to the strategy for handling the optimizer step
        """
-        self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
        self._optimizer = optimizer
        self._strategy = strategy
        self._callbacks = callbacks or []
-        self._refresh()
+        # imitate the class of the wrapped object to make isinstance checks work
+        self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

    @property
    def optimizer(self) -> Optimizer:
@@ -65,9 +64,6 @@ def state_dict(self) -> Dict[str, Tensor]:

    def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None:
        self.optimizer.load_state_dict(state_dict)
-        # `Optimizer.load_state_dict` modifies `optimizer.__dict__`, so we need to update the `__dict__` on
-        # this wrapper
-        self._refresh()

    def step(self, closure: Optional[Callable] = None) -> Any:
        kwargs = {"closure": closure} if closure is not None else {}
@@ -86,20 +82,8 @@ def step(self, closure: Optional[Callable] = None) -> Any:
                hook(strategy=self._strategy, optimizer=optimizer)
        return output

-    def _refresh(self) -> None:
-        """Refreshes the ``__dict__`` so that it matches the internal states in the wrapped optimizer.
-        This is only needed to present the user with an updated view in case they inspect the state of this wrapper.
-        """
-        # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
-        # not want to call on destruction of the `_FabricOptimizer
-        self.__dict__.update(
-            {
-                k: v
-                for k, v in self.optimizer.__dict__.items()
-                if k not in ("load_state_dict", "state_dict", "step", "__del__")
-            }
-        )
+    def __getattr__(self, item: Any) -> Any:
+        return getattr(self._optimizer, item)


class _FabricModule(_DeviceDtypeModuleMixin):
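Aside: the pattern this commit settles on can be illustrated outside Lightning. The sketch below uses a hypothetical `_OptimizerProxy` class (not part of the codebase) to show how the two pieces interact: attribute reads fall through to the wrapped optimizer via `__getattr__`, while the dynamically created subclass keeps `isinstance` checks against the original optimizer class working. It only assumes `torch` is installed.

from typing import Any

import torch


class _OptimizerProxy:
    """Hypothetical stand-in for _FabricOptimizer/LightningOptimizer, illustrating the wrapping pattern."""

    def __init__(self, optimizer: torch.optim.Optimizer) -> None:
        self._optimizer = optimizer
        # imitate the class of the wrapped object to make isinstance checks work
        self.__class__ = type(
            "Proxy" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}
        )

    def __getattr__(self, item: Any) -> Any:
        # only called when normal attribute lookup fails, so the proxy's own
        # attributes (e.g. `_optimizer`) are served as usual and everything
        # else is read live from the wrapped optimizer
        return getattr(self._optimizer, item)


param = torch.nn.Parameter(torch.zeros(1))
proxy = _OptimizerProxy(torch.optim.SGD([param], lr=0.1))

assert isinstance(proxy, torch.optim.SGD)   # thanks to the dynamic subclass
assert type(proxy).__name__ == "ProxySGD"
assert proxy.param_groups[0]["lr"] == 0.1   # delegated via __getattr__, no __dict__ copy needed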
4 changes: 1 addition & 3 deletions src/lightning/pytorch/core/module.py
@@ -164,8 +164,6 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
            opts: MODULE_OPTIMIZERS = self._fabric_optimizers
        elif use_pl_optimizer:
            opts = self.trainer.strategy._lightning_optimizers
-            for opt in opts:
-                opt.refresh()
        else:
            opts = self.trainer.optimizers

@@ -1089,7 +1087,7 @@ def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:

        # Then iterate over the current optimizer's parameters and set its `requires_grad`
        # properties accordingly
-        for group in optimizer.param_groups:  # type: ignore[union-attr]
+        for group in optimizer.param_groups:
            for param in group["params"]:
                param.requires_grad = param_requires_grad_state[param]
        self._param_requires_grad_state = param_requires_grad_state
18 changes: 5 additions & 13 deletions src/lightning/pytorch/core/optimizer.py
@@ -44,15 +44,13 @@ class LightningOptimizer:
    """

    def __init__(self, optimizer: Optimizer):
-        self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
-
        self._optimizer = optimizer
        self._strategy: Optional[pl.strategies.Strategy] = None
        # to inject logic around the optimizer step, particularly useful with manual optimization
        self._on_before_step = do_nothing_closure
        self._on_after_step = do_nothing_closure
-
-        self.refresh()
+        # imitate the class of the wrapped object to make isinstance checks work
+        self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

    @property
    def optimizer(self) -> Optimizer:
@@ -81,15 +79,6 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
        yield
        lightning_module.untoggle_optimizer(self)

-    def refresh(self) -> None:
-        """Refreshes the ``__dict__`` so that it matches the internal states in the wrapped optimizer.
-        This is only needed to present the user with an updated view in case they inspect the state of this wrapper.
-        """
-        # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
-        # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer`
-        self.__dict__.update({k: v for k, v in self.optimizer.__dict__.items() if k not in ("step", "__del__")})
-
    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
        """Performs a single optimization step (parameter update).
@@ -175,6 +164,9 @@ def _to_lightning_optimizer(
        lightning_optimizer._strategy = proxy(strategy)
        return lightning_optimizer

+    def __getattr__(self, item: Any) -> Any:
+        return getattr(self._optimizer, item)
+

def _init_optimizers_and_lr_schedulers(
    model: "pl.LightningModule",
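For context on why the `refresh()` machinery could be dropped: the old approach copied the wrapped optimizer's `__dict__` into the wrapper, so that copy went stale whenever the inner optimizer changed (for example after `load_state_dict`). With `__getattr__`, every attribute read goes to the live optimizer instead. A minimal sketch of that behaviour, mirroring the updated `test_state_mutation` further down (assumes `torch` and `lightning` are installed; the module shape and learning rates are arbitrary):

import torch
from lightning.pytorch.core.optimizer import LightningOptimizer

model = torch.nn.Linear(3, 4)
state_dict0 = torch.optim.Adam(model.parameters(), lr=1.0).state_dict()

optimizer1 = torch.optim.Adam(model.parameters(), lr=100)
lightning_optimizer1 = LightningOptimizer(optimizer1)

# loading state into the wrapped optimizer is immediately visible through the
# wrapper: `param_groups` is resolved on every access via `__getattr__`
optimizer1.load_state_dict(state_dict0)
assert lightning_optimizer1.param_groups[0]["lr"] == 1.0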
2 changes: 2 additions & 0 deletions tests/tests_fabric/test_wrappers.py
@@ -358,6 +358,8 @@ def test_fabric_optimizer_wraps():
    fabric_optimizer = _FabricOptimizer(optimizer, Mock())
    assert fabric_optimizer.optimizer is optimizer
    assert isinstance(fabric_optimizer, optimizer_cls)
+    assert isinstance(fabric_optimizer, _FabricOptimizer)
+    assert type(fabric_optimizer).__name__ == "FabricSGD"


def test_fabric_optimizer_state_dict():
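Restated as a standalone snippet, the new assertions boil down to the following (a sketch assuming `torch` and `lightning` are installed; `_FabricOptimizer` is private API, used here exactly as the test uses it):

import torch
from unittest.mock import Mock

from lightning.fabric.wrappers import _FabricOptimizer

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
fabric_optimizer = _FabricOptimizer(optimizer, Mock())  # Mock() stands in for the strategy

# the wrapper is a dynamically created subclass of both _FabricOptimizer and SGD,
# so isinstance checks against either class succeed and the class name reflects the pair
assert isinstance(fabric_optimizer, torch.optim.SGD)
assert isinstance(fabric_optimizer, _FabricOptimizer)
assert type(fabric_optimizer).__name__ == "FabricSGD"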
12 changes: 0 additions & 12 deletions tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -150,15 +150,7 @@ def test_state():
    assert isinstance(lightning_optimizer, Adam)
    assert isinstance(lightning_optimizer, Optimizer)

-    lightning_dict = {
-        k: v
-        for k, v in lightning_optimizer.__dict__.items()
-        if k not in {"_optimizer", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
-    }
-
-    assert lightning_dict == optimizer.__dict__
    assert optimizer.state_dict() == lightning_optimizer.state_dict()
    assert optimizer.state == lightning_optimizer.state


def test_state_mutation():
@@ -174,10 +166,6 @@ def test_state_mutation():
    optimizer1 = torch.optim.Adam(model.parameters(), lr=100)
    lightning_optimizer1 = LightningOptimizer(optimizer1)
    optimizer1.load_state_dict(state_dict0)
-
-    # LightningOptimizer needs to be refreshed to see the new state
-    assert lightning_optimizer1.param_groups[0]["lr"] != 1.0
-    lightning_optimizer1.refresh()
    assert lightning_optimizer1.param_groups[0]["lr"] == 1.0

    # Load state into wrapped optimizer
