[Enhancement] Support hrnet frozen stage #743

Merged · 2 commits · Aug 3, 2021
33 changes: 33 additions & 0 deletions mmseg/models/backbones/hrnet.py
@@ -230,6 +230,8 @@ class HRNet(BaseModule):
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
        pretrained (str, optional): model pretrained path. Default: None
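For orientation, here is a sketch of how the new option would typically be enabled from a config. This fragment is not part of the PR; the layout is an assumption based on common mmseg configs, and the abbreviated `extra` dict must still define stage1-stage4 as in the test added below:

    # Hypothetical mmseg config fragment (not from this PR).
    model = dict(
        backbone=dict(
            type='HRNet',
            norm_cfg=dict(type='BN', requires_grad=True),
            frozen_stages=1,  # -1 (the default) freezes nothing
            extra=dict(
                stage1=dict(num_modules=1, num_branches=1, block='BOTTLENECK',
                            num_blocks=(4, ), num_channels=(64, )),
                # ... stage2-stage4 as usual ...
            )))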
@@ -285,6 +287,7 @@ def __init__(self,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 frozen_stages=-1,
                 zero_init_residual=False,
                 pretrained=None,
                 init_cfg=None):
@@ -315,6 +318,7 @@ def __init__(self,
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.frozen_stages = frozen_stages

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -388,6 +392,8 @@ def __init__(self,
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels)

        self._freeze_stages()

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
@@ -534,6 +540,32 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):

        return Sequential(*hr_modules), in_channels

    def _freeze_stages(self):
        """Freeze parameters and normalization statistics of the stem and
        the first `frozen_stages` stages."""
        if self.frozen_stages >= 0:
            # Freeze the stem (conv1/norm1 and conv2/norm2).
            self.norm1.eval()
            self.norm2.eval()
            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
                for param in m.parameters():
                    param.requires_grad = False

            for i in range(1, self.frozen_stages + 1):
                # HRNet names its first stage `layer1`; later stages are
                # `stage2`..`stage4`. Every stage except the last is
                # followed by a `transition{i}`, which is frozen with it.
                if i == 1:
                    m = getattr(self, f'layer{i}')
                    t = getattr(self, f'transition{i}')
                elif i == 4:
                    m = getattr(self, f'stage{i}')
                    t = None  # no transition after the last stage
                else:
                    m = getattr(self, f'stage{i}')
                    t = getattr(self, f'transition{i}')
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False
                if t is not None:
                    t.eval()
                    for param in t.parameters():
                        param.requires_grad = False
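The net effect is cumulative: frozen_stages=k freezes the stem plus stages 1..k and the transitions that follow them. A minimal sanity check, assuming `extra` is a valid stage config such as the one in the test added below:

    model = HRNet(extra, frozen_stages=2)
    model.train()  # freezing survives .train(); see the override below
    assert model.norm1.training is False  # stem BN stays in eval mode
    assert not any(p.requires_grad for p in model.stage2.parameters())
    assert not any(p.requires_grad for p in model.transition2.parameters())
    assert all(p.requires_grad for p in model.stage3.parameters())  # unfrozen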

    def forward(self, x):
        """Forward function."""

@@ -575,6 +607,7 @@ def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization
layer freezed."""
super(HRNet, self).train(mode)
self._freeze_stages()
Review comment — xvjiarui (Collaborator), Aug 2, 2021:
Move this line inside train() like resnet.

Reply — PR author (Contributor):
@xvjiarui I did not find any difference between resnet and hrnet regarding self._freeze_stages(). Could you give more information?

[screenshot attached]

Reply — xvjiarui (Collaborator):
Sorry. My bad, I forgot to expand the collapsed code lines.

        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
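Worth noting: frozen_stages and norm_eval play different roles and can be combined. The snippet below only restates what the code above already does:

    # frozen_stages=k: the stem and the first k stages (plus their
    # transitions) are put in eval mode and their parameters get
    # requires_grad=False, once at __init__ and again on every .train().
    # norm_eval=True: every BatchNorm in the backbone is put in eval mode
    # on each .train() call; requires_grad is left untouched.
    model = HRNet(extra, frozen_stages=1, norm_eval=True)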
63 changes: 63 additions & 0 deletions tests/test_models/test_backbones/test_hrnet.py
@@ -0,0 +1,63 @@
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.models.backbones import HRNet


def test_hrnet_backbone():
    # Test HRNet with the first two stages frozen.

    extra = dict(
        stage1=dict(
            num_modules=1,
            num_branches=1,
            block='BOTTLENECK',
            num_blocks=(4, ),
            num_channels=(64, )),
        stage2=dict(
            num_modules=1,
            num_branches=2,
            block='BASIC',
            num_blocks=(4, 4),
            num_channels=(32, 64)),
        stage3=dict(
            num_modules=4,
            num_branches=3,
            block='BASIC',
            num_blocks=(4, 4, 4),
            num_channels=(32, 64, 128)),
        stage4=dict(
            num_modules=3,
            num_branches=4,
            block='BASIC',
            num_blocks=(4, 4, 4, 4),
            num_channels=(32, 64, 128, 256)))
    frozen_stages = 2
    model = HRNet(extra, frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    assert model.norm1.training is False

    for layer in [model.conv1, model.norm1]:
        for param in layer.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        if i == 1:
            layer = getattr(model, f'layer{i}')
            transition = getattr(model, f'transition{i}')
        elif i == 4:
            layer = getattr(model, f'stage{i}')
            transition = None  # no transition after the last stage
        else:
            layer = getattr(model, f'stage{i}')
            transition = getattr(model, f'transition{i}')

        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

        if transition is not None:
            for mod in transition.modules():
                if isinstance(mod, _BatchNorm):
                    assert mod.training is False
            for param in transition.parameters():
                assert param.requires_grad is False
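The new test can be run on its own with pytest, e.g. pytest tests/test_models/test_backbones/test_hrnet.py.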