[Feature] Add Dsnas Algorithm #226

Merged: 39 commits, Sep 29, 2022

Commits (39):
6fad48c  [tmp] Update Dsnas (Aug 1, 2022)
6ded2d4  [tmp] refactor arch_loss & flops_loss (Aug 5, 2022)
1932d82  Update Dsnas & MMRAZOR_EVALUATOR: (Aug 10, 2022)
be98e8b  Merge branch 'dev-1.x' into gy/dsnas (Aug 11, 2022)
b1f5c7d  Update lr scheduler & fix a bug: (Aug 12, 2022)
8f8beef  remove old evaluators (Aug 16, 2022)
6f4e6e8  Merge branch 'dev-1.x' of github.com:open-mmlab/mmrazor into gy/dsnas (Aug 16, 2022)
3deb833  remove old evaluators (Aug 16, 2022)
2247bdd  update param_scheduler config (Aug 16, 2022)
8d8d1b8  merge dev-1.x into gy/estimator (Aug 23, 2022)
c2dbcaf  add flops_loss in Dsnas using ResourcesEstimator (Aug 23, 2022)
bc813f3  get resources before mutator.prepare_from_supernet (Aug 25, 2022)
ce87f89  delete unness broadcast api from gml (Aug 25, 2022)
73d9c3b  broadcast spec_modules_resources when estimating (Aug 25, 2022)
5de5bc9  update early fix mechanism for Dsnas (Aug 25, 2022)
1e4014d  merge dev-1.x into gy/dsnas (Aug 25, 2022)
e98043a  fix merge (Aug 25, 2022)
2fbdd01  update units in estimator (Aug 26, 2022)
d676d93  minor change (Aug 26, 2022)
420dcac  merge dev-1.x into gy/dsnas (Aug 29, 2022)
d6a401b  fix data_preprocessor api (Aug 31, 2022)
aa2e0f2  add flops_loss_coef (Aug 31, 2022)
08964ce  remove DsnasOptimWrapper (Aug 31, 2022)
32baf69  fix bn eps and data_preprocessor (Aug 31, 2022)
22cf3ed  fix bn weight decay bug (Sep 6, 2022)
e0bafe2  add betas for mutator optimizer (Sep 6, 2022)
84b5367  set diff_rank_seed=True for dsnas (Sep 6, 2022)
eef0514  fix start_factor of lr when warm up (Sep 6, 2022)
7258926  remove .module in non-ddp mode (Sep 6, 2022)
2dfdc79  add GlobalAveragePoolingWithDropout (Sep 6, 2022)
79d69ef  add UT for dsnas (Sep 6, 2022)
b511403  remove unness channel adjustment for shufflenetv2 (Sep 7, 2022)
91ef7d8  update supernet configs (Sep 26, 2022)
cfccb5e  delete unness dropout (Sep 26, 2022)
6794188  delete unness part with minor change on dsnas (Sep 26, 2022)
bbc93ff  merge dev-1.x into gy/dsnas (Sep 29, 2022)
3a6b696  minor change on the flag of search stage (Sep 29, 2022)
6df18d1  update README and subnet configs (Sep 29, 2022)
a77d894  add UT for OneHotMutableOP (Sep 29, 2022)
22 changes: 22 additions & 0 deletions configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py
@@ -0,0 +1,22 @@
_STAGE_MUTABLE = dict(
type='mmrazor.OneHotMutableOP',
candidates=dict(
shuffle_3x3=dict(type='ShuffleBlock', kernel_size=3),
shuffle_5x5=dict(type='ShuffleBlock', kernel_size=5),
shuffle_7x7=dict(type='ShuffleBlock', kernel_size=7),
shuffle_xception=dict(type='ShuffleXception')))

arch_setting = [
# Parameters to build layers. 3 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, mutable_cfg.
[64, 4, _STAGE_MUTABLE],
[160, 4, _STAGE_MUTABLE],
[320, 8, _STAGE_MUTABLE],
[640, 4, _STAGE_MUTABLE]
]

nas_backbone = dict(
type='mmrazor.SearchableShuffleNetV2',
widen_factor=1.0,
arch_setting=arch_setting,
adjust_channels=True)
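
For reference, a minimal smoke-test sketch (not part of this PR; the config path and registry usage follow mmengine/mmrazor 1.x conventions and are assumptions) showing how this supernet backbone config could be built and run on a dummy input:

# Assumed usage: build the DSNAS ShuffleNet supernet backbone from the config
# above and run a dummy forward pass. Not part of the PR diff.
import torch
from mmengine.config import Config
from mmrazor.registry import MODELS

cfg = Config.fromfile(
    'configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py')
backbone = MODELS.build(cfg.nas_backbone)  # mmrazor.SearchableShuffleNetV2
backbone.init_weights()
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # feature maps from the selected out stages
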
91 changes: 91 additions & 0 deletions configs/_base_/settings/imagenet_bs1024_dsnas.py
@@ -0,0 +1,91 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
preprocess_cfg = dict(
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)

train_pipeline = [
dict(type='mmcls.LoadImageFromFile'),
dict(type='mmcls.RandomResizedCrop', scale=224, backend='pillow'),
dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
dict(type='mmcls.LoadImageFromFile'),
dict(type='mmcls.ResizeEdge', scale=256, edge='short', backend='pillow'),
dict(type='mmcls.CenterCrop', crop_size=224),
dict(type='mmcls.PackClsInputs'),
]

train_dataloader = dict(
batch_size=512,
num_workers=15,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
persistent_workers=True,
)

val_dataloader = dict(
batch_size=100,
num_workers=15,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# parameter-wise optimizer settings
paramwise_cfg = dict(
bias_decay_mult=0.0, norm_decay_mult=0.0, dwconv_decay_mult=0.0)

# optimizer
optim_wrapper = dict(
architecture=dict(
type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5),
mutator=dict(type='mmcls.Adam', lr=2e-3, weight_decay=0.0),
# clip_grad=dict(max_norm=5, norm_type=2),
)

# learning policy
param_scheduler = dict(
architecture=[
dict(type='mmcls.LinearLR', end=5, by_epoch=True, start_factor=0.0001),
dict(
type='mmcls.CosineAnnealingLR',
T_max=240,
begin=5,
eta_min=0.0,
by_epoch=True,
),
],
mutator=[])

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=240)
val_cfg = dict(
type='mmrazor.EvaluatorLoop',
dataloader=val_dataloader,
evaluator=dict(
type='mmrazor.NaiveEvaluator',
metrics=dict(type='mmcls.Accuracy', topk=(1, 5)),
))
test_cfg = dict()
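
As a sanity check on the schedule above (a plain-PyTorch approximation, not the mmengine schedulers the config actually uses): LinearLR warms the architecture lr up from 0.5 * 0.0001 over the first 5 epochs, after which CosineAnnealingLR decays it towards 0 by epoch 240. A rough equivalent:

# Rough plain-PyTorch approximation (assumption) of the warm-up + cosine
# schedule configured above; the config uses T_max=240 beginning at epoch 5,
# here 235 epochs remain after warm-up.
import torch
from torch.optim.lr_scheduler import (CosineAnnealingLR, LinearLR,
                                      SequentialLR)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.5, momentum=0.9, weight_decay=4e-5)
warmup = LinearLR(optimizer, start_factor=0.0001, total_iters=5)
cosine = CosineAnnealingLR(optimizer, T_max=235, eta_min=0.0)
scheduler = SequentialLR(optimizer, [warmup, cosine], milestones=[5])

lrs = []
for _ in range(240):  # one step per epoch (by_epoch=True in the config)
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step()
print(lrs[0], lrs[5], lrs[-1])  # ~5e-5 at start, 0.5 after warm-up, ~0 at end
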
48 changes: 48 additions & 0 deletions configs/nas/mmcls/dsnas/dsnas_supernet_8xb512_imagenet.py
@@ -0,0 +1,48 @@
_base_ = [
'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py',
'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py',
'mmcls::_base_/default_runtime.py',
]

# model
model = dict(
type='mmrazor.Dsnas',
architecture=dict(
type='ImageClassifier',
backbone=_base_.nas_backbone,
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(
type='LabelSmoothLoss',
num_classes=1000,
label_smooth_val=0.1,
mode='original',
loss_weight=1.0),
topk=(1, 5))),
mutator=dict(type='mmrazor.DiffModuleMutator'),
pretrain_epochs=0,
Review comment (Contributor): Is it normal to set pretrain_epochs to zero?

finetune_epochs=80,
)

model_wrapper_cfg = dict(
type='mmrazor.DsnasDDP',
broadcast_buffers=False,
find_unused_parameters=True)

custom_hooks = [
dict(type='mmrazor.DumpSubnetHook', interval=5, max_keep_subnets=2),
]

# TRAINING
optim_wrapper = dict(
_delete_=True,
constructor='mmrazor.SeparateOptimWrapperConstructor',
architecture=dict(
type='mmrazor.DsnasOptimWrapper',
optimizer=dict(type='SGD', lr=0.5, momentum=0.9, weight_decay=4e-5)),
mutator=dict(optimizer=dict(type='Adam', lr=0.001, weight_decay=0.0)))

randomness = dict(seed=22, diff_rank_seed=False)
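
One way to launch supernet training from this config programmatically (a sketch assuming standard mmengine usage; the usual route is tools/train.py with a distributed launcher):

# Assumed single-process launch for debugging; the 8xb512 setting in the
# config name implies 8-GPU distributed training in practice.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/nas/mmcls/dsnas/dsnas_supernet_8xb512_imagenet.py')
cfg.work_dir = 'work_dirs/dsnas_supernet_8xb512_imagenet'
runner = Runner.from_cfg(cfg)
runner.train()
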
4 changes: 2 additions & 2 deletions mmrazor/engine/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import DumpSubnetHook
from .optimizers import SeparateOptimWrapperConstructor
from .optimizers import DsnasOptimWrapper, SeparateOptimWrapperConstructor
from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
GreedySamplerTrainLoop, SingleTeacherDistillValLoop,
@@ -10,5 +10,5 @@
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop'
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'DsnasOptimWrapper'
]
52 changes: 52 additions & 0 deletions mmrazor/engine/hooks/early_fix_arch_hook.py
@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class EarlyArchFixerHook(Hook):
"""Fix the arch params early.
The EarlyArchFixerHook will fix the value of the max arch param in each
layer at 1 when the difference between the top-2 arch params is larger
than the `threshold`.
NOTE: Only supports differentiable NAS methods at present.

Args:
by_epoch (bool): By epoch or by iteration.
Default: True.
threshold (float): Threshold to judge whether to fix params or not.
Default: 0.3 (in paper).
"""

def __init__(self, by_epoch=True, threshold=0.3, **kwargs):
self.by_epoch = by_epoch
self.threshold = threshold

def before_train_epoch(self, runner):
"""Executed in before_train_epoch stage."""
if not self.by_epoch:
return

model = runner.model.module
mutator = model.mutator
Review comment (Contributor): Please add type hints for mutator, because the mutator here (DiffModuleMutator) is not suitable for all kinds of mutators.

if mutator.early_fix_arch:
if len(mutator.fix_arch_index.keys()) > 0:
for k, v in mutator.fix_arch_index.items():
mutator.arch_params[k].data = v[1]
for mutable in mutator.mutables:
arch_param = mutator.arch_params[mutable.key].detach().clone()
# find the top-2 values of arch_params in the layer
sort_arch_params = torch.topk(
mutator.compute_arch_probs(arch_param), 2)
argmax_index = (
sort_arch_params[0][0] - sort_arch_params[0][1] >=
self.threshold)
# if the max value is large enough, fix current layer.
if argmax_index:
if mutable.key not in mutator.fix_arch_index.keys():
mutator.fix_arch_index[mutable.key] = [
sort_arch_params[1][0].item(), arch_param
]
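
Since the hook registers itself into mmengine's HOOKS registry, it could be enabled from a config like the following (assumed usage; the DSNAS configs in this PR do not add it to custom_hooks):

# Assumed config usage of the hook added above.
custom_hooks = [
    dict(type='EarlyArchFixerHook', by_epoch=True, threshold=0.3),
]
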
3 changes: 2 additions & 1 deletion mmrazor/engine/optimizers/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dsnas_optimizer_wrapper import DsnasOptimWrapper
from .optimizer_constructor import SeparateOptimWrapperConstructor

__all__ = ['SeparateOptimWrapperConstructor']
__all__ = ['DsnasOptimWrapper', 'SeparateOptimWrapperConstructor']
129 changes: 129 additions & 0 deletions mmrazor/engine/optimizers/dsnas_optimizer_wrapper.py
@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from torch.optim import Optimizer

from mmengine.optim import OptimWrapper
from mmengine.registry import OPTIM_WRAPPERS


@OPTIM_WRAPPERS.register_module()
class DsnasOptimWrapper(OptimWrapper):
"""Optimizer wrapper provides a common interface for updating parameters.
Review comment (Contributor): Please update the docstring of DsnasOptimWrapper and describe the difference between DsnasOptimWrapper and OptimWrapper.
Reply (Author): DsnasOptimWrapper will be dropped; retain_graph=True is supposed to be implemented in mmengine.

Optimizer wrapper provides a unified interface for single precision
training and automatic mixed precision training with different hardware.
OptimWrapper encapsulates optimizer to provide simplified interfaces
for commonly used training techniques such as gradient accumulative and
grad clips. ``OptimWrapper`` implements the basic logic of gradient
accumulation and gradient clipping based on ``torch.optim.Optimizer``.
The subclasses only need to override some methods to implement the mixed
precision training. See more information in :class:`AmpOptimWrapper`.

Args:
optimizer (Optimizer): Optimizer used to update model parameters.
accumulative_counts (int): The number of iterations to accumulate
gradients. The parameters will be updated per
``accumulative_counts``.
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
the arguments of ``torch.nn.utils.clip_grad``.

Note:
If ``accumulative_counts`` is larger than 1, perform
:meth:`update_params` under the context of ``optim_context``
could avoid unnecessary gradient synchronization.

Note:
If you use ``IterBasedRunner`` and enable gradient accumulation,
the original `max_iters` should be multiplied by
``accumulative_counts``.

Note:
The subclass should ensure that once :meth:`update_params` is called,
``_inner_count += 1`` is automatically performed.

Examples:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete the Examples or add new examples for DsnasOptimWrapper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DsnasOptimWrapper is dropped.

>>> # Config sample of OptimWrapper.
>>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper',
>>> _accumulative_counts=1,
>>> clip_grad=dict(max_norm=0.2))
>>> # Use OptimWrapper to update model.
>>> import torch.nn as nn
>>> import torch
>>> from torch.optim import SGD
>>> from torch.utils.data import DataLoader
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(1, 1)
>>> dataset = torch.randn(10, 1, 1)
>>> dataloader = DataLoader(dataset)
>>> optimizer = SGD(model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer)
>>>
>>> for data in dataloader:
>>> loss = model(data)
>>> optim_wrapper.update_params(loss)
>>> # Enable gradient accumulation
>>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper',
>>> _accumulative_counts=3,
>>> clip_grad=dict(max_norm=0.2))
>>> ddp_model = DistributedDataParallel(model)
>>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer)
>>> optim_wrapper.initialize_count_status(0, len(dataloader))
>>> # If model is a subclass instance of DistributedDataParallel,
>>> # `optim_context` context manager can avoid unnecessary gradient
>>> # synchronize.
>>> for iter, data in enumerate(dataloader):
>>> with optim_wrapper.optim_context(ddp_model):
>>> loss = model(data)
>>> optim_wrapper.update_params(loss)
"""

def __init__(self,
optimizer: Optimizer,
accumulative_counts: int = 1,
clip_grad: Optional[dict] = None):
super().__init__(
optimizer,
accumulative_counts=accumulative_counts,
clip_grad=clip_grad,
)

def update_params(self,
loss: torch.Tensor,
retain_graph: bool = False) -> None:
"""Update parameters in :attr:`optimizer`.

Args:
loss (torch.Tensor): A tensor for back propagation.
"""
loss = self.scale_loss(loss)
self.backward(loss, retain_graph=retain_graph)
# Update parameters only if `self._inner_count` is divisible by
# `self._accumulative_counts` or `self._inner_count` equals to
# `self._max_counts`
if self.should_update():
self.step()
self.zero_grad()

def backward(self, loss: torch.Tensor, retain_graph: bool = False) -> None:
"""Perform gradient back propagation.

Provide unified ``backward`` interface compatible with automatic mixed
precision training. Subclass can overload this method to implement the
required logic. For example, ``torch.cuda.amp`` require some extra
operation on GradScaler during backward process.

Note:
If subclasses inherit from ``OptimWrapper`` override
``backward``, ``_inner_count +=1`` must be implemented.

Args:
loss (torch.Tensor): The loss of current iteration.
"""
loss.backward(retain_graph=retain_graph)
self._inner_count += 1
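
What the extra retain_graph argument buys, in a minimal sketch (assumed usage, not from this PR): two backward passes over one shared computation graph before a single parameter update, which the stock OptimWrapper did not expose at the time (see the review thread above).

# Assumed usage sketch: accumulate gradients from two losses that share one
# computation graph, keeping the graph alive for the second backward pass.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
wrapper = DsnasOptimWrapper(torch.optim.SGD(model.parameters(), lr=0.1))

out = model(torch.randn(8, 4))
cls_loss = out.mean()
aux_loss = out.pow(2).mean()  # stand-in for e.g. a flops/arch regularizer

wrapper.backward(cls_loss, retain_graph=True)  # graph kept for a second pass
wrapper.backward(aux_loss)                     # gradients accumulate
wrapper.step()
wrapper.zero_grad()
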
4 changes: 2 additions & 2 deletions mmrazor/models/algorithms/__init__.py
@@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseAlgorithm
from .distill import FpnTeacherDistill, SingleTeacherDistill
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
from .pruning import SlimmableNetwork, SlimmableNetworkDDP

__all__ = [
'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
'Darts', 'DartsDDP'
'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
]
5 changes: 4 additions & 1 deletion mmrazor/models/algorithms/nas/__init__.py
@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .autoslim import AutoSlim, AutoSlimDDP
from .darts import Darts, DartsDDP
from .dsnas import Dsnas, DsnasDDP
from .spos import SPOS

__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP']
__all__ = [
'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
]