From fffa30dd4827395274d1ebd35752f9f3482ea73f Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Fri, 29 Oct 2021 10:37:16 +0800 Subject: [PATCH] [Feature] Add Tokens-to-Token ViT backbone and converted checkpoints. (#467) * add t2t backbone * register t2t_vit * add t2t_vit config * [Temp] Align posterize transform with timm. * Fix lint * Refactor t2t-vit * Add config for t2t-vit * Add metafile and README for t2t-vit * Add unit tests * configs * Update metafile and README * Improve docstring * Fix batch size which should be 8x64 instead of 8x128 * Fix typo * Update model zoo * Update training augments config. * Move some arguments of T2TModule to T2TViT * Update docs. * Update unit test Co-authored-by: HIT-cwh <2892770585@qq.com> --- .../_base_/datasets/imagenet_bs64_t2t_224.py | 71 ++++ configs/_base_/models/t2t-vit-t-14.py | 41 ++ configs/_base_/models/t2t-vit-t-19.py | 41 ++ configs/_base_/models/t2t-vit-t-24.py | 41 ++ configs/t2t_vit/README.md | 33 ++ configs/t2t_vit/metafile.yml | 64 +++ configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py | 31 ++ configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py | 31 ++ configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py | 31 ++ docs/model_zoo.md | 3 + mmcls/datasets/pipelines/auto_augment.py | 4 +- mmcls/models/backbones/__init__.py | 3 +- mmcls/models/backbones/t2t_vit.py | 367 ++++++++++++++++++ model-index.yml | 1 + .../test_backbones/test_t2t_vit.py | 84 ++++ 15 files changed, 844 insertions(+), 2 deletions(-) create mode 100644 configs/_base_/datasets/imagenet_bs64_t2t_224.py create mode 100644 configs/_base_/models/t2t-vit-t-14.py create mode 100644 configs/_base_/models/t2t-vit-t-19.py create mode 100644 configs/_base_/models/t2t-vit-t-24.py create mode 100644 configs/t2t_vit/README.md create mode 100644 configs/t2t_vit/metafile.yml create mode 100644 configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py create mode 100644 configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py create mode 100644 configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py create mode 100644 mmcls/models/backbones/t2t_vit.py create mode 100644 tests/test_models/test_backbones/test_t2t_vit.py diff --git a/configs/_base_/datasets/imagenet_bs64_t2t_224.py b/configs/_base_/datasets/imagenet_bs64_t2t_224.py new file mode 100644 index 00000000000..375775debdc --- /dev/null +++ b/configs/_base_/datasets/imagenet_bs64_t2t_224.py @@ -0,0 +1,71 @@ +_base_ = ['./pipelines/rand_aug.py'] + +# dataset settings +dataset_type = 'ImageNet' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + size=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies={{_base_.rand_increasing_policies}}, + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]], + interpolation='bicubic')), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=img_norm_cfg['mean'][::-1], + fill_std=img_norm_cfg['std'][::-1]), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + size=(248, -1), + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', 
crop_size=224), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) +] +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_prefix='data/imagenet/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_prefix='data/imagenet/val', + ann_file='data/imagenet/meta/val.txt', + pipeline=test_pipeline), + test=dict( + # replace `data/val` with `data/test` for standard test + type=dataset_type, + data_prefix='data/imagenet/val', + ann_file='data/imagenet/meta/val.txt', + pipeline=test_pipeline)) + +evaluation = dict(interval=10, metric='accuracy') diff --git a/configs/_base_/models/t2t-vit-t-14.py b/configs/_base_/models/t2t-vit-t-14.py new file mode 100644 index 00000000000..91dbb67621b --- /dev/null +++ b/configs/_base_/models/t2t-vit-t-14.py @@ -0,0 +1,41 @@ +# model settings +embed_dims = 384 +num_classes = 1000 + +model = dict( + type='ImageClassifier', + backbone=dict( + type='T2T_ViT', + img_size=224, + in_channels=3, + embed_dims=embed_dims, + t2t_cfg=dict( + token_dims=64, + use_performer=False, + ), + num_layers=14, + layer_cfgs=dict( + num_heads=6, + feedforward_channels=3 * embed_dims, # mlp_ratio = 3 + ), + drop_path_rate=0.1, + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.), + ]), + neck=None, + head=dict( + type='VisionTransformerClsHead', + num_classes=num_classes, + in_channels=embed_dims, + loss=dict( + type='LabelSmoothLoss', + label_smooth_val=0.1, + mode='original', + ), + topk=(1, 5), + init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)), + train_cfg=dict(augments=[ + dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes), + dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes), + ])) diff --git a/configs/_base_/models/t2t-vit-t-19.py b/configs/_base_/models/t2t-vit-t-19.py new file mode 100644 index 00000000000..8ab139d679d --- /dev/null +++ b/configs/_base_/models/t2t-vit-t-19.py @@ -0,0 +1,41 @@ +# model settings +embed_dims = 448 +num_classes = 1000 + +model = dict( + type='ImageClassifier', + backbone=dict( + type='T2T_ViT', + img_size=224, + in_channels=3, + embed_dims=embed_dims, + t2t_cfg=dict( + token_dims=64, + use_performer=False, + ), + num_layers=19, + layer_cfgs=dict( + num_heads=7, + feedforward_channels=3 * embed_dims, # mlp_ratio = 3 + ), + drop_path_rate=0.1, + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.), + ]), + neck=None, + head=dict( + type='VisionTransformerClsHead', + num_classes=num_classes, + in_channels=embed_dims, + loss=dict( + type='LabelSmoothLoss', + label_smooth_val=0.1, + mode='original', + ), + topk=(1, 5), + init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)), + train_cfg=dict(augments=[ + dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes), + dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes), + ])) diff --git a/configs/_base_/models/t2t-vit-t-24.py b/configs/_base_/models/t2t-vit-t-24.py new file mode 100644 index 00000000000..5990960ab4e --- /dev/null +++ b/configs/_base_/models/t2t-vit-t-24.py @@ -0,0 +1,41 @@ +# model settings +embed_dims = 512 +num_classes = 1000 + +model = dict( + type='ImageClassifier', + backbone=dict( + type='T2T_ViT', + img_size=224, + in_channels=3, + embed_dims=embed_dims, + t2t_cfg=dict( + token_dims=64, + use_performer=False, + ), + 
num_layers=24,
+        layer_cfgs=dict(
+            num_heads=8,
+            feedforward_channels=3 * embed_dims,  # mlp_ratio = 3
+        ),
+        drop_path_rate=0.1,
+        init_cfg=[
+            dict(type='TruncNormal', layer='Linear', std=.02),
+            dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+        ]),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=num_classes,
+        in_channels=embed_dims,
+        loss=dict(
+            type='LabelSmoothLoss',
+            label_smooth_val=0.1,
+            mode='original',
+        ),
+        topk=(1, 5),
+        init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
+    train_cfg=dict(augments=[
+        dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
+        dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
+    ]))
diff --git a/configs/t2t_vit/README.md b/configs/t2t_vit/README.md
new file mode 100644
index 00000000000..501ef24b975
--- /dev/null
+++ b/configs/t2t_vit/README.md
@@ -0,0 +1,33 @@
+# Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet
+
+
+## Introduction
+
+
+
+```latex
+@article{yuan2021tokens,
+  title={Tokens-to-token vit: Training vision transformers from scratch on imagenet},
+  author={Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Tay, Francis EH and Feng, Jiashi and Yan, Shuicheng},
+  journal={arXiv preprint arXiv:2101.11986},
+  year={2021}
+}
+```
+
+## Pretrained models
+
+The pre-trained models are converted from the [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
+
+### ImageNet-1k
+
+| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | Log |
+|:--------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|:---:|
+| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth) | [log]()|
+| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth) | [log]()|
+| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
+
+*Models with \* are converted from other repos.*
+
+## Results and models
+
+Waiting to be added.
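The converted checkpoints above can be used for inference directly. Below is a minimal sketch (not part of this patch), assuming the `init_model`/`inference_model` helpers from `mmcls.apis` and the bundled `demo/demo.JPEG` image already present in the repository:

```python
# A hedged sketch: the config path and checkpoint URL come from the table
# above; the mmcls.apis helpers and the demo image are assumed from the
# existing repository, not introduced by this patch.
from mmcls.apis import inference_model, init_model

config = 'configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py'
checkpoint = ('https://download.openmmlab.com/mmclassification/v0/t2t-vit/'
              't2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth')

# Build the classifier from the config and load the converted weights.
model = init_model(config, checkpoint, device='cpu')

# Run single-image inference and print the predicted class and score.
result = inference_model(model, 'demo/demo.JPEG')
print(result['pred_class'], result['pred_score'])
```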
diff --git a/configs/t2t_vit/metafile.yml b/configs/t2t_vit/metafile.yml new file mode 100644 index 00000000000..0abcfe0617d --- /dev/null +++ b/configs/t2t_vit/metafile.yml @@ -0,0 +1,64 @@ +Collections: + - Name: Tokens-to-Token ViT + Metadata: + Training Data: ImageNet-1k + Architecture: + - Layer Normalization + - Scaled Dot-Product Attention + - Attention Dropout + - Dropout + - Tokens to Token + Paper: + URL: https://arxiv.org/abs/2101.11986 + Title: "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet" + README: configs/t2t_vit/README.md + +Models: + - Name: t2t-vit-t-14_3rdparty_8xb64_in1k + Metadata: + FLOPs: 4340000000 + Parameters: 21470000 + In Collection: Tokens-to-Token ViT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 81.69 + Top 5 Accuracy: 95.85 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth + Converted From: + Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.7_T2T_ViTt_14.pth.tar + Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L243 + Config: configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py + - Name: t2t-vit-t-19_3rdparty_8xb64_in1k + Metadata: + FLOPs: 7800000000 + Parameters: 39080000 + In Collection: Tokens-to-Token ViT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 82.43 + Top 5 Accuracy: 96.08 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth + Converted From: + Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.4_T2T_ViTt_19.pth.tar + Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L254 + Config: configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py + - Name: t2t-vit-t-24_3rdparty_8xb64_in1k + Metadata: + FLOPs: 12690000000 + Parameters: 64000000 + In Collection: Tokens-to-Token ViT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 82.55 + Top 5 Accuracy: 96.06 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth + Converted From: + Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.6_T2T_ViTt_24.pth.tar + Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L265 + Config: configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py diff --git a/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py b/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py new file mode 100644 index 00000000000..126d564ed27 --- /dev/null +++ b/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py @@ -0,0 +1,31 @@ +_base_ = [ + '../_base_/models/t2t-vit-t-14.py', + '../_base_/datasets/imagenet_bs64_t2t_224.py', + '../_base_/default_runtime.py', +] + +# optimizer +paramwise_cfg = dict( + bias_decay_mult=0.0, + custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)}, +) +optimizer = dict( + type='AdamW', + lr=5e-4, + weight_decay=0.05, + paramwise_cfg=paramwise_cfg, +) +optimizer_config = dict(grad_clip=None) + +# learning policy +# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and +# the lr in the last 10 epoch equals to min_lr +lr_config = dict( + policy='CosineAnnealing', + min_lr=1e-5, + by_epoch=True, + warmup_by_epoch=True, + warmup='linear', + warmup_iters=10, + warmup_ratio=1e-6) +runner = dict(type='EpochBasedRunner', max_epochs=310) 
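The FIXME above describes the intended schedule: a 10-epoch linear warmup, cosine annealing that reaches `min_lr` around epoch 300, and the last 10 epochs held at `min_lr`. The sketch below only illustrates that target curve using the values from this config; the actual MMCV `CosineAnnealing` hook may behave slightly differently, which is exactly why the FIXME is there.

```python
# Illustrative sketch of the schedule described by the FIXME above.
# lr, min_lr, warmup settings and max_epochs are taken from the config;
# the epoch split (10 warmup + 290 cosine + 10 tail) is an interpretation.
import math

base_lr, min_lr = 5e-4, 1e-5
warmup_ratio = 1e-6
warmup_epochs, cosine_epochs, tail_epochs = 10, 290, 10  # total = 310


def lr_at(epoch):
    if epoch < warmup_epochs:
        # Linear warmup from base_lr * warmup_ratio up to base_lr.
        alpha = epoch / warmup_epochs
        return base_lr * (warmup_ratio + (1 - warmup_ratio) * alpha)
    if epoch < warmup_epochs + cosine_epochs:
        # Cosine annealing from base_lr down to min_lr by epoch 300.
        t = (epoch - warmup_epochs) / cosine_epochs
        return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))
    # Final 10 epochs stay at min_lr.
    return min_lr


for e in (0, 10, 160, 300, 309):
    print(e, f'{lr_at(e):.2e}')
```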
diff --git a/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py b/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py new file mode 100644 index 00000000000..afd05a76a47 --- /dev/null +++ b/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py @@ -0,0 +1,31 @@ +_base_ = [ + '../_base_/models/t2t-vit-t-19.py', + '../_base_/datasets/imagenet_bs64_t2t_224.py', + '../_base_/default_runtime.py', +] + +# optimizer +paramwise_cfg = dict( + bias_decay_mult=0.0, + custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)}, +) +optimizer = dict( + type='AdamW', + lr=5e-4, + weight_decay=0.065, + paramwise_cfg=paramwise_cfg, +) +optimizer_config = dict(grad_clip=None) + +# learning policy +# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and +# the lr in the last 10 epoch equals to min_lr +lr_config = dict( + policy='CosineAnnealing', + min_lr=1e-5, + by_epoch=True, + warmup_by_epoch=True, + warmup='linear', + warmup_iters=10, + warmup_ratio=1e-6) +runner = dict(type='EpochBasedRunner', max_epochs=310) diff --git a/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py b/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py new file mode 100644 index 00000000000..9f856f3e592 --- /dev/null +++ b/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py @@ -0,0 +1,31 @@ +_base_ = [ + '../_base_/models/t2t-vit-t-24.py', + '../_base_/datasets/imagenet_bs64_t2t_224.py', + '../_base_/default_runtime.py', +] + +# optimizer +paramwise_cfg = dict( + bias_decay_mult=0.0, + custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)}, +) +optimizer = dict( + type='AdamW', + lr=5e-4, + weight_decay=0.065, + paramwise_cfg=paramwise_cfg, +) +optimizer_config = dict(grad_clip=None) + +# learning policy +# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and +# the lr in the last 10 epoch equals to min_lr +lr_config = dict( + policy='CosineAnnealing', + min_lr=1e-5, + by_epoch=True, + warmup_by_epoch=True, + warmup='linear', + warmup_iters=10, + warmup_ratio=1e-6) +runner = dict(type='EpochBasedRunner', max_epochs=310) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index e1942fe7c55..dc4163569d9 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -58,6 +58,9 @@ The ResNet family models below are trained by standard data augmentations, i.e., | Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219.log.json)| | Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742.log.json)| | Transformer in Transformer small\* | 23.76 | 3.36 | 81.52 | 95.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/tnt/tnt_s_patch16_224_evalonly_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth) | [log]()| +| T2T-ViT_t-14\* | 21.47 | 4.34 | 
81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-420df0f6.pth) | [log]()| +| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-e479c2a6.pth) | [log]()| +| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-b5bf2526.pth) | [log]()| Models with * are converted from other repos, others are trained by ourselves. diff --git a/mmcls/datasets/pipelines/auto_augment.py b/mmcls/datasets/pipelines/auto_augment.py index d9ea6c911c2..973c1bd2506 100644 --- a/mmcls/datasets/pipelines/auto_augment.py +++ b/mmcls/datasets/pipelines/auto_augment.py @@ -2,6 +2,7 @@ import copy import inspect import random +from math import ceil from numbers import Number from typing import Sequence @@ -668,7 +669,8 @@ def __init__(self, bits, prob=0.5): assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ f'got {prob} instead.' - self.bits = int(bits) + # To align timm version, we need to round up to integer here. + self.bits = ceil(bits) self.prob = prob def __call__(self, results): diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py index 526d3450e69..3be2e924268 100644 --- a/mmcls/models/backbones/__init__.py +++ b/mmcls/models/backbones/__init__.py @@ -15,6 +15,7 @@ from .shufflenet_v1 import ShuffleNetV1 from .shufflenet_v2 import ShuffleNetV2 from .swin_transformer import SwinTransformer +from .t2t_vit import T2T_ViT from .timm_backbone import TIMMBackbone from .tnt import TNT from .vgg import VGG @@ -24,5 +25,5 @@ 'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', - 'SwinTransformer', 'TNT', 'TIMMBackbone', 'Res2Net', 'RepVGG' + 'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG' ] diff --git a/mmcls/models/backbones/t2t_vit.py b/mmcls/models/backbones/t2t_vit.py new file mode 100644 index 00000000000..2e9cb527ff1 --- /dev/null +++ b/mmcls/models/backbones/t2t_vit.py @@ -0,0 +1,367 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN +from mmcv.cnn.utils.weight_init import trunc_normal_ +from mmcv.runner.base_module import BaseModule, ModuleList + +from ..builder import BACKBONES +from ..utils import MultiheadAttention +from .base_backbone import BaseBackbone + + +class T2TTransformerLayer(BaseModule): + """Transformer Layer for T2T_ViT. + + Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports + different ``input_dims`` and ``embed_dims``. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. 
+        feedforward_channels (int): The hidden dimension for FFNs.
+        input_dims (int, optional): The input token dimension.
+            Defaults to None.
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Defaults to 0.
+        attn_drop_rate (float): The drop out rate for attention output weights.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Defaults to 2.
+        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
+        qk_scale (float, optional): Override default qk scale of
+            ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
+        act_cfg (dict): The activation config for FFNs.
+            Defaults to ``dict(type='GELU')``.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+
+    Notes:
+        In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
+        ``(embed_dims // num_heads) ** -0.5``. However, in the official
+        code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
+        keep the same as the official implementation.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 feedforward_channels,
+                 input_dims=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 num_fcs=2,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None):
+        super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)
+
+        self.v_shortcut = True if input_dims is not None else False
+        input_dims = input_dims or embed_dims
+
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, input_dims, postfix=1)
+        self.add_module(self.norm1_name, norm1)
+
+        self.attn = MultiheadAttention(
+            input_dims=input_dims,
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            attn_drop=attn_drop_rate,
+            proj_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
+            v_shortcut=self.v_shortcut)
+
+        self.norm2_name, norm2 = build_norm_layer(
+            norm_cfg, embed_dims, postfix=2)
+        self.add_module(self.norm2_name, norm2)
+
+        self.ffn = FFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            num_fcs=num_fcs,
+            ffn_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            act_cfg=act_cfg)
+
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
+
+    def forward(self, x):
+        if self.v_shortcut:
+            x = self.attn(self.norm1(x))
+        else:
+            x = x + self.attn(self.norm1(x))
+        x = self.ffn(self.norm2(x), identity=x)
+        return x
+
+
+class T2TModule(BaseModule):
+    """Tokens-to-Token module.
+
+    "Tokens-to-Token module" (T2T Module) can model the local structure
+    information of images and reduce the length of tokens progressively.
+
+    Args:
+        img_size (int): Input image size.
+        in_channels (int): Number of input channels.
+        embed_dims (int): Embedding dimension.
+        token_dims (int): Token dimension used in the T2T attention layers.
+        use_performer (bool): If True, use Performer attention instead of
+            regular self-attention. Defaults to False.
+        init_cfg (dict, optional): The extra config for initialization.
+            Defaults to None.
+
+    Notes:
+        Usually, ``token_dims`` is set to a small value (32 or 64) to reduce
+        MACs.
+    """
+
+    def __init__(
+        self,
+        img_size=224,
+        in_channels=3,
+        embed_dims=384,
+        token_dims=64,
+        use_performer=False,
+        init_cfg=None,
+    ):
+        super(T2TModule, self).__init__(init_cfg)
+
+        self.embed_dims = embed_dims
+
+        self.soft_split0 = nn.Unfold(
+            kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
+        self.soft_split1 = nn.Unfold(
+            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
+        self.soft_split2 = nn.Unfold(
+            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
+
+        if not use_performer:
+            self.attention1 = T2TTransformerLayer(
+                input_dims=in_channels * 7 * 7,
+                embed_dims=token_dims,
+                num_heads=1,
+                feedforward_channels=token_dims)
+
+            self.attention2 = T2TTransformerLayer(
+                input_dims=token_dims * 3 * 3,
+                embed_dims=token_dims,
+                num_heads=1,
+                feedforward_channels=token_dims)
+
+            self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
+        else:
+            raise NotImplementedError("Performer hasn't been implemented.")
+
+        # There are 3 soft splits with strides 4, 2 and 2, respectively.
+        self.num_patches = (img_size // (4 * 2 * 2))**2
+
+    def forward(self, x):
+        # step0: soft split
+        x = self.soft_split0(x).transpose(1, 2)
+
+        for step in [1, 2]:
+            # re-structurization/reconstruction
+            attn = getattr(self, f'attention{step}')
+            x = attn(x).transpose(1, 2)
+            B, C, new_HW = x.shape
+            x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
+
+            # soft split
+            soft_split = getattr(self, f'soft_split{step}')
+            x = soft_split(x).transpose(1, 2)
+
+        # final tokens
+        x = self.project(x)
+        return x
+
+
+def get_sinusoid_encoding(n_position, embed_dims):
+    """Generate sinusoid encoding table.
+
+    Sinusoid encoding is a fixed position encoding method introduced in
+    `Attention Is All You Need`_.
+
+    Args:
+        n_position (int): The length of the input token sequence.
+        embed_dims (int): The position embedding dimension.
+
+    Returns:
+        :obj:`torch.FloatTensor`: The sinusoid encoding table.
+    """
+
+    def get_position_angle_vec(position):
+        return [
+            position / np.power(10000, 2 * (i // 2) / embed_dims)
+            for i in range(embed_dims)
+        ]
+
+    sinusoid_table = np.array(
+        [get_position_angle_vec(pos) for pos in range(n_position)])
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+
+@BACKBONES.register_module()
+class T2T_ViT(BaseBackbone):
+    """Tokens-to-Token Vision Transformer (T2T-ViT).
+
+    A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
+    Transformers from Scratch on ImageNet`_.
+
+    Args:
+        img_size (int): Input image size.
+        in_channels (int): Number of input channels.
+        embed_dims (int): Embedding dimension.
+        t2t_cfg (dict): Extra config of the Tokens-to-Token module.
+            Defaults to an empty dict.
+        drop_rate (float): Dropout rate after position embedding.
+            Defaults to 0.
+        num_layers (int): Number of transformer layers in the encoder.
+            Defaults to 14.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, which means the last stage.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            the encoder. Defaults to an empty dict.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        norm_cfg (dict): Config dict for normalization layer. Defaults to
+            ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize
+            the final feature map. Defaults to True.
+        output_cls_token (bool): Whether to output the cls_token.
+            Defaults to True.
+        init_cfg (dict, optional): The config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 in_channels=3,
+                 embed_dims=384,
+                 t2t_cfg=dict(),
+                 drop_rate=0.,
+                 num_layers=14,
+                 out_indices=-1,
+                 layer_cfgs=dict(),
+                 drop_path_rate=0.,
+                 norm_cfg=dict(type='LN'),
+                 final_norm=True,
+                 output_cls_token=True,
+                 init_cfg=None):
+        super(T2T_ViT, self).__init__(init_cfg)
+
+        # Tokens-to-Token module
+        self.tokens_to_token = T2TModule(
+            img_size=img_size,
+            in_channels=in_channels,
+            embed_dims=embed_dims,
+            **t2t_cfg)
+        num_patches = self.tokens_to_token.num_patches
+
+        # Class token
+        self.output_cls_token = output_cls_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
+
+        # Position Embedding
+        sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
+        self.register_buffer('pos_embed', sinusoid_table)
+        self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+        for i, index in enumerate(out_indices):
+            if index < 0:
+                out_indices[i] = num_layers + index
+            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+        self.out_indices = out_indices
+
+        dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
+        self.encoder = ModuleList()
+        for i in range(num_layers):
+            if isinstance(layer_cfgs, Sequence):
+                layer_cfg = layer_cfgs[i]
+            else:
+                layer_cfg = deepcopy(layer_cfgs)
+            layer_cfg = {
+                'embed_dims': embed_dims,
+                'num_heads': 6,
+                'feedforward_channels': 3 * embed_dims,
+                'drop_path_rate': dpr[i],
+                'qkv_bias': False,
+                'norm_cfg': norm_cfg,
+                **layer_cfg
+            }
+
+            layer = T2TTransformerLayer(**layer_cfg)
+            self.encoder.append(layer)
+
+        self.final_norm = final_norm
+        if final_norm:
+            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+        else:
+            self.norm = nn.Identity()
+
+    def init_weights(self):
+        super().init_weights()
+
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            # Suppress custom init if using a pretrained model.
+ return + + trunc_normal_(self.cls_token, std=.02) + + def forward(self, x): + B = x.shape[0] + x = self.tokens_to_token(x) + num_patches = self.tokens_to_token.num_patches + patch_resolution = [int(np.sqrt(num_patches))] * 2 + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.encoder): + x = layer(x) + + if i == len(self.encoder) - 1 and self.final_norm: + x = self.norm(x) + + if i in self.out_indices: + B, _, C = x.shape + patch_token = x[:, 1:].reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = x[:, 0] + if self.output_cls_token: + out = [patch_token, cls_token] + else: + out = patch_token + outs.append(out) + + return tuple(outs) diff --git a/model-index.yml b/model-index.yml index 814a3c854c8..f48e0937a4e 100644 --- a/model-index.yml +++ b/model-index.yml @@ -12,3 +12,4 @@ Import: - configs/repvgg/metafile.yml - configs/tnt/metafile.yml - configs/vision_transformer/metafile.yml + - configs/t2t_vit/metafile.yml diff --git a/tests/test_models/test_backbones/test_t2t_vit.py b/tests/test_models/test_backbones/test_t2t_vit.py new file mode 100644 index 00000000000..e15f92f9ae9 --- /dev/null +++ b/tests/test_models/test_backbones/test_t2t_vit.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy + +import pytest +import torch +from torch.nn.modules import GroupNorm +from torch.nn.modules.batchnorm import _BatchNorm + +from mmcls.models.backbones import T2T_ViT + + +def is_norm(modules): + """Check if is one of the norms.""" + if isinstance(modules, (GroupNorm, _BatchNorm)): + return True + return False + + +def check_norm_state(modules, train_state): + """Check if norm layer is in correct train state.""" + for mod in modules: + if isinstance(mod, _BatchNorm): + if mod.training != train_state: + return False + return True + + +def test_vit_backbone(): + + cfg_ori = dict( + img_size=224, + in_channels=3, + embed_dims=384, + t2t_cfg=dict( + token_dims=64, + use_performer=False, + ), + num_layers=14, + layer_cfgs=dict( + num_heads=6, + feedforward_channels=3 * 384, # mlp_ratio = 3 + ), + drop_path_rate=0.1, + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.), + ]) + + with pytest.raises(NotImplementedError): + # test if use performer + cfg = deepcopy(cfg_ori) + cfg['t2t_cfg']['use_performer'] = True + T2T_ViT(**cfg) + + # Test T2T-ViT model with input size of 224 + model = T2T_ViT(**cfg_ori) + model.init_weights() + model.train() + + assert check_norm_state(model.modules(), True) + + imgs = torch.randn(3, 3, 224, 224) + patch_token, cls_token = model(imgs)[-1] + assert cls_token.shape == (3, 384) + assert patch_token.shape == (3, 384, 14, 14) + + # Test custom arch T2T-ViT without output cls token + cfg = deepcopy(cfg_ori) + cfg['embed_dims'] = 256 + cfg['num_layers'] = 16 + cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024) + cfg['output_cls_token'] = False + + model = T2T_ViT(**cfg) + patch_token = model(imgs)[-1] + assert patch_token.shape == (3, 256, 14, 14) + + # Test T2T_ViT with multi out indices + cfg = deepcopy(cfg_ori) + cfg['out_indices'] = [-3, -2, -1] + model = T2T_ViT(**cfg) + for out in model(imgs): + assert out[0].shape == (3, 384, 14, 14) + assert out[1].shape == (3, 384)
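For quick reference, the new backbone can also be exercised on its own, mirroring the unit test above. This is a minimal sketch; the output shapes follow from the 224x224 input and the three soft splits (224 / (4 * 2 * 2) = 14 patches per side).

```python
# A minimal sketch mirroring the unit test above: build the T2T-ViT backbone
# directly and inspect the output shapes of the last stage.
import torch

from mmcls.models.backbones import T2T_ViT

model = T2T_ViT(
    img_size=224,
    embed_dims=384,
    num_layers=14,
    t2t_cfg=dict(token_dims=64, use_performer=False),
    layer_cfgs=dict(num_heads=6, feedforward_channels=3 * 384))
model.init_weights()
model.eval()

with torch.no_grad():
    imgs = torch.randn(1, 3, 224, 224)
    # With default out_indices=-1 and output_cls_token=True, the last stage
    # returns [patch_token, cls_token].
    patch_token, cls_token = model(imgs)[-1]

print(patch_token.shape)  # torch.Size([1, 384, 14, 14])
print(cls_token.shape)    # torch.Size([1, 384])
```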