[Improvement] Add metafile, readme and converted models for Mlp-Mixer #539

Merged 6 commits on Nov 24, 2021
48 changes: 48 additions & 0 deletions configs/_base_/datasets/imagenet_bs64_mixer_224.py
@@ -0,0 +1,48 @@
# dataset settings
dataset_type = 'ImageNet'

# change according to https://github.com/rwightman/pytorch-image-models/blob
# /master/timm/models/mlp_mixer.py
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)

# training is not supported for now
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224, backend='cv2'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize', size=(256, -1), backend='cv2', interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=64,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        data_prefix='data/imagenet/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline))

evaluation = dict(interval=10, metric='accuracy')
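Note on the normalization above: with mean and std both set to 127.5, `Normalize` maps 8-bit pixel values from [0, 255] into [-1, 1], matching timm's Mixer preprocessing. A minimal sanity-check sketch:

```python
import numpy as np

# Hypothetical pixel values at the extremes and the midpoint of [0, 255].
pixels = np.array([0.0, 127.5, 255.0])

# The same arithmetic `Normalize` applies with mean=std=127.5.
print((pixels - 127.5) / 127.5)  # [-1.  0.  1.]
```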
@@ -3,9 +3,9 @@
     type='ImageClassifier',
     backbone=dict(
         type='MlpMixer',
-        arch='b',
+        arch='l',
         img_size=224,
-        patch_size=32,
+        patch_size=16,
         drop_rate=0.1,
         init_cfg=[
             dict(
@@ -18,7 +18,7 @@
     head=dict(
         type='LinearClsHead',
         num_classes=1000,
-        in_channels=768,
+        in_channels=1024,
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
         topk=(1, 5),
     ),
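For context, the `'b'`/`'l'` aliases select the Base and Large variants from the paper; a sketch of the dimensions involved, assuming the backbone's architecture table follows Table 1 of arXiv:2105.01601:

```python
# Mixer variant settings from the paper (Table 1, arXiv:2105.01601);
# the mmcls `MlpMixer` backbone is assumed to resolve 'b'/'l' to these.
ARCHS = {
    'b': dict(embed_dims=768, num_layers=12,
              tokens_mlp_dims=384, channels_mlp_dims=3072),
    'l': dict(embed_dims=1024, num_layers=24,
              tokens_mlp_dims=512, channels_mlp_dims=4096),
}

# Hence the head's `in_channels` change from 768 to 1024 in this diff.
print(ARCHS['l']['embed_dims'])  # 1024
```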
37 changes: 37 additions & 0 deletions configs/mlp_mixer/README.md
@@ -0,0 +1,37 @@
# MLP-Mixer: An all-MLP Architecture for Vision
<!-- {Mlp-Mixer} -->
<!-- [ALGORITHM] -->

## Abstract
<!-- [ABSTRACT] -->
Convolutional Neural Networks (CNNs) are the go-to model for computer vision. Recently, attention-based networks, such as the Vision Transformer, have also become popular. In this paper we show that while convolutions and attention are both sufficient for good performance, neither of them are necessary. We present MLP-Mixer, an architecture based exclusively on multi-layer perceptrons (MLPs). MLP-Mixer contains two types of layers: one with MLPs applied independently to image patches (i.e. "mixing" the per-location features), and one with MLPs applied across patches (i.e. "mixing" spatial information). When trained on large datasets, or with modern regularization schemes, MLP-Mixer attains competitive scores on image classification benchmarks, with pre-training and inference cost comparable to state-of-the-art models. We hope that these results spark further research beyond the realms of well established CNNs and Transformers.

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/143178327-7118b48a-5f5f-4844-a614-a571917384ca.png" width="90%"/>
</div>

## Citation
```latex
@misc{tolstikhin2021mlpmixer,
      title={MLP-Mixer: An all-MLP Architecture for Vision},
      author={Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Andreas Steiner and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
      year={2021},
      eprint={2105.01601},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Pretrained models

The pre-trained models are converted from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py).

### ImageNet-1k

| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:--------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth)|
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth)|

*Models with \* are converted from other repos.*
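A quick inference sketch (illustrative, not part of this PR) using the high-level `mmcls.apis` helpers; the image path and local checkpoint name are placeholders:

```python
from mmcls.apis import inference_model, init_model

# Assumes the Mixer-B/16 checkpoint from the table above was downloaded.
config = 'configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py'
checkpoint = 'mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth'

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo.jpg')  # placeholder image path
print(result['pred_class'], result['pred_score'])
```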
50 changes: 50 additions & 0 deletions configs/mlp_mixer/metafile.yml
@@ -0,0 +1,50 @@
Collections:
  - Name: MLP-Mixer
    Metadata:
      Training Data: ImageNet-1k
      Architecture:
        - MLP
        - Layer Normalization
        - Dropout
    Paper:
      URL: https://arxiv.org/abs/2105.01601
      Title: "MLP-Mixer: An all-MLP Architecture for Vision"
    README: configs/mlp_mixer/README.md
    # Code:
    #   URL: # todo
    #   Version: # todo

Models:
  - Name: mlp-mixer-base-p16_3rdparty_64xb64_in1k
    In Collection: MLP-Mixer
    Config: configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py
    Metadata:
      FLOPs: 12610000000 # 12.61 G
      Parameters: 59880000 # 59.88 M
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 76.68
          Top 5 Accuracy: 92.25
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth
    Converted From:
      Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth
      Code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py#L70

  - Name: mlp-mixer-large-p16_3rdparty_64xb64_in1k
    In Collection: MLP-Mixer
    Config: configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py
    Metadata:
      FLOPs: 44570000000 # 44.57 G
      Parameters: 208200000 # 208.2 M
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 72.34
          Top 5 Accuracy: 88.02
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth
    Converted From:
      Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth
      Code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py#L73
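A small check (illustrative) that the metafile parses and lists both converted models:

```python
import yaml

with open('configs/mlp_mixer/metafile.yml') as f:
    meta = yaml.safe_load(f)

# Expect the two converted models declared above.
print([m['Name'] for m in meta['Models']])
```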
@@ -1,6 +1,6 @@
 _base_ = [
     '../_base_/models/mlp_mixer_base_patch16.py',
-    '../_base_/datasets/imagenet_bs64_pil_resize.py',
+    '../_base_/datasets/imagenet_bs64_mixer_224.py',
     '../_base_/schedules/imagenet_bs4096_AdamW.py',
     '../_base_/default_runtime.py',
 ]
6 changes: 6 additions & 0 deletions configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py
@@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/mlp_mixer_large_patch16.py',
    '../_base_/datasets/imagenet_bs64_mixer_224.py',
    '../_base_/schedules/imagenet_bs4096_AdamW.py',
    '../_base_/default_runtime.py',
]
6 changes: 0 additions & 6 deletions configs/mlp_mixer/mlp_mixer-base_p32_64xb64_in1k.py

This file was deleted.

3 changes: 3 additions & 0 deletions docs/model_zoo.md
@@ -61,6 +61,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
 | T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-420df0f6.pth) &#124; [log]()|
 | T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-e479c2a6.pth) &#124; [log]()|
 | T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-b5bf2526.pth) &#124; [log]()|
+| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) &#124; [log]()|
+| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) &#124; [log]()|

 Models with * are converted from other repos, others are trained by ourselves.
14 changes: 9 additions & 5 deletions mmcls/models/backbones/mlp_mixer.py
@@ -12,7 +12,10 @@


 class MixerBlock(BaseModule):
-    """Implements mixer block in MLP Mixer.
+    """Mlp-Mixer basic block.
+
+    Basic module of `MLP-Mixer: An all-MLP Architecture for Vision
+    <https://arxiv.org/pdf/2105.01601.pdf>`_

     Args:
         num_tokens (int): The number of patched tokens
@@ -96,13 +99,14 @@ def forward(self, x):

 @BACKBONES.register_module()
 class MlpMixer(BaseBackbone):
-    """Mlp Mixer.
-
-    A PyTorch implement of : `MLP-Mixer: An all-MLP Architecture for Vision` -
-    https://arxiv.org/abs/2105.01601
+    """Mlp-Mixer backbone.
+
+    Pytorch implementation of `MLP-Mixer: An all-MLP Architecture for Vision
+    <https://arxiv.org/pdf/2105.01601.pdf>`_
+
     Args:
         arch (str | dict): MLP Mixer architecture
-            Default: 'b'.
+            Defaults to 'b'.
         img_size (int | tuple): Input image size.
         patch_size (int | tuple): The patch size.
         out_indices (Sequence | int): Output from which layer.
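For readers skimming the diff, a minimal PyTorch sketch (my annotation, not the mmcls implementation) of what a Mixer block computes: a token-mixing MLP across patches, then a channel-mixing MLP within each patch, each behind a pre-LayerNorm residual:

```python
import torch
import torch.nn as nn


class TinyMixerBlock(nn.Module):
    """Simplified Mixer block; mmcls' MixerBlock follows the same pattern
    but builds its MLPs from mmcv bricks (hence converter key names like
    `token_mix.layers.0.0`)."""

    def __init__(self, num_tokens, embed_dims, tokens_mlp_dims,
                 channels_mlp_dims):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dims)
        self.token_mix = nn.Sequential(
            nn.Linear(num_tokens, tokens_mlp_dims), nn.GELU(),
            nn.Linear(tokens_mlp_dims, num_tokens))
        self.ln2 = nn.LayerNorm(embed_dims)
        self.channel_mix = nn.Sequential(
            nn.Linear(embed_dims, channels_mlp_dims), nn.GELU(),
            nn.Linear(channels_mlp_dims, embed_dims))

    def forward(self, x):  # x: (batch, num_tokens, embed_dims)
        # Token mixing acts across patches, so transpose token/channel axes.
        x = x + self.token_mix(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        # Channel mixing acts within each patch.
        x = x + self.channel_mix(self.ln2(x))
        return x


x = torch.randn(2, 196, 768)  # 14x14 patches of a 224px image, patch size 16
print(TinyMixerBlock(196, 768, 384, 3072)(x).shape)  # torch.Size([2, 196, 768])
```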
1 change: 1 addition & 0 deletions model-index.yml
@@ -13,3 +13,4 @@ Import:
   - configs/tnt/metafile.yml
   - configs/vision_transformer/metafile.yml
   - configs/t2t_vit/metafile.yml
+  - configs/mlp_mixer/metafile.yml
57 changes: 57 additions & 0 deletions tools/convert_models/mlpmixer_to_mmcls.py
@@ -0,0 +1,57 @@
import argparse
from pathlib import Path

import torch


def convert_weights(weight):
    """Weight Converter.

    Converts the weights from timm to mmcls.

    Args:
        weight (dict): weight dict from timm

    Returns:
        dict: converted weight dict for mmcls
    """
    result = dict()
    result['meta'] = dict()
    temp = dict()
    mapping = {
        'stem': 'patch_embed',
        'proj': 'projection',
        'mlp_tokens.fc1': 'token_mix.layers.0.0',
        'mlp_tokens.fc2': 'token_mix.layers.1',
        'mlp_channels.fc1': 'channel_mix.layers.0.0',
        'mlp_channels.fc2': 'channel_mix.layers.1',
        'norm1': 'ln1',
        'norm2': 'ln2',
        'norm.': 'ln1.',
        'blocks': 'layers'
    }
    for k, v in weight.items():
        # Rename each timm key piece by piece according to the mapping.
        for mk, mv in mapping.items():
            if mk in k:
                k = k.replace(mk, mv)
        if k.startswith('head.'):
            temp['head.fc.' + k[5:]] = v
        else:
            temp['backbone.' + k] = v
    result['state_dict'] = temp
    return result


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src timm model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The destination path must end with the .pth suffix.')
        exit()
    dst.parent.mkdir(parents=True, exist_ok=True)

    original_model = torch.load(args.src, map_location='cpu')
    converted_model = convert_weights(original_model)
    torch.save(converted_model, args.dst)
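A usage sketch (illustrative): the checkpoint filenames are placeholders, and the import assumes the script's directory is on `PYTHONPATH`:

```python
# Equivalent to: python tools/convert_models/mlpmixer_to_mmcls.py SRC DST
import torch

from mlpmixer_to_mmcls import convert_weights  # assumes script is importable

state = torch.load('jx_mixer_b16_224-76587d61.pth', map_location='cpu')
out = convert_weights(state)

# Spot-check a couple of renamed keys before saving.
assert any(k.startswith('backbone.patch_embed.') for k in out['state_dict'])
assert any(k.startswith('head.fc.') for k in out['state_dict'])
torch.save(out, 'mixer-base-p16_converted.pth')
```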