diff --git a/README.md b/README.md
index 6ffc91ebc7..0b73beda53 100644
--- a/README.md
+++ b/README.md
@@ -118,6 +118,7 @@ Supported methods:
- [x] [STDC (CVPR'2021)](configs/stdc)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [DPT (ArXiv'2021)](configs/dpt)
+- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
Supported datasets:
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 0b0503f984..ebcdd45047 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -117,6 +117,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [STDC (CVPR'2021)](configs/stdc)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [DPT (ArXiv'2021)](configs/dpt)
+- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
已支持的数据集:
diff --git a/configs/_base_/models/segmenter_vit-b16_mask.py b/configs/_base_/models/segmenter_vit-b16_mask.py
new file mode 100644
index 0000000000..967a65c200
--- /dev/null
+++ b/configs/_base_/models/segmenter_vit-b16_mask.py
@@ -0,0 +1,35 @@
+# model settings
+backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='pretrain/vit_base_p16_384.pth',
+ backbone=dict(
+ type='VisionTransformer',
+ img_size=(512, 512),
+ patch_size=16,
+ in_channels=3,
+ embed_dims=768,
+ num_layers=12,
+ num_heads=12,
+ drop_path_rate=0.1,
+ attn_drop_rate=0.0,
+ drop_rate=0.0,
+ final_norm=True,
+ norm_cfg=backbone_norm_cfg,
+ with_cls_token=True,
+ interpolate_mode='bicubic',
+ ),
+ decode_head=dict(
+ type='SegmenterMaskTransformerHead',
+ in_channels=768,
+ channels=768,
+ num_classes=150,
+ num_layers=2,
+ num_heads=12,
+ embed_dims=768,
+ dropout_ratio=0.0,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ ),
+ test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(480, 480)),
+)
diff --git a/configs/segmenter/README.md b/configs/segmenter/README.md
new file mode 100644
index 0000000000..b073e88ceb
--- /dev/null
+++ b/configs/segmenter/README.md
@@ -0,0 +1,73 @@
+# Segmenter
+
+[Segmenter: Transformer for Semantic Segmentation](https://arxiv.org/abs/2105.05633)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+Image segmentation is often ambiguous at the level of individual image patches and requires contextual information to reach label consensus. In this paper we introduce Segmenter, a transformer model for semantic segmentation. In contrast to convolution-based methods, our approach allows to model global context already at the first layer and throughout the network. We build on the recent Vision Transformer (ViT) and extend it to semantic segmentation. To do so, we rely on the output embeddings corresponding to image patches and obtain class labels from these embeddings with a point-wise linear decoder or a mask transformer decoder. We leverage models pre-trained for image classification and show that we can fine-tune them on moderate sized datasets available for semantic segmentation. The linear decoder allows to obtain excellent results already, but the performance can be further improved by a mask transformer generating class masks. We conduct an extensive ablation study to show the impact of the different parameters, in particular the performance is better for large models and small patch sizes. Segmenter attains excellent results for semantic segmentation. It outperforms the state of the art on both ADE20K and Pascal Context datasets and is competitive on Cityscapes.
+
+
+
+
![](https://user-images.githubusercontent.com/24582831/148507554-87eb80bd-02c7-4c31-b102-c6141e231ec8.png)
+
+
+```bibtex
+@article{strudel2021Segmenter,
+ title={Segmenter: Transformer for Semantic Segmentation},
+ author={Strudel, Robin and Ricardo, Garcia, and Laptev, Ivan and Schmid, Cordelia},
+ journal={arXiv preprint arXiv:2105.05633},
+ year={2021}
+}
+```
+
+
+## Usage
+
+To use the pre-trained ViT model from [Segmenter](https://github.com/rstrudel/segmenter), it is necessary to convert keys.
+
+We provide a script [`vitjax2mmseg.py`](../../tools/model_converters/vitjax2mmseg.py) in the tools directory to convert the key of models from [ViT-AugReg](https://github.com/rwightman/pytorch-image-models/blob/f55c22bebf9d8afc449d317a723231ef72e0d662/timm/models/vision_transformer.py#L54-L106) to MMSegmentation style.
+
+```shell
+python tools/model_converters/vitjax2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+
+E.g.
+
+```shell
+python tools/model_converters/vitjax2mmseg.py \
+Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz \
+pretrain/vit_tiny_p16_384.pth
+```
+
+This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.
+
+In our default setting, pretrained models and their corresponding [ViT-AugReg](https://github.com/rwightman/pytorch-image-models/blob/f55c22bebf9d8afc449d317a723231ef72e0d662/timm/models/vision_transformer.py#L54-L106) models could be defined below:
+
+ | pretrained models | original models |
+ | ------ | -------- |
+ |vit_tiny_p16_384.pth | ['vit_tiny_patch16_384'](https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz) |
+ |vit_small_p16_384.pth | ['vit_small_patch16_384'](https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz) |
+ |vit_base_p16_384.pth | ['vit_base_patch16_384'](https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz) |
+ |vit_large_p16_384.pth | ['vit_large_patch16_384'](https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz) |
+
+## Results and models
+
+### ADE20K
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- |
+| Segmenter-Mask | ViT-T_16 | 512x512 | 160000 | 1.21 | 27.98 | 39.99 | 40.83 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706-ffcf7509.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
+| Segmenter-Linear | ViT-S_16 | 512x512 | 160000 | 1.78 | 28.07 | 45.75 | 46.82 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713-39658c46.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713.log.json) |
+| Segmenter-Mask | ViT-S_16 | 512x512 | 160000 | 2.03 | 24.80 | 46.19 | 47.85 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706-511bb103.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
+| Segmenter-Mask | ViT-B_16 |512x512 | 160000 | 4.20 | 13.20 | 49.60 | 51.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706-bc533b08.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
+| Segmenter-Mask | ViT-L_16 |640x640 | 160000 | 16.56 | 2.62 | 52.16 | 53.65 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750-7ef345be.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750.log.json) |
diff --git a/configs/segmenter/segmenter.yml b/configs/segmenter/segmenter.yml
new file mode 100644
index 0000000000..67cec8932d
--- /dev/null
+++ b/configs/segmenter/segmenter.yml
@@ -0,0 +1,125 @@
+Collections:
+- Name: segmenter
+ Metadata:
+ Training Data:
+ - ADE20K
+ Paper:
+ URL: https://arxiv.org/abs/2105.05633
+ Title: 'Segmenter: Transformer for Semantic Segmentation'
+ README: configs/segmenter/README.md
+ Code:
+ URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.21.0/mmseg/models/decode_heads/segmenter_mask_head.py#L15
+ Version: v0.21.0
+ Converted From:
+ Code: https://github.com/rstrudel/segmenter
+Models:
+- Name: segmenter_vit-t_mask_8x1_512x512_160k_ade20k
+ In Collection: segmenter
+ Metadata:
+ backbone: ViT-T_16
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 35.74
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 1.21
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 39.99
+ mIoU(ms+flip): 40.83
+ Config: configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706-ffcf7509.pth
+- Name: segmenter_vit-s_linear_8x1_512x512_160k_ade20k
+ In Collection: segmenter
+ Metadata:
+ backbone: ViT-S_16
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 35.63
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 1.78
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 45.75
+ mIoU(ms+flip): 46.82
+ Config: configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713-39658c46.pth
+- Name: segmenter_vit-s_mask_8x1_512x512_160k_ade20k
+ In Collection: segmenter
+ Metadata:
+ backbone: ViT-S_16
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 40.32
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 2.03
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 46.19
+ mIoU(ms+flip): 47.85
+ Config: configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706-511bb103.pth
+- Name: segmenter_vit-b_mask_8x1_512x512_160k_ade20k
+ In Collection: segmenter
+ Metadata:
+ backbone: ViT-B_16
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 75.76
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 4.2
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 49.6
+ mIoU(ms+flip): 51.07
+ Config: configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706-bc533b08.pth
+- Name: segmenter_vit-l_mask_8x1_512x512_160k_ade20k
+ In Collection: segmenter
+ Metadata:
+ backbone: ViT-L_16
+ crop size: (640,640)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 381.68
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (640,640)
+ Training Memory (GB): 16.56
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 52.16
+ mIoU(ms+flip): 53.65
+ Config: configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750-7ef345be.pth
diff --git a/configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py b/configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..766a99fbf0
--- /dev/null
+++ b/configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py
@@ -0,0 +1,43 @@
+_base_ = [
+ '../_base_/models/segmenter_vit-b16_mask.py',
+ '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k.py'
+]
+optimizer = dict(lr=0.001, weight_decay=0.0)
+
+img_norm_cfg = dict(
+ mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ # num_gpus: 8 -> batch_size: 8
+ samples_per_gpu=1,
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline))
diff --git a/configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py b/configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..0ed004e55c
--- /dev/null
+++ b/configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py
@@ -0,0 +1,60 @@
+_base_ = [
+ '../_base_/models/segmenter_vit-b16_mask.py',
+ '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k.py'
+]
+
+model = dict(
+ pretrained='pretrain/vit_large_p16_384.pth',
+ backbone=dict(
+ type='VisionTransformer',
+ img_size=(640, 640),
+ embed_dims=1024,
+ num_layers=24,
+ num_heads=16),
+ decode_head=dict(
+ type='SegmenterMaskTransformerHead',
+ in_channels=1024,
+ channels=1024,
+ num_heads=16,
+ embed_dims=1024),
+ test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(608, 608)))
+
+optimizer = dict(lr=0.001, weight_decay=0.0)
+
+img_norm_cfg = dict(
+ mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+crop_size = (640, 640)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 640),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ # num_gpus: 8 -> batch_size: 8
+ samples_per_gpu=1,
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline))
diff --git a/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py b/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..adc8c1b283
--- /dev/null
+++ b/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py
@@ -0,0 +1,14 @@
+_base_ = './segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py'
+
+model = dict(
+ decode_head=dict(
+ _delete_=True,
+ type='FCNHead',
+ in_channels=384,
+ channels=384,
+ num_convs=0,
+ dropout_ratio=0.0,
+ concat_input=False,
+ num_classes=150,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
diff --git a/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py b/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..8455ebe1da
--- /dev/null
+++ b/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py
@@ -0,0 +1,64 @@
+_base_ = [
+ '../_base_/models/segmenter_vit-b16_mask.py',
+ '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k.py'
+]
+
+backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
+model = dict(
+ pretrained='pretrain/vit_small_p16_384.pth',
+ backbone=dict(
+ img_size=(512, 512),
+ embed_dims=384,
+ num_heads=6,
+ ),
+ decode_head=dict(
+ type='SegmenterMaskTransformerHead',
+ in_channels=384,
+ channels=384,
+ num_classes=150,
+ num_layers=2,
+ num_heads=6,
+ embed_dims=384,
+ dropout_ratio=0.0,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
+
+optimizer = dict(lr=0.001, weight_decay=0.0)
+
+img_norm_cfg = dict(
+ mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ # num_gpus: 8 -> batch_size: 8
+ samples_per_gpu=1,
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline))
diff --git a/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py b/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..c9332fe8e5
--- /dev/null
+++ b/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py
@@ -0,0 +1,54 @@
+_base_ = [
+ '../_base_/models/segmenter_vit-b16_mask.py',
+ '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k.py'
+]
+
+model = dict(
+ pretrained='pretrain/vit_tiny_p16_384.pth',
+ backbone=dict(embed_dims=192, num_heads=3),
+ decode_head=dict(
+ type='SegmenterMaskTransformerHead',
+ in_channels=192,
+ channels=192,
+ num_heads=3,
+ embed_dims=192))
+
+optimizer = dict(lr=0.001, weight_decay=0.0)
+
+img_norm_cfg = dict(
+ mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ # num_gpus: 8 -> batch_size: 8
+ samples_per_gpu=1,
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline))
diff --git a/docs/en/changelog.md b/docs/en/changelog.md
index 7a1b4ea3ce..a615a03c6a 100644
--- a/docs/en/changelog.md
+++ b/docs/en/changelog.md
@@ -1,5 +1,6 @@
## Changelog
+
### V0.20.2 (12/15/2021)
**Bug Fixes**
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index b5375a1f5a..dcde813264 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -20,6 +20,7 @@
from .psa_head import PSAHead
from .psp_head import PSPHead
from .segformer_head import SegformerHead
+from .segmenter_mask_head import SegmenterMaskTransformerHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
@@ -32,6 +33,6 @@
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
- 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead',
- 'STDCHead'
+ 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
+ 'SegformerHead', 'ISAHead', 'STDCHead'
]
diff --git a/mmseg/models/decode_heads/segmenter_mask_head.py b/mmseg/models/decode_heads/segmenter_mask_head.py
new file mode 100644
index 0000000000..6a9b3d47ec
--- /dev/null
+++ b/mmseg/models/decode_heads/segmenter_mask_head.py
@@ -0,0 +1,133 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
+ trunc_normal_init)
+from mmcv.runner import ModuleList
+
+from mmseg.models.backbones.vit import TransformerEncoderLayer
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class SegmenterMaskTransformerHead(BaseDecodeHead):
+ """Segmenter: Transformer for Semantic Segmentation.
+
+ This head is the implementation of
+ `Segmenter: `_.
+
+ Args:
+ backbone_cfg:(dict): Config of backbone of
+ Context Path.
+ in_channels (int): The number of channels of input image.
+ num_layers (int): The depth of transformer.
+ num_heads (int): The number of attention heads.
+ embed_dims (int): The number of embedding dimension.
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ drop_path_rate (float): stochastic depth rate. Default 0.1.
+ drop_rate (float): Probability of an element to be zeroed.
+ Default 0.0
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default 0.0
+ num_fcs (int): The number of fully-connected layers for FFNs.
+ Default: 2.
+ qkv_bias (bool): Enable bias for qkv if True. Default: True.
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN')
+ init_std (float): The value of std in weight initialization.
+ Default: 0.02.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ num_layers,
+ num_heads,
+ embed_dims,
+ mlp_ratio=4,
+ drop_path_rate=0.1,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ num_fcs=2,
+ qkv_bias=True,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ init_std=0.02,
+ **kwargs,
+ ):
+ super(SegmenterMaskTransformerHead, self).__init__(
+ in_channels=in_channels, **kwargs)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
+ self.layers = ModuleList()
+ for i in range(num_layers):
+ self.layers.append(
+ TransformerEncoderLayer(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=mlp_ratio * embed_dims,
+ attn_drop_rate=attn_drop_rate,
+ drop_rate=drop_rate,
+ drop_path_rate=dpr[i],
+ num_fcs=num_fcs,
+ qkv_bias=qkv_bias,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ batch_first=True,
+ ))
+
+ self.dec_proj = nn.Linear(in_channels, embed_dims)
+
+ self.cls_emb = nn.Parameter(
+ torch.randn(1, self.num_classes, embed_dims))
+ self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False)
+ self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False)
+
+ self.decoder_norm = build_norm_layer(
+ norm_cfg, embed_dims, postfix=1)[1]
+ self.mask_norm = build_norm_layer(
+ norm_cfg, self.num_classes, postfix=2)[1]
+
+ self.init_std = init_std
+
+ delattr(self, 'conv_seg')
+
+ def init_weights(self):
+ trunc_normal_(self.cls_emb, std=self.init_std)
+ trunc_normal_init(self.patch_proj, std=self.init_std)
+ trunc_normal_init(self.classes_proj, std=self.init_std)
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=self.init_std, bias=0)
+ elif isinstance(m, nn.LayerNorm):
+ constant_init(m, val=1.0, bias=0.0)
+
+ def forward(self, inputs):
+ x = self._transform_inputs(inputs)
+ b, c, h, w = x.shape
+ x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)
+
+ x = self.dec_proj(x)
+ cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
+ x = torch.cat((x, cls_emb), 1)
+ for layer in self.layers:
+ x = layer(x)
+ x = self.decoder_norm(x)
+
+ patches = self.patch_proj(x[:, :-self.num_classes])
+ cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])
+
+ patches = F.normalize(patches, dim=2, p=2)
+ cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)
+
+ masks = patches @ cls_seg_feat.transpose(1, 2)
+ masks = self.mask_norm(masks)
+ masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)
+
+ return masks
diff --git a/model-index.yml b/model-index.yml
index 0c02909fad..1a491d9340 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -30,6 +30,7 @@ Import:
- configs/pspnet/pspnet.yml
- configs/resnest/resnest.yml
- configs/segformer/segformer.yml
+- configs/segmenter/segmenter.yml
- configs/sem_fpn/sem_fpn.yml
- configs/setr/setr.yml
- configs/stdc/stdc.yml
diff --git a/tests/test_models/test_heads/test_segmenter_mask_head.py b/tests/test_models/test_heads/test_segmenter_mask_head.py
new file mode 100644
index 0000000000..7b681ac15c
--- /dev/null
+++ b/tests/test_models/test_heads/test_segmenter_mask_head.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmseg.models.decode_heads import SegmenterMaskTransformerHead
+from .utils import _conv_has_norm, to_cuda
+
+
+def test_segmenter_mask_transformer_head():
+ head = SegmenterMaskTransformerHead(
+ in_channels=2,
+ channels=2,
+ num_classes=150,
+ num_layers=2,
+ num_heads=3,
+ embed_dims=192,
+ dropout_ratio=0.0)
+ assert _conv_has_norm(head, sync_bn=True)
+ head.init_weights()
+
+ inputs = [torch.randn(1, 2, 32, 32)]
+ if torch.cuda.is_available():
+ head, inputs = to_cuda(head, inputs)
+ outputs = head(inputs)
+ assert outputs.shape == (1, head.num_classes, 32, 32)
diff --git a/tools/model_converters/vitjax2mmseg.py b/tools/model_converters/vitjax2mmseg.py
new file mode 100644
index 0000000000..e3a0986ac6
--- /dev/null
+++ b/tools/model_converters/vitjax2mmseg.py
@@ -0,0 +1,122 @@
+import argparse
+import os.path as osp
+
+import mmcv
+import numpy as np
+import torch
+
+
+def vit_jax_to_torch(jax_weights, num_layer=12):
+ torch_weights = dict()
+
+ # patch embedding
+ conv_filters = jax_weights['embedding/kernel']
+ conv_filters = conv_filters.permute(3, 2, 0, 1)
+ torch_weights['patch_embed.projection.weight'] = conv_filters
+ torch_weights['patch_embed.projection.bias'] = jax_weights[
+ 'embedding/bias']
+
+ # pos embedding
+ torch_weights['pos_embed'] = jax_weights[
+ 'Transformer/posembed_input/pos_embedding']
+
+ # cls token
+ torch_weights['cls_token'] = jax_weights['cls']
+
+ # head
+ torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale']
+ torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias']
+
+ # transformer blocks
+ for i in range(num_layer):
+ jax_block = f'Transformer/encoderblock_{i}'
+ torch_block = f'layers.{i}'
+
+ # attention norm
+ torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[
+ f'{jax_block}/LayerNorm_0/scale']
+ torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[
+ f'{jax_block}/LayerNorm_0/bias']
+
+ # attention
+ query_weight = jax_weights[
+ f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel']
+ query_bias = jax_weights[
+ f'{jax_block}/MultiHeadDotProductAttention_1/query/bias']
+ key_weight = jax_weights[
+ f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel']
+ key_bias = jax_weights[
+ f'{jax_block}/MultiHeadDotProductAttention_1/key/bias']
+ value_weight = jax_weights[
+ f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel']
+ value_bias = jax_weights[
+ f'{jax_block}/MultiHeadDotProductAttention_1/value/bias']
+
+ qkv_weight = torch.from_numpy(
+ np.stack((query_weight, key_weight, value_weight), 1))
+ qkv_weight = torch.flatten(qkv_weight, start_dim=1)
+ qkv_bias = torch.from_numpy(
+ np.stack((query_bias, key_bias, value_bias), 0))
+ qkv_bias = torch.flatten(qkv_bias, start_dim=0)
+
+ torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight
+ torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias
+ to_out_weight = jax_weights[
+ f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel']
+ to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1)
+ torch_weights[
+ f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight
+ torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[
+ f'{jax_block}/MultiHeadDotProductAttention_1/out/bias']
+
+ # mlp norm
+ torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[
+ f'{jax_block}/LayerNorm_2/scale']
+ torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[
+ f'{jax_block}/LayerNorm_2/bias']
+
+ # mlp
+ torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[
+ f'{jax_block}/MlpBlock_3/Dense_0/kernel']
+ torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[
+ f'{jax_block}/MlpBlock_3/Dense_0/bias']
+ torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[
+ f'{jax_block}/MlpBlock_3/Dense_1/kernel']
+ torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[
+ f'{jax_block}/MlpBlock_3/Dense_1/bias']
+
+ # transpose weights
+ for k, v in torch_weights.items():
+ if 'weight' in k and 'patch_embed' not in k and 'ln' not in k:
+ v = v.permute(1, 0)
+ torch_weights[k] = v
+
+ return torch_weights
+
+
+def main():
+ # stole refactoring code from Robin Strudel, thanks
+ parser = argparse.ArgumentParser(
+ description='Convert keys from jax official pretrained vit models to '
+ 'MMSegmentation style.')
+ parser.add_argument('src', help='src model path or url')
+ # The dst path must be a full path of the new checkpoint.
+ parser.add_argument('dst', help='save path')
+ args = parser.parse_args()
+
+ jax_weights = np.load(args.src)
+ jax_weights_tensor = {}
+ for key in jax_weights.files:
+ value = torch.from_numpy(jax_weights[key])
+ jax_weights_tensor[key] = value
+ if 'L_16-i21k' in args.src:
+ num_layer = 24
+ else:
+ num_layer = 12
+ torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer)
+ mmcv.mkdir_or_exist(osp.dirname(args.dst))
+ torch.save(torch_weights, args.dst)
+
+
+if __name__ == '__main__':
+ main()