[Feature] Add Dsnas Algorithm #226
@@ -0,0 +1,22 @@
_STAGE_MUTABLE = dict(
    type='mmrazor.OneHotMutableOP',
    candidates=dict(
        shuffle_3x3=dict(type='ShuffleBlock', kernel_size=3),
        shuffle_5x5=dict(type='ShuffleBlock', kernel_size=5),
        shuffle_7x7=dict(type='ShuffleBlock', kernel_size=7),
        shuffle_xception=dict(type='ShuffleXception')))

arch_setting = [
    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, mutable_cfg.
    [64, 4, _STAGE_MUTABLE],
    [160, 4, _STAGE_MUTABLE],
    [320, 8, _STAGE_MUTABLE],
    [640, 4, _STAGE_MUTABLE]
]

nas_backbone = dict(
    type='mmrazor.SearchableShuffleNetV2',
    widen_factor=1.0,
    arch_setting=arch_setting,
    adjust_channels=True)
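A quick note for readers skimming the diff: each [channel, num_blocks, mutable_cfg] row of arch_setting above expands into one ShuffleNetV2 stage whose blocks draw their candidate ops from _STAGE_MUTABLE. A minimal sketch (not part of this PR; the file path is hypothetical) of building the backbone from such a config:

from mmengine import Config
from mmrazor.registry import MODELS

# Hypothetical local copy of the supernet config shown above.
cfg = Config.fromfile('dsnas_shufflenet_supernet.py')
# Each [channel, num_blocks, mutable_cfg] row becomes one searchable stage.
backbone = MODELS.build(cfg.nas_backbone)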
@@ -0,0 +1,91 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
preprocess_cfg = dict(
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.RandomResizedCrop', scale=224, backend='pillow'),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.ResizeEdge', scale=256, edge='short', backend='pillow'),
    dict(type='mmcls.CenterCrop', crop_size=224),
    dict(type='mmcls.PackClsInputs'),
]

train_dataloader = dict(
    batch_size=512,
    num_workers=15,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=100,
    num_workers=15,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want the standard test, please manually configure the test dataset.
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# optimizer
paramwise_cfg = dict(
    bias_decay_mult=0.0, norm_decay_mult=0.0, dwconv_decay_mult=0.0)

# optimizer
optim_wrapper = dict(
    architecture=dict(
        type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5),
    mutator=dict(type='mmcls.Adam', lr=2e-3, weight_decay=0.0),
    # clip_grad=dict(max_norm=5, norm_type=2),
)

# learning policy
param_scheduler = dict(
    architecture=[
        dict(type='mmcls.LinearLR', end=5, by_epoch=True, start_factor=0.0001),
        dict(
            type='mmcls.CosineAnnealingLR',
            T_max=240,
            begin=5,
            eta_min=0.0,
            by_epoch=True,
        ),
    ],
    mutator=[])

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=240)
val_cfg = dict(
    type='mmrazor.EvaluatorLoop',
    dataloader=val_dataloader,
    evaluator=dict(
        type='mmrazor.NaiveEvaluator',
        metrics=dict(type='mmcls.Accuracy', topk=(1, 5)),
    ))
test_cfg = dict()
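To make the architecture schedule above concrete: the learning rate warms up linearly over the first 5 epochs starting from 0.5 * 1e-4, then follows cosine annealing towards eta_min=0. The helper below is only a rough approximation of what mmengine's LinearLR and CosineAnnealingLR compute, written for illustration; it is not the library implementation.

import math

def approx_arch_lr(epoch: int,
                   base_lr: float = 0.5,
                   warmup_epochs: int = 5,
                   start_factor: float = 1e-4,
                   t_max: int = 240,
                   eta_min: float = 0.0) -> float:
    """Rough per-epoch LR for the `architecture` param_scheduler above."""
    if epoch < warmup_epochs:
        # LinearLR: the scale factor ramps from `start_factor` to 1.0 by `end=5`.
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # CosineAnnealingLR: time is measured from the scheduler's `begin` epoch.
    t = epoch - warmup_epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

# approx_arch_lr(0) ~= 5e-5, approx_arch_lr(5) == 0.5, and the value decays
# towards 0 as training approaches max_epochs=240.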
@@ -0,0 +1,48 @@
_base_ = [
    'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py',
    'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py',
    'mmcls::_base_/default_runtime.py',
]

# model
model = dict(
    type='mmrazor.Dsnas',
    architecture=dict(
        type='ImageClassifier',
        backbone=_base_.nas_backbone,
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=1024,
            loss=dict(
                type='LabelSmoothLoss',
                num_classes=1000,
                label_smooth_val=0.1,
                mode='original',
                loss_weight=1.0),
            topk=(1, 5))),
    mutator=dict(type='mmrazor.DiffModuleMutator'),
    pretrain_epochs=0,
    finetune_epochs=80,
)

model_wrapper_cfg = dict(
    type='mmrazor.DsnasDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)

custom_hooks = [
    dict(type='mmrazor.DumpSubnetHook', interval=5, max_keep_subnets=2),
]

# TRAINING
optim_wrapper = dict(
    _delete_=True,
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(
        type='mmrazor.DsnasOptimWrapper',
        optimizer=dict(type='SGD', lr=0.5, momentum=0.9, weight_decay=4e-5)),
    mutator=dict(optimizer=dict(type='Adam', lr=0.001, weight_decay=0.0)))

randomness = dict(seed=22, diff_rank_seed=False)
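For context rather than as part of the diff: a config like this is normally driven through mmengine's Runner. The file name and work_dir below are placeholders, not names taken from the PR.

from mmengine import Config
from mmengine.runner import Runner

cfg = Config.fromfile('dsnas_supernet_config.py')  # placeholder path
cfg.work_dir = './work_dirs/dsnas_supernet'  # Runner requires a work_dir
runner = Runner.from_cfg(cfg)
runner.train()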
@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class EarlyArchFixerHook(Hook):
    """Fix the arch params early.

    The EarlyArchFixerHook will fix the value of the max arch param in each
    layer at 1 when the difference between the top-2 arch params is larger
    than the `threshold`.
    NOTE: Only supports differentiable NAS methods at present.

    Args:
        by_epoch (bool): By epoch or by iteration. Default: True.
        threshold (float): Threshold to judge whether to fix params or not.
            Default: 0.3 (in paper).
    """

    def __init__(self, by_epoch=True, threshold=0.3, **kwargs):
        self.by_epoch = by_epoch
        self.threshold = threshold

    def before_train_epoch(self, runner):
        """Executed in before_train_epoch stage."""
        if not self.by_epoch:
            return

        model = runner.model.module
        mutator = model.mutator

        if mutator.early_fix_arch:
            if len(mutator.fix_arch_index.keys()) > 0:
                for k, v in mutator.fix_arch_index.items():
                    mutator.arch_params[k].data = v[1]
            for mutable in mutator.mutables:
                arch_param = mutator.arch_params[mutable.key].detach().clone()
                # find the top-2 values of arch_params in the layer
                sort_arch_params = torch.topk(
                    mutator.compute_arch_probs(arch_param), 2)
                argmax_index = (
                    sort_arch_params[0][0] - sort_arch_params[0][1] >=
                    self.threshold)
                # if the max value is large enough, fix current layer.
                if argmax_index:
                    if mutable.key not in mutator.fix_arch_index.keys():
                        mutator.fix_arch_index[mutable.key] = [
                            sort_arch_params[1][0].item(), arch_param
                        ]

Review comment (on the line `mutator = model.mutator`): Please add type hints for mutator, because the mutator here (DiffModuleMutator) is not suitable for all kinds of mutators.
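Separately, if the hook is meant to be switched on from a config, a plausible entry would mirror the DumpSubnetHook entry in the supernet config above. Treat the snippet below as a sketch: whether a scope prefix is needed depends on which HOOKS registry the decorator ultimately targets.

custom_hooks = [
    dict(type='mmrazor.DumpSubnetHook', interval=5, max_keep_subnets=2),
    # Hypothetical entry for the new hook; 0.3 is the documented default.
    dict(type='EarlyArchFixerHook', by_epoch=True, threshold=0.3),
]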
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .dsnas_optimizer_wrapper import DsnasOptimWrapper
 from .optimizer_constructor import SeparateOptimWrapperConstructor

-__all__ = ['SeparateOptimWrapperConstructor']
+__all__ = ['DsnasOptimWrapper', 'SeparateOptimWrapperConstructor']
@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from torch.optim import Optimizer

from mmengine.optim import OptimWrapper
from mmengine.registry import OPTIM_WRAPPERS


@OPTIM_WRAPPERS.register_module()
class DsnasOptimWrapper(OptimWrapper):
    """Optimizer wrapper provides a common interface for updating parameters.

    Optimizer wrapper provides a unified interface for single precision
    training and automatic mixed precision training with different hardware.
    OptimWrapper encapsulates optimizer to provide simplified interfaces
    for commonly used training techniques such as gradient accumulative and
    grad clips. ``OptimWrapper`` implements the basic logic of gradient
    accumulation and gradient clipping based on ``torch.optim.Optimizer``.
    The subclasses only need to override some methods to implement the mixed
    precision training. See more information in :class:`AmpOptimWrapper`.

    Args:
        optimizer (Optimizer): Optimizer used to update model parameters.
        accumulative_counts (int): The number of iterations to accumulate
            gradients. The parameters will be updated per
            ``accumulative_counts``.
        clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
            the arguments of ``torch.nn.utils.clip_grad``.

    Note:
        If ``accumulative_counts`` is larger than 1, perform
        :meth:`update_params` under the context of ``optim_context``
        could avoid unnecessary gradient synchronization.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.

    Note:
        The subclass should ensure that once :meth:`update_params` is called,
        ``_inner_count += 1`` is automatically performed.

    Examples:
        >>> # Config sample of OptimWrapper.
        >>> optim_wrapper_cfg = dict(
        >>>     type='OptimWrapper',
        >>>     _accumulative_counts=1,
        >>>     clip_grad=dict(max_norm=0.2))
        >>> # Use OptimWrapper to update model.
        >>> import torch.nn as nn
        >>> import torch
        >>> from torch.optim import SGD
        >>> from torch.utils.data import DataLoader
        >>> from mmengine.optim import OptimWrapper
        >>>
        >>> model = nn.Linear(1, 1)
        >>> dataset = torch.randn(10, 1, 1)
        >>> dataloader = DataLoader(dataset)
        >>> optimizer = SGD(model.parameters(), lr=0.1)
        >>> optim_wrapper = OptimWrapper(optimizer)
        >>>
        >>> for data in dataloader:
        >>>     loss = model(data)
        >>>     optim_wrapper.update_params(loss)
        >>> # Enable gradient accumulation
        >>> optim_wrapper_cfg = dict(
        >>>     type='OptimWrapper',
        >>>     _accumulative_counts=3,
        >>>     clip_grad=dict(max_norm=0.2))
        >>> ddp_model = DistributedDataParallel(model)
        >>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
        >>> optim_wrapper = OptimWrapper(optimizer)
        >>> optim_wrapper.initialize_count_status(0, len(dataloader))
        >>> # If model is a subclass instance of DistributedDataParallel,
        >>> # `optim_context` context manager can avoid unnecessary gradient
        >>> # synchronize.
        >>> for iter, data in enumerate(dataloader):
        >>>     with optim_wrapper.optim_context(ddp_model):
        >>>         loss = model(data)
        >>>     optim_wrapper.update_params(loss)
    """

    def __init__(self,
                 optimizer: Optimizer,
                 accumulative_counts: int = 1,
                 clip_grad: Optional[dict] = None):
        super().__init__(
            optimizer,
            accumulative_counts=accumulative_counts,
            clip_grad=clip_grad,
        )

    def update_params(self,
                      loss: torch.Tensor,
                      retain_graph: bool = False) -> None:
        """Update parameters in :attr:`optimizer`.

        Args:
            loss (torch.Tensor): A tensor for back propagation.
        """
        loss = self.scale_loss(loss)
        self.backward(loss, retain_graph=retain_graph)
        # Update parameters only if `self._inner_count` is divisible by
        # `self._accumulative_counts` or `self._inner_count` equals to
        # `self._max_counts`
        if self.should_update():
            self.step()
            self.zero_grad()

    def backward(self, loss: torch.Tensor, retain_graph: bool = False) -> None:
        """Perform gradient back propagation.

        Provide unified ``backward`` interface compatible with automatic mixed
        precision training. Subclass can overload this method to implement the
        required logic. For example, ``torch.cuda.amp`` require some extra
        operation on GradScaler during backward process.

        Note:
            If subclasses inherit from ``OptimWrapper`` override
            ``backward``, ``_inner_count += 1`` must be implemented.

        Args:
            loss (torch.Tensor): The loss of current iteration.
        """
        loss.backward(retain_graph=retain_graph)
        self._inner_count += 1

Review comments on the class docstring:
- Please update the docstring of DsnasOptimWrapper and describe the difference between DsnasOptimWrapper and OptimWrapper.
- DsnasOptimWrapper will be dropped.

Review comments on the Examples section:
- Delete the Examples or add new examples for DsnasOptimWrapper.
- DsnasOptimWrapper is dropped.
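Whatever the final fate of the class, the practical difference from the base OptimWrapper is the retain_graph flag threaded through update_params and backward, which matters when the same computation graph must be back-propagated more than once. Below is a minimal standalone sketch of that usage; it is an illustration, not code from the PR, and it assumes the DsnasOptimWrapper defined above is importable.

import torch
import torch.nn as nn
from torch.optim import SGD

model = nn.Linear(2, 1)
wrapper = DsnasOptimWrapper(optimizer=SGD(model.parameters(), lr=0.1))

x = torch.randn(4, 2)
out = model(x)
# Keep the graph alive so a second backward pass over `out` stays legal.
wrapper.update_params(out.pow(2).mean(), retain_graph=True)
# The second update reuses the same graph, e.g. for an additional loss term.
wrapper.update_params(out.mean())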
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base import BaseAlgorithm
 from .distill import FpnTeacherDistill, SingleTeacherDistill
-from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
+from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
 from .pruning import SlimmableNetwork, SlimmableNetworkDDP

 __all__ = [
     'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
     'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
-    'Darts', 'DartsDDP'
+    'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
 ]
@@ -1,6 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .autoslim import AutoSlim, AutoSlimDDP
 from .darts import Darts, DartsDDP
+from .dsnas import Dsnas, DsnasDDP
 from .spos import SPOS

-__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP']
+__all__ = [
+    'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
+]
Review comment (on `pretrain_epochs=0` in the supernet config): Is it normal to set pretrain_epochs to zero?