diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py
index 055fc985bb..0f064cff7d 100644
--- a/mmseg/models/backbones/hrnet.py
+++ b/mmseg/models/backbones/hrnet.py
@@ -230,6 +230,8 @@ class HRNet(BaseModule):
             and its variants only.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Default: -1.
         zero_init_residual (bool): whether to use zero init for last norm layer
             in resblocks to let them behave as identity.
         pretrained (str, optional): model pretrained path. Default: None
@@ -285,6 +287,7 @@ def __init__(self,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
+                 frozen_stages=-1,
                  zero_init_residual=False,
                  pretrained=None,
                  init_cfg=None):
@@ -315,6 +318,7 @@ def __init__(self,
         self.norm_cfg = norm_cfg
         self.norm_eval = norm_eval
         self.with_cp = with_cp
+        self.frozen_stages = frozen_stages
 
         # stem net
         self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -388,6 +392,8 @@ def __init__(self,
         self.stage4, pre_stage_channels = self._make_stage(
             self.stage4_cfg, num_channels)
 
+        self._freeze_stages()
+
     @property
     def norm1(self):
         """nn.Module: the normalization layer named "norm1" """
@@ -534,6 +540,32 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
 
         return Sequential(*hr_modules), in_channels
 
+    def _freeze_stages(self):
+        """Freeze stages param and norm stats."""
+        if self.frozen_stages >= 0:
+
+            self.norm1.eval()
+            self.norm2.eval()
+            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
+                for param in m.parameters():
+                    param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            if i == 1:
+                m = getattr(self, f'layer{i}')
+                t = getattr(self, f'transition{i}')
+            elif i == 4:
+                m = getattr(self, f'stage{i}')
+            else:
+                m = getattr(self, f'stage{i}')
+                t = getattr(self, f'transition{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+            t.eval()
+            for param in t.parameters():
+                param.requires_grad = False
+
     def forward(self, x):
         """Forward function."""
 
@@ -575,6 +607,7 @@ def train(self, mode=True):
         """Convert the model into training mode will keeping the
         normalization layer freezed."""
         super(HRNet, self).train(mode)
+        self._freeze_stages()
         if mode and self.norm_eval:
             for m in self.modules():
                 # trick: eval have effect on BatchNorm only
diff --git a/tests/test_models/test_backbones/test_hrnet.py b/tests/test_models/test_backbones/test_hrnet.py
new file mode 100644
index 0000000000..81611a0d11
--- /dev/null
+++ b/tests/test_models/test_backbones/test_hrnet.py
@@ -0,0 +1,63 @@
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmseg.models.backbones import HRNet
+
+
+def test_hrnet_backbone():
+    # Test HRNET with two stage frozen
+
+    extra = dict(
+        stage1=dict(
+            num_modules=1,
+            num_branches=1,
+            block='BOTTLENECK',
+            num_blocks=(4, ),
+            num_channels=(64, )),
+        stage2=dict(
+            num_modules=1,
+            num_branches=2,
+            block='BASIC',
+            num_blocks=(4, 4),
+            num_channels=(32, 64)),
+        stage3=dict(
+            num_modules=4,
+            num_branches=3,
+            block='BASIC',
+            num_blocks=(4, 4, 4),
+            num_channels=(32, 64, 128)),
+        stage4=dict(
+            num_modules=3,
+            num_branches=4,
+            block='BASIC',
+            num_blocks=(4, 4, 4, 4),
+            num_channels=(32, 64, 128, 256)))
+    frozen_stages = 2
+    model = HRNet(extra, frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+    assert model.norm1.training is False
+
+    for layer in [model.conv1, model.norm1]:
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        if i == 1:
+            layer = getattr(model, f'layer{i}')
+            transition = getattr(model, f'transition{i}')
+        elif i == 4:
+            layer = getattr(model, f'stage{i}')
+        else:
+            layer = getattr(model, f'stage{i}')
+            transition = getattr(model, f'transition{i}')
+
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+        for mod in transition.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in transition.parameters():
+            assert param.requires_grad is False
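
Usage sketch (illustrative only, not part of the diff): constructing the backbone with frozen_stages=2 freezes the stem (conv1/norm1/conv2/norm2) plus the first two stages and their transitions, and train() keeps them frozen because it re-runs _freeze_stages(). The `extra` config below is copied from the test above; any other valid HRNet config is assumed to behave the same way.

from mmseg.models.backbones import HRNet

# Same 4-stage `extra` config as in the test above.
extra = dict(
    stage1=dict(num_modules=1, num_branches=1, block='BOTTLENECK',
                num_blocks=(4, ), num_channels=(64, )),
    stage2=dict(num_modules=1, num_branches=2, block='BASIC',
                num_blocks=(4, 4), num_channels=(32, 64)),
    stage3=dict(num_modules=4, num_branches=3, block='BASIC',
                num_blocks=(4, 4, 4), num_channels=(32, 64, 128)),
    stage4=dict(num_modules=3, num_branches=4, block='BASIC',
                num_blocks=(4, 4, 4, 4), num_channels=(32, 64, 128, 256)))

# frozen_stages=2 freezes the stem, layer1/transition1 and stage2/transition2;
# the default -1 freezes nothing.
model = HRNet(extra, frozen_stages=2)
model.train()  # train() calls _freeze_stages() again, so frozen parts stay in eval mode

# Frozen parts have requires_grad=False; later stages remain trainable.
assert all(not p.requires_grad for p in model.layer1.parameters())
assert all(not p.requires_grad for p in model.stage2.parameters())
assert all(not p.requires_grad for p in model.transition2.parameters())
assert all(p.requires_grad for p in model.stage3.parameters())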