[Feature] Support ConvNeXt-V2 in projects (#9619)
Showing 2 changed files with 130 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# ConvNeXt-V2 | ||
|
||
> [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](http://arxiv.org/abs/2301.00808) | ||
## Abstract | ||
|
||
Driven by improved architectures and better representation learning frameworks, the field of visual recognition has enjoyed rapid modernization and performance boost in the early 2020s. For example, modern ConvNets, represented by ConvNeXt \[52\], have demonstrated strong performance in various scenarios. While these models were originally designed for supervised learning with ImageNet labels, they can also potentially benefit from self-supervised learning techniques such as masked autoencoders (MAE). However, we found that simply combining these two approaches leads to subpar performance. In this paper, we propose a fully convolutional masked autoencoder framework and a new Global Response Normalization (GRN) layer that can be added to the ConvNeXt architecture to enhance inter-channel feature competition. This co-design of self-supervised learning techniques and architectural improvement results in a new model family called ConvNeXt V2, which significantly improves the performance of pure ConvNets on various recognition benchmarks, including ImageNet classification, COCO detection, and ADE20K segmentation. We also provide pre-trained ConvNeXt V2 models of various sizes, ranging from an efficient 3.7M-parameter Atto model with 76.7% top-1 accuracy on ImageNet, to a 650M Huge model that achieves a state-of-the-art 88.9% accuracy using only public training data. | ||
|
||
<div align=center> | ||
<img src="https://user-images.githubusercontent.com/12907710/212588579-02d621d8-5796-4f0d-b4d2-758fe9c2f395.png" width="50%"/> | ||
</div> | ||
|
||
## Results and models | ||
|
||
| Method | Backbone | Pretrain | Lr schd | Augmentation | Mem (GB) | box AP | mask AP | Config | Download | | ||
| :--------: | :-----------: | :------: | :-----: | :----------: | :------: | :----: | :-----: | :----------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | ||
| Mask R-CNN | ConvNeXt-V2-B | FCMAE | 3x | LSJ | 22.5 | 52.9 | 46.4 | [config](./mask-rcnn_convnext-v2-b_fpn_lsj-3x-fcmae_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/convnextv2/mask-rcnn_convnext-v2-b_fpn_lsj-3x-fcmae_coco/mask-rcnn_convnext-v2-b_fpn_lsj-3x-fcmae_coco_20230113_110947-757ee2dd.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/convnextv2/mask-rcnn_convnext-v2-b_fpn_lsj-3x-fcmae_coco/mask-rcnn_convnext-v2-b_fpn_lsj-3x-fcmae_coco_20230113_110947.log.json) | | ||
|
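As a rough illustration of what the `3x` entry in the `Lr schd` column amounts to, the sketch below reproduces the learning-rate curve implied by the config in this commit: a linear warmup from a factor of 0.001 over the first 1000 iterations, then a step decay by 0.1 at epochs 27 and 33 of a 36-epoch run, on a base learning rate of 0.0001. The function name `lr_at` and the choice to combine the two schedulers by multiplying their factors are assumptions made for illustration; this is not MMEngine's scheduler code, and the exact warmup endpoint may differ by an iteration.

```python
from bisect import bisect_right

BASE_LR = 1e-4          # optimizer lr in the config
WARMUP_ITERS = 1000     # `end` of the LinearLR scheduler
START_FACTOR = 0.001    # `start_factor` of the LinearLR scheduler
MILESTONES = [27, 33]   # MultiStepLR milestones, in epochs
GAMMA = 0.1             # MultiStepLR decay factor


def lr_at(epoch: int, iteration: int) -> float:
    """Approximate learning rate at a 0-based epoch and global iteration."""
    # Linear warmup: interpolate the factor from START_FACTOR up to 1.0.
    progress = min(iteration, WARMUP_ITERS) / WARMUP_ITERS
    warmup = START_FACTOR + (1.0 - START_FACTOR) * progress
    # Step decay: one factor of GAMMA for every milestone already reached.
    decay = GAMMA ** bisect_right(MILESTONES, epoch)
    return BASE_LR * warmup * decay


if __name__ == "__main__":
    for epoch in (0, 26, 27, 33, 35):
        # Use a large iteration count so the warmup is already finished.
        print(epoch, lr_at(epoch, iteration=10**6))
```

Under these assumptions the learning rate stays at the base value for most of training and only drops in the last nine epochs, which is the usual shape of a multi-step detection schedule.
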
||
**Note**: | ||
|
||
- This is a pre-release version of ConvNeXt-V2 object detection. The official fine-tuning settings for ConvNeXt-V2 have not been released yet. | ||
- The ConvNeXt backbone requires the [MMClassification dev-1.x branch](https://github.com/open-mmlab/mmclassification/tree/dev-1.x), which provides many backbones for downstream tasks. Install it first: | ||
|
||
```shell | ||
git clone -b dev-1.x https://github.com/open-mmlab/mmclassification.git | ||
cd mmclassification | ||
pip install -U openmim && mim install -e . | ||
``` | ||
|
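Because the config in this commit imports `mmcls.models` with `allow_failed_imports=False`, a missing MMClassification install only shows up when the config is loaded. A small pre-flight check along the following lines can report the problem earlier. It is a sketch that uses only the Python standard library, and the helper name `is_installed` is made up for illustration.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if a top-level package can be located without importing it."""
    return importlib.util.find_spec(package) is not None


if __name__ == "__main__":
    if is_installed("mmcls"):
        print("mmcls found")
    else:
        print("mmcls not found; install the dev-1.x branch as shown above")
```

The check only confirms that a package named `mmcls` is importable. It does not verify that the installed version is the dev-1.x branch.
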
||
## Citation | ||
|
||
```bibtex | ||
@article{Woo2023ConvNeXtV2, | ||
title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders}, | ||
author={Sanghyun Woo and Shoubhik Debnath and Ronghang Hu and Xinlei Chen and Zhuang Liu and In So Kweon and Saining Xie}, | ||
year={2023}, | ||
journal={arXiv preprint arXiv:2301.00808}, | ||
} | ||
``` |
91 changes: 91 additions & 0 deletions
projects/ConvNeXt-V2/configs/mask-rcnn_convnext-v2-b_fpn_lsj-3x-fcmae_coco.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
_base_ = [ | ||
    'mmdet::_base_/models/mask-rcnn_r50_fpn.py', | ||
    'mmdet::_base_/datasets/coco_instance.py', | ||
    'mmdet::_base_/schedules/schedule_1x.py', | ||
    'mmdet::_base_/default_runtime.py' | ||
] | ||
|
||
# please install the mmclassification dev-1.x branch | ||
# import mmcls.models to trigger register_module in mmcls | ||
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False) | ||
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext-v2/convnext-v2-base_3rdparty-fcmae_in1k_20230104-8a798eaf.pth' # noqa | ||
image_size = (1024, 1024) | ||
|
||
model = dict( | ||
    backbone=dict( | ||
        _delete_=True, | ||
        type='mmcls.ConvNeXt', | ||
        arch='base', | ||
        out_indices=[0, 1, 2, 3], | ||
        # TODO: verify stochastic depth rate {0.1, 0.2, 0.3, 0.4} | ||
        drop_path_rate=0.4, | ||
        layer_scale_init_value=0.,  # disable layer scale when using GRN | ||
        gap_before_final_norm=False, | ||
        use_grn=True,  # V2 uses GRN | ||
        init_cfg=dict( | ||
            type='Pretrained', checkpoint=checkpoint_file, | ||
            prefix='backbone.')), | ||
    neck=dict(in_channels=[128, 256, 512, 1024]), | ||
    test_cfg=dict( | ||
        rpn=dict(nms=dict(type='nms')),  # TODO: does RPN use soft_nms? | ||
        rcnn=dict(nms=dict(type='soft_nms')))) | ||
|
||
train_pipeline = [ | ||
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), | ||
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True), | ||
    dict( | ||
        type='RandomResize', | ||
        scale=image_size, | ||
        ratio_range=(0.1, 2.0), | ||
        keep_ratio=True), | ||
    dict( | ||
        type='RandomCrop', | ||
        crop_type='absolute_range', | ||
        crop_size=image_size, | ||
        recompute_bbox=True, | ||
        allow_negative_crop=True), | ||
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), | ||
    dict(type='RandomFlip', prob=0.5), | ||
    dict(type='PackDetInputs') | ||
] | ||
|
||
train_dataloader = dict( | ||
    batch_size=4,  # total batch size 32 = 8 GPUs x 4 images per GPU | ||
    num_workers=8, | ||
    dataset=dict(pipeline=train_pipeline)) | ||
|
||
max_epochs = 36 | ||
train_cfg = dict(max_epochs=max_epochs) | ||
|
||
# learning rate | ||
param_scheduler = [ | ||
    dict( | ||
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, | ||
        end=1000), | ||
    dict( | ||
        type='MultiStepLR', | ||
        begin=0, | ||
        end=max_epochs, | ||
        by_epoch=True, | ||
        milestones=[27, 33], | ||
        gamma=0.1) | ||
] | ||
|
||
# Enable automatic-mixed-precision training with AmpOptimWrapper. | ||
optim_wrapper = dict( | ||
    type='AmpOptimWrapper', | ||
    constructor='LearningRateDecayOptimizerConstructor', | ||
    paramwise_cfg={ | ||
        'decay_rate': 0.95, | ||
        'decay_type': 'layer_wise',  # TODO: sweep layer-wise lr decay? | ||
        'num_layers': 12 | ||
    }, | ||
    optimizer=dict( | ||
        _delete_=True, | ||
        type='AdamW', | ||
        lr=0.0001, | ||
        betas=(0.9, 0.999), | ||
        weight_decay=0.05, | ||
    )) | ||
|
||
default_hooks = dict(checkpoint=dict(max_keep_ckpts=1)) |
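
To get a feel for the `paramwise_cfg` above, the sketch below computes per-group learning-rate multipliers for layer-wise decay with `decay_rate=0.95` and `num_layers=12`. It rests on an assumption that is not confirmed by this diff: that parameters are split into `num_layers + 2` groups (one extra for the stem, one for everything outside the backbone) and that group `i` is scaled by `decay_rate ** (num_layers + 1 - i)`. The function name `layer_lr_scales` is hypothetical, and the real behaviour is defined by `LearningRateDecayOptimizerConstructor`, which should be checked before relying on these numbers.

```python
from typing import List

DECAY_RATE = 0.95   # paramwise_cfg['decay_rate']
NUM_LAYERS = 12     # paramwise_cfg['num_layers']
BASE_LR = 1e-4      # optimizer lr


def layer_lr_scales(decay_rate: float, num_layers: int) -> List[float]:
    """Per-group LR multipliers; index 0 is the earliest group (assumed rule)."""
    # Assumption: one extra group for the stem and one for non-backbone params.
    total_groups = num_layers + 2
    return [decay_rate ** (total_groups - 1 - i) for i in range(total_groups)]


if __name__ == "__main__":
    scales = layer_lr_scales(DECAY_RATE, NUM_LAYERS)
    for i in (0, len(scales) - 1):
        print(f"group {i}: lr = {BASE_LR * scales[i]:.3e}")
```

Under this assumed rule the last group trains at the full base learning rate, while earlier groups are damped geometrically, so the pretrained FCMAE weights in the first stages change more slowly than the detection heads.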