[Feature] Add Lion optimizer (open-mmlab#952)
zhouzaida authored Feb 23, 2023
1 parent 25dfe41 commit fc9518e
Showing 5 changed files with 37 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/en/tutorials/optim_wrapper.md
@@ -243,7 +243,7 @@ As shown in the above example, `OptimWrapperDict` exports learning rates and mom

### Configure the OptimWapper in [Runner](runner.md)

We first need to configure the `optimizer` for the OptimWrapper. MMEngine automatically adds all optimizers in PyTorch to the `OPTIMIZERS` registry, and users can specify the optimizers they need in the form of a `dict`. All supported optimizers in PyTorch are listed [here](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, 'DAdaptAdaGrad', 'DAdaptAdam', and 'DAdaptSGD' can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation).
We first need to configure the `optimizer` for the OptimWrapper. MMEngine automatically adds all optimizers in PyTorch to the `OPTIMIZERS` registry, and users can specify the optimizers they need in the form of a `dict`. All supported optimizers in PyTorch are listed [here](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, `DAdaptAdaGrad`, `DAdaptAdam`, and `DAdaptSGD` can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation). The `Lion` optimizer can be used by installing [lion-pytorch](https://github.com/lucidrains/lion-pytorch).
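As a minimal sketch of what such a configuration might look like (assuming lion-pytorch is installed so that `type='Lion'` resolves in the registry; the `lr` and `weight_decay` values below are illustrative placeholders, not recommendations):

```python
# Sketch of an OptimWrapper config selecting the Lion optimizer by name.
# MMEngine's registry resolves type='Lion' only when lion-pytorch is
# installed; otherwise the optimizer is simply not registered.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='Lion', lr=1e-4, weight_decay=1e-2))
```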

Now we take setting up a SGD OptimWrapper as an example.

2 changes: 1 addition & 1 deletion docs/zh_cn/tutorials/optim_wrapper.md
@@ -243,7 +243,7 @@ print(optim_dict.get_momentum())  # {'gen.momentum': [0], 'disc.momentum': [0]}

### Configuring the OptimWrapper in the [Runner](./runner.md)

The optimizer wrapper accepts an `optimizer` argument, so we first need to configure the `optimizer` for it. MMEngine automatically adds all PyTorch optimizers to the `OPTIMIZERS` registry, and users can specify the optimizer they need in the form of a `dict`. All supported optimizers are listed in the [PyTorch optimizer list](https://pytorch.org/docs/stable/optim.html#algorithms).
The optimizer wrapper accepts an `optimizer` argument, so we first need to configure the `optimizer` for it. MMEngine automatically adds all PyTorch optimizers to the `OPTIMIZERS` registry, and users can specify the optimizer they need in the form of a `dict`. All supported optimizers are listed in the [PyTorch optimizer list](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, the three optimizers `DAdaptAdaGrad`, `DAdaptAdam`, and `DAdaptSGD` can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation). The `Lion` optimizer can also be used by installing [lion-pytorch](https://github.com/lucidrains/lion-pytorch).

Take configuring an SGD optimizer wrapper as an example:

20 changes: 20 additions & 0 deletions mmengine/optim/optimizer/builder.py
@@ -57,6 +57,26 @@ def register_dadaptation_optimizers() -> List[str]:
DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()


def register_lion_optimizers() -> List[str]:
    """Register Lion optimizer to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' name.
    """
    optimizers = []
    try:
        from lion_pytorch import Lion
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(module=Lion)
        optimizers.append('Lion')
    return optimizers


LION_OPTIMIZERS = register_lion_optimizers()
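The function above follows MMEngine's optional-dependency registration pattern: import inside `try`, and register only in the `else` branch so a missing package is a silent no-op. A self-contained sketch of that pattern (the names `OPTIMIZERS_DEMO` and `register_optional` are stand-ins invented here, with a plain dict in place of MMEngine's registry):

```python
import importlib
from typing import List

OPTIMIZERS_DEMO = {}  # stand-in for MMEngine's OPTIMIZERS registry


def register_optional(module_name: str, attr: str) -> List[str]:
    """Register ``attr`` from ``module_name`` only if it is importable."""
    registered = []
    try:
        mod = importlib.import_module(module_name)
    except ImportError:
        pass  # optional dependency missing: registration is a silent no-op
    else:
        OPTIMIZERS_DEMO[attr] = getattr(mod, attr)
        registered.append(attr)
    return registered


# An importable module registers its attribute; a missing one does not.
present = register_optional('math', 'sqrt')        # -> ['sqrt']
absent = register_optional('no_such_pkg', 'Lion')  # -> []
```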


def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
    """Build function of OptimWrapper.
2 changes: 2 additions & 0 deletions requirements/tests.txt
@@ -1,4 +1,6 @@
coverage
dadaptation
lion-pytorch
lmdb
parameterized
pytest
13 changes: 13 additions & 0 deletions tests/test_optim/test_optimizer/test_optimizer.py
@@ -14,6 +14,7 @@
                           DefaultOptimWrapperConstructor, OptimWrapper,
                           build_optim_wrapper)
from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS,
                                              LION_OPTIMIZERS,
                                              TORCH_OPTIMIZERS)
from mmengine.registry import build_from_cfg
from mmengine.testing._internal import MultiProcessTestCase
@@ -34,6 +35,14 @@ def has_dadaptation() -> bool:
        return False


def has_lion() -> bool:
    try:
        import lion_pytorch  # noqa: F401
        return True
    except ImportError:
        return False


class ExampleModel(nn.Module):

    def __init__(self):
@@ -221,6 +230,10 @@ def test_dadaptation_optimizers(self):
        assert set(dadaptation_optimizers).issubset(
            set(DADAPTATION_OPTIMIZERS))

    @unittest.skipIf(not has_lion(), 'lion-pytorch is not installed')
    def test_lion_optimizers(self):
        assert 'Lion' in LION_OPTIMIZERS

    def test_build_optimizer(self):
        # test build function without ``constructor`` and ``paramwise_cfg``
        optim_wrapper_cfg = dict(
