Skip to content

Commit

Permalink
switch to new config
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis committed Sep 15, 2023
1 parent e781fcf commit b19fffa
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 34 deletions.
2 changes: 2 additions & 0 deletions configs/body_2d_keypoint/edpose/coco/edpose_coco.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,5 @@ Results on COCO val2017.
| [edpose_res50_coco](/configs/body_2d_keypoint/edpose/coco/edpose_res50_8xb2-50e_coco-800x1333.py) | ResNet-50 | 0.716 | 0.898 | 0.783 | 0.793 | 0.944 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.json) |

The checkpoint is converted from the official repo. The training of EDPose is not supported yet. It will be supported in the future updates.

The above config follows [Pure Python style](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta). Please install `mmengine>=0.8.2` to use this config.
Original file line number Diff line number Diff line change
@@ -1,21 +1,40 @@
_base_ = ['../../../_base_/default_runtime.py']
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from mmpose.configs._base_.default_runtime import * # noqa

from mmcv.transforms import RandomChoice, RandomChoiceResize
from mmengine.dataset import DefaultSampler
from mmengine.model import PretrainedInit
from mmengine.optim import LinearLR, MultiStepLR
from torch.nn import GroupNorm
from torch.optim import Adam

from mmpose.codecs import EDPoseLabel
from mmpose.datasets import (BottomupRandomChoiceResize, BottomupRandomCrop,
CocoDataset, LoadImage, PackPoseInputs,
RandomFlip)
from mmpose.evaluation import CocoMetric
from mmpose.models import (BottomupPoseEstimator, ChannelMapper, EDPoseHead,
PoseDataPreprocessor, ResNet)
from mmpose.models.utils import FrozenBatchNorm2d

# runtime
train_cfg = dict(max_epochs=50, val_interval=10)
train_cfg.update(max_epochs=50, val_interval=10) # noqa

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
type=Adam,
lr=1e-3,
))

# learning policy
param_scheduler = [
dict(type=LinearLR, begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
type=MultiStepLR,
begin=0,
end=140,
milestones=[33, 45],
Expand All @@ -27,40 +46,42 @@
auto_scale_lr = dict(base_batch_size=80)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
default_hooks.update( # noqa
checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(type='EDPoseLabel', num_select=50, num_keypoints=17)
codec = dict(type=EDPoseLabel, num_select=50, num_keypoints=17)

# model settings
model = dict(
type='BottomupPoseEstimator',
type=BottomupPoseEstimator,
data_preprocessor=dict(
type='PoseDataPreprocessor',
type=PoseDataPreprocessor,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=1),
backbone=dict(
type='ResNet',
type=ResNet,
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='FrozenBatchNorm2d', requires_grad=False),
norm_cfg=dict(type=FrozenBatchNorm2d, requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
init_cfg=dict(
type=PretrainedInit, checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
type=ChannelMapper,
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
norm_cfg=dict(type=GroupNorm, num_groups=32),
num_outs=4),
head=dict(
type='EDPoseHead',
type=EDPoseHead,
num_queries=900,
num_feature_levels=4,
num_keypoints=17,
Expand Down Expand Up @@ -117,57 +138,57 @@
find_unused_parameters = True

# base dataset settings
dataset_type = 'CocoDataset'
dataset_type = CocoDataset
data_mode = 'bottomup'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
dict(type='LoadImage'),
dict(type='RandomFlip', direction='horizontal'),
dict(type=LoadImage),
dict(type=RandomFlip, direction='horizontal'),
dict(
type='RandomChoice',
type=RandomChoice,
transforms=[
[
dict(
type='RandomChoiceResize',
type=RandomChoiceResize,
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='BottomupRandomChoiceResize',
type=BottomupRandomChoiceResize,
# The radio of all image in train dataset < 7
# follow the original implement
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='BottomupRandomCrop',
type=BottomupRandomCrop,
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='BottomupRandomChoiceResize',
type=BottomupRandomChoiceResize,
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(type='PackPoseInputs'),
dict(type=PackPoseInputs),
]

val_pipeline = [
dict(type='LoadImage'),
dict(type=LoadImage),
dict(
type='BottomupRandomChoiceResize',
type=BottomupRandomChoiceResize,
scales=[(800, 1333)],
keep_ratio=True,
backend='pillow'),
dict(
type='PackPoseInputs',
type=PackPoseInputs,
meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
'img_shape', 'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
Expand All @@ -179,7 +200,7 @@
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
Expand All @@ -194,7 +215,7 @@
num_workers=8,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
sampler=dict(type=DefaultSampler, shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
Expand All @@ -208,8 +229,7 @@

# evaluators
val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'annotations/person_keypoints_val2017.json',
type=CocoMetric,
nms_mode='none',
score_mode='keypoint',
)
Expand Down
3 changes: 2 additions & 1 deletion mmpose/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def dataset_meta_from_config(config: Config,
import mmpose.datasets.datasets # noqa: F401, F403
from mmpose.registry import DATASETS

dataset_class = DATASETS.get(dataset_cfg.type)
dataset_class = dataset_cfg.type if isinstance(
dataset_cfg.type, type) else DATASETS.get(dataset_cfg.type)
metainfo = dataset_class.METAINFO

metainfo = parse_pose_metainfo(metainfo)
Expand Down

0 comments on commit b19fffa

Please sign in to comment.