[Feature] Add segformer decode head and related train config (open-mmlab#599)

* [Feature] Segformer re-implementation

* Using act_cfg and norm_cfg to control activation and normalization

* Split this PR into several smaller PRs

* Fix lint error

* Remove SegFormerHead

* [Feature] Add segformer decode head and related train config

* Add ADE20K trainval support for segformer

1. Add related train and val configs;

2. Add AlignedResize;

* Set arg: find_unused_parameters = True

* Refactor parameters init

* 1. Refactor segformer backbone parameters init;

2. Remove redundant functions and unit tests;

* Remove redundant code

* Replace Linear layer with 1x1 Conv

* Use nn.ModuleList to refactor segformer head.

* Remove local to_xtuple

* 1. Remove redundant code;

2. Modify module name;

* Refactor the backbone of segformer using mmcv.cnn.bricks.transformer

* Fix some code logic bugs.

* Add mit_convert.py to match pretrain keys of segformer.

* Resolve some comments.

* 1. Add some asserts to ensure correct params;

2. Support flexible peconv position;

* Add pe_index assert and fix unit test.

* 1. Add doc string for MixVisionTransformer;

2. Add some unit tests for MixVisionTransformer;

* Use hw_shape to pass shape of feature map.

* 1. Fix doc string of MixVisionTransformer;

2. Simplify MixFFN;

3. Modify H, W to hw_shape;

* Add more unit tests.

* Add doc string for shape conversion functions.

* Add some unit tests to improve code coverage.

* Fix Segformer backbone pretrain weights match bug.

* Modify configs of segformer.

* Resolve doc strings of the shape conversion functions.

* Add pad_to_patch_size arg.

* Support progressive test with lower memory cost.

* Modify default value of pad_to_patch_size arg.

* Temp code

* Use processor to refactor evaluation workflow.

* Refactor eval hook.

* Fix progress bar.

* Fix middle save argument.

* Modify some variable names of the dataset evaluate API.

* Modify some variable names of the eval hook.

* Fix some priority bugs of the eval hook.

* Fix some bugs about model loading and eval hook.

* Add ade20k 640x640 dataset.

* Fix related segformer configs.

* Deprecate efficient_test.

* Fix training progress blocked by eval hook.

* Deprecate old test API.

* Fix incorrect patch size.

* Fix pretrain of mit_b0

* Fix the test API error.

* Modify dataset base config.

* Fix test API error.

* Modify outer API.

* Build a sampler test API.

* TODO: Refactor format_results.

* Modify variable names.

* Fix num_classes bug.

* Fix sampler index bug.

* Fix grammar bug.

* Add part of benchmark results.

* Support batch sampler.

* More readable test API.

* Remove some command args and fix eval hook bug.

* Support format-only arg.

* Modify format_results of datasets.

* Modify tools which use test APIs.

* Update readme.

* Update readme of segformer.

* Update readme of segformer.

* Update segformer readme and fix segformer mit_b4.

* Update readme of segformer.

* Clean AlignedResize related config.

* Clean code from pr open-mmlab#709

* Clean code from pr open-mmlab#709

* Add 512x512 segformer_mit-b5.

* Fix lint.

* Fix some segformer head bugs.

* Add segformer unit tests.

* Replace AlignedResize with ResizeToMultiple.

* Modify readme of segformer.

* Fix bug of ResizeToMultiple.

* Add ResizeToMultiple unit tests.

* Resolve conflict.

* Simplify the implementation of ResizeToMultiple.

* Update test results.

* Fix multi-scale test error when resize_ratio=1.75 and input size=640x640.

* Update segformer results.

* Update Segformer results.

* Fix some URL bugs and a pipeline bug.

* Move ckpt conversion to tools.

* Add segformer official pretrain weights usage.

* Clean redundant code.

* Remove redundant code.

* Unify format.

* Add description for segformer converter.

* Update workers.
clownrat6 authored Aug 13, 2021
1 parent 0abacd8 commit b4fd32d
Showing 18 changed files with 494 additions and 62 deletions.
34 changes: 34 additions & 0 deletions configs/_base_/models/segformer_mit-b0.py
@@ -0,0 +1,34 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='MixVisionTransformer',
in_channels=3,
embed_dims=32,
num_stages=4,
num_layers=[2, 2, 2, 2],
num_heads=[1, 2, 5, 8],
patch_sizes=[7, 3, 3, 3],
sr_ratios=[8, 4, 2, 1],
out_indices=(0, 1, 2, 3),
mlp_ratio=4,
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1),
decode_head=dict(
type='SegformerHead',
in_channels=[32, 64, 160, 256],
in_index=[0, 1, 2, 3],
channels=256,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
73 changes: 73 additions & 0 deletions configs/segformer/readme.md
@@ -0,0 +1,73 @@
# SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{xie2021segformer,
title={SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers},
author={Xie, Enze and Wang, Wenhai and Yu, Zhiding and Anandkumar, Anima and Alvarez, Jose M and Luo, Ping},
journal={arXiv preprint arXiv:2105.15203},
year={2021}
}
```

## Results and models

### ADE20k

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------: | -------------- | ---: | ------------- | ------ | -------- |
|Segformer | MIT-B0 | 512x512 | 160000 | 2.1 | 51.32 | 37.41 | 38.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530-8ffa8fda.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530.log.json) |
|Segformer | MIT-B1 | 512x512 | 160000 | 2.6 | 47.66 | 40.97 | 42.54 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106-d70e859d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106.log.json) |
|Segformer | MIT-B2 | 512x512 | 160000 | 3.6 | 30.88 | 45.58 | 47.03 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103-cbd414ac.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103.log.json) |
|Segformer | MIT-B3 | 512x512 | 160000 | 4.8 | 22.11 | 47.82 | 48.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410-962b98d2.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410.log.json) |
|Segformer | MIT-B4 | 512x512 | 160000 | 6.1 | 15.45 | 48.46 | 49.76 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055.log.json) |
|Segformer | MIT-B5 | 512x512 | 160000 | 7.2 | 11.89 | 49.13 | 50.22 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235-94cedf59.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235.log.json) |
|Segformer | MIT-B5 | 640x640 | 160000 | 11.5 | 11.30 | 49.62 | 50.36 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243-41d2845b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243.log.json) |

Evaluation with AlignedResize:

| Method | Backbone | Crop Size | Lr schd | mIoU | mIoU(ms+flip) |
| ------ | -------- | --------- | ------: | ---: | ------------- |
|Segformer | MIT-B0 | 512x512 | 160000 | 38.1 | 38.57 |
|Segformer | MIT-B1 | 512x512 | 160000 | 41.64 | 42.76 |
|Segformer | MIT-B2 | 512x512 | 160000 | 46.53 | 47.49 |
|Segformer | MIT-B3 | 512x512 | 160000 | 48.46 | 49.14 |
|Segformer | MIT-B4 | 512x512 | 160000 | 49.34 | 50.29 |
|Segformer | MIT-B5 | 512x512 | 160000 | 50.08 | 50.72 |
|Segformer | MIT-B5 | 640x640 | 160000 | 50.58 | 50.8 |

We replace `AlignedResize` in the original implementation with `Resize + ResizeToMultiple`. If you want to test
with the `AlignedResize` behavior, you can change the dataset pipeline like this:

```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            # Resize images to a multiple of 32; improves SegFormer by 0.5-1.0 mIoU.
            dict(type='ResizeToMultiple', size_divisor=32),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
```

## How to use SegFormer official pretrained weights

We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with `tools/model_converters/mit_convert.py`.

You may follow the steps below to prepare for SegFormer training:

1. Download the SegFormer pretrained weights (we suggest putting them in `pretrain/`);
2. Run the conversion script on the official pretrained weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`;
3. Set `pretrained` of the SegFormer model config accordingly; for example, `pretrained` of `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth` (see the sketch after this list).
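
Below is a minimal Python sketch of step 3. It assumes the standard `mmcv.Config` and `mmseg.models.build_segmentor` APIs and that the converted weights were saved to `pretrain/mit_b0.pth`; it is an illustration, not part of this commit.

```python
from mmcv import Config
from mmseg.models import build_segmentor

# Load the training config and point `pretrained` at the converted weights.
cfg = Config.fromfile('configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py')
cfg.model.pretrained = 'pretrain/mit_b0.pth'

# Build the segmentor; init_weights() loads the converted backbone weights.
model = build_segmentor(cfg.model)
model.init_weights()
```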
33 changes: 33 additions & 0 deletions configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py
@@ -0,0 +1,33 @@
_base_ = [
'../_base_/models/segformer_mit-b0.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

model = dict(
pretrained='pretrain/mit_b0.pth', decode_head=dict(num_classes=150))

# optimizer
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
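# custom_keys: parameters whose names contain 'pos_block' or 'norm' get no
# weight decay; parameters under 'head' get a 10x learning rate.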
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}))

lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)

data = dict(samples_per_gpu=2, workers_per_gpu=2)
8 changes: 8 additions & 0 deletions configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
pretrained='pretrain/mit_b1.pth',
backbone=dict(
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[2, 2, 2, 2]),
decode_head=dict(in_channels=[64, 128, 320, 512]))
8 changes: 8 additions & 0 deletions configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
pretrained='pretrain/mit_b2.pth',
backbone=dict(
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 6, 3]),
decode_head=dict(in_channels=[64, 128, 320, 512]))
8 changes: 8 additions & 0 deletions configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
pretrained='pretrain/mit_b3.pth',
backbone=dict(
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 18, 3]),
decode_head=dict(in_channels=[64, 128, 320, 512]))
8 changes: 8 additions & 0 deletions configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
pretrained='pretrain/mit_b4.pth',
backbone=dict(
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 8, 27, 3]),
decode_head=dict(in_channels=[64, 128, 320, 512]))
8 changes: 8 additions & 0 deletions configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# model settings
model = dict(
pretrained='pretrain/mit_b5.pth',
backbone=dict(
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]),
decode_head=dict(in_channels=[64, 128, 320, 512]))
44 changes: 44 additions & 0 deletions configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py
@@ -0,0 +1,44 @@
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']

# dataset settings
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (640, 640)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 640),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

# model settings
model = dict(
pretrained='pretrain/mit_b5.pth',
backbone=dict(
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]),
decode_head=dict(in_channels=[64, 128, 320, 512]))
57 changes: 57 additions & 0 deletions mmseg/datasets/pipelines/transforms.py
@@ -6,6 +6,63 @@
from ..builder import PIPELINES


@PIPELINES.register_module()
class ResizeToMultiple(object):
"""Resize images & seg to multiple of divisor.
Args:
size_divisor (int): images and gt seg maps need to resize to multiple
of size_divisor. Default: 32.
interpolation (str, optional): The interpolation mode of image resize.
Default: None
"""

def __init__(self, size_divisor=32, interpolation=None):
self.size_divisor = size_divisor
self.interpolation = interpolation

def __call__(self, results):
"""Call function to resize images, semantic segmentation map to
multiple of size divisor.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img_shape', 'pad_shape' keys are updated.
"""
# Align image to multiple of size divisor.
img = results['img']
img = mmcv.imresize_to_multiple(
img,
self.size_divisor,
scale_factor=1,
interpolation=self.interpolation
if self.interpolation else 'bilinear')

results['img'] = img
results['img_shape'] = img.shape
results['pad_shape'] = img.shape

# Align segmentation map to multiple of size divisor.
for key in results.get('seg_fields', []):
gt_seg = results[key]
gt_seg = mmcv.imresize_to_multiple(
gt_seg,
self.size_divisor,
scale_factor=1,
interpolation='nearest')
results[key] = gt_seg

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(size_divisor={self.size_divisor}, '
f'interpolation={self.interpolation})')
return repr_str


@PIPELINES.register_module()
class Resize(object):
"""Resize images & seg.
18 changes: 8 additions & 10 deletions mmseg/models/backbones/mit.py
@@ -11,7 +11,7 @@

from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed, mit_convert, nchw_to_nlc, nlc_to_nchw
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw


class MixFFN(BaseModule):
@@ -159,7 +159,13 @@ def forward(self, x, hw_shape, identity=None):
if identity is None:
identity = x_q

out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
# `need_weights=True` makes nn.MultiheadAttention return
# `attn_output, attn_output_weights.sum(dim=1) / num_heads`.
# The `attn_output_weights.sum(dim=1)` may cause a CUDA error, so we set
# `need_weights=False` to skip computing it.
# See https://github.com/pytorch/pytorch/issues/37583, which reports that
# large-scale tensor sum operations may cause CUDA errors.
out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]

return identity + self.dropout_layer(self.proj_drop(out))

@@ -387,17 +393,9 @@ def init_weights(self):
self.pretrained, logger=logger, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint

if self.pretrain_style == 'official':
# Because the segformer backbone is not supported by mmcls,
# we need to convert the pretrained weights to match this
# implementation.
state_dict = mit_convert(state_dict)

self.load_state_dict(state_dict, False)

def forward(self, x):
4 changes: 3 additions & 1 deletion mmseg/models/decode_heads/__init__.py
@@ -16,6 +16,7 @@
from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
from .segformer_head import SegformerHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
@@ -26,5 +27,6 @@
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 'SETRMLAHead'
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'SegformerHead'
]
65 changes: 65 additions & 0 deletions mmseg/models/decode_heads/segformer_head.py
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize


@HEADS.register_module()
class SegformerHead(BaseDecodeHead):
"""The all mlp Head of segformer.
This head is the implementation of
`Segformer <https://arxiv.org/abs/2105.15203>` _.
Args:
interpolate_mode: The interpolate mode of MLP head upsample operation.
Default: 'bilinear'.
"""

def __init__(self, interpolate_mode='bilinear', **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)

self.interpolate_mode = interpolate_mode
num_inputs = len(self.in_channels)

assert num_inputs == len(self.in_index)

self.convs = nn.ModuleList()
for i in range(num_inputs):
self.convs.append(
ConvModule(
in_channels=self.in_channels[i],
out_channels=self.channels,
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))

self.fusion_conv = ConvModule(
in_channels=self.channels * num_inputs,
out_channels=self.channels,
kernel_size=1,
norm_cfg=self.norm_cfg)

def forward(self, inputs):
# Receive 4 stage backbone feature maps: 1/4, 1/8, 1/16, 1/32.
inputs = self._transform_inputs(inputs)
outs = []
for idx in range(len(inputs)):
x = inputs[idx]
conv = self.convs[idx]
outs.append(
resize(
input=conv(x),
size=inputs[0].shape[2:],
mode=self.interpolate_mode,
align_corners=self.align_corners))

out = self.fusion_conv(torch.cat(outs, dim=1))

out = self.cls_seg(out)

return out
4 changes: 2 additions & 2 deletions mmseg/models/utils/__init__.py
@@ -1,4 +1,4 @@
from .ckpt_convert import mit_convert, swin_convert, vit_convert
from .ckpt_convert import swin_convert, vit_convert
from .embed import PatchEmbed
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
@@ -11,5 +11,5 @@
__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
'mit_convert', 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
]
49 changes: 0 additions & 49 deletions mmseg/models/utils/ckpt_convert.py
@@ -1,7 +1,5 @@
from collections import OrderedDict

import torch


def swin_convert(ckpt):
new_ckpt = OrderedDict()
@@ -90,50 +88,3 @@ def vit_convert(ckpt):
new_ckpt[new_k] = v

return new_ckpt


def mit_convert(ckpt):
new_ckpt = OrderedDict()
# Process the concat between q linear weights and kv linear weights
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('patch_embed'):
stage_i = int(k.split('.')[0].replace('patch_embed', ''))
new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
new_v = v
if 'proj.' in new_k:
new_k = new_k.replace('proj.', 'projection.')
elif k.startswith('block'):
stage_i = int(k.split('.')[0].replace('block', ''))
new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
new_v = v
if 'attn.q.' in new_k:
sub_item_k = k.replace('q.', 'kv.')
new_k = new_k.replace('q.', 'attn.in_proj_')
new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
elif 'attn.kv.' in new_k:
continue
elif 'attn.proj.' in new_k:
new_k = new_k.replace('proj.', 'attn.out_proj.')
elif 'attn.sr.' in new_k:
new_k = new_k.replace('sr.', 'sr.')
elif 'mlp.' in new_k:
string = f'{new_k}-'
new_k = new_k.replace('mlp.', 'ffn.layers.')
if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
new_v = v.reshape((*v.shape, 1, 1))
new_k = new_k.replace('fc1.', '0.')
new_k = new_k.replace('dwconv.dwconv.', '1.')
new_k = new_k.replace('fc2.', '4.')
string += f'{new_k} {v.shape}-{new_v.shape}'
# print(string)
elif k.startswith('norm'):
stage_i = int(k.split('.')[0].replace('norm', ''))
new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
new_v = v
else:
new_k = k
new_v = v
new_ckpt[new_k] = new_v
return new_ckpt
20 changes: 20 additions & 0 deletions tests/test_data/test_transform.py
@@ -10,6 +10,26 @@
from mmseg.datasets.builder import PIPELINES


def test_resize_to_multiple():
transform = dict(type='ResizeToMultiple', size_divisor=32)
transform = build_from_cfg(transform, PIPELINES)

img = np.random.randn(213, 232, 3)
seg = np.random.randint(0, 19, (213, 232))
results = dict()
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['pad_shape'] = img.shape

results = transform(results)
assert results['img'].shape == (224, 256, 3)
assert results['gt_semantic_seg'].shape == (224, 256)
assert results['img_shape'] == (224, 256, 3)
assert results['pad_shape'] == (224, 256, 3)


def test_resize():
# test assertion if img_scale is a list
with pytest.raises(AssertionError):
39 changes: 39 additions & 0 deletions tests/test_models/test_heads/test_segformer_head.py
@@ -0,0 +1,39 @@
import pytest
import torch

from mmseg.models.decode_heads import SegformerHead


def test_segformer_head():
with pytest.raises(AssertionError):
# `in_channels` must have same length as `in_index`
SegformerHead(
in_channels=(1, 2, 3), in_index=(0, 1), channels=5, num_classes=2)

H, W = (64, 64)
in_channels = (32, 64, 160, 256)
shapes = [(H // 2**(i + 2), W // 2**(i + 2))
for i in range(len(in_channels))]
model = SegformerHead(
in_channels=in_channels,
in_index=[0, 1, 2, 3],
channels=256,
num_classes=19)

with pytest.raises(IndexError):
# in_index must match the input feature maps.
inputs = [
torch.randn((1, in_channel, *shape))
for in_channel, shape in zip(in_channels, shapes)
][:3]
temp = model(inputs)

# Normal Input
# ((1, 32, 16, 16), (1, 64, 8, 8), (1, 160, 4, 4), (1, 256, 2, 2))
inputs = [
torch.randn((1, in_channel, *shape))
for in_channel, shape in zip(in_channels, shapes)
]
temp = model(inputs)

assert temp.shape == (1, 19, H // 4, W // 4)
76 changes: 76 additions & 0 deletions tools/model_converters/mit_convert.py
@@ -0,0 +1,76 @@
import argparse
from collections import OrderedDict

import torch


def mit_convert(ckpt):
new_ckpt = OrderedDict()
# Process the concatenation of the q and kv linear weights
for k, v in ckpt.items():
if k.startswith('head'):
continue
# patch embedding conversion
elif k.startswith('patch_embed'):
stage_i = int(k.split('.')[0].replace('patch_embed', ''))
new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
new_v = v
if 'proj.' in new_k:
new_k = new_k.replace('proj.', 'projection.')
# transformer encoder layer conversion
elif k.startswith('block'):
stage_i = int(k.split('.')[0].replace('block', ''))
new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
new_v = v
if 'attn.q.' in new_k:
sub_item_k = k.replace('q.', 'kv.')
new_k = new_k.replace('q.', 'attn.in_proj_')
new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
elif 'attn.kv.' in new_k:
continue
elif 'attn.proj.' in new_k:
new_k = new_k.replace('proj.', 'attn.out_proj.')
elif 'attn.sr.' in new_k:
new_k = new_k.replace('sr.', 'sr.')
elif 'mlp.' in new_k:
string = f'{new_k}-'
new_k = new_k.replace('mlp.', 'ffn.layers.')
if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
new_v = v.reshape((*v.shape, 1, 1))
new_k = new_k.replace('fc1.', '0.')
new_k = new_k.replace('dwconv.dwconv.', '1.')
new_k = new_k.replace('fc2.', '4.')
string += f'{new_k} {v.shape}-{new_v.shape}'
# norm layer conversion
elif k.startswith('norm'):
stage_i = int(k.split('.')[0].replace('norm', ''))
new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
new_v = v
else:
new_k = k
new_v = v
new_ckpt[new_k] = new_v
return new_ckpt


def parse_args():
parser = argparse.ArgumentParser(
'Convert official segformer backbone weights to mmseg style.')
parser.add_argument(
'src', help='Source path of official segformer backbone weights.')
parser.add_argument(
'dst',
help='Destination path of converted segformer backbone weights.')

return parser.parse_args()


if __name__ == '__main__':
args = parse_args()
src_path = args.src
dst_path = args.dst

ckpt = torch.load(src_path, map_location='cpu')

ckpt = mit_convert(ckpt)
torch.save(ckpt, dst_path)
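
To illustrate the key renaming `mit_convert` performs, here is a toy checkpoint run through the function above (a sketch with arbitrary shapes; assumes `mit_convert` is in scope, e.g. in the same script):

```python
import torch
from collections import OrderedDict

toy_ckpt = OrderedDict({
    'patch_embed1.proj.weight': torch.zeros(32, 3, 7, 7),
    'block1.0.attn.q.weight': torch.zeros(32, 32),
    'block1.0.attn.kv.weight': torch.zeros(64, 32),
    'norm1.weight': torch.zeros(32),
})
converted = mit_convert(toy_ckpt)
# 'patch_embed1.proj.weight' -> 'layers.0.0.projection.weight'
# 'block1.0.attn.q.weight'   -> 'layers.0.1.0.attn.attn.in_proj_weight',
#                               with the kv weights concatenated: shape (96, 32)
# 'norm1.weight'             -> 'layers.0.2.weight'
print(list(converted.keys()))
```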
