diff --git a/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vipnas_coco.md b/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vipnas_coco.md
new file mode 100644
index 0000000000..b42a872809
--- /dev/null
+++ b/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vipnas_coco.md
@@ -0,0 +1,39 @@
+
+
+
+ViPNAS (CVPR'2021)
+
+```bibtex
+@inproceedings{xu2021vipnas,
+ title={ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search},
+ author={Xu, Lumin and Guan, Yingda and Jin, Sheng and Liu, Wentao and Qian, Chen and Luo, Ping and Ouyang, Wanli and Wang, Xiaogang},
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
+ year={2021}
+}
+```
+
+
+
+
+
+
+COCO (ECCV'2014)
+
+```bibtex
+@inproceedings{lin2014microsoft,
+ title={Microsoft coco: Common objects in context},
+ author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
+ booktitle={European conference on computer vision},
+ pages={740--755},
+ year={2014},
+ organization={Springer}
+}
+```
+
+
+
+Results on COCO val2017 with detector having human AP of 56.4 on COCO val2017 dataset
+
+| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log |
+| :-------------- | :-----------: | :------: | :------: | :------: | :------: | :------: |:------: |:------: |
+| [S-ViPNAS-Res50](/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vipnas_res50_coco_256x192.py) | 256x192 | 0.711 | 0.893 | 0.789 | 0.769 | 0.769 | [ckpt](https://download.openmmlab.com/mmpose/top_down/vipnas/vipnas_res50_coco_256x192-cc43b466_20210624.pth) | [log](https://download.openmmlab.com/mmpose/top_down/vipnas/vipnas_res50_coco_256x192_20210624.log.json) |
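+
+A minimal inference sketch, assuming a standard MMPose v0.x installation
+(`demo/top_down_img_demo.py` offers a full command-line alternative):
+
+```python
+from mmpose.apis import init_pose_model
+
+config = ('configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/'
+          'vipnas_res50_coco_256x192.py')
+checkpoint = ('https://download.openmmlab.com/mmpose/top_down/vipnas/'
+              'vipnas_res50_coco_256x192-cc43b466_20210624.pth')
+
+# Build the model and load the released weights.
+model = init_pose_model(config, checkpoint, device='cpu')
+```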
diff --git a/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vipnas_res50_coco_256x192.py b/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vipnas_res50_coco_256x192.py
new file mode 100644
index 0000000000..14c56e869b
--- /dev/null
+++ b/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vipnas_res50_coco_256x192.py
@@ -0,0 +1,142 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=10)
+evaluation = dict(interval=10, metric='mAP', key_indicator='AP')
+
+optimizer = dict(
+ type='Adam',
+ lr=5e-4,
+)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[170, 200])
+total_epochs = 210
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=17,
+ dataset_joints=17,
+ dataset_channel=[
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+ ],
+ inference_channel=[
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+ ])
+
+# model settings
+model = dict(
+ type='TopDown',
+ pretrained=None,
+ backbone=dict(type='ViPNAS_ResNet', depth=50),
+ keypoint_head=dict(
+ type='ViPNASHeatmapSimpleHead',
+ in_channels=608,
+ out_channels=channel_cfg['num_output_channels'],
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=True,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[192, 256],
+ heatmap_size=[48, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'],
+ soft_nms=False,
+ nms_thr=1.0,
+ oks_thr=0.9,
+ vis_thr=0.2,
+ use_gt_bbox=False,
+ det_bbox_thr=0.0,
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+)
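+
+# Note: with ``use_gt_bbox=False`` above, validation/testing uses the person
+# detection results from ``bbox_file`` (a detector with human AP 56.4 on COCO
+# val2017), which is the setting reported in the results table.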
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownRandomFlip', flip_prob=0.5),
+ dict(
+ type='TopDownHalfBodyTransform',
+ num_joints_half_body=8,
+ prob_half_body=0.3),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=30,
+ scale_factor=0.25),
+ dict(type='TopDownAffine'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTarget', sigma=2),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs'
+ ]),
+]
+
+val_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownAffine'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'image_file', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs'
+ ]),
+]
+
+test_pipeline = val_pipeline
+
+data_root = 'data/coco'
+data = dict(
+ samples_per_gpu=64,
+ workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=32),
+ test_dataloader=dict(samples_per_gpu=32),
+ train=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
+ data_cfg=data_cfg,
+ pipeline=train_pipeline),
+ val=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ data_cfg=data_cfg,
+ pipeline=val_pipeline),
+ test=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ data_cfg=data_cfg,
+        pipeline=test_pipeline),
+)
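+
+# Launch sketch (assuming the standard MMPose tools/ layout):
+#   single GPU:
+#     python tools/train.py \
+#       configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/vipnas_res50_coco_256x192.py
+#   multiple GPUs:
+#     ./tools/dist_train.sh <CONFIG_FILE> <GPU_NUM>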
diff --git a/docs/papers/backbones/vipnas.md b/docs/papers/backbones/vipnas.md
new file mode 100644
index 0000000000..d6a3bac47a
--- /dev/null
+++ b/docs/papers/backbones/vipnas.md
@@ -0,0 +1,19 @@
+# ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search
+
+## Introduction
+
+
+
+
+ViPNAS (CVPR'2021)
+
+```bibtex
+@inproceedings{xu2021vipnas,
+ title={ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search},
+ author={Xu, Lumin and Guan, Yingda and Jin, Sheng and Liu, Wentao and Qian, Chen and Luo, Ping and Ouyang, Wanli and Wang, Xiaogang},
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
+ year={2021}
+}
+```
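+
+A minimal sketch, assuming the MMPose backbone builder, of constructing the
+searched S-ViPNAS-Res50 backbone and checking its output shape:
+
+```python
+import torch
+
+from mmpose.models import build_backbone
+
+# The default arguments encode the searched S-ViPNAS-Res50 architecture.
+backbone = build_backbone(dict(type='ViPNAS_ResNet', depth=50))
+feat = backbone(torch.randn(1, 3, 256, 192))
+print(feat.shape)  # torch.Size([1, 608, 8, 6])
+```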
+
+
diff --git a/mmpose/core/evaluation/top_down_eval.py b/mmpose/core/evaluation/top_down_eval.py
index c2a6530312..d44056e865 100644
--- a/mmpose/core/evaluation/top_down_eval.py
+++ b/mmpose/core/evaluation/top_down_eval.py
@@ -558,7 +558,9 @@ def keypoints_from_heatmaps(heatmaps,
N, K, H, W = heatmaps.shape
if use_udp:
- assert target_type in ['GaussianHeatMap', 'CombinedTarget']
+ assert target_type.lower() in [
+ 'GaussianHeatMap'.lower(), 'CombinedTarget'.lower()
+ ]
if target_type == 'GaussianHeatMap':
preds, maxvals = _get_max_preds(heatmaps)
preds = post_dark_udp(preds, heatmaps, kernel=kernel)
diff --git a/mmpose/core/post_processing/post_transforms.py b/mmpose/core/post_processing/post_transforms.py
index ba6594f778..2d32884686 100644
--- a/mmpose/core/post_processing/post_transforms.py
+++ b/mmpose/core/post_processing/post_transforms.py
@@ -126,7 +126,8 @@ def flip_back(output_flipped, flip_pairs, target_type='GaussianHeatMap'):
"""
assert output_flipped.ndim == 4, \
'output_flipped should be [batch_size, num_keypoints, height, width]'
- assert target_type in ('GaussianHeatMap', 'CombinedTarget')
+ assert target_type.lower() in ('GaussianHeatMap'.lower(),
+ 'CombinedTarget'.lower())
shape_ori = output_flipped.shape
channels = 1
if target_type == 'CombinedTarget':
diff --git a/mmpose/datasets/pipelines/top_down_transform.py b/mmpose/datasets/pipelines/top_down_transform.py
index 124aa3de74..2d3cc34647 100644
--- a/mmpose/datasets/pipelines/top_down_transform.py
+++ b/mmpose/datasets/pipelines/top_down_transform.py
@@ -434,7 +434,9 @@ def _udp_generate_target(self, cfg, joints_3d, joints_3d_visible, factor,
target_weight = np.ones((num_joints, 1), dtype=np.float32)
target_weight[:, 0] = joints_3d_visible[:, 0]
- assert target_type in ['GaussianHeatMap', 'CombinedTarget']
+ assert target_type.lower() in [
+ 'GaussianHeatMap'.lower(), 'CombinedTarget'.lower()
+ ]
if target_type == 'GaussianHeatMap':
target = np.zeros((num_joints, heatmap_size[1], heatmap_size[0]),
diff --git a/mmpose/models/backbones/__init__.py b/mmpose/models/backbones/__init__.py
index eb3959d495..f772c7a3d9 100644
--- a/mmpose/models/backbones/__init__.py
+++ b/mmpose/models/backbones/__init__.py
@@ -17,10 +17,11 @@
from .shufflenet_v2 import ShuffleNetV2
from .tcn import TCN
from .vgg import VGG
+from .vipnas_resnet import ViPNAS_ResNet
__all__ = [
'AlexNet', 'HourglassNet', 'HRNet', 'MobileNetV2', 'MobileNetV3', 'RegNet',
'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet', 'SEResNet', 'SEResNeXt',
'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN', 'MSPN', 'ResNeSt', 'VGG',
- 'TCN'
+ 'TCN', 'ViPNAS_ResNet'
]
diff --git a/mmpose/models/backbones/vipnas_resnet.py b/mmpose/models/backbones/vipnas_resnet.py
new file mode 100644
index 0000000000..65777b370f
--- /dev/null
+++ b/mmpose/models/backbones/vipnas_resnet.py
@@ -0,0 +1,597 @@
+import copy
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
+from mmcv.cnn.bricks import ContextBlock
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from ..builder import BACKBONES
+from .base_backbone import BaseBackbone
+
+
+class ViPNAS_Bottleneck(nn.Module):
+ """Bottleneck block for ViPNAS_ResNet.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int): The ratio of ``out_channels/mid_channels`` where
+ ``mid_channels`` is the input/output channels of conv2. Default: 4.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module): downsample operation on identity branch.
+ Default: None.
+ style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
+ stride-two layer is the 3x3 conv layer, otherwise the stride-two
+ layer is the first 1x1 conv layer. Default: "pytorch".
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+        kernel_size (int): kernel size of conv2 searched in ViPNAS.
+ groups (int): group number of conv2 searched in ViPNAS.
+ attention (bool): whether to use attention module in the end of
+ the block.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expansion=4,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ kernel_size=3,
+ groups=1,
+ attention=False):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__()
+ assert style in ['pytorch', 'caffe']
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.expansion = expansion
+ assert out_channels % expansion == 0
+ self.mid_channels = out_channels // expansion
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=kernel_size,
+ stride=self.conv2_stride,
+ padding=kernel_size // 2,
+ groups=groups,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ if attention:
+ self.attention = ContextBlock(out_channels,
+ max(1.0 / 16, 16.0 / out_channels))
+ else:
+ self.attention = None
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: the normalization layer named "norm3" """
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.attention is not None:
+ out = self.attention(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
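+
+# Shape sketch (mirrors tests/test_backbones/test_vipnas.py): with the default
+# expansion of 4, ViPNAS_Bottleneck(64, 64) uses mid_channels=16 and, at
+# stride 1, maps (1, 64, 56, 56) -> (1, 64, 56, 56).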
+
+
+def get_expansion(block, expansion=None):
+ """Get the expansion of a residual block.
+
+ The block expansion will be obtained by the following order:
+
+ 1. If ``expansion`` is given, just return it.
+ 2. If ``block`` has the attribute ``expansion``, then return
+ ``block.expansion``.
+    3. Return the default value according to the block type:
+       1 for ``ViPNAS_Bottleneck``.
+
+ Args:
+ block (class): The block class.
+ expansion (int | None): The given expansion ratio.
+
+ Returns:
+ int: The expansion of the block.
+ """
+ if isinstance(expansion, int):
+ assert expansion > 0
+ elif expansion is None:
+ if hasattr(block, 'expansion'):
+ expansion = block.expansion
+ elif issubclass(block, ViPNAS_Bottleneck):
+ expansion = 1
+ else:
+ raise TypeError(f'expansion is not specified for {block.__name__}')
+ else:
+ raise TypeError('expansion must be an integer or None')
+
+ return expansion
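+
+# Lookup-order sketch (mirrors the unit tests):
+#   get_expansion(ViPNAS_Bottleneck, 2) -> 2  (an explicit value wins)
+#   get_expansion(ViPNAS_Bottleneck)    -> 1  (class has no `expansion` attr)
+#   get_expansion(BlockWithAttr)        -> BlockWithAttr.expansion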
+
+
+class ViPNAS_ResLayer(nn.Sequential):
+ """ViPNAS_ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): Residual block used to build ViPNAS ResLayer.
+ num_blocks (int): Number of blocks.
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+        expansion (int, optional): The expansion of the residual block.
+            If not specified, it will first be obtained via
+            ``block.expansion``. If the block has no attribute "expansion",
+            the default value 1 is used for ``ViPNAS_Bottleneck``.
+            Default: None.
+ stride (int): stride of the first block. Default: 1.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ kernel_size (int): Kernel Size of the corresponding convolution layer
+ searched in the block.
+ groups (int): Group number of the corresponding convolution layer
+ searched in the block.
+ attention (bool): Whether to use attention module in the end of the
+ block.
+ """
+
+ def __init__(self,
+ block,
+ num_blocks,
+ in_channels,
+ out_channels,
+ expansion=None,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ downsample_first=True,
+ kernel_size=3,
+ groups=1,
+ attention=False,
+ **kwargs):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ self.block = block
+ self.expansion = get_expansion(block, expansion)
+
+ downsample = None
+ if stride != 1 or in_channels != out_channels:
+ downsample = []
+ conv_stride = stride
+ if avg_down and stride != 1:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, out_channels)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if downsample_first:
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=kernel_size,
+ groups=groups,
+ attention=attention,
+ **kwargs))
+ in_channels = out_channels
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=kernel_size,
+ groups=groups,
+ attention=attention,
+ **kwargs))
+ else: # downsample_first=False is for HourglassModule
+ for i in range(0, num_blocks - 1):
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ expansion=self.expansion,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=kernel_size,
+ groups=groups,
+ attention=attention,
+ **kwargs))
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=kernel_size,
+ groups=groups,
+ attention=attention,
+ **kwargs))
+
+ super().__init__(*layers)
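+
+# Usage sketch (mirrors tests/test_backbones/test_vipnas.py): three bottlenecks
+# going from 32 to 64 channels at stride 2; only the first block carries the
+# downsample branch, so (1, 32, 56, 56) -> (1, 64, 28, 28):
+#   layer = ViPNAS_ResLayer(ViPNAS_Bottleneck, 3, 32, 64, stride=2)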
+
+
+@BACKBONES.register_module()
+class ViPNAS_ResNet(BaseBackbone):
+ """ResNet backbone.
+
+ ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search.
+    More details can be found in the `paper
+    <https://arxiv.org/abs/2105.10154>`__ .
+
+ Args:
+        depth (int): Network depth. Currently only 50 is supported,
+            see ``arch_settings``.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. If only one
+ stage is specified, a single tensor (feature map) is returned,
+ otherwise multiple stages are specified, a tuple of tensors will
+ be returned. Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ wid (list(int)): searched width config for each stage.
+ expan (list(int)): searched expansion ratio config for each stage.
+ dep (list(int)): searched depth config for each stage.
+ ks (list(int)): searched kernel size config for each stage.
+ group (list(int)): searched group number config for each stage.
+ att (list(int)): searched attention config for each stage.
+ """
+
+ arch_settings = {
+ 50: ViPNAS_Bottleneck,
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(3, ),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=True,
+ wid=[48, 80, 160, 304, 608],
+ expan=[None, 1, 1, 1, 1],
+ dep=[None, 4, 6, 7, 3],
+ ks=[7, 3, 5, 5, 5],
+ group=[None, 16, 16, 16, 16],
+ att=[None, True, False, True, True]):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ self.depth = depth
+        self.stem_channels = wid[0]
+ self.num_stages = num_stages
+ assert 1 <= num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.zero_init_residual = zero_init_residual
+ self.block = self.arch_settings[depth]
+ self.stage_blocks = dep[1:1 + num_stages]
+
+ self._make_stem_layer(in_channels, wid[0], ks[0])
+
+ self.res_layers = []
+ _in_channels = wid[0]
+ for i, num_blocks in enumerate(self.stage_blocks):
+ expansion = get_expansion(self.block, expan[i + 1])
+ _out_channels = wid[i + 1] * expansion
+ stride = strides[i]
+ dilation = dilations[i]
+ res_layer = self.make_res_layer(
+ block=self.block,
+ num_blocks=num_blocks,
+ in_channels=_in_channels,
+ out_channels=_out_channels,
+ expansion=expansion,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=ks[i + 1],
+ groups=group[i + 1],
+ attention=att[i + 1])
+ _in_channels = _out_channels
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = res_layer[-1].out_channels
+
+ def make_res_layer(self, **kwargs):
+ """Make a ViPNAS ResLayer."""
+ return ViPNAS_ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels, kernel_size):
+ """Make stem layer."""
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ ConvModule(
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True))
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=kernel_size // 2,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+        """Initialize model weights."""
+ super().init_weights(pretrained)
+ if pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, std=0.001)
+ for name, _ in m.named_parameters():
+ if name in ['bias']:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ nn.init.normal_(m.weight, std=0.001)
+ for name, _ in m.named_parameters():
+ if name in ['bias']:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py
index 5f4da4484b..4922ebfe78 100644
--- a/mmpose/models/heads/__init__.py
+++ b/mmpose/models/heads/__init__.py
@@ -8,10 +8,12 @@
from .topdown_heatmap_multi_stage_head import (TopdownHeatmapMSMUHead,
TopdownHeatmapMultiStageHead)
from .topdown_heatmap_simple_head import TopdownHeatmapSimpleHead
+from .vipnas_heatmap_simple_head import ViPNASHeatmapSimpleHead
__all__ = [
'TopdownHeatmapSimpleHead', 'TopdownHeatmapMultiStageHead',
'TopdownHeatmapMSMUHead', 'TopdownHeatmapBaseHead',
'AEHigherResolutionHead', 'AESimpleHead', 'DeepposeRegressionHead',
- 'TemporalRegressionHead', 'Interhand3DHead', 'HMRMeshHead'
+ 'TemporalRegressionHead', 'Interhand3DHead', 'HMRMeshHead',
+ 'ViPNASHeatmapSimpleHead'
]
diff --git a/mmpose/models/heads/vipnas_heatmap_simple_head.py b/mmpose/models/heads/vipnas_heatmap_simple_head.py
new file mode 100644
index 0000000000..69852e0d13
--- /dev/null
+++ b/mmpose/models/heads/vipnas_heatmap_simple_head.py
@@ -0,0 +1,346 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
+ constant_init, normal_init)
+
+from mmpose.core.evaluation import pose_pck_accuracy
+from mmpose.core.post_processing import flip_back
+from mmpose.models.builder import build_loss
+from mmpose.models.utils.ops import resize
+from ..builder import HEADS
+from .topdown_heatmap_base_head import TopdownHeatmapBaseHead
+
+
+@HEADS.register_module()
+class ViPNASHeatmapSimpleHead(TopdownHeatmapBaseHead):
+ """ViPNAS heatmap simple head.
+
+ ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search.
+    More details can be found in the `paper
+    <https://arxiv.org/abs/2105.10154>`__ .
+
+    ViPNASHeatmapSimpleHead consists of (>=0) deconv layers
+    and a simple conv2d layer.
+
+ Args:
+ in_channels (int): Number of input channels
+ out_channels (int): Number of output channels
+ num_deconv_layers (int): Number of deconv layers.
+            ``num_deconv_layers`` should be >= 0. Note that 0 means
+            no deconv layers.
+        num_deconv_filters (list|tuple): Number of filters. If
+            ``num_deconv_layers`` > 0, its length should equal
+            ``num_deconv_layers``.
+        num_deconv_kernels (list|tuple): Kernel sizes of each deconv layer.
+        num_deconv_groups (list|tuple): Group numbers of each deconv layer.
+        extra (dict|None): Config for extra conv layers before the final
+            layer and for the final conv kernel. Default: None.
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+            'resize_concat': Multiple feature maps will be resized to the
+                same size as the first one and then concatenated together.
+                Usually used in FCN head of HRNet.
+            'multiple_select': Multiple feature maps will be bundled into
+                a list and passed into the decode head.
+ None: Only one select feature map is allowed.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ loss_keypoint (dict): Config for keypoint loss. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_deconv_layers=3,
+ num_deconv_filters=(144, 144, 144),
+ num_deconv_kernels=(4, 4, 4),
+ num_deconv_groups=(16, 16, 16),
+ extra=None,
+ in_index=0,
+ input_transform=None,
+ align_corners=False,
+ loss_keypoint=None,
+ train_cfg=None,
+ test_cfg=None):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.loss = build_loss(loss_keypoint)
+
+ self.train_cfg = {} if train_cfg is None else train_cfg
+ self.test_cfg = {} if test_cfg is None else test_cfg
+ self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')
+
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.in_index = in_index
+ self.align_corners = align_corners
+
+ if extra is not None and not isinstance(extra, dict):
+ raise TypeError('extra should be dict or None.')
+
+ if num_deconv_layers > 0:
+ self.deconv_layers = self._make_deconv_layer(
+ num_deconv_layers, num_deconv_filters, num_deconv_kernels,
+ num_deconv_groups)
+ elif num_deconv_layers == 0:
+ self.deconv_layers = nn.Identity()
+ else:
+ raise ValueError(
+                f'num_deconv_layers ({num_deconv_layers}) should be >= 0.')
+
+ identity_final_layer = False
+ if extra is not None and 'final_conv_kernel' in extra:
+ assert extra['final_conv_kernel'] in [0, 1, 3]
+ if extra['final_conv_kernel'] == 3:
+ padding = 1
+ elif extra['final_conv_kernel'] == 1:
+ padding = 0
+ else:
+ # 0 for Identity mapping.
+ identity_final_layer = True
+ kernel_size = extra['final_conv_kernel']
+ else:
+ kernel_size = 1
+ padding = 0
+
+ if identity_final_layer:
+ self.final_layer = nn.Identity()
+ else:
+ conv_channels = num_deconv_filters[
+ -1] if num_deconv_layers > 0 else self.in_channels
+
+ layers = []
+ if extra is not None:
+ num_conv_layers = extra.get('num_conv_layers', 0)
+ num_conv_kernels = extra.get('num_conv_kernels',
+ [1] * num_conv_layers)
+
+ for i in range(num_conv_layers):
+ layers.append(
+ build_conv_layer(
+ dict(type='Conv2d'),
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=num_conv_kernels[i],
+ stride=1,
+ padding=(num_conv_kernels[i] - 1) // 2))
+ layers.append(
+ build_norm_layer(dict(type='BN'), conv_channels)[1])
+ layers.append(nn.ReLU(inplace=True))
+
+ layers.append(
+ build_conv_layer(
+ cfg=dict(type='Conv2d'),
+ in_channels=conv_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding))
+
+ if len(layers) > 1:
+ self.final_layer = nn.Sequential(*layers)
+ else:
+ self.final_layer = layers[0]
+
+ def get_loss(self, output, target, target_weight):
+ """Calculate top-down keypoint loss.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+ heatmaps height: H
+            heatmaps width: W
+
+ Args:
+ output (torch.Tensor[NxKxHxW]): Output heatmaps.
+ target (torch.Tensor[NxKxHxW]): Target heatmaps.
+ target_weight (torch.Tensor[NxKx1]):
+ Weights across different joint types.
+ """
+
+ losses = dict()
+
+ assert not isinstance(self.loss, nn.Sequential)
+ assert target.dim() == 4 and target_weight.dim() == 3
+ losses['mse_loss'] = self.loss(output, target, target_weight)
+
+ return losses
+
+ def get_accuracy(self, output, target, target_weight):
+ """Calculate accuracy for top-down keypoint loss.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+ heatmaps height: H
+            heatmaps width: W
+
+ Args:
+ output (torch.Tensor[NxKxHxW]): Output heatmaps.
+ target (torch.Tensor[NxKxHxW]): Target heatmaps.
+ target_weight (torch.Tensor[NxKx1]):
+ Weights across different joint types.
+ """
+
+ accuracy = dict()
+
+ if self.target_type.lower() == 'GaussianHeatmap'.lower():
+ _, avg_acc, _ = pose_pck_accuracy(
+ output.detach().cpu().numpy(),
+ target.detach().cpu().numpy(),
+ target_weight.detach().cpu().numpy().squeeze(-1) > 0)
+ accuracy['acc_pose'] = float(avg_acc)
+
+ return accuracy
+
+ def forward(self, x):
+ """Forward function."""
+ x = self._transform_inputs(x)
+ x = self.deconv_layers(x)
+ x = self.final_layer(x)
+ return x
+
+ def inference_model(self, x, flip_pairs=None):
+ """Inference function.
+
+ Returns:
+ output_heatmap (np.ndarray): Output heatmaps.
+
+ Args:
+ x (torch.Tensor[NxKxHxW]): Input features.
+        flip_pairs (None | list[tuple]):
+ Pairs of keypoints which are mirrored.
+ """
+ output = self.forward(x)
+
+ if flip_pairs is not None:
+ output_heatmap = flip_back(
+ output.detach().cpu().numpy(),
+ flip_pairs,
+ target_type=self.target_type)
+ # feature is not aligned, shift flipped heatmap for higher accuracy
+ if self.test_cfg.get('shift_heatmap', False):
+ output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
+ else:
+ output_heatmap = output.detach().cpu().numpy()
+ return output_heatmap
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+ Specifically, when input_transform is None, only single feature map
+ will be selected. So in_channels and in_index must be of type int.
+ When input_transform is not None, in_channels and in_index must be
+ list or tuple, with the same length.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+                'resize_concat': Multiple feature maps will be resized to the
+                    same size as the first one and then concatenated together.
+                    Usually used in FCN head of HRNet.
+                'multiple_select': Multiple feature maps will be bundled into
+                    a list and passed into the decode head.
+ None: Only one select feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor] | Tensor): multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+ if not isinstance(inputs, list):
+ return inputs
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels,
+ num_groups):
+ """Make deconv layers."""
+ if num_layers != len(num_filters):
+ error_msg = f'num_layers({num_layers}) ' \
+ f'!= length of num_filters({len(num_filters)})'
+ raise ValueError(error_msg)
+ if num_layers != len(num_kernels):
+ error_msg = f'num_layers({num_layers}) ' \
+ f'!= length of num_kernels({len(num_kernels)})'
+ raise ValueError(error_msg)
+ if num_layers != len(num_groups):
+ error_msg = f'num_layers({num_layers}) ' \
+ f'!= length of num_groups({len(num_groups)})'
+ raise ValueError(error_msg)
+
+ layers = []
+ for i in range(num_layers):
+ kernel, padding, output_padding = \
+ self._get_deconv_cfg(num_kernels[i])
+
+ planes = num_filters[i]
+ groups = num_groups[i]
+ layers.append(
+ build_upsample_layer(
+ dict(type='deconv'),
+ in_channels=self.in_channels,
+ out_channels=planes,
+ kernel_size=kernel,
+ groups=groups,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=False))
+ layers.append(nn.BatchNorm2d(planes))
+ layers.append(nn.ReLU(inplace=True))
+ self.in_channels = planes
+
+ return nn.Sequential(*layers)
+
+ def init_weights(self):
+ """Initialize model weights."""
+ for _, m in self.deconv_layers.named_modules():
+ if isinstance(m, nn.ConvTranspose2d):
+ normal_init(m, std=0.001)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ for m in self.final_layer.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.001, bias=0)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
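+
+# Shape sketch: the three stride-2 deconv layers upsample 8x overall, so with
+# the S-ViPNAS-Res50 config (in_channels=608, out_channels=17) a backbone
+# feature map of (1, 608, 8, 6) becomes a (1, 17, 64, 48) heatmap, matching
+# heatmap_size=[48, 64] in the config.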
diff --git a/tests/test_backbones/test_vipnas.py b/tests/test_backbones/test_vipnas.py
new file mode 100644
index 0000000000..899aec45d7
--- /dev/null
+++ b/tests/test_backbones/test_vipnas.py
@@ -0,0 +1,340 @@
+import pytest
+import torch
+import torch.nn as nn
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmpose.models.backbones import ViPNAS_ResNet
+from mmpose.models.backbones.vipnas_resnet import (ViPNAS_Bottleneck,
+ ViPNAS_ResLayer,
+ get_expansion)
+
+
+def is_block(modules):
+ """Check if is ViPNAS_ResNet building block."""
+    if isinstance(modules, ViPNAS_Bottleneck):
+ return True
+ return False
+
+
+def all_zeros(modules):
+ """Check if the weight(and bias) is all zero."""
+ weight_zero = torch.equal(modules.weight.data,
+ torch.zeros_like(modules.weight.data))
+ if hasattr(modules, 'bias'):
+ bias_zero = torch.equal(modules.bias.data,
+ torch.zeros_like(modules.bias.data))
+ else:
+ bias_zero = True
+
+ return weight_zero and bias_zero
+
+
+def check_norm_state(modules, train_state):
+ """Check if norm layer is in correct train state."""
+ for mod in modules:
+ if isinstance(mod, _BatchNorm):
+ if mod.training != train_state:
+ return False
+ return True
+
+
+def test_get_expansion():
+ assert get_expansion(ViPNAS_Bottleneck, 2) == 2
+ assert get_expansion(ViPNAS_Bottleneck) == 1
+
+ class MyResBlock(nn.Module):
+
+ expansion = 8
+
+ assert get_expansion(MyResBlock) == 8
+
+ # expansion must be an integer or None
+ with pytest.raises(TypeError):
+ get_expansion(ViPNAS_Bottleneck, '0')
+
+ # expansion is not specified and cannot be inferred
+ with pytest.raises(TypeError):
+
+ class SomeModule(nn.Module):
+ pass
+
+ get_expansion(SomeModule)
+
+
+def test_vipnas_bottleneck():
+ # style must be in ['pytorch', 'caffe']
+ with pytest.raises(AssertionError):
+ ViPNAS_Bottleneck(64, 64, style='tensorflow')
+
+    # out_channels must be divisible by expansion
+ with pytest.raises(AssertionError):
+ ViPNAS_Bottleneck(64, 64, expansion=3)
+
+ # Test ViPNAS_Bottleneck style
+ block = ViPNAS_Bottleneck(64, 64, stride=2, style='pytorch')
+ assert block.conv1.stride == (1, 1)
+ assert block.conv2.stride == (2, 2)
+ block = ViPNAS_Bottleneck(64, 64, stride=2, style='caffe')
+ assert block.conv1.stride == (2, 2)
+ assert block.conv2.stride == (1, 1)
+
+ # ViPNAS_Bottleneck with stride 1
+ block = ViPNAS_Bottleneck(64, 64, style='pytorch')
+ assert block.in_channels == 64
+ assert block.mid_channels == 16
+ assert block.out_channels == 64
+ assert block.conv1.in_channels == 64
+ assert block.conv1.out_channels == 16
+ assert block.conv1.kernel_size == (1, 1)
+ assert block.conv2.in_channels == 16
+ assert block.conv2.out_channels == 16
+ assert block.conv2.kernel_size == (3, 3)
+ assert block.conv3.in_channels == 16
+ assert block.conv3.out_channels == 64
+ assert block.conv3.kernel_size == (1, 1)
+ x = torch.randn(1, 64, 56, 56)
+ x_out = block(x)
+ assert x_out.shape == (1, 64, 56, 56)
+
+ # ViPNAS_Bottleneck with stride 1 and downsample
+ downsample = nn.Sequential(
+ nn.Conv2d(64, 128, kernel_size=1), nn.BatchNorm2d(128))
+ block = ViPNAS_Bottleneck(64, 128, style='pytorch', downsample=downsample)
+ assert block.in_channels == 64
+ assert block.mid_channels == 32
+ assert block.out_channels == 128
+ assert block.conv1.in_channels == 64
+ assert block.conv1.out_channels == 32
+ assert block.conv1.kernel_size == (1, 1)
+ assert block.conv2.in_channels == 32
+ assert block.conv2.out_channels == 32
+ assert block.conv2.kernel_size == (3, 3)
+ assert block.conv3.in_channels == 32
+ assert block.conv3.out_channels == 128
+ assert block.conv3.kernel_size == (1, 1)
+ x = torch.randn(1, 64, 56, 56)
+ x_out = block(x)
+ assert x_out.shape == (1, 128, 56, 56)
+
+ # ViPNAS_Bottleneck with stride 2 and downsample
+ downsample = nn.Sequential(
+ nn.Conv2d(64, 128, kernel_size=1, stride=2), nn.BatchNorm2d(128))
+ block = ViPNAS_Bottleneck(
+ 64, 128, stride=2, style='pytorch', downsample=downsample)
+ x = torch.randn(1, 64, 56, 56)
+ x_out = block(x)
+ assert x_out.shape == (1, 128, 28, 28)
+
+ # ViPNAS_Bottleneck with expansion 2
+ block = ViPNAS_Bottleneck(64, 64, style='pytorch', expansion=2)
+ assert block.in_channels == 64
+ assert block.mid_channels == 32
+ assert block.out_channels == 64
+ assert block.conv1.in_channels == 64
+ assert block.conv1.out_channels == 32
+ assert block.conv1.kernel_size == (1, 1)
+ assert block.conv2.in_channels == 32
+ assert block.conv2.out_channels == 32
+ assert block.conv2.kernel_size == (3, 3)
+ assert block.conv3.in_channels == 32
+ assert block.conv3.out_channels == 64
+ assert block.conv3.kernel_size == (1, 1)
+ x = torch.randn(1, 64, 56, 56)
+ x_out = block(x)
+ assert x_out.shape == (1, 64, 56, 56)
+
+ # Test ViPNAS_Bottleneck with checkpointing
+ block = ViPNAS_Bottleneck(64, 64, with_cp=True)
+ block.train()
+ assert block.with_cp
+ x = torch.randn(1, 64, 56, 56, requires_grad=True)
+ x_out = block(x)
+ assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+
+def test_vipnas_bottleneck_reslayer():
+ # 3 Bottleneck w/o downsample
+ layer = ViPNAS_ResLayer(ViPNAS_Bottleneck, 3, 32, 32)
+ assert len(layer) == 3
+ for i in range(3):
+ assert layer[i].in_channels == 32
+ assert layer[i].out_channels == 32
+ assert layer[i].downsample is None
+ x = torch.randn(1, 32, 56, 56)
+ x_out = layer(x)
+ assert x_out.shape == (1, 32, 56, 56)
+
+ # 3 ViPNAS_Bottleneck w/ stride 1 and downsample
+ layer = ViPNAS_ResLayer(ViPNAS_Bottleneck, 3, 32, 64)
+ assert len(layer) == 3
+ assert layer[0].in_channels == 32
+ assert layer[0].out_channels == 64
+ assert layer[0].stride == 1
+ assert layer[0].conv1.out_channels == 64
+ assert layer[0].downsample is not None and len(layer[0].downsample) == 2
+ assert isinstance(layer[0].downsample[0], nn.Conv2d)
+ assert layer[0].downsample[0].stride == (1, 1)
+ for i in range(1, 3):
+ assert layer[i].in_channels == 64
+ assert layer[i].out_channels == 64
+ assert layer[i].conv1.out_channels == 64
+ assert layer[i].stride == 1
+ assert layer[i].downsample is None
+ x = torch.randn(1, 32, 56, 56)
+ x_out = layer(x)
+ assert x_out.shape == (1, 64, 56, 56)
+
+ # 3 ViPNAS_Bottleneck w/ stride 2 and downsample
+ layer = ViPNAS_ResLayer(ViPNAS_Bottleneck, 3, 32, 64, stride=2)
+ assert len(layer) == 3
+ assert layer[0].in_channels == 32
+ assert layer[0].out_channels == 64
+ assert layer[0].stride == 2
+ assert layer[0].conv1.out_channels == 64
+ assert layer[0].downsample is not None and len(layer[0].downsample) == 2
+ assert isinstance(layer[0].downsample[0], nn.Conv2d)
+ assert layer[0].downsample[0].stride == (2, 2)
+ for i in range(1, 3):
+ assert layer[i].in_channels == 64
+ assert layer[i].out_channels == 64
+ assert layer[i].conv1.out_channels == 64
+ assert layer[i].stride == 1
+ assert layer[i].downsample is None
+ x = torch.randn(1, 32, 56, 56)
+ x_out = layer(x)
+ assert x_out.shape == (1, 64, 28, 28)
+
+ # 3 ViPNAS_Bottleneck w/ stride 2 and downsample with avg pool
+ layer = ViPNAS_ResLayer(
+ ViPNAS_Bottleneck, 3, 32, 64, stride=2, avg_down=True)
+ assert len(layer) == 3
+ assert layer[0].in_channels == 32
+ assert layer[0].out_channels == 64
+ assert layer[0].stride == 2
+ assert layer[0].conv1.out_channels == 64
+ assert layer[0].downsample is not None and len(layer[0].downsample) == 3
+ assert isinstance(layer[0].downsample[0], nn.AvgPool2d)
+ assert layer[0].downsample[0].stride == 2
+ for i in range(1, 3):
+ assert layer[i].in_channels == 64
+ assert layer[i].out_channels == 64
+ assert layer[i].conv1.out_channels == 64
+ assert layer[i].stride == 1
+ assert layer[i].downsample is None
+ x = torch.randn(1, 32, 56, 56)
+ x_out = layer(x)
+ assert x_out.shape == (1, 64, 28, 28)
+
+ # 3 ViPNAS_Bottleneck with custom expansion
+ layer = ViPNAS_ResLayer(ViPNAS_Bottleneck, 3, 32, 32, expansion=2)
+ assert len(layer) == 3
+ for i in range(3):
+ assert layer[i].in_channels == 32
+ assert layer[i].out_channels == 32
+ assert layer[i].stride == 1
+ assert layer[i].conv1.out_channels == 16
+ assert layer[i].downsample is None
+ x = torch.randn(1, 32, 56, 56)
+ x_out = layer(x)
+ assert x_out.shape == (1, 32, 56, 56)
+
+
+def test_resnet():
+ """Test ViPNAS_ResNet backbone."""
+ with pytest.raises(KeyError):
+        # ViPNAS_ResNet only supports depth 50 for now
+ ViPNAS_ResNet(20)
+
+ with pytest.raises(AssertionError):
+ # In ViPNAS_ResNet: 1 <= num_stages <= 4
+ ViPNAS_ResNet(50, num_stages=0)
+
+ with pytest.raises(AssertionError):
+ # In ViPNAS_ResNet: 1 <= num_stages <= 4
+ ViPNAS_ResNet(50, num_stages=5)
+
+ with pytest.raises(AssertionError):
+ # len(strides) == len(dilations) == num_stages
+ ViPNAS_ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)
+
+ with pytest.raises(TypeError):
+ # pretrained must be a string path
+ model = ViPNAS_ResNet(50)
+ model.init_weights(pretrained=0)
+
+ with pytest.raises(AssertionError):
+ # Style must be in ['pytorch', 'caffe']
+ ViPNAS_ResNet(50, style='tensorflow')
+
+ # Test ViPNAS_ResNet50 norm_eval=True
+ model = ViPNAS_ResNet(50, norm_eval=True)
+ model.init_weights()
+ model.train()
+ assert check_norm_state(model.modules(), False)
+
+ # Test ViPNAS_ResNet50 with first stage frozen
+ frozen_stages = 1
+ model = ViPNAS_ResNet(50, frozen_stages=frozen_stages)
+ model.init_weights()
+ model.train()
+ assert model.norm1.training is False
+ for layer in [model.conv1, model.norm1]:
+ for param in layer.parameters():
+ assert param.requires_grad is False
+ for i in range(1, frozen_stages + 1):
+ layer = getattr(model, f'layer{i}')
+ for mod in layer.modules():
+ if isinstance(mod, _BatchNorm):
+ assert mod.training is False
+ for param in layer.parameters():
+ assert param.requires_grad is False
+
+ # Test ViPNAS_ResNet50 with BatchNorm forward
+ model = ViPNAS_ResNet(50, out_indices=(0, 1, 2, 3))
+ model.init_weights()
+ model.train()
+
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert len(feat) == 4
+ assert feat[0].shape == (1, 80, 56, 56)
+ assert feat[1].shape == (1, 160, 28, 28)
+ assert feat[2].shape == (1, 304, 14, 14)
+ assert feat[3].shape == (1, 608, 7, 7)
+
+ # Test ViPNAS_ResNet50 with layers 1, 2, 3 out forward
+ model = ViPNAS_ResNet(50, out_indices=(0, 1, 2))
+ model.init_weights()
+ model.train()
+
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert len(feat) == 3
+ assert feat[0].shape == (1, 80, 56, 56)
+ assert feat[1].shape == (1, 160, 28, 28)
+ assert feat[2].shape == (1, 304, 14, 14)
+
+ # Test ViPNAS_ResNet50 with layers 3 (top feature maps) out forward
+ model = ViPNAS_ResNet(50, out_indices=(3, ))
+ model.init_weights()
+ model.train()
+
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat.shape == (1, 608, 7, 7)
+
+ # Test ViPNAS_ResNet50 with checkpoint forward
+ model = ViPNAS_ResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
+ for m in model.modules():
+ if is_block(m):
+ assert m.with_cp
+ model.init_weights()
+ model.train()
+
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert len(feat) == 4
+ assert feat[0].shape == (1, 80, 56, 56)
+ assert feat[1].shape == (1, 160, 28, 28)
+ assert feat[2].shape == (1, 304, 14, 14)
+ assert feat[3].shape == (1, 608, 7, 7)
diff --git a/tests/test_model/test_top_down_forward.py b/tests/test_model/test_top_down_forward.py
index e0caa02dbd..42efa3e036 100644
--- a/tests/test_model/test_top_down_forward.py
+++ b/tests/test_model/test_top_down_forward.py
@@ -4,6 +4,57 @@
from mmpose.models.detectors import TopDown
+def test_vipnas_forward():
+ # model settings
+
+ channel_cfg = dict(
+ num_output_channels=17,
+ dataset_joints=17,
+ dataset_channel=[
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+ ],
+ inference_channel=[
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+ ])
+
+ model_cfg = dict(
+ type='TopDown',
+ pretrained=None,
+ backbone=dict(type='ViPNAS_ResNet', depth=50),
+ keypoint_head=dict(
+ type='ViPNASHeatmapSimpleHead',
+ in_channels=608,
+ out_channels=channel_cfg['num_output_channels'],
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=True,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+ detector = TopDown(model_cfg['backbone'], None, model_cfg['keypoint_head'],
+ model_cfg['train_cfg'], model_cfg['test_cfg'],
+ model_cfg['pretrained'])
+
+ input_shape = (1, 3, 256, 256)
+ mm_inputs = _demo_mm_inputs(input_shape)
+
+ imgs = mm_inputs.pop('imgs')
+ target = mm_inputs.pop('target')
+ target_weight = mm_inputs.pop('target_weight')
+ img_metas = mm_inputs.pop('img_metas')
+
+ # Test forward train
+ losses = detector.forward(
+ imgs, target, target_weight, img_metas, return_loss=True)
+ assert isinstance(losses, dict)
+
+ # Test forward test
+ with torch.no_grad():
+ _ = detector.forward(imgs, img_metas=img_metas, return_loss=False)
+
+
def test_topdown_forward():
model_cfg = dict(
type='TopDown',
diff --git a/tests/test_model/test_top_down_head.py b/tests/test_model/test_top_down_head.py
index 48d95423ae..502ef327dc 100644
--- a/tests/test_model/test_top_down_head.py
+++ b/tests/test_model/test_top_down_head.py
@@ -4,7 +4,135 @@
from mmpose.models import (DeepposeRegressionHead, TopdownHeatmapMSMUHead,
TopdownHeatmapMultiStageHead,
- TopdownHeatmapSimpleHead)
+ TopdownHeatmapSimpleHead, ViPNASHeatmapSimpleHead)
+
+
+def test_vipnas_simple_head():
+ """Test simple head."""
+ with pytest.raises(TypeError):
+ # extra
+ _ = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ extra=[],
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+
+ with pytest.raises(TypeError):
+ head = ViPNASHeatmapSimpleHead(
+ out_channels=3, in_channels=512, extra={'final_conv_kernel': 1})
+
+ # test num deconv layers
+ with pytest.raises(ValueError):
+ _ = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ num_deconv_layers=-1,
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+
+ _ = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ num_deconv_layers=0,
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+
+ with pytest.raises(ValueError):
+ # the number of layers should match
+ _ = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ num_deconv_layers=3,
+ num_deconv_filters=(256, 256),
+ num_deconv_kernels=(4, 4),
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+
+ with pytest.raises(ValueError):
+ # the number of kernels should match
+ _ = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ num_deconv_layers=3,
+ num_deconv_filters=(256, 256, 256),
+ num_deconv_kernels=(4, 4),
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+
+ with pytest.raises(ValueError):
+ # the deconv kernels should be 4, 3, 2
+ _ = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ num_deconv_layers=3,
+ num_deconv_filters=(256, 256, 256),
+ num_deconv_kernels=(3, 2, 0),
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+
+ with pytest.raises(ValueError):
+ # the deconv kernels should be 4, 3, 2
+ _ = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ num_deconv_layers=3,
+ num_deconv_filters=(256, 256, 256),
+ num_deconv_kernels=(4, 4, -1),
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+
+ # test final_conv_kernel
+ head = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ extra={'final_conv_kernel': 3},
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+ head.init_weights()
+ assert head.final_layer.padding == (1, 1)
+ head = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ extra={'final_conv_kernel': 1},
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+ assert head.final_layer.padding == (0, 0)
+ _ = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ extra={'final_conv_kernel': 0},
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+
+ head = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True),
+ extra=dict(
+ final_conv_kernel=1, num_conv_layers=1, num_conv_kernels=(1, )))
+ assert len(head.final_layer) == 4
+
+ head = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+ input_shape = (1, 512, 32, 32)
+ inputs = _demo_inputs(input_shape)
+ out = head(inputs)
+ assert out.shape == torch.Size([1, 3, 256, 256])
+
+ head = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ num_deconv_layers=0,
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+ input_shape = (1, 512, 32, 32)
+ inputs = _demo_inputs(input_shape)
+ out = head(inputs)
+ assert out.shape == torch.Size([1, 3, 32, 32])
+
+ head = ViPNASHeatmapSimpleHead(
+ out_channels=3,
+ in_channels=512,
+ num_deconv_layers=0,
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
+ input_shape = (1, 512, 32, 32)
+ inputs = _demo_inputs(input_shape)
+ out = head([inputs])
+ assert out.shape == torch.Size([1, 3, 32, 32])
+
+ head.init_weights()
def test_top_down_simple_head():