[Community] Add IDOL, VITA to project #835

Open · wants to merge 1 commit into base: dev-1.x
add
pixeli99 committed Feb 9, 2023

commit 96f0538533612f2649b0e33bc8a45142eb78958f
120 changes: 120 additions & 0 deletions projects/VIS_SOTA/IDOL/README.md
@@ -0,0 +1,120 @@
# IDOL: In Defense of Online Models for Video Instance Segmentation

## Description

This is an implementation of [IDOL](https://github.com/wjf5203/VNext.git) based on [MMTracking](https://github.com/open-mmlab/mmtracking/tree/1.x), [MMDetection](https://github.com/open-mmlab/mmdetection/tree/3.x), [MMCV](https://github.com/open-mmlab/mmcv), and [MMEngine](https://github.com/open-mmlab/mmengine).

In recent years, video instance segmentation (VIS) has been largely advanced by offline models, while online models have typically trailed contemporaneous offline models by more than 10 AP, a substantial gap. By dissecting current online and offline models, we demonstrate that the main cause of this performance gap is error-prone association, and we propose IDOL, which outperforms all online and offline methods on three benchmarks. IDOL won first place in the video instance segmentation track of the 4th Large-scale Video Object Segmentation Challenge (CVPR 2022).

<center>
<img src="https://github.com/wjf5203/VNext/blob/main/assets/IDOL/arch.png">
</center>

## Usage

<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->

### Training commands

In MMTracking's root directory, run the following command to train the model:

```bash
python tools/train.py projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py
```

For multi-GPU training, run:

```bash
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py
```
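
If you prefer to launch training from Python instead of the shell, the same config can be driven through MMEngine's `Runner`. The snippet below is only a sketch, assuming an MMTracking 1.x environment run from the repository root with the datasets already prepared; the work directory is a hypothetical example.

```python
# Rough Python equivalent of tools/train.py (sketch, not the official entry point).
import projects.VIS_SOTA.IDOL.idol_src  # noqa: F401  # registers the IDOL modules
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py')
cfg.work_dir = './work_dirs/idol_r50_youtubevis2019'  # hypothetical output dir

runner = Runner.from_cfg(cfg)
runner.train()
```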

### Testing commands

In MMTracking's root directory, run the following command to test the model:

```bash
python tools/test.py projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py ${CHECKPOINT_PATH}
```
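
Because the YouTube-VIS evaluator in this project's config sets `format_only=True`, testing writes a submission file (prefixed with `./youtube_vis_results`) for the evaluation server instead of reporting AP locally. A hedged Python sketch of the same test run, with a placeholder checkpoint path:

```python
# Rough Python equivalent of tools/test.py (sketch; paths are placeholders).
import projects.VIS_SOTA.IDOL.idol_src  # noqa: F401  # registers the IDOL modules
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py')
cfg.work_dir = './work_dirs/idol_r50_youtubevis2019'   # hypothetical
cfg.load_from = 'path/to/idol_r50_youtubevis2019.pth'  # placeholder checkpoint

Runner.from_cfg(cfg).test()
```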

## Results

### YouTubeVIS-2019

| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download |
| :----: | :------: | :-----: | :-----: | :------: | :------------: | :--: | :--------------------------------------------------------------------------: | :----------------------: |
| IDOL | R-50 | pytorch | 12k | 27.0 | - | 49.3 | [config](projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py) | [model](<>) \| [log](<>) |

### OVIS

| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download |
| :----: | :------: | :-----: | :-----: | :------: | :------------: | :--: | :----------: | :----------------------: |
| IDOL | R-50 | pytorch | 12k | 27.0 | - | 29.7 | [config](<>) | [model](<>) \| [log](<>) |

## Citation

If you find IDOL useful in your research or applications, please consider giving a star 🌟 to the [official repository](https://github.com/wjf5203/VNext) and citing IDOL with the following BibTeX entry.

```BibTeX
@inproceedings{IDOL,
title={In Defense of Online Models for Video Instance Segmentation},
author={Wu, Junfeng and Liu, Qihao and Jiang, Yi and Bai, Song and Yuille, Alan and Bai, Xiang},
booktitle={ECCV},
year={2022},
}
```

## Checklist

<!-- Here is a checklist illustrating a usual development workflow of a successful project, and also serves as an overview of this project's progress. The PIC (person in charge) or contributors of this project should check all the items that they believe have been finished, which will further be verified by codebase maintainers via a PR.
OpenMMLab's maintainer will review the code to ensure the project's quality. Reaching the first milestone means that this project suffices the minimum requirement of being merged into 'projects/'. But this project is only eligible to become a part of the core package upon attaining the last milestone.
Note that keeping this section up-to-date is crucial not only for this project's developers but the entire community, since there might be some other contributors joining this project and deciding their starting point from this list. It also helps maintainers accurately estimate time and effort on further code polishing, if needed.
A project does not necessarily have to be finished in a single PR, but it's essential for the project to at least reach the first milestone in its very first PR. -->

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

<!-- The code's design shall follow existing interfaces and convention. For example, each model component should be registered into `mmdet.registry.MODELS` and configurable via a config file. -->

- [x] Basic docstrings & proper citation

<!-- Each major object should contain a docstring, describing its functionality and arguments. If you have adapted the code from other open-source projects, don't forget to cite the source project in docstring and make sure your behavior is not against its license. Typically, we do not accept any code snippet under GPL license. [A Short Guide to Open Source Licenses](https://medium.com/nationwide-technology/a-short-guide-to-open-source-licenses-cf5b1c329edd) -->

- [x] Test-time correctness

<!-- If you are reproducing the result from a paper, make sure your model's inference-time performance matches that in the original paper. The weights usually could be obtained by simply renaming the keys in the official pre-trained weights. This test could be skipped though, if you are able to prove the training-time correctness and check the second milestone. -->

- [x] A full README

<!-- As this template does. -->

- [x] Milestone 2: Indicates a successful model implementation.

- [x] Training-time correctness

<!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

<!-- Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmdetection/blob/5b0d5b40d5c6cfda906db7464ca22cbd4396728a/mmdet/datasets/transforms/transforms.py#L41-L169) -->

- [ ] Unit tests

<!-- Unit tests for each module are required. [Example](https://github.com/open-mmlab/mmdetection/blob/5b0d5b40d5c6cfda906db7464ca22cbd4396728a/tests/test_datasets/test_transforms/test_transforms.py#L35-L88) -->

- [ ] Code polishing

<!-- Refactor your code according to reviewer's comment. -->

- [ ] Metafile.yml

<!-- It will be parsed by MIM and Inferencer. [Example](https://github.com/open-mmlab/mmdetection/blob/3.x/configs/faster_rcnn/metafile.yml) -->

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

<!-- In particular, you may have to refactor this README into a standard one. [Example](https://github.com/open-mmlab/mmdetection/blob/3.x/configs/faster_rcnn/README.md) -->

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
64 changes: 64 additions & 0 deletions projects/VIS_SOTA/IDOL/configs/coco_instance.py
@@ -0,0 +1,64 @@
# dataset settings
dataset_type = 'mmdet.CocoDataset'
data_root = 'data/coco/'

# file_client_args = dict(
# backend='petrel',
# path_mapping=dict({
# './data/': 's3://openmmlab/datasets/detection/',
# 'data/': 's3://openmmlab/datasets/detection/'
# }))
file_client_args = dict(backend='disk')

train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='LoadTrackAnnotations', with_bbox=True, with_mask=True),
dict(type='mmdet.Resize', scale=(1333, 800), keep_ratio=True),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='PackTrackInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmdet.Resize', scale=(2000, 640), keep_ratio=True),
dict(
type='LoadTrackAnnotations',
with_instance_id=False,
with_bbox=True,
with_mask=True),
dict(type='PackTrackInputs', pack_single_img=True)
]
train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
batch_sampler=dict(type='mmdet.AspectRatioBatchSampler'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_val2017.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
_scope_='mmdet',
type='CocoMetric',
ann_file=data_root + 'annotations/instances_val2017.json',
metric=['bbox', 'segm'],
format_only=False)
test_evaluator = val_evaluator
136 changes: 136 additions & 0 deletions projects/VIS_SOTA/IDOL/configs/idol_r50_8xb2-16e_coco-seq.py
@@ -0,0 +1,136 @@
_base_ = [
'./coco_instance.py', # noqa: E501
'../../../../configs/_base_/default_runtime.py'
]

custom_imports = dict(imports=['projects.VIS_SOTA.IDOL.idol_src'], )

model = dict(
type='IDOL',
data_preprocessor=dict(
type='TrackDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_mask=True,
pad_size_divisor=32),
backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='mmdet.ChannelMapper',
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
track_head=dict(
_scope_='mmdet',
type='mmtrack.IDOLTrackHead',
num_query=300,
num_classes=80,
in_channels=2048,
with_box_refine=True,
sync_cls_avg_factor=True,
as_two_stage=False,
transformer=dict(
type='mmtrack.DeformableDetrTransformer',
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention', embed_dims=256),
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
decoder=dict(
type='DeformableDetrTransformerDecoder',
num_layers=6,
return_intermediate=True,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(
type='MultiScaleDeformableAttention',
embed_dims=256)
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')))),
positional_encoding=dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True,
offset=-0.5),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=2.0),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
# training and testing settings
    # the 'mmtrack.' scope prefix in the assigner type must be kept
train_cfg=dict(
assigner=dict(type='mmtrack.SimOTAAssigner', center_radius=2.5),
cur_train_mode='COCO_Video'),
)

# optimizer
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=0.0001,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999)),
paramwise_cfg=dict(
custom_keys={
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'query_embed': embed_multi,
'query_feat': embed_multi,
'level_embed': embed_multi,
},
norm_decay_mult=0.0),
clip_grad=dict(max_norm=0.01, norm_type=2))

# learning policy
max_iters = 6000
param_scheduler = dict(
type='MultiStepLR',
begin=0,
end=max_iters,
by_epoch=False,
milestones=[
4000,
],
gamma=0.1)
# runtime settings
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=max_iters, val_interval=6001)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

default_hooks = dict(
checkpoint=dict(
type='CheckpointHook', by_epoch=False, save_last=True, interval=2000))
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False)
227 changes: 227 additions & 0 deletions projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py
@@ -0,0 +1,227 @@
_base_ = [
'../../../../configs/_base_/datasets/youtube_vis.py', # noqa: E501
'../../../../configs/_base_/default_runtime.py'
]

custom_imports = dict(imports=['projects.VIS_SOTA.IDOL.idol_src'], )

model = dict(
type='IDOL',
data_preprocessor=dict(
type='TrackDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_mask=True,
pad_size_divisor=32),
backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='mmdet.ChannelMapper',
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
track_head=dict(
_scope_='mmdet',
type='mmtrack.IDOLTrackHead',
num_query=300,
num_classes=40,
in_channels=2048,
with_box_refine=True,
sync_cls_avg_factor=True,
as_two_stage=False,
transformer=dict(
type='mmtrack.DeformableDetrTransformer',
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention', embed_dims=256),
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
decoder=dict(
type='DeformableDetrTransformerDecoder',
num_layers=6,
return_intermediate=True,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(
type='MultiScaleDeformableAttention',
embed_dims=256)
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')))),
positional_encoding=dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True,
offset=-0.5),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=2.0),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0),
loss_mask=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=20.0),
loss_dice=dict(
type='DiceLoss',
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=1.0),
loss_track=dict(
type='mmtrack.MultiPosCrossEntropyLoss', loss_weight=0.25),
loss_track_aux=dict(
type='mmtrack.L2Loss',
neg_pos_ub=3,
pos_margin=0,
neg_margin=0.1,
hard_mining=True,
loss_weight=1.0)),
tracker=dict(
type='IDOLTracker',
init_score_thr=0.2,
obj_score_thr=0.1,
nms_thr_pre=0.5,
nms_thr_post=0.05,
addnew_score_thr=0.2,
memo_tracklet_frames=10,
memo_momentum=0.8,
long_match=True,
frame_weight=True,
temporal_weight=True,
memory_len=3,
match_metric='bisoftmax'),
# training and testing settings
    # the 'mmtrack.' scope prefix in the assigner type must be kept
train_cfg=dict(
assigner=dict(
type='mmtrack.SimOTAAssigner',
center_radius=2.5,
match_costs=[
dict(type='FocalLossCost', weight=1.0),
dict(type='IoUCost', iou_mode='giou', weight=3.0)
]),
cur_train_mode='VIS'),
)

# optimizer
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=0.0001,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999)),
paramwise_cfg=dict(
custom_keys={
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'query_embed': embed_multi,
'query_feat': embed_multi,
'level_embed': embed_multi,
},
norm_decay_mult=0.0),
clip_grad=dict(max_norm=0.01, norm_type=2))

# learning policy
max_iters = 6000
param_scheduler = dict(
type='MultiStepLR',
begin=0,
end=max_iters,
by_epoch=False,
milestones=[
4000,
],
gamma=0.1)
# runtime settings
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=max_iters, val_interval=6001)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

default_hooks = dict(
checkpoint=dict(
type='CheckpointHook', by_epoch=False, save_last=True, interval=2000))
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False)

train_pipeline = [
dict(
type='TransformBroadcaster',
share_random_params=True,
transforms=[
dict(type='LoadImageFromFile'),
dict(
type='LoadTrackAnnotations',
with_instance_id=True,
with_mask=True,
with_bbox=True),
dict(type='mmdet.Resize', scale=(640, 360), keep_ratio=True),
dict(type='mmdet.RandomFlip', prob=0.5),
]),
dict(type='PackTrackInputs', num_key_frames=2)
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='LoadTrackAnnotations',
with_instance_id=True,
with_mask=True,
with_bbox=True),
# dict(type='mmdet.Resize', scale=(480, 1000), keep_ratio=True),
dict(type='PackTrackInputs', pack_single_img=True)
]
# dataloader
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(
pipeline=train_pipeline,
ref_img_sampler=dict(
num_ref_imgs=1,
frame_range=5,
filter_key_img=True,
method='uniform')))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
# evaluator
val_evaluator = dict(
type='YouTubeVISMetric',
metric='youtube_vis_ap',
outfile_prefix='./youtube_vis_results',
format_only=True)
test_evaluator = val_evaluator
5 changes: 5 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .idol import IDOL
from .models import IDOLTrackHead

__all__ = ['IDOL', 'IDOLTrackHead']
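
As a quick sanity check, importing this package should register the exported components in MMTracking's model registry. A minimal sketch, assuming it is run from the MMTracking root so that `projects.*` is importable:

```python
# Sketch: verify the registration performed by the __init__ above.
import projects.VIS_SOTA.IDOL.idol_src  # noqa: F401

from mmtrack.registry import MODELS

assert MODELS.get('IDOL') is not None           # registered in idol.py
assert MODELS.get('IDOLTrackHead') is not None  # registered in models/
```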
96 changes: 96 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/idol.py
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union

from torch import Tensor

from mmtrack.models.mot import BaseMultiObjectTracker
from mmtrack.registry import MODELS
from mmtrack.utils import OptConfigType, OptMultiConfig, SampleList


@MODELS.register_module()
class IDOL(BaseMultiObjectTracker):
    """In Defense of Online Models for Video Instance Segmentation (IDOL)."""

def __init__(self,
backbone: Optional[dict] = None,
neck: Optional[dict] = None,
track_head: Optional[dict] = None,
tracker: Optional[dict] = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(data_preprocessor, init_cfg)

if backbone is not None:
self.backbone = MODELS.build(backbone)

if neck is not None:
self.neck = MODELS.build(neck)

if track_head is not None:
track_head.update(train_cfg=train_cfg)
track_head.update(test_cfg=test_cfg)
self.track_head = MODELS.build(track_head)

if tracker is not None:
self.tracker = MODELS.build(tracker)

self.train_cfg = train_cfg
self.test_cfg = test_cfg

self.cur_train_mode = train_cfg.cur_train_mode

    def loss(self, inputs: Dict[str, Tensor], data_samples: SampleList,
             **kwargs) -> Union[dict, tuple]:
        """Compute the track head losses from a batch of key/reference frames
        packed as a 5D image tensor."""

img = inputs['img']
assert img.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
# shape (N * T, C, H, W)
img = img.flatten(0, 1)

x = self.extract_feat(img)
losses = self.track_head.loss(x, data_samples)

return losses

    def predict(self,
                inputs: dict,
                data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict instances for a single frame and, in VIS mode, associate
        them across frames with the tracker."""

img = inputs['img']
assert img.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
# the "T" is 1
img = img.squeeze(1)
feats = self.extract_feat(img)
pred_det_ins_list = self.track_head.predict(feats, data_samples,
rescale)
track_data_sample = data_samples[0]
pred_det_ins = pred_det_ins_list[0]
track_data_sample.pred_det_instances = \
pred_det_ins.clone()

if self.cur_train_mode == 'VIS':
pred_track_instances = self.tracker.track(
data_sample=track_data_sample, rescale=rescale)
track_data_sample.pred_track_instances = pred_track_instances
else:
pred_det_ins.masks = pred_det_ins.masks.squeeze(1) > 0.5
track_data_sample.pred_instances = pred_det_ins

return [track_data_sample]

def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
"""Extract features.
Args:
batch_inputs (Tensor): Image tensor with shape (N * T, C, H ,W).
Returns:
tuple[Tensor]: Multi-level features that may have
different resolutions.
"""
x = self.backbone(batch_inputs)
x = self.neck(x)
return x
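
For reference, a minimal shape sketch of the inputs that `loss` and `predict` expect, based on the asserts above; the concrete sizes are arbitrary examples.

```python
# Shape sketch only; frame count and resolution are arbitrary.
import torch

num_videos, num_frames = 2, 2                             # N videos, T frames
imgs = torch.randn(num_videos, num_frames, 3, 360, 640)   # (N, T, C, H, W)
flat = imgs.flatten(0, 1)                # what loss() feeds to extract_feat
assert flat.shape == (num_videos * num_frames, 3, 360, 640)

single = torch.randn(num_videos, 1, 3, 360, 640)          # at test time T == 1
assert single.squeeze(1).shape == (num_videos, 3, 360, 640)
```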
10 changes: 10 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/models/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .idol_query_track_head import IDOLTrackHead
from .idol_tracker import IDOLTracker
from .sim_ota_assigner import SimOTAAssigner
from .transformer import DeformableDetrTransformer

__all__ = [
'IDOLTrackHead', 'DeformableDetrTransformer', 'SimOTAAssigner',
'IDOLTracker'
]
929 changes: 929 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/models/idol_query_track_head.py

Large diffs are not rendered by default.

328 changes: 328 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/models/idol_tracker.py
@@ -0,0 +1,328 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor

from mmtrack.models.trackers import BaseTracker
from mmtrack.registry import MODELS
from mmtrack.structures import TrackDataSample
from .utils import mask_iou, mask_nms


@MODELS.register_module()
class IDOLTracker(BaseTracker):
    """Online tracker used by IDOL.

    Detections are associated across frames via their embedding similarities,
    and a memory of tracklets and backdrops is kept for long-term matching.
    """

def __init__(self,
nms_thr_pre=0.7,
nms_thr_post=0.3,
init_score_thr=0.2,
addnew_score_thr=0.5,
obj_score_thr=0.1,
match_score_thr=0.5,
memo_tracklet_frames=10,
memo_backdrop_frames=1,
memo_momentum=0.5,
nms_conf_thr=0.5,
nms_backdrop_iou_thr=0.5,
nms_class_iou_thr=0.7,
with_cats=True,
match_metric='bisoftmax',
long_match=False,
frame_weight=False,
temporal_weight=False,
memory_len=10,
**kwargs):
super().__init__(**kwargs)
assert 0 <= memo_momentum <= 1.0
assert memo_tracklet_frames >= 0
assert memo_backdrop_frames >= 0
self.memory_len = memory_len
self.temporal_weight = temporal_weight
self.long_match = long_match
self.frame_weight = frame_weight
self.nms_thr_pre = nms_thr_pre
self.nms_thr_post = nms_thr_post
self.init_score_thr = init_score_thr
self.addnew_score_thr = addnew_score_thr
self.obj_score_thr = obj_score_thr
self.match_score_thr = match_score_thr
self.memo_tracklet_frames = memo_tracklet_frames
self.memo_backdrop_frames = memo_backdrop_frames
self.memo_momentum = memo_momentum
self.nms_conf_thr = nms_conf_thr
self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
self.nms_class_iou_thr = nms_class_iou_thr
self.with_cats = with_cats
assert match_metric in ['bisoftmax', 'softmax', 'cosine']
self.match_metric = match_metric

self.reset()

def reset(self):
"""Reset the buffer of the tracker."""
self.num_tracks = 0
self.tracks = dict()
self.backdrops = []

def update(self, ids: Tensor, bboxes: Tensor, embeds: Tensor,
labels: Tensor, scores: Tensor, frame_id: int) -> None:
"""Tracking forward function.
Args:
ids (Tensor): of shape(N, ).
bboxes (Tensor): of shape (N, 5).
embeds (Tensor): of shape (N, 256).
labels (Tensor): of shape (N, ).
scores (Tensor): of shape (N, ).
frame_id (int): The id of current frame, 0-index.
"""
tracklet_inds = ids > -1

for id, bbox, embed, label, score in zip(ids[tracklet_inds],
bboxes[tracklet_inds],
embeds[tracklet_inds],
labels[tracklet_inds],
scores[tracklet_inds]):
id = int(id)
# update the tracked ones and initialize new tracks
if id in self.tracks.keys():
velocity = (bbox - self.tracks[id]['bbox']) / (
frame_id - self.tracks[id]['last_frame'])
self.tracks[id]['bbox'] = bbox
self.tracks[id]['long_score'].append(score)
self.tracks[id]['embed'] = (
1 - self.memo_momentum
) * self.tracks[id]['embed'] + self.memo_momentum * embed
self.tracks[id]['long_embed'].append(embed)
self.tracks[id]['last_frame'] = frame_id
self.tracks[id]['label'] = label
self.tracks[id]['score'] = score
self.tracks[id]['velocity'] = (
self.tracks[id]['velocity'] * self.tracks[id]['acc_frame']
+ velocity) / (
self.tracks[id]['acc_frame'] + 1)
self.tracks[id]['acc_frame'] += 1
self.tracks[id]['exist_frame'] += 1
else:
self.tracks[id] = dict(
bbox=bbox,
embed=embed,
label=label,
long_embed=[embed],
long_score=[score],
score=score,
last_frame=frame_id,
velocity=torch.zeros_like(bbox),
acc_frame=0,
exist_frame=1)
# backdrop update according to IoU
backdrop_inds = torch.nonzero(ids == -1, as_tuple=False).squeeze(1)
self.backdrops.insert(
0,
dict(
bboxes=bboxes[backdrop_inds],
embeds=embeds[backdrop_inds],
labels=labels[backdrop_inds]))

# pop memo
invalid_ids = []
for k, v in self.tracks.items():
if frame_id - v['last_frame'] >= self.memo_tracklet_frames:
invalid_ids.append(k)
if len(v['long_embed']) > self.memory_len:
v['long_embed'].pop(0)
if len(v['long_score']) > self.memory_len:
v['long_score'].pop(0)
for invalid_id in invalid_ids:
self.tracks.pop(invalid_id)

if len(self.backdrops) > self.memo_backdrop_frames:
self.backdrops.pop()

@property
def memo(self) -> Tuple[Tensor, ...]:
"""Get tracks memory."""
memo_embeds = []
memo_ids = []
memo_bboxes = []
memo_labels = []
# velocity of tracks
memo_vs = []
# for long term tracking
memo_long_embeds = []
memo_long_score = []
memo_exist_frame = []
# get tracks
for k, v in self.tracks.items():
memo_bboxes.append(v['bbox'][None, :])
if self.long_match:
weights = torch.stack(v['long_score'])
if self.temporal_weight:
length = len(weights)
temporal_weight = torch.range(0.0, 1,
1 / length)[1:].to(weights)
weights = weights + temporal_weight
sum_embed = (torch.stack(v['long_embed']) *
weights.unsqueeze(1)).sum(0) / weights.sum()
memo_embeds.append(sum_embed[None, :])
else:
memo_embeds.append(v['embed'][None, :])

memo_long_embeds.append(torch.stack(v['long_embed']))
memo_long_score.append(torch.stack(v['long_score']))
memo_exist_frame.append(v['exist_frame'])

memo_ids.append(k)
memo_labels.append(v['label'].view(1, 1))
memo_vs.append(v['velocity'][None, :])
memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1)
memo_exist_frame = torch.tensor(memo_exist_frame, dtype=torch.long)

memo_bboxes = torch.cat(memo_bboxes, dim=0)
memo_embeds = torch.cat(memo_embeds, dim=0)
memo_labels = torch.cat(memo_labels, dim=0).squeeze(1)
memo_vs = torch.cat(memo_vs, dim=0)
return memo_bboxes, memo_labels, memo_embeds, memo_ids.squeeze(
0), memo_vs, memo_long_embeds, memo_long_score, memo_exist_frame

    def track(self,
              data_sample: TrackDataSample,
              rescale=True,
              **kwargs) -> InstanceData:
        """Associate the detections of the current frame with the track
        memory and return the tracked instances."""

metainfo = data_sample.metainfo
bboxes = data_sample.pred_det_instances.bboxes
masks = data_sample.pred_det_instances.masks
labels = data_sample.pred_det_instances.labels
scores = data_sample.pred_det_instances.scores
embeds = data_sample.pred_det_instances.track_feats

valids = mask_nms(masks, scores, self.nms_thr_pre)
frame_id = metainfo.get('frame_id', -1)

bboxes = bboxes[valids, :]
labels = labels[valids]
masks = masks[valids]
scores = scores[valids]
embeds = embeds[valids, :]
# create pred_track_instances
pred_track_instances = InstanceData(metainfo=metainfo)

# return zero bboxes if there is no track targets
if bboxes.shape[0] == 0:
ids = torch.zeros_like(labels)
pred_track_instances = data_sample.pred_det_instances.clone()
pred_track_instances.instances_id = ids
return pred_track_instances

# init ids container
ids = torch.full((bboxes.size(0), ), -2, dtype=torch.long)

# match if buffer is not empty
if bboxes.size(0) > 0 and not self.empty:
(memo_bboxes, memo_labels, memo_embeds, memo_ids, memo_vs,
memo_long_embeds, memo_long_score, memo_exist_frame) = self.memo

memo_exist_frame = memo_exist_frame.to(memo_embeds)
memo_ids = memo_ids.to(memo_embeds)

if self.match_metric == 'bisoftmax':
feats = torch.mm(embeds, memo_embeds.t())
d2t_scores = feats.softmax(dim=1)
t2d_scores = feats.softmax(dim=0)
match_scores = (d2t_scores + t2d_scores) / 2
elif self.match_metric == 'softmax':
feats = torch.mm(embeds, memo_embeds.t())
match_scores = feats.softmax(dim=1)
elif self.match_metric == 'cosine':
match_scores = torch.mm(
F.normalize(embeds, p=2, dim=1),
F.normalize(memo_embeds, p=2, dim=1).t())
else:
raise NotImplementedError
# track according to match_scores
for i in range(bboxes.size(0)):
if self.frame_weight:
non_backs = (memo_ids > -1) & (match_scores[i, :] > 0.5)
if (match_scores[i, non_backs] > 0.5).sum() > 1:
                        weighted_scores = match_scores.clone()
                        frame_weight = memo_exist_frame[
                            match_scores[i, :][memo_ids > -1] > 0.5]
                        weighted_scores[i, non_backs] = weighted_scores[
                            i, non_backs] * frame_weight
                        weighted_scores[i, ~non_backs] = weighted_scores[
                            i, ~non_backs] * frame_weight.mean()
                        conf, memo_ind = torch.max(
                            weighted_scores[i, :], dim=0)
else:
conf, memo_ind = torch.max(match_scores[i, :], dim=0)
else:
conf, memo_ind = torch.max(match_scores[i, :], dim=0)
id = memo_ids[memo_ind]
if conf > self.match_score_thr:
if id > -1:
ids[i] = id
match_scores[:i, memo_ind] = 0
match_scores[i + 1:, memo_ind] = 0
# initialize new tracks
new_inds = (ids == -2) & (scores > self.addnew_score_thr).cpu()
num_news = new_inds.sum()
ids[new_inds] = torch.arange(
self.num_tracks, self.num_tracks + num_news, dtype=torch.long)
self.num_tracks += num_news

# get backdrops
unselected_inds = torch.nonzero(
ids == -2, as_tuple=False).squeeze(1)
mask_ious = mask_iou(masks[unselected_inds].sigmoid() > 0.5,
masks.permute(1, 0, 2, 3).sigmoid() > 0.5)
for i, ind in enumerate(unselected_inds):
if (mask_ious[i, :ind] < self.nms_thr_post).all():
ids[ind] = -1

self.update(ids, bboxes, embeds, labels, scores, frame_id)

elif self.empty:
init_inds = (ids == -2) & (scores > self.init_score_thr).cpu()
num_news = init_inds.sum()
ids[init_inds] = torch.arange(
self.num_tracks, self.num_tracks + num_news, dtype=torch.long)
self.num_tracks += num_news
unselected_inds = torch.nonzero(
ids == -2, as_tuple=False).squeeze(1)
mask_ious = mask_iou(masks[unselected_inds].sigmoid() > 0.5,
masks.permute(1, 0, 2, 3).sigmoid() > 0.5)
for i, ind in enumerate(unselected_inds):
if (mask_ious[i, :ind] < self.nms_thr_post).all():
ids[ind] = -1
self.update(ids, bboxes, embeds, labels, scores, frame_id)

tracklet_inds = ids > -1

if rescale:
# return result in original resolution
# rz_*: the resize shape
pad_height, pad_width = metainfo['pad_shape']
rz_height, rz_width = metainfo['img_shape']
masks = F.interpolate(
masks,
size=(pad_height, pad_width),
mode='bilinear',
align_corners=False).sigmoid()
# crop the padding area
masks = masks[:, :, :rz_height, :rz_width]
ori_height, ori_width = metainfo['ori_shape']
masks = (
F.interpolate(
masks, size=(ori_height, ori_width), mode='nearest'))

# update pred_track_instances
pred_track_instances.bboxes = bboxes[tracklet_inds]
pred_track_instances.masks = masks[tracklet_inds].squeeze(1) > 0.5
pred_track_instances.labels = labels[tracklet_inds]
pred_track_instances.scores = scores[tracklet_inds]
pred_track_instances.instances_id = ids[tracklet_inds]

return pred_track_instances
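
To make the `bisoftmax` match metric used above concrete, a toy sketch with random embeddings (shapes only; the values carry no meaning):

```python
# Toy illustration of the 'bisoftmax' metric in IDOLTracker.track (sketch).
import torch

embeds = torch.randn(3, 256)        # detections in the current frame
memo_embeds = torch.randn(5, 256)   # embeddings held in the track memory

feats = torch.mm(embeds, memo_embeds.t())
d2t_scores = feats.softmax(dim=1)   # detection -> track scores
t2d_scores = feats.softmax(dim=0)   # track -> detection scores
match_scores = (d2t_scores + t2d_scores) / 2  # symmetric bi-softmax score
print(match_scores.shape)           # torch.Size([3, 5])
```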
232 changes: 232 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/models/pos_neg_select.py
@@ -0,0 +1,232 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List

import torch
import torch.nn as nn
import torchvision.ops as ops
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh

from .utils import generalized_box_iou


def get_contrast_items(key_embeds, ref_embeds, key_gt_instances,
ref_gt_instances, ref_bbox, ref_cls, query_inds,
batch_img_metas) -> List:

one = torch.tensor(1).to(ref_embeds)
zero = torch.tensor(0).to(ref_embeds)
contrast_items = []

for bid, (key_gt, ref_gt, indices) in enumerate(
zip(key_gt_instances, ref_gt_instances, query_inds)):

key_ins_ids = key_gt.instances_id
ref_ins_ids = ref_gt.instances_id
valid = torch.tensor([(ref_ins_id in key_ins_ids)
for ref_ins_id in ref_ins_ids],
dtype=torch.bool)

img_h, img_w, = batch_img_metas[bid]['img_shape']
gt_bboxes = bbox_xyxy_to_cxcywh(ref_gt['bboxes'])
factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
img_h]).unsqueeze(0).repeat(
gt_bboxes.size(0), 1)
gt_bboxes /= factor
gt_labels = ref_gt['labels']
contrastive_pos, contrastive_neg = get_pos_idx(ref_bbox[bid],
ref_cls[bid], gt_bboxes,
gt_labels, valid)

for inst_i, (is_valid, matched_query_id) in enumerate(
zip(valid, query_inds[bid])):

if not is_valid:
continue
# key_embeds: (bs, num_queries, c)
key_embed_i = key_embeds[bid, matched_query_id].unsqueeze(0)

pos_embed = ref_embeds[bid][contrastive_pos[inst_i]]
neg_embed = ref_embeds[bid][~contrastive_neg[inst_i]]
contrastive_embed = torch.cat([pos_embed, neg_embed], dim=0)
contrastive_label = torch.cat(
[one.repeat(len(pos_embed)),
zero.repeat(len(neg_embed))],
dim=0).unsqueeze(1)

contrast = torch.einsum('nc,kc->nk',
[contrastive_embed, key_embed_i])

if len(pos_embed) == 0:
num_sample_neg = 10
elif len(pos_embed) * 10 >= len(neg_embed):
num_sample_neg = len(neg_embed)
else:
num_sample_neg = len(pos_embed) * 10

# for aux loss
sample_ids = random.sample(
list(range(0, len(neg_embed))), num_sample_neg)
aux_contrastive_embed = torch.cat(
[pos_embed, neg_embed[sample_ids]], dim=0)
aux_contrastive_label = torch.cat(
[one.repeat(len(pos_embed)),
zero.repeat(num_sample_neg)],
dim=0).unsqueeze(1)
aux_contrastive_embed = nn.functional.normalize(
aux_contrastive_embed.float(), dim=1)
key_embed_i = nn.functional.normalize(key_embed_i.float(), dim=1)
cosine = torch.einsum('nc,kc->nk',
[aux_contrastive_embed, key_embed_i])

contrast_items.append({
'contrast': contrast,
'label': contrastive_label,
'aux_consin': cosine,
'aux_label': aux_contrastive_label
})

return contrast_items


def get_pos_idx(ref_bbox, ref_cls, gt_bbox, gt_cls, valid):

with torch.no_grad():
if False in valid:
gt_bbox = gt_bbox[valid]
gt_cls = gt_cls[valid]

fg_mask, is_in_boxes_and_center = \
get_in_gt_and_in_center_info(
ref_bbox,
gt_bbox,
expanded_strides=32)
pair_wise_ious = ops.box_iou(
bbox_cxcywh_to_xyxy(ref_bbox), bbox_cxcywh_to_xyxy(gt_bbox))

# Compute the classification cost.
alpha = 0.25
gamma = 2.0
neg_cost_class = (1 - alpha) * (ref_cls ** gamma) * \
(-(1 - ref_cls + 1e-8).log())
pos_cost_class = alpha * (
(1 - ref_cls)**gamma) * (-(ref_cls + 1e-8).log())
cost_class = pos_cost_class[:, gt_cls] - neg_cost_class[:, gt_cls]
cost_giou = - \
generalized_box_iou(bbox_cxcywh_to_xyxy(
ref_bbox), bbox_cxcywh_to_xyxy(gt_bbox))

cost = (
cost_class + 3.0 * cost_giou + 100.0 * (~is_in_boxes_and_center))

cost[~fg_mask] = cost[~fg_mask] + 10000.0

if False in valid:
pos_indices = []
neg_indices = []
if valid.sum() > 0:
# Select positive sample on reference frame, k = 10
pos_matched = dynamic_k_matching(cost, pair_wise_ious,
int(valid.sum()), 10)
# Select positive sample on reference frame, k = 100
neg_matched = dynamic_k_matching(cost, pair_wise_ious,
int(valid.sum()), 100)
valid_idx = 0
valid_list = valid.tolist()
for istrue in valid_list:
if istrue:
pos_indices.append(pos_matched[valid_idx])
neg_indices.append(neg_matched[valid_idx])
valid_idx += 1
else:
pos_indices.append(None)
neg_indices.append(None)

else:
if valid.sum() > 0:
pos_indices = dynamic_k_matching(cost, pair_wise_ious,
gt_bbox.shape[0], 10)
neg_indices = dynamic_k_matching(cost, pair_wise_ious,
gt_bbox.shape[0], 100)
else:
pos_indices = [None]
neg_indices = [None]

return (pos_indices, neg_indices)


def get_in_gt_and_in_center_info(bboxes, target_gts, expanded_strides):

xy_target_gts = bbox_cxcywh_to_xyxy(target_gts)

anchor_center_x = bboxes[:, 0].unsqueeze(1)
anchor_center_y = bboxes[:, 1].unsqueeze(1)

b_l = anchor_center_x > xy_target_gts[:, 0].unsqueeze(0)
b_r = anchor_center_x < xy_target_gts[:, 2].unsqueeze(0)
b_t = anchor_center_y > xy_target_gts[:, 1].unsqueeze(0)
b_b = anchor_center_y < xy_target_gts[:, 3].unsqueeze(0)
is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
is_in_boxes_all = is_in_boxes.sum(1) > 0

# in fixed center
center_radius = 2.5
b_l = anchor_center_x > (
target_gts[:, 0] - (1 * center_radius / expanded_strides)).unsqueeze(0)
b_r = anchor_center_x < (
target_gts[:, 0] + (1 * center_radius / expanded_strides)).unsqueeze(0)
b_t = anchor_center_y > (
target_gts[:, 1] - (1 * center_radius / expanded_strides)).unsqueeze(0)
b_b = anchor_center_y < (
target_gts[:, 1] + (1 * center_radius / expanded_strides)).unsqueeze(0)
is_in_centers = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
is_in_centers_all = is_in_centers.sum(1) > 0

# in boxes or in centers, shape: [num_priors]
is_in_boxes_or_centers = is_in_boxes_all | is_in_centers_all
# both in boxes and centers, shape: [num_fg, num_gt]
is_in_boxes_and_center = (is_in_boxes & is_in_centers)

return is_in_boxes_or_centers, is_in_boxes_and_center


def dynamic_k_matching(cost, pair_wise_ious, num_gt, n_candidate_k):
matching_matrix = torch.zeros_like(cost)
ious_in_boxes_matrix = pair_wise_ious

topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=0)
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
for gt_idx in range(num_gt):
_, pos_idx = torch.topk(
cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
matching_matrix[:, gt_idx][pos_idx] = 1.0

del topk_ious, dynamic_ks, pos_idx

anchor_matching_gt = matching_matrix.sum(1)

if (anchor_matching_gt > 1).sum() > 0:
_, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1)
matching_matrix[anchor_matching_gt > 1] *= 0
matching_matrix[anchor_matching_gt > 1, cost_argmin, ] = 1

while (matching_matrix.sum(0) == 0).any():
matched_query_id = matching_matrix.sum(1) > 0
cost[matched_query_id] += 100000.0
unmatch_id = torch.nonzero(
matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
for gt_idx in unmatch_id:
pos_idx = torch.argmin(cost[:, gt_idx])
matching_matrix[:, gt_idx][pos_idx] = 1.0
        if (matching_matrix.sum(1) > 1).sum() > 0:
            # recompute which queries now match more than one gt and keep
            # only the gt with the minimal cost for them
            multi_match_mask = matching_matrix.sum(1) > 1
            _, cost_argmin = torch.min(cost[multi_match_mask], dim=1)
            matching_matrix[multi_match_mask] *= 0
            matching_matrix[multi_match_mask, cost_argmin] = 1

assert not (matching_matrix.sum(0) == 0).any()

matched_pos = []
for gt_idx in range(num_gt):
matched_pos.append(matching_matrix[:, gt_idx] > 0)

return matched_pos
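
A toy run of `dynamic_k_matching` above, to make the dynamic-k selection concrete (the numbers are arbitrary):

```python
# Sketch: 4 candidate queries, 1 ground-truth instance, arbitrary values.
import torch

cost = torch.tensor([[0.1], [0.5], [0.2], [0.9]])  # (num_queries, num_gt)
ious = torch.tensor([[0.8], [0.6], [0.7], [0.1]])  # pairwise IoUs
pos = dynamic_k_matching(cost.clone(), ious, num_gt=1, n_candidate_k=2)
print(pos[0])  # tensor([True, False, False, False]): dynamic k = 1 here
```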
211 changes: 211 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/models/sim_ota_assigner.py
@@ -0,0 +1,211 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torchvision.ops as ops
from mmdet.models.task_modules import AssignResult
from mmdet.models.task_modules import SimOTAAssigner as MMDET_SimOTAAssigner
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmtrack.registry import TASK_UTILS

INF = 10000.0
EPS = 1.0e-7


@TASK_UTILS.register_module()
class SimOTAAssigner(MMDET_SimOTAAssigner):
"""Computes matching between predictions and ground truth.
Args:
center_radius (float): Ground truth center size
to judge whether a prior is in center. Defaults to 2.5.
candidate_topk (int): The candidate top-k which used to
get top-k ious to calculate dynamic-k. Defaults to 10.
iou_weight (float): The scale factor for regression
iou cost. Defaults to 3.0.
cls_weight (float): The scale factor for classification
cost. Defaults to 1.0.
iou_calculator (ConfigType): Config of overlaps Calculator.
Defaults to dict(type='BboxOverlaps2D').
"""

def __init__(self,
match_costs: Union[List[Union[dict, ConfigDict]], dict,
ConfigDict],
center_radius: float = 2.5,
candidate_topk: int = 10):

if isinstance(match_costs, dict):
match_costs = [match_costs]
elif isinstance(match_costs, list):
assert len(match_costs) > 0, \
'match_costs must not be a empty list.'

self.center_radius = center_radius
self.candidate_topk = candidate_topk
self.match_costs = [
TASK_UTILS.build(match_cost) for match_cost in match_costs
]

def assign(self,
pred_instances: InstanceData,
gt_instances: InstanceData,
img_meta: Optional[dict] = None,
**kwargs) -> AssignResult:

gt_bboxes = gt_instances.bboxes
gt_labels = gt_instances.labels
num_gt = gt_bboxes.size(0)

pred_bboxes = pred_instances.bboxes
priors = pred_instances.bboxes
num_bboxes = pred_bboxes.size(0)

# assign 0 by default
assigned_gt_inds = pred_bboxes.new_full((num_bboxes, ),
0,
dtype=torch.long)
assigned_labels = pred_bboxes.new_full((num_bboxes, ),
0,
dtype=torch.long)
if num_gt == 0 or num_bboxes == 0:
# No ground truth or boxes, return empty assignment
if num_gt == 0:
# No ground truth, assign all to background
assigned_gt_inds[:] = 0
return AssignResult(
num_gts=num_gt,
gt_inds=assigned_gt_inds,
max_overlaps=None,
labels=assigned_labels)

valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
priors, gt_bboxes)
pairwise_ious = ops.box_iou(
bbox_cxcywh_to_xyxy(pred_bboxes), bbox_cxcywh_to_xyxy(gt_bboxes))

# compute weighted cost
cost_list = []
for match_cost in self.match_costs:
cost = match_cost(
pred_instances=pred_instances,
gt_instances=gt_instances,
img_meta=img_meta)
cost_list.append(cost)
cost = torch.stack(cost_list).sum(dim=0)
cost += 100 * (~is_in_boxes_and_center)
cost[~valid_mask] = cost[~valid_mask] + INF

fg_mask_inboxes, matched_gt_inds, min_cost_qid = \
self.dynamic_k_matching(
cost, pairwise_ious, num_gt)

# convert to AssignResult format
assigned_gt_inds[fg_mask_inboxes] = matched_gt_inds + 1
assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].long()

assign_res = AssignResult(
num_gt, assigned_gt_inds, None, labels=assigned_labels)
assign_res.set_extra_property('query_inds', min_cost_qid)
return assign_res

def get_in_gt_and_in_center_info(
self, priors: Tensor, gt_bboxes: Tensor) -> Tuple[Tensor, Tensor]:
"""Get the information of which prior is in gt bboxes and gt center
priors."""
num_gt = gt_bboxes.size(0)

repeated_x = priors[:, 0].unsqueeze(1).repeat(1, num_gt)
repeated_y = priors[:, 1].unsqueeze(1).repeat(1, num_gt)
repeated_stride_x = priors[:, 2].unsqueeze(1).repeat(1, num_gt)
repeated_stride_y = priors[:, 3].unsqueeze(1).repeat(1, num_gt)

# is prior centers in gt bboxes, shape: [n_prior, n_gt]
l_ = repeated_x - gt_bboxes[:, 0]
t_ = repeated_y - gt_bboxes[:, 1]
r_ = gt_bboxes[:, 2] - repeated_x
b_ = gt_bboxes[:, 3] - repeated_y

deltas = torch.stack([l_, t_, r_, b_], dim=1)
is_in_gts = deltas.min(dim=1).values > 0
is_in_gts_all = is_in_gts.sum(dim=1) > 0

# is prior centers in gt centers
gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
ct_box_l = gt_cxs - self.center_radius * repeated_stride_x
ct_box_t = gt_cys - self.center_radius * repeated_stride_y
ct_box_r = gt_cxs + self.center_radius * repeated_stride_x
ct_box_b = gt_cys + self.center_radius * repeated_stride_y

cl_ = repeated_x - ct_box_l
ct_ = repeated_y - ct_box_t
cr_ = ct_box_r - repeated_x
cb_ = ct_box_b - repeated_y

ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
is_in_cts = ct_deltas.min(dim=1).values > 0
is_in_cts_all = is_in_cts.sum(dim=1) > 0

# in boxes or in centers, shape: [num_priors]
is_in_gts_or_centers = is_in_gts_all | is_in_cts_all

# both in boxes and centers, shape: [num_fg, num_gt]
is_in_boxes_and_centers = (is_in_gts & is_in_cts)
return is_in_gts_or_centers, is_in_boxes_and_centers

def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
num_gt: int) -> Tuple[Tensor, Tensor]:
"""Use IoU and matching cost to calculate the dynamic top-k positive
targets."""
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
# select candidate topk ious for dynamic-k calculation
candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
# calculate dynamic k for each gt
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
for gt_idx in range(num_gt):
_, pos_idx = torch.topk(
cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
matching_matrix[:, gt_idx][pos_idx] = 1

del topk_ious, dynamic_ks, pos_idx

prior_match_gt_mask = matching_matrix.sum(1) > 1
if prior_match_gt_mask.sum() > 0:
cost_min, cost_argmin = torch.min(
cost[prior_match_gt_mask, :], dim=1)
matching_matrix[prior_match_gt_mask, :] *= 0
matching_matrix[prior_match_gt_mask, cost_argmin] = 1

while (matching_matrix.sum(0) == 0).any():
matched_query_id = matching_matrix.sum(1) > 0
cost[matched_query_id] += 100000.0
unmatch_id = torch.nonzero(
matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
for gt_idx in unmatch_id:
pos_idx = torch.argmin(cost[:, gt_idx])
matching_matrix[:, gt_idx][pos_idx] = 1.0
            if (matching_matrix.sum(1) >
                    1).sum() > 0:  # If a query matches more than one gt
                # find gt for these queries with minimal cost
                multi_match_mask = matching_matrix.sum(1) > 1
                _, cost_argmin = torch.min(
                    cost[multi_match_mask, :], dim=1)
                # reset mapping relationship
                matching_matrix[multi_match_mask, :] *= 0
                # keep gt with minimal cost
                matching_matrix[multi_match_mask, cost_argmin] = 1

assert not (matching_matrix.sum(0) == 0).any()
# get foreground mask inside box and center prior
fg_mask_inboxes = matching_matrix.sum(1) > 0
matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)

cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf')
min_cost_query_id = torch.min(cost, dim=0)[1]

return fg_mask_inboxes, matched_gt_inds, min_cost_query_id
184 changes: 184 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/models/transformer.py
@@ -0,0 +1,184 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.layers.transformer import \
DeformableDetrTransformer as MMDET_DeformableDetrTransformer

from mmtrack.registry import MODELS


@MODELS.register_module()
class DeformableDetrTransformer(MMDET_DeformableDetrTransformer):
"""Implements the DeformableDETR transformer.
Rewritten the forward function to return `memory`.
"""

def forward(self,
mlvl_feats,
mlvl_masks,
query_embed,
mlvl_pos_embeds,
reg_branches=None,
cls_branches=None,
**kwargs):
"""Forward function for `Transformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
[bs, embed_dims, h, w].
mlvl_masks (list(Tensor)): The key_padding_mask from
different level used for encoder and decoder,
each element has shape [bs, h, w].
query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
mlvl_pos_embeds (list(Tensor)): The positional encoding
of feats from different level, has the shape
[bs, embed_dims, h, w].
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when
`with_box_refine` is True. Default to None.
cls_branches (obj:`nn.ModuleList`): Classification heads
for feature maps from each decoder layer. Only would
be passed when `as_two_stage`
is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape \
(num_dec_layers, bs, num_query, embed_dims), else has \
shape (1, bs, num_query, embed_dims).
- memory: Output results from encoder, with shape \
(h*w, bs, embed_dims).
- init_reference_out: The initial value of reference \
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference \
points in decoder, has shape \
(num_dec_layers, bs,num_query, embed_dims)
- enc_outputs_class: The classification score of \
proposals generated from \
encoder's feature maps, has shape \
(batch, h*w, num_classes). \
Only would be returned when `as_two_stage` is True, \
otherwise None.
- enc_outputs_coord_unact: The regression results \
generated from encoder's feature maps., has shape \
(batch, h*w, 4). Only would \
be returned when `as_two_stage` is True, \
otherwise None.
"""
assert self.as_two_stage or query_embed is not None

feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (feat, mask, pos_embed) in enumerate(
zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
bs, c, h, w = feat.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
feat = feat.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
feat_flatten.append(feat)
mask_flatten.append(mask)
feat_flatten = torch.cat(feat_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=feat_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack(
[self.get_valid_ratio(m) for m in mlvl_masks], 1)

reference_points = \
self.get_reference_points(spatial_shapes,
valid_ratios,
device=feat.device)

feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
1, 0, 2) # (H*W, bs, embed_dims)
memory = self.encoder(
query=feat_flatten,
key=None,
value=None,
query_pos=lvl_pos_embed_flatten,
query_key_padding_mask=mask_flatten,
spatial_shapes=spatial_shapes,
reference_points=reference_points,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
**kwargs)

memory = memory.permute(1, 0, 2)
bs, _, c = memory.shape
if self.as_two_stage:
output_memory, output_proposals = \
self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes)
enc_outputs_class = cls_branches[self.decoder.num_layers](
output_memory)
enc_outputs_coord_unact = \
reg_branches[
self.decoder.num_layers](output_memory) + output_proposals

topk = self.two_stage_num_proposals
# We only use the first channel in enc_outputs_class as foreground,
# the other (num_classes - 1) channels are actually not used.
# Its targets are set to be 0s, which indicates the first
# class (foreground) because we use [0, num_classes - 1] to
# indicate class labels, background class is indicated by
# num_classes (similar convention in RPN).
# See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa
# This follows the official implementation of Deformable DETR.
topk_proposals = torch.topk(
enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_unact = torch.gather(
enc_outputs_coord_unact, 1,
topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
topk_coords_unact = topk_coords_unact.detach()
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
pos_trans_out = self.pos_trans_norm(
self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
query_pos, query = torch.split(pos_trans_out, c, dim=2)
else:
query_pos, query = torch.split(query_embed, c, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
reference_points = self.reference_points(query_pos).sigmoid()
init_reference_out = reference_points

# decoder
query = query.permute(1, 0, 2)
memory = memory.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=memory,
query_pos=query_pos,
key_padding_mask=mask_flatten,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
reg_branches=reg_branches,
**kwargs)

inter_references_out = inter_references
if self.as_two_stage:
return inter_states, memory, init_reference_out,\
inter_references_out, enc_outputs_class,\
enc_outputs_coord_unact
return inter_states, memory, init_reference_out, \
inter_references_out, None, None
246 changes: 246 additions & 0 deletions projects/VIS_SOTA/IDOL/idol_src/models/utils.py
@@ -0,0 +1,246 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear
from torchvision.ops.boxes import box_area


class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""

def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x


def _expand(tensor, length: int):
return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)


class MaskHeadSmallConv(nn.Module):
"""Simple convolutional head, using group norm.
Upsampling is done using a FPN approach
"""

def __init__(self, dim, fpn_dims, context_dim):
super().__init__()

inter_dims = [
dim, context_dim, context_dim, context_dim, context_dim,
context_dim
]

# used after upsampling to reduce dimension of fused features!
self.lay1 = torch.nn.Conv2d(dim, dim // 4, 3, padding=1)
self.lay2 = torch.nn.Conv2d(dim // 4, dim // 32, 3, padding=1)
self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
self.dcn = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
self.dim = dim

if fpn_dims is not None:
self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

for name, m in self.named_modules():
if name == 'conv_offset':
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)
else:
if isinstance(m, nn.Conv2d):
nn.init.kaiming_uniform_(m.weight, a=1)
nn.init.constant_(m.bias, 0)

def forward(self, x, fpns):

if fpns is not None:
cur_fpn = self.adapter1(fpns[0])
if cur_fpn.size(0) != x[-1].size(0):
cur_fpn = _expand(cur_fpn, x[-1].size(0) // cur_fpn.size(0))
fused_x = (cur_fpn + x[-1]) / 2
else:
fused_x = x[-1]
fused_x = self.lay3(fused_x)
fused_x = F.relu(fused_x)

if fpns is not None:
cur_fpn = self.adapter2(fpns[1])
if cur_fpn.size(0) != x[-2].size(0):
cur_fpn = _expand(cur_fpn, x[-2].size(0) // cur_fpn.size(0))
fused_x = (cur_fpn + x[-2]) / 2 + F.interpolate(
fused_x, size=cur_fpn.shape[-2:], mode='nearest')
else:
fused_x = x[-2] + F.interpolate(
fused_x, size=x[-2].shape[-2:], mode='nearest')
fused_x = self.lay4(fused_x)
fused_x = F.relu(fused_x)

if fpns is not None:
cur_fpn = self.adapter3(fpns[2])
if cur_fpn.size(0) != x[-3].size(0):
cur_fpn = _expand(cur_fpn, x[-3].size(0) // cur_fpn.size(0))
fused_x = (cur_fpn + x[-3]) / 2 + F.interpolate(
fused_x, size=cur_fpn.shape[-2:], mode='nearest')
else:
fused_x = x[-3] + F.interpolate(
fused_x, size=x[-3].shape[-2:], mode='nearest')
fused_x = self.dcn(fused_x)
fused_x = F.relu(fused_x)
fused_x = self.lay1(fused_x)
fused_x = F.relu(fused_x)
fused_x = self.lay2(fused_x)
fused_x = F.relu(fused_x)

return fused_x


def compute_locations(h, w, device, stride=1):

shifts_x = torch.arange(
0, w * stride, step=stride, dtype=torch.float32, device=device)

shifts_y = torch.arange(
0, h * stride, step=stride, dtype=torch.float32, device=device)

shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2

return locations


def parse_dynamic_params(params, channels, weight_nums, bias_nums):
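    """Split flat per-instance parameter vectors into convolution weights
    and biases for a dynamic mask head: intermediate layers output
    ``channels`` channels, the last layer outputs a single mask channel,
    and everything is reshaped so it can be applied as grouped 1x1
    convolutions over all instances at once."""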

assert params.dim() == 2
assert len(weight_nums) == len(bias_nums)
assert params.size(1) == sum(weight_nums) + sum(bias_nums)

num_insts = params.size(0)
num_layers = len(weight_nums)

params_splits = list(
torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))

weight_splits = params_splits[:num_layers]
bias_splits = params_splits[num_layers:]

for layer in range(num_layers):
if layer < num_layers - 1:
# out_channels x in_channels x 1 x 1
weight_splits[layer] = weight_splits[layer].reshape(
num_insts * channels, -1, 1, 1)
bias_splits[layer] = bias_splits[layer].reshape(num_insts *
channels)
else:
# out_channels x in_channels x 1 x 1
weight_splits[layer] = weight_splits[layer].reshape(
num_insts * 1, -1, 1, 1)
bias_splits[layer] = bias_splits[layer].reshape(num_insts)

return weight_splits, bias_splits


def aligned_bilinear(tensor, factor):
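    """Bilinearly upsample a ``(N, C, H, W)`` tensor by an integer
    ``factor`` while keeping the output aligned with the input pixel
    centers via replicate padding."""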

assert tensor.dim() == 4
assert factor >= 1
assert int(factor) == factor

if factor == 1:
return tensor

h, w = tensor.size()[2:]
tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode='replicate')
oh = factor * h + 1
ow = factor * w + 1
tensor = F.interpolate(
tensor, size=(oh, ow), mode='bilinear', align_corners=True)
tensor = F.pad(
tensor, pad=(factor // 2, 0, factor // 2, 0), mode='replicate')

return tensor[:, :, :oh - 1, :ow - 1]


def mask_iou(mask1, mask2):
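    """IoU between corresponding binary masks, reduced over the last two
    (spatial) dimensions."""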
mask1 = mask1.char()
mask2 = mask2.char()

intersection = (mask1[:, :, :] * mask2[:, :, :]).sum(-1).sum(-1)
union = (mask1[:, :, :] + mask2[:, :, :] -
mask1[:, :, :] * mask2[:, :, :]).sum(-1).sum(-1)

return (intersection + 1e-6) / (union + 1e-6)


def mask_nms(seg_masks, scores, nms_thr=0.5):
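    """Greedy mask NMS: assuming ``seg_masks`` are ordered by descending
    score, suppress every mask whose IoU with an earlier kept mask exceeds
    ``nms_thr`` and return the boolean keep list."""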
n_samples = len(scores)
if n_samples == 0:
return []
keep = [True for i in range(n_samples)]
seg_masks = seg_masks.sigmoid() > 0.5

for i in range(n_samples - 1):
if not keep[i]:
continue
mask_i = seg_masks[i]
for j in range(i + 1, n_samples, 1):
if not keep[j]:
continue
mask_j = seg_masks[j]

iou = mask_iou(mask_i, mask_j)[0]
if iou > nms_thr:
keep[j] = False
return keep


def box_iou(boxes1, boxes2):
# modified from torchvision to also return the union
area1 = box_area(boxes1)
area2 = box_area(boxes2)

lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter

iou = inter / union
return iou, union


def generalized_box_iou(boxes1, boxes2):
"""Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
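    GIoU(A, B) = IoU(A, B) - |C - (A ∪ B)| / |C|, where C is the smallest
    axis-aligned box enclosing both A and B.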
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)

lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]

return iou - (area - union) / (area + 1e-7)
42 changes: 42 additions & 0 deletions projects/VIS_SOTA/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# VIS (Video Instance Segmentation)

## Introduction

The goal of the VIS task is simultaneous detection, segmentation and tracking of instances in videos. Here we implement `IDOL`, which is based on contrastive learning, and the `VITA` series of algorithms, which build on `Mask2Former`. The project currently provides advanced online and offline video instance segmentation algorithms, and we will keep refining the framework so that it stays unified and efficient.

In recent years, online methods for video instance segmentation have advanced significantly, largely thanks to improvements in image-level object detection. Meanwhile, semi-online and offline paradigms exploit the temporal context across multiple frames, offering a more comprehensive view of each video.

## Requirements

The dependencies used at the outset of this project are listed below. You are not required to use exactly these versions; they are simply a recommendation that is known to work.

```
mmcv==2.0.0rc4
mmdet==3.0.0rc4
mmengine==0.4.0
```
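
For a quick sanity check of your environment, a snippet like the following (assuming the three packages are already installed) prints the installed versions:

```python
# Print the versions of the core dependencies installed in the current
# environment; this is only a convenience check, not part of the project.
import mmcv
import mmdet
import mmengine

print('mmcv:', mmcv.__version__)
print('mmdet:', mmdet.__version__)
print('mmengine:', mmengine.__version__)
```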

## Citation

```BibTeX
@inproceedings{IDOL,
title={In Defense of Online Models for Video Instance Segmentation},
author={Wu, Junfeng and Liu, Qihao and Jiang, Yi and Bai, Song and Yuille, Alan and Bai, Xiang},
booktitle={ECCV},
year={2022},
}
@article{GenVIS,
title={A Generalized Framework for Video Instance Segmentation},
author={Heo, Miran and Hwang, Sukjun and Hyun, Jeongseok and Kim, Hanjung and Oh, Seoung Wug and Lee, Joon-Young and Kim, Seon Joo},
journal={arXiv preprint arXiv:2211.08834},
year={2022}
}
@inproceedings{VITA,
title={VITA: Video Instance Segmentation via Object Token Association},
author={Heo, Miran and Hwang, Sukjun and Oh, Seoung Wug and Lee, Joon-Young and Kim, Seon Joo},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
```
120 changes: 120 additions & 0 deletions projects/VIS_SOTA/VITA/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# VITA: Video Instance Segmentation via Object Token Association

## Description

This is an implementation of [VITA](https://github.com/sukjunhwang/VITA) based on [MMTracking](https://github.com/open-mmlab/mmtracking/tree/1.x), [MMDetection](https://github.com/open-mmlab/mmdetection/tree/3.x), [MMCV](https://github.com/open-mmlab/mmcv), and [MMEngine](https://github.com/open-mmlab/mmengine).

We introduce a novel paradigm for offline Video Instance Segmentation (VIS), based on the hypothesis that explicit object-oriented information can be a strong clue for understanding the context of the entire sequence. To this end, we propose VITA, a simple structure built on top of an off-the-shelf Transformer-based image instance segmentation model. Specifically, we use an image object detector as a means of distilling object-specific contexts into object tokens. VITA accomplishes video-level understanding by associating frame-level object tokens without using spatio-temporal backbone features. By effectively building relationships between objects using the condensed information, VITA achieves state-of-the-art results on VIS benchmarks with a ResNet-50 backbone: 49.8 AP and 45.7 AP on YouTube-VIS 2019 and 2021, and 19.6 AP on OVIS. Moreover, thanks to its object token-based structure that is disjoint from the backbone features, VITA shows several practical advantages that previous offline VIS methods have not explored: handling long and high-resolution videos on a common GPU and freezing a frame-level detector trained on the image domain.

<center>
<img src="https://github.com/sukjunhwang/VITA/blob/main/vita_teaser.png">
</center>
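
Below is a minimal, hypothetical PyTorch sketch of this idea of associating frame-level object tokens with a small video-level decoder; the module choices, query counts and shapes are illustrative assumptions and do not reproduce the actual VITA heads implemented in this project.

```python
import torch
import torch.nn as nn

num_frames, num_queries, embed_dim = 4, 100, 256
num_video_queries = 20

# Per-frame object tokens, e.g. decoder queries from a frozen image-level
# detector, flattened over frames: (num_frames * num_queries, batch, dim).
frame_tokens = torch.randn(num_frames * num_queries, 1, embed_dim)

# Learnable video-level queries that aggregate instance information
# across the whole clip.
video_queries = nn.Parameter(torch.randn(num_video_queries, 1, embed_dim))

# A small decoder whose video queries cross-attend to all frame-level
# tokens, so clip-level instances are formed without touching dense
# spatio-temporal backbone features.
decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=8)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)

video_tokens = decoder(tgt=video_queries, memory=frame_tokens)
print(video_tokens.shape)  # torch.Size([20, 1, 256])
```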

## Usage

<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->

### Training commands

In MMTracking's root directory, run the following command to train the model:

```bash
python tools/train.py projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py
```

For multi-gpu training, run:

```bash
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py
```

### Testing commands

In MMTracking's root directory, run the following command to test the model:

```bash
python tools/test.py projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py ${CHECKPOINT_PATH}
```

## Results

#### YouTubeVIS-2019

| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download |
| :----: | :------: | :-----: | :-----: | :------: | :------------: | :--: | :-------------------------------------------------------------------------: | :----------------------: |
| VITA | R-50 | pytorch | 140k | 10.0 | - | 50.3 | [config](projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py) | [model](<>) \| [log](<>) |

#### OVIS

| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download |
| :----: | :------: | :-----: | :-----: | :------: | :------------: | :--: | :----------: | :----------------------: |
| VITA | R-50 | pytorch | 150k | 10.0 | - | 19.2 | [config](<>) | [model](<>) \| [log](<>) |

## Citation

If you find VITA useful in your research or applications, please consider giving a star 🌟 to the [official repository](https://github.com/sukjunhwang/VITA) and citing VITA with the following BibTeX entry.

```BibTeX
@inproceedings{VITA,
title={VITA: Video Instance Segmentation via Object Token Association},
author={Heo, Miran and Hwang, Sukjun and Oh, Seoung Wug and Lee, Joon-Young and Kim, Seon Joo},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
```

## Checklist

<!-- Here is a checklist illustrating a usual development workflow of a successful project, and also serves as an overview of this project's progress. The PIC (person in charge) or contributors of this project should check all the items that they believe have been finished, which will further be verified by codebase maintainers via a PR.
OpenMMLab's maintainer will review the code to ensure the project's quality. Reaching the first milestone means that this project suffices the minimum requirement of being merged into 'projects/'. But this project is only eligible to become a part of the core package upon attaining the last milestone.
Note that keeping this section up-to-date is crucial not only for this project's developers but the entire community, since there might be some other contributors joining this project and deciding their starting point from this list. It also helps maintainers accurately estimate time and effort on further code polishing, if needed.
A project does not necessarily have to be finished in a single PR, but it's essential for the project to at least reach the first milestone in its very first PR. -->

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

<!-- The code's design shall follow existing interfaces and convention. For example, each model component should be registered into `mmdet.registry.MODELS` and configurable via a config file. -->

- [x] Basic docstrings & proper citation

<!-- Each major object should contain a docstring, describing its functionality and arguments. If you have adapted the code from other open-source projects, don't forget to cite the source project in docstring and make sure your behavior is not against its license. Typically, we do not accept any code snippet under GPL license. [A Short Guide to Open Source Licenses](https://medium.com/nationwide-technology/a-short-guide-to-open-source-licenses-cf5b1c329edd) -->

- [x] Test-time correctness

<!-- If you are reproducing the result from a paper, make sure your model's inference-time performance matches that in the original paper. The weights usually could be obtained by simply renaming the keys in the official pre-trained weights. This test could be skipped though, if you are able to prove the training-time correctness and check the second milestone. -->

- [x] A full README

<!-- As this template does. -->

- [x] Milestone 2: Indicates a successful model implementation.

- [x] Training-time correctness

<!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

<!-- Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmdetection/blob/5b0d5b40d5c6cfda906db7464ca22cbd4396728a/mmdet/datasets/transforms/transforms.py#L41-L169) -->

- [ ] Unit tests

<!-- Unit tests for each module are required. [Example](https://github.com/open-mmlab/mmdetection/blob/5b0d5b40d5c6cfda906db7464ca22cbd4396728a/tests/test_datasets/test_transforms/test_transforms.py#L35-L88) -->

- [ ] Code polishing

<!-- Refactor your code according to reviewer's comment. -->

- [ ] Metafile.yml

<!-- It will be parsed by MIM and Inferencer. [Example](https://github.com/open-mmlab/mmdetection/blob/3.x/configs/faster_rcnn/metafile.yml) -->

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

<!-- In particular, you may have to refactor this README into a standard one. [Example](https://github.com/open-mmlab/mmdetection/blob/3.x/configs/faster_rcnn/README.md) -->

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
279 changes: 279 additions & 0 deletions projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
_base_ = [
'../../../../configs/_base_/datasets/youtube_vis.py',
'../../../../configs/_base_/default_runtime.py',
]

custom_imports = dict(imports=['projects.VIS_SOTA.VITA.vita_src'], )

num_classes = 40
num_frames = 2
model = dict(
type='VITA',
data_preprocessor=dict(
type='TrackDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_mask=True,
pad_size_divisor=32),
backbone=dict(
_scope_='mmdet',
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
seg_head=dict(
type='mmtrack.VITASegHead',
in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside
strides=[4, 8, 16, 32],
feat_channels=256,
out_channels=256,
num_things_classes=num_classes,
num_stuff_classes=0,
num_queries=100,
num_frames=num_frames,
num_transformer_feat_level=3,
pixel_decoder=dict(
type='mmtrack.VITAPixelDecoder',
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='mmdet.DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=128,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.0,
act_cfg=dict(type='ReLU', inplace=True)),
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='mmdet.SinePositionalEncoding',
num_feats=128,
normalize=True),
init_cfg=None),
enforce_decoder_input_project=False,
positional_encoding=dict(
type='mmdet.SinePositionalEncoding', num_feats=128,
normalize=True),
transformer_decoder=dict(
type='mmdet.DetrTransformerDecoder',
return_intermediate=True,
num_layers=9,
transformerlayers=dict(
type='mmdet.DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True),
feedforward_channels=2048,
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
'ffn', 'norm')),
init_cfg=None),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=False,
loss_weight=2.0,
reduction='mean',
class_weight=[1.0] * num_classes + [0.1]),
loss_mask=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=5.0),
loss_dice=dict(
type='mmdet.DiceLoss',
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=5.0)),
track_head=dict(
type='VITATrackHead',
num_classes=num_classes,
frame_query_encoder=dict(
type='mmdet.DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.0,
batch_first=False),
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.0,
act_cfg=dict(type='ReLU', inplace=True)),
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
frame_query_decoder=dict(
type='mmdet.DetrTransformerDecoder',
return_intermediate=True,
num_layers=3,
transformerlayers=dict(
type='mmdet.DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True),
feedforward_channels=2048,
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
'ffn', 'norm')),
init_cfg=None),
),
train_cfg=dict(
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
assigner=dict(
type='mmdet.HungarianAssigner',
match_costs=[
dict(type='mmdet.ClassificationCost', weight=2.0),
dict(
type='mmdet.CrossEntropyLossCost',
weight=5.0,
use_sigmoid=True),
dict(
type='mmdet.DiceCost', weight=5.0, pred_act=True, eps=1.0)
]),
sampler=dict(type='mmdet.MaskPseudoSampler')),
test_cfg=dict(
test_run_chunk_size=18,
test_interpolate_chunk_size=5,
max_per_image=100,
))

# optimizer
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=0.0001,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999)),
paramwise_cfg=dict(
custom_keys={
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'query_embed': embed_multi,
'query_feat': embed_multi,
'level_embed': embed_multi,
},
norm_decay_mult=0.0),
clip_grad=dict(max_norm=0.01, norm_type=2))

# learning policy
max_iters = 6000
param_scheduler = dict(
type='MultiStepLR',
begin=0,
end=max_iters,
by_epoch=False,
milestones=[
4000,
],
gamma=0.1)
# runtime settings
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=max_iters, val_interval=6001)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

default_hooks = dict(
checkpoint=dict(
type='CheckpointHook', by_epoch=False, save_last=True, interval=2000))
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False)

train_pipeline = [
dict(
type='TransformBroadcaster',
share_random_params=True,
transforms=[
dict(type='LoadImageFromFile'),
dict(
type='LoadTrackAnnotations',
with_instance_id=True,
with_mask=True,
with_bbox=True),
dict(type='mmdet.Resize', scale=(640, 360), keep_ratio=True),
dict(type='mmdet.RandomFlip', prob=0.5),
]),
dict(type='PackTrackInputs', num_key_frames=num_frames)
]
# dataloader
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(
pipeline=train_pipeline,
ref_img_sampler=dict(
num_ref_imgs=1,
frame_range=5,
filter_key_img=True,
method='uniform')))
val_dataloader = dict(
num_workers=2,
sampler=dict(type='VideoSampler'),
batch_sampler=dict(type='EntireVideoBatchSampler'),
)
test_dataloader = val_dataloader

# evaluator
val_evaluator = dict(
type='YouTubeVISMetric',
metric='youtube_vis_ap',
outfile_prefix='./youtube_vis_results',
format_only=True)
test_evaluator = val_evaluator
5 changes: 5 additions & 0 deletions projects/VIS_SOTA/VITA/vita_src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .models import VITATrackHead
from .vita import VITA

__all__ = ['VITA', 'VITATrackHead']
6 changes: 6 additions & 0 deletions projects/VIS_SOTA/VITA/vita_src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .vita_pixel_decoder import VITAPixelDecoder
from .vita_query_track_head import VITATrackHead
from .vita_seg_head import VITASegHead

__all__ = ['VITATrackHead', 'VITASegHead', 'VITAPixelDecoder']
125 changes: 125 additions & 0 deletions projects/VIS_SOTA/VITA/vita_src/models/vita_pixel_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn.functional as F
from mmdet.models.layers import \
MSDeformAttnPixelDecoder as MMDET_MSDeformAttnPixelDecoder
from torch import Tensor

from mmtrack.registry import MODELS


@MODELS.register_module()
class VITAPixelDecoder(MMDET_MSDeformAttnPixelDecoder):

    def forward(
        self, feats: List[Tensor]
    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Args:
feats (list[Tensor]): Feature maps of each level. Each has
shape of (batch_size, c, h, w).
        Returns:
            tuple: A tuple containing the following:
                - mask_feature (Tensor): shape (batch_size, c, h, w).
                - clip_mask_feature (Tensor): the highest-resolution
                    decoder output before the final mask-feature
                    projection, shape (batch_size, c, h, w).
                - multi_scale_features (list[Tensor]): Multi scale \
                        features, each in shape (batch_size, c, h, w).
"""
# generate padding mask for each level, for each image
batch_size = feats[0].shape[0]
encoder_input_list = []
padding_mask_list = []
level_positional_encoding_list = []
spatial_shapes = []
reference_points_list = []
for i in range(self.num_encoder_levels):
level_idx = self.num_input_levels - i - 1
feat = feats[level_idx]
feat_projected = self.input_convs[i](feat)
h, w = feat.shape[-2:]

# no padding
padding_mask_resized = feat.new_zeros(
(batch_size, ) + feat.shape[-2:], dtype=torch.bool)
pos_embed = self.postional_encoding(padding_mask_resized)
level_embed = self.level_encoding.weight[i]
level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
# (h_i * w_i, 2)
reference_points = self.point_generator.single_level_grid_priors(
feat.shape[-2:], level_idx, device=feat.device)
# normalize
factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
reference_points = reference_points / factor

# shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
padding_mask_resized = padding_mask_resized.flatten(1)

encoder_input_list.append(feat_projected)
padding_mask_list.append(padding_mask_resized)
level_positional_encoding_list.append(level_pos_embed)
spatial_shapes.append(feat.shape[-2:])
reference_points_list.append(reference_points)
        # shape (batch_size, total_num_query),
        # where total_num_query = sum(h_i * w_i) over all levels
padding_masks = torch.cat(padding_mask_list, dim=1)
# shape (total_num_query, batch_size, c)
encoder_inputs = torch.cat(encoder_input_list, dim=0)
level_positional_encodings = torch.cat(
level_positional_encoding_list, dim=0)
device = encoder_inputs.device
# shape (num_encoder_levels, 2), from low
# resolution to high resolution
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=device)
# shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
reference_points = torch.cat(reference_points_list, dim=0)
reference_points = reference_points[None, :, None].repeat(
batch_size, 1, self.num_encoder_levels, 1)
valid_radios = reference_points.new_ones(
(batch_size, self.num_encoder_levels, 2))
# shape (num_total_query, batch_size, c)
memory = self.encoder(
query=encoder_inputs,
key=None,
value=None,
query_pos=level_positional_encodings,
key_pos=None,
attn_masks=None,
key_padding_mask=None,
query_key_padding_mask=padding_masks,
spatial_shapes=spatial_shapes,
reference_points=reference_points,
level_start_index=level_start_index,
valid_radios=valid_radios)
# (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
memory = memory.permute(1, 2, 0)

# from low resolution to high resolution
num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
outs = torch.split(memory, num_query_per_level, dim=-1)
outs = [
x.reshape(batch_size, -1, spatial_shapes[i][0],
spatial_shapes[i][1]) for i, x in enumerate(outs)
]

for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
-1):
x = feats[i]
cur_feat = self.lateral_convs[i](x)
y = cur_feat + F.interpolate(
outs[-1],
size=cur_feat.shape[-2:],
mode='bilinear',
align_corners=False)
y = self.output_convs[i](y)
outs.append(y)
multi_scale_features = outs[:self.num_outs]

clip_mask_feature = outs[-1]
mask_feature = self.mask_feature(outs[-1])
return mask_feature, clip_mask_feature, multi_scale_features