From 96f0538533612f2649b0e33bc8a45142eb78958f Mon Sep 17 00:00:00 2001 From: pixeli Date: Thu, 9 Feb 2023 15:33:46 +0800 Subject: [PATCH] add --- projects/VIS_SOTA/IDOL/README.md | 120 +++ .../VIS_SOTA/IDOL/configs/coco_instance.py | 64 ++ .../configs/idol_r50_8xb2-16e_coco-seq.py | 136 +++ .../idol_r50_8xb4-12k_youtubevis2019.py | 227 +++++ projects/VIS_SOTA/IDOL/idol_src/__init__.py | 5 + projects/VIS_SOTA/IDOL/idol_src/idol.py | 96 ++ .../VIS_SOTA/IDOL/idol_src/models/__init__.py | 10 + .../idol_src/models/idol_query_track_head.py | 929 ++++++++++++++++++ .../IDOL/idol_src/models/idol_tracker.py | 328 +++++++ .../IDOL/idol_src/models/pos_neg_select.py | 232 +++++ .../IDOL/idol_src/models/sim_ota_assigner.py | 211 ++++ .../IDOL/idol_src/models/transformer.py | 184 ++++ .../VIS_SOTA/IDOL/idol_src/models/utils.py | 246 +++++ projects/VIS_SOTA/README.md | 42 + projects/VIS_SOTA/VITA/README.md | 120 +++ .../vita_r50_8xb2-8e_youtubevis2019.py | 279 ++++++ projects/VIS_SOTA/VITA/vita_src/__init__.py | 5 + .../VIS_SOTA/VITA/vita_src/models/__init__.py | 6 + .../vita_src/models/vita_pixel_decoder.py | 125 +++ .../vita_src/models/vita_query_track_head.py | 390 ++++++++ .../VITA/vita_src/models/vita_seg_head.py | 221 +++++ projects/VIS_SOTA/VITA/vita_src/vita.py | 125 +++ 22 files changed, 4101 insertions(+) create mode 100644 projects/VIS_SOTA/IDOL/README.md create mode 100644 projects/VIS_SOTA/IDOL/configs/coco_instance.py create mode 100644 projects/VIS_SOTA/IDOL/configs/idol_r50_8xb2-16e_coco-seq.py create mode 100644 projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/__init__.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/idol.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/models/__init__.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/models/idol_query_track_head.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/models/idol_tracker.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/models/pos_neg_select.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/models/sim_ota_assigner.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/models/transformer.py create mode 100644 projects/VIS_SOTA/IDOL/idol_src/models/utils.py create mode 100644 projects/VIS_SOTA/README.md create mode 100644 projects/VIS_SOTA/VITA/README.md create mode 100644 projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py create mode 100644 projects/VIS_SOTA/VITA/vita_src/__init__.py create mode 100644 projects/VIS_SOTA/VITA/vita_src/models/__init__.py create mode 100644 projects/VIS_SOTA/VITA/vita_src/models/vita_pixel_decoder.py create mode 100644 projects/VIS_SOTA/VITA/vita_src/models/vita_query_track_head.py create mode 100644 projects/VIS_SOTA/VITA/vita_src/models/vita_seg_head.py create mode 100644 projects/VIS_SOTA/VITA/vita_src/vita.py diff --git a/projects/VIS_SOTA/IDOL/README.md b/projects/VIS_SOTA/IDOL/README.md new file mode 100644 index 000000000..1bf06abe1 --- /dev/null +++ b/projects/VIS_SOTA/IDOL/README.md @@ -0,0 +1,120 @@ +# IDOL: In Defense of Online Models for Video Instance Segmentation + +## Description + +This is an implementation of [IDOL](https://github.com/wjf5203/VNext.git) based on [MMTracking](https://github.com/open-mmlab/mmtracking/tree/1.x), [MMDetection](https://github.com/open-mmlab/mmdetection/tree/3.x), [MMCV](https://github.com/open-mmlab/mmcv), and [MMEngine](https://github.com/open-mmlab/mmengine). + +In recent years, video instance segmentation (VIS) has been largely advanced by offline models, while online models are usually inferior to the contemporaneous offline models by over 10 AP, which is a huge drawback. By dissecting current online models and offline models, we demonstrate that the main cause of the performance gap is the error-prone association and propose IDOL, which outperforms all online and offline methods on three benchmarks. IDOL won first place in the video instance segmentation track of the 4th Large-scale Video Object Segmentation Challenge (CVPR2022). + +
+ +
+ +## Usage + + + +### Training commands + +In MMTracking's root directory, run the following command to train the model: + +```bash +python tools/train.py projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py +``` + +For multi-gpu training, run: + +```bash +python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py +``` + +### Testing commands + +In MMTracking's root directory, run the following command to test the model: + +```bash +python tools/test.py projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py ${CHECKPOINT_PATH} +``` + +## Results + +#### YouTubeVIS-2019 + +| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download | +| :----: | :------: | :-----: | :-----: | :------: | :------------: | :--: | :--------------------------------------------------------------------------: | :----------------------: | +| IDOL | R-50 | pytorch | 12k | 27.0 | - | 49.3 | [config](projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py) | [model](<>) \| [log](<>) | + +#### OVIS + +| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download | +| :----: | :------: | :-----: | :-----: | :------: | :------------: | :--: | :----------: | :----------------------: | +| IDOL | R-50 | pytorch | 12k | 27.0 | - | 29.7 | [config](<>) | [model](<>) \| [log](<>) | + +## Citation + +If you find IDOL is useful in your research or applications, please consider giving a star 🌟 to the [official repository](https://github.com/wjf5203/VNext) and citing IDOL by the following BibTeX entry. + +```BibTeX +@inproceedings{IDOL, + title={In Defense of Online Models for Video Instance Segmentation}, + author={Wu, Junfeng and Liu, Qihao and Jiang, Yi and Bai, Song and Yuille, Alan and Bai, Xiang}, + booktitle={ECCV}, + year={2022}, +} + +``` + +## Checklist + + + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + + + + - [x] Basic docstrings & proper citation + + + + - [x] Test-time correctness + + + + - [x] A full README + + + +- [x] Milestone 2: Indicates a successful model implementation. + + - [x] Training-time correctness + + + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + + + + - [ ] Unit tests + + + + - [ ] Code polishing + + + + - [ ] Metafile.yml + + + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + + + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/VIS_SOTA/IDOL/configs/coco_instance.py b/projects/VIS_SOTA/IDOL/configs/coco_instance.py new file mode 100644 index 000000000..29992ab0f --- /dev/null +++ b/projects/VIS_SOTA/IDOL/configs/coco_instance.py @@ -0,0 +1,64 @@ +# dataset settings +dataset_type = 'mmdet.CocoDataset' +data_root = 'data/coco/' + +# file_client_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +file_client_args = dict(backend='disk') + +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='LoadTrackAnnotations', with_bbox=True, with_mask=True), + dict(type='mmdet.Resize', scale=(1333, 800), keep_ratio=True), + dict(type='mmdet.RandomFlip', prob=0.5), + dict(type='PackTrackInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='mmdet.Resize', scale=(2000, 640), keep_ratio=True), + dict( + type='LoadTrackAnnotations', + with_instance_id=False, + with_bbox=True, + with_mask=True), + dict(type='PackTrackInputs', pack_single_img=True) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='mmdet.AspectRatioBatchSampler'), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + _scope_='mmdet', + type='CocoMetric', + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['bbox', 'segm'], + format_only=False) +test_evaluator = val_evaluator diff --git a/projects/VIS_SOTA/IDOL/configs/idol_r50_8xb2-16e_coco-seq.py b/projects/VIS_SOTA/IDOL/configs/idol_r50_8xb2-16e_coco-seq.py new file mode 100644 index 000000000..bda90914b --- /dev/null +++ b/projects/VIS_SOTA/IDOL/configs/idol_r50_8xb2-16e_coco-seq.py @@ -0,0 +1,136 @@ +_base_ = [ + './coco_instance.py', # noqa: E501 + '../../../../configs/_base_/default_runtime.py' +] + +custom_imports = dict(imports=['projects.VIS_SOTA.IDOL.idol_src'], ) + +model = dict( + type='IDOL', + data_preprocessor=dict( + type='TrackDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32), + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='mmdet.ChannelMapper', + in_channels=[512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type='GN', num_groups=32), + num_outs=4), + track_head=dict( + _scope_='mmdet', + type='mmtrack.IDOLTrackHead', + num_query=300, + num_classes=80, + in_channels=2048, + with_box_refine=True, + sync_cls_avg_factor=True, + as_two_stage=False, + transformer=dict( + type='mmtrack.DeformableDetrTransformer', + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', embed_dims=256), + feedforward_channels=1024, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'ffn', 'norm'))), + decoder=dict( + type='DeformableDetrTransformerDecoder', + num_layers=6, + return_intermediate=True, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + dropout=0.1), + dict( + type='MultiScaleDeformableAttention', + embed_dims=256) + ], + feedforward_channels=1024, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm')))), + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + normalize=True, + offset=-0.5), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0), + loss_bbox=dict(type='L1Loss', loss_weight=5.0), + loss_iou=dict(type='GIoULoss', loss_weight=2.0)), + # training and testing settings + # can't del 'mmtrack' + train_cfg=dict( + assigner=dict(type='mmtrack.SimOTAAssigner', center_radius=2.5), + cur_train_mode='COCO_Video'), +) + +# optimizer +embed_multi = dict(lr_mult=1.0, decay_mult=0.0) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.0001, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + custom_keys={ + 'backbone': dict(lr_mult=0.1, decay_mult=1.0), + 'query_embed': embed_multi, + 'query_feat': embed_multi, + 'level_embed': embed_multi, + }, + norm_decay_mult=0.0), + clip_grad=dict(max_norm=0.01, norm_type=2)) + +# learning policy +max_iters = 6000 +param_scheduler = dict( + type='MultiStepLR', + begin=0, + end=max_iters, + by_epoch=False, + milestones=[ + 4000, + ], + gamma=0.1) +# runtime settings +train_cfg = dict( + type='IterBasedTrainLoop', max_iters=max_iters, val_interval=6001) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', by_epoch=False, save_last=True, interval=2000)) +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False) diff --git a/projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py b/projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py new file mode 100644 index 000000000..1f1947865 --- /dev/null +++ b/projects/VIS_SOTA/IDOL/configs/idol_r50_8xb4-12k_youtubevis2019.py @@ -0,0 +1,227 @@ +_base_ = [ + '../../../../configs/_base_/datasets/youtube_vis.py', # noqa: E501 + '../../../../configs/_base_/default_runtime.py' +] + +custom_imports = dict(imports=['projects.VIS_SOTA.IDOL.idol_src'], ) + +model = dict( + type='IDOL', + data_preprocessor=dict( + type='TrackDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32), + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='mmdet.ChannelMapper', + in_channels=[512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type='GN', num_groups=32), + num_outs=4), + track_head=dict( + _scope_='mmdet', + type='mmtrack.IDOLTrackHead', + num_query=300, + num_classes=40, + in_channels=2048, + with_box_refine=True, + sync_cls_avg_factor=True, + as_two_stage=False, + transformer=dict( + type='mmtrack.DeformableDetrTransformer', + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', embed_dims=256), + feedforward_channels=1024, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'ffn', 'norm'))), + decoder=dict( + type='DeformableDetrTransformerDecoder', + num_layers=6, + return_intermediate=True, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + dropout=0.1), + dict( + type='MultiScaleDeformableAttention', + embed_dims=256) + ], + feedforward_channels=1024, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm')))), + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + normalize=True, + offset=-0.5), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0), + loss_bbox=dict(type='L1Loss', loss_weight=5.0), + loss_iou=dict(type='GIoULoss', loss_weight=2.0), + loss_mask=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=20.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=1.0), + loss_track=dict( + type='mmtrack.MultiPosCrossEntropyLoss', loss_weight=0.25), + loss_track_aux=dict( + type='mmtrack.L2Loss', + neg_pos_ub=3, + pos_margin=0, + neg_margin=0.1, + hard_mining=True, + loss_weight=1.0)), + tracker=dict( + type='IDOLTracker', + init_score_thr=0.2, + obj_score_thr=0.1, + nms_thr_pre=0.5, + nms_thr_post=0.05, + addnew_score_thr=0.2, + memo_tracklet_frames=10, + memo_momentum=0.8, + long_match=True, + frame_weight=True, + temporal_weight=True, + memory_len=3, + match_metric='bisoftmax'), + # training and testing settings + # can't del 'mmtrack' + train_cfg=dict( + assigner=dict( + type='mmtrack.SimOTAAssigner', + center_radius=2.5, + match_costs=[ + dict(type='FocalLossCost', weight=1.0), + dict(type='IoUCost', iou_mode='giou', weight=3.0) + ]), + cur_train_mode='VIS'), +) + +# optimizer +embed_multi = dict(lr_mult=1.0, decay_mult=0.0) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.0001, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + custom_keys={ + 'backbone': dict(lr_mult=0.1, decay_mult=1.0), + 'query_embed': embed_multi, + 'query_feat': embed_multi, + 'level_embed': embed_multi, + }, + norm_decay_mult=0.0), + clip_grad=dict(max_norm=0.01, norm_type=2)) + +# learning policy +max_iters = 6000 +param_scheduler = dict( + type='MultiStepLR', + begin=0, + end=max_iters, + by_epoch=False, + milestones=[ + 4000, + ], + gamma=0.1) +# runtime settings +train_cfg = dict( + type='IterBasedTrainLoop', max_iters=max_iters, val_interval=6001) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', by_epoch=False, save_last=True, interval=2000)) +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False) + +train_pipeline = [ + dict( + type='TransformBroadcaster', + share_random_params=True, + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='LoadTrackAnnotations', + with_instance_id=True, + with_mask=True, + with_bbox=True), + dict(type='mmdet.Resize', scale=(640, 360), keep_ratio=True), + dict(type='mmdet.RandomFlip', prob=0.5), + ]), + dict(type='PackTrackInputs', num_key_frames=2) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTrackAnnotations', + with_instance_id=True, + with_mask=True, + with_bbox=True), + # dict(type='mmdet.Resize', scale=(480, 1000), keep_ratio=True), + dict(type='PackTrackInputs', pack_single_img=True) +] +# dataloader +train_dataloader = dict( + batch_size=2, + num_workers=2, + dataset=dict( + pipeline=train_pipeline, + ref_img_sampler=dict( + num_ref_imgs=1, + frame_range=5, + filter_key_img=True, + method='uniform'))) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader +# evaluator +val_evaluator = dict( + type='YouTubeVISMetric', + metric='youtube_vis_ap', + outfile_prefix='./youtube_vis_results', + format_only=True) +test_evaluator = val_evaluator diff --git a/projects/VIS_SOTA/IDOL/idol_src/__init__.py b/projects/VIS_SOTA/IDOL/idol_src/__init__.py new file mode 100644 index 000000000..f0d10aeb4 --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .idol import IDOL +from .models import IDOLTrackHead + +__all__ = ['IDOL', 'IDOLTrackHead'] diff --git a/projects/VIS_SOTA/IDOL/idol_src/idol.py b/projects/VIS_SOTA/IDOL/idol_src/idol.py new file mode 100644 index 000000000..12b40ea3f --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/idol.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple, Union + +from torch import Tensor + +from mmtrack.models.mot import BaseMultiObjectTracker +from mmtrack.registry import MODELS +from mmtrack.utils import OptConfigType, OptMultiConfig, SampleList + + +@MODELS.register_module() +class IDOL(BaseMultiObjectTracker): + + def __init__(self, + backbone: Optional[dict] = None, + neck: Optional[dict] = None, + track_head: Optional[dict] = None, + tracker: Optional[dict] = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__(data_preprocessor, init_cfg) + + if backbone is not None: + self.backbone = MODELS.build(backbone) + + if neck is not None: + self.neck = MODELS.build(neck) + + if track_head is not None: + track_head.update(train_cfg=train_cfg) + track_head.update(test_cfg=test_cfg) + self.track_head = MODELS.build(track_head) + + if tracker is not None: + self.tracker = MODELS.build(tracker) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self.cur_train_mode = train_cfg.cur_train_mode + + def loss(self, inputs: Dict[str, Tensor], data_samples: SampleList, + **kwargs) -> Union[dict, tuple]: + + img = inputs['img'] + assert img.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + # shape (N * T, C, H, W) + img = img.flatten(0, 1) + + x = self.extract_feat(img) + losses = self.track_head.loss(x, data_samples) + + return losses + + def predict(self, + inputs: dict, + data_samples: SampleList, + rescale: bool = True) -> SampleList: + + img = inputs['img'] + assert img.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + # the "T" is 1 + img = img.squeeze(1) + feats = self.extract_feat(img) + pred_det_ins_list = self.track_head.predict(feats, data_samples, + rescale) + track_data_sample = data_samples[0] + pred_det_ins = pred_det_ins_list[0] + track_data_sample.pred_det_instances = \ + pred_det_ins.clone() + + if self.cur_train_mode == 'VIS': + pred_track_instances = self.tracker.track( + data_sample=track_data_sample, rescale=rescale) + track_data_sample.pred_track_instances = pred_track_instances + else: + pred_det_ins.masks = pred_det_ins.masks.squeeze(1) > 0.5 + track_data_sample.pred_instances = pred_det_ins + + return [track_data_sample] + + def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor with shape (N * T, C, H ,W). + + Returns: + tuple[Tensor]: Multi-level features that may have + different resolutions. + """ + x = self.backbone(batch_inputs) + x = self.neck(x) + return x diff --git a/projects/VIS_SOTA/IDOL/idol_src/models/__init__.py b/projects/VIS_SOTA/IDOL/idol_src/models/__init__.py new file mode 100644 index 000000000..c347a365e --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .idol_query_track_head import IDOLTrackHead +from .idol_tracker import IDOLTracker +from .sim_ota_assigner import SimOTAAssigner +from .transformer import DeformableDetrTransformer + +__all__ = [ + 'IDOLTrackHead', 'DeformableDetrTransformer', 'SimOTAAssigner', + 'IDOLTracker' +] diff --git a/projects/VIS_SOTA/IDOL/idol_src/models/idol_query_track_head.py b/projects/VIS_SOTA/IDOL/idol_src/models/idol_query_track_head.py new file mode 100644 index 000000000..1408d60b4 --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/models/idol_query_track_head.py @@ -0,0 +1,929 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch +import torch.nn.functional as F +import torchvision.ops as ops +from mmdet.models.dense_heads import DeformableDETRHead +from mmdet.models.layers import inverse_sigmoid +from mmdet.models.utils import multi_apply +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh +from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean +from mmengine.structures import InstanceData +from torch import Tensor + +from mmtrack.registry import MODELS +from .pos_neg_select import get_contrast_items +from .utils import (MLP, MaskHeadSmallConv, aligned_bilinear, + compute_locations, parse_dynamic_params) + + +@MODELS.register_module() +class IDOLTrackHead(DeformableDETRHead): + + def __init__(self, + *args, + loss_mask: ConfigType = dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=20.0), + loss_dice: ConfigType = dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + naive_dice=True, + loss_weight=1.0), + loss_track: ConfigType = dict( + type='MultiPosCrossEntropyLoss', loss_weight=0.25), + loss_track_aux: ConfigType = dict( + type='L2Loss', + sample_ratio=3, + margin=0.3, + loss_weight=1.0, + hard_mining=True), + enable_reid: bool = True, + rel_coord: bool = True, + inference_select_thres: float = 0.1, + **kwargs) -> None: + + super().__init__(*args, **kwargs) + + self.enable_reid = enable_reid + self.rel_coord = rel_coord + self.inference_select_thres = inference_select_thres + embed_dims = self.transformer.embed_dims + + self.in_channels = embed_dims // 32 + self.dynamic_mask_channels = 8 + self.controller_layers = 3 + self.max_insts_num = 100 + self.mask_out_stride = 4 + self.up_rate = 8 // self.mask_out_stride + + # dynamic_mask_head params + weight_nums, bias_nums = [], [] + for layer in range(self.controller_layers): + if layer == 0: + if self.rel_coord: + weight_nums.append( + (self.in_channels + 2) * self.dynamic_mask_channels) + else: + weight_nums.append(self.in_channels * + self.dynamic_mask_channels) + bias_nums.append(self.dynamic_mask_channels) + elif layer == self.controller_layers - 1: + weight_nums.append(self.dynamic_mask_channels * 1) + bias_nums.append(1) + else: + weight_nums.append(self.dynamic_mask_channels * + self.dynamic_mask_channels) + bias_nums.append(self.dynamic_mask_channels) + + self.weight_nums = weight_nums + self.bias_nums = bias_nums + self.num_gen_params = sum(weight_nums) + sum(bias_nums) + + self.controller = MLP(embed_dims, embed_dims, self.num_gen_params, 3) + self.simple_conv = MaskHeadSmallConv(embed_dims, None, embed_dims) + + if self.enable_reid: + self.reid_embed_head = MLP(embed_dims, embed_dims, embed_dims, 3) + + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + self.loss_track = MODELS.build(loss_track) + self.loss_track_aux = MODELS.build(loss_track_aux) + + def forward(self, x: Tuple[Tensor], + batch_img_metas: List[dict]) -> Tuple[Tensor, ...]: + """Forward function. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + """ + if self.training: + return self.training_forward(x, batch_img_metas) + else: + return self.inference_forward(x, batch_img_metas) + + def training_forward(self, x: Tuple[Tensor], + batch_img_metas: List[dict]) -> Tuple[Tensor, ...]: + + batch_size = x[0].size(0) + num_frames = len(batch_img_metas[0]['frame_id']) + # Since ref_img exists, batch_img_metas[0]['batch_input_shape'][0] + # means the key img. + input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape'][0] + img_masks = x[0].new_ones((batch_size, input_img_h, input_img_w)) + for batch_id in range(batch_size // num_frames): + img_h, img_w = batch_img_metas[batch_id]['img_shape'][0] + img_masks[batch_id * num_frames:(batch_id + 1) * + num_frames, :img_h, :img_w] = 0 + + mlvl_masks = [] + mlvl_positional_encodings = [] + spatial_shapes = [] + + for feat in x: + mlvl_masks.append( + F.interpolate(img_masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + mlvl_positional_encodings.append( + self.positional_encoding(mlvl_masks[-1])) + spatial_shapes.append((feat.shape[2], feat.shape[3])) + + query_embeds = None + if not self.as_two_stage: + query_embeds = self.query_embedding.weight + + # record key img info + x_key = [] + mlvl_masks_key = [] + mlvl_positional_encodings_key = [] + # record ref img info + x_ref = [] + mlvl_masks_ref = [] + mlvl_positional_encodings_ref = [] + + key_ids = list(range(0, batch_size - 1, 2)) + ref_ids = list(range(1, batch_size, 2)) + + # get key frame and ref frame infos + for n_l in range(self.transformer.num_feature_levels): + x_key.append(x[n_l][key_ids]) + x_ref.append(x[n_l][ref_ids]) + + mlvl_masks_key.append(mlvl_masks[n_l][key_ids]) + mlvl_masks_ref.append(mlvl_masks[n_l][ref_ids]) + + mlvl_positional_encodings_key.append( + mlvl_positional_encodings[n_l][key_ids]) + mlvl_positional_encodings_ref.append( + mlvl_positional_encodings[n_l][ref_ids]) + + hs, memory, init_reference, inter_references, \ + enc_outputs_class, enc_outputs_coord = self.transformer( + x_key, + mlvl_masks_key, + query_embeds, + mlvl_positional_encodings_key, + reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501 + ) + + hs_ref, memory_ref, init_reference_ref, inter_references_ref, \ + enc_outputs_class_ref, enc_outputs_coord_ref = \ + self.transformer( + x_ref, + mlvl_masks_ref, + query_embeds, + mlvl_positional_encodings_ref, + reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501 + ) + + hs = hs.permute(0, 2, 1, 3) + hs_ref = hs_ref.permute(0, 2, 1, 3) + + outputs_classes = [] + outputs_coords = [] + outputs_masks = [] + + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.cls_branches[lvl](hs[lvl]) + tmp = self.reg_branches[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + # [bs, num_quries, num_params] + dynamic_mask_head_params = self.controller(hs[lvl]) + mask_head_params = dynamic_mask_head_params.reshape( + 1, -1, dynamic_mask_head_params.shape[-1]) + + reference_points = [] + for batch_id in range(batch_size // num_frames): + img_h, img_w = batch_img_metas[batch_id]['img_shape'][0] + img_h = torch.as_tensor(img_h).to(reference[batch_id]) + img_w = torch.as_tensor(img_w).to(reference[batch_id]) + scale_f = torch.stack([img_w, img_h], dim=0) + ref_cur_f = reference[batch_id].sigmoid()[..., :2] * scale_f[ + None, :] + reference_points.append(ref_cur_f.unsqueeze(0)) + + reference_points = torch.cat(reference_points, dim=1) + num_insts = [ + self.num_query for i in range(batch_size // num_frames) + ] + outputs_mask = self.mask_head(memory, spatial_shapes, + reference_points, mask_head_params, + num_insts) + outputs_masks.append(outputs_mask) + + outputs_classes = torch.stack(outputs_classes) + outputs_coords = torch.stack(outputs_coords) + outputs_masks = torch.stack(outputs_masks) + outputs_embeds = InstanceData( + ref_embeds=self.reid_embed_head(hs_ref[-1]), + key_embeds=self.reid_embed_head(hs[-1]), + ref_cls=self.cls_branches[-1](hs_ref[-1]), + ref_bbox=inter_references_ref[-1]) + # not support two_stage yet + return outputs_classes, outputs_coords, \ + outputs_masks, outputs_embeds, None, None + + def inference_forward(self, x: Tuple[Tensor], + batch_img_metas: List[dict]) -> Tuple[Tensor, ...]: + + batch_size = x[0].size(0) + + input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape'] + img_masks = x[0].new_ones((batch_size, input_img_h, input_img_w)) + for batch_id in range(batch_size): + img_h, img_w = batch_img_metas[batch_id]['img_shape'] + img_masks[batch_id, :img_h, :img_w] = 0 + + mlvl_masks = [] + mlvl_positional_encodings = [] + spatial_shapes = [] + + for feat in x: + # feat: (1, c, h, w) + mlvl_masks.append( + F.interpolate(img_masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + mlvl_positional_encodings.append( + self.positional_encoding(mlvl_masks[-1])) + spatial_shapes.append((feat.shape[2], feat.shape[3])) + + query_embeds = None + if not self.as_two_stage: + query_embeds = self.query_embedding.weight + + hs, memory, init_reference, inter_references, \ + enc_outputs_class, enc_outputs_coord = self.transformer( + x, + mlvl_masks, + query_embeds, + mlvl_positional_encodings, + reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501 + ) + + hs = hs.permute(0, 2, 1, 3) + + reference = inter_references[-1 - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.cls_branches[-1](hs[-1]) + tmp = self.reg_branches[-1](hs[-1]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + + outputs_embeds = self.reid_embed_head(hs[-1]) + + dynamic_mask_head_params = self.controller( + hs[-1]) # [bs, num_queries, num_params] + + norm_reference_points = inter_references[-2, :, :, :2] + + reference_points = [] + for batch_id in range(batch_size): + img_h, img_w = batch_img_metas[batch_id]['img_shape'] + img_h = torch.as_tensor(img_h).to(norm_reference_points[batch_id]) + img_w = torch.as_tensor(img_w).to(norm_reference_points[batch_id]) + scale_f = torch.stack([img_w, img_h], dim=0) + ref_cur_f = norm_reference_points[batch_id] * scale_f[None, :] + reference_points.append(ref_cur_f.unsqueeze(0)) + + reference_points = torch.cat(reference_points, dim=1) + mask_head_params = dynamic_mask_head_params.reshape( + 1, -1, dynamic_mask_head_params.shape[-1]) + + num_insts = [self.num_query for i in range(batch_size)] + outputs_masks = self.mask_head(memory, spatial_shapes, + reference_points, mask_head_params, + num_insts) + # not support two_stage yet + return outputs_class, outputs_coord, \ + outputs_masks, outputs_embeds, None, None + + def mask_head(self, feats, spatial_shapes, reference_points, + mask_head_params, num_insts): + + feats = feats.transpose(0, 1) + bs, _, c = feats.shape + + # encod_feat_l: num_layers x [bs, C, hi, wi] + encod_feat_l = [] + spatial_indx = 0 + for lvl in range(self.transformer.num_feature_levels - 1): + h, w = spatial_shapes[lvl] + mem_l = feats[:, spatial_indx:spatial_indx + 1 * h * w, :].reshape( + bs, h, w, c).permute(0, 3, 1, 2) + encod_feat_l.append(mem_l) + spatial_indx += 1 * h * w + + decod_feat_f = self.simple_conv(encod_feat_l, fpns=None) + + mask_logits = self.dynamic_mask_with_coords( + decod_feat_f, + reference_points, + mask_head_params, + num_insts=num_insts, + mask_feat_stride=8, + rel_coord=self.rel_coord) + # mask_logits: [1, num_queries_all, H/4, W/4] + mask_f = [] + inst_st = 0 + for num_inst in num_insts: + # [1, selected_queries, 1, H/4, W/4] + mask_f.append(mask_logits[:, inst_st:inst_st + + num_inst, :, :].unsqueeze(2)) + inst_st += num_inst + + output_pred_masks = torch.cat(mask_f, dim=0) + + return output_pred_masks + + def dynamic_mask_with_coords(self, + mask_feats, + reference_points, + mask_head_params, + num_insts, + mask_feat_stride, + rel_coord=True): + + device = mask_feats.device + + bs, in_channels, H, W = mask_feats.size() + num_insts_all = reference_points.shape[1] + + locations = compute_locations( + mask_feats.size(2), + mask_feats.size(3), + device=device, + stride=mask_feat_stride) + # locations: [H*W, 2] + + if rel_coord: + instance_locations = reference_points + relative_coords = instance_locations.reshape( + 1, num_insts_all, 1, 1, 2) - locations.reshape(1, 1, H, W, 2) + relative_coords = relative_coords.float() + relative_coords = relative_coords.permute(0, 1, 4, 2, + 3).flatten(-2, -1) + mask_head_inputs = [] + inst_st = 0 + for i, num_inst in enumerate(num_insts): + # [1, num_queries * (C/32+2), H/8 * W/8] + relative_coords_b = relative_coords[:, inst_st:inst_st + + num_inst, :, :] + mask_feats_b = mask_feats[i].reshape( + 1, in_channels, + H * W).unsqueeze(1).repeat(1, num_inst, 1, 1) + mask_head_b = torch.cat([relative_coords_b, mask_feats_b], + dim=2) + + mask_head_inputs.append(mask_head_b) + inst_st += num_inst + + else: + mask_head_inputs = [] + inst_st = 0 + for i, num_inst in enumerate(num_insts): + mask_head_b = mask_feats[i].reshape(1, in_channels, + H * W).unsqueeze(1).repeat( + 1, num_inst, 1, 1) + mask_head_b = mask_head_b.reshape(1, -1, H, W) + mask_head_inputs.append(mask_head_b) + + # mask_head_inputs: [1, \sum{num_queries * (C/32+2)}, H/8, W/8] + mask_head_inputs = torch.cat(mask_head_inputs, dim=1) + mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W) + + # mask_head_params: [num_insts_all, num_params] + mask_head_params = torch.flatten(mask_head_params, 0, 1) + + if num_insts_all != 0: + weights, biases = parse_dynamic_params(mask_head_params, + self.dynamic_mask_channels, + self.weight_nums, + self.bias_nums) + + mask_logits = self.dynamic_conv_forward(mask_head_inputs, weights, + biases, + mask_head_params.shape[0]) + else: + mask_logits = mask_head_inputs + return mask_logits + # mask_logits: [1, num_insts_all, H/8, W/8] + mask_logits = mask_logits.reshape(-1, 1, H, W) + + # upsample predicted masks + assert mask_feat_stride >= self.mask_out_stride + assert mask_feat_stride % self.mask_out_stride == 0 + + mask_logits = aligned_bilinear( + mask_logits, int(mask_feat_stride / self.mask_out_stride)) + + mask_logits = mask_logits.reshape(1, -1, mask_logits.shape[-2], + mask_logits.shape[-1]) + # mask_logits: [1, num_insts_all, H/4, W/4] + + return mask_logits + + def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor], + biases: List[Tensor], num_insts: int) -> Tensor: + """dynamic forward, each layer follow a relu.""" + n_layers = len(weights) + x = features + for i, (w, b) in enumerate(zip(weights, biases)): + x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts) + if i < n_layers - 1: + x = F.relu(x) + return x + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the features of the upstream network. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + outs = self(x, batch_img_metas) + loss_inputs = outs + (batch_gt_instances, batch_img_metas) + losses = self.loss_by_feat(*loss_inputs) + return losses + + def loss_by_feat( + self, + all_cls_scores: Tensor, + all_bbox_preds: Tensor, + all_masks_preds: Tensor, + embeds_preds: InstanceData, + enc_cls_scores: Tensor, + enc_bbox_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None + ) -> Dict[str, Tensor]: + + assert batch_gt_instances_ignore is None, \ + f'{self.__class__.__name__} only supports ' \ + f'for gt_bboxes_ignore setting to None.' + # fix batch_img_metas + key_list = [ + 'batch_input_shape', 'pad_shape', 'img_shape', 'scale_factor' + ] + for key in key_list: + for batch_id in range(len(batch_img_metas)): + batch_img_metas[batch_id][key] = batch_img_metas[batch_id][ + key][0] + + num_dec_layers = len(all_cls_scores) + batch_key_gt_instances = [] + batch_ref_gt_instances = [] + for gt_instances in batch_gt_instances: + batch_key_gt_instances.append( + gt_instances[gt_instances.map_instances_to_img_idx == 0]) + batch_ref_gt_instances.append( + gt_instances[gt_instances.map_instances_to_img_idx == 1]) + + batch_key_gt_instances_list = [ + batch_key_gt_instances for _ in range(num_dec_layers) + ] + batch_ref_gt_instances_list = [ + batch_ref_gt_instances for _ in range(num_dec_layers) + ] + batch_img_metas_list = [batch_img_metas for _ in range(num_dec_layers)] + + losses_cls, losses_bbox, losses_iou, \ + losses_mask, losses_dice, min_cost_qids = multi_apply( + self.loss_by_feat_single, + all_cls_scores, + all_bbox_preds, + all_masks_preds, + batch_key_gt_instances_list, + batch_ref_gt_instances_list, + batch_img_metas_list) + + # get track loss + contrast_items = get_contrast_items( + key_embeds=embeds_preds.key_embeds, + ref_embeds=embeds_preds.ref_embeds, + key_gt_instances=batch_key_gt_instances, + ref_gt_instances=batch_ref_gt_instances, + ref_bbox=embeds_preds.ref_bbox, # rescale bboxes + ref_cls=embeds_preds.ref_cls.sigmoid(), + query_inds=min_cost_qids[-1], + batch_img_metas=batch_img_metas) + + loss_track = 0. + loss_track_aux = 0. + for contrast_item in contrast_items: + loss_track += self.loss_track(contrast_item['contrast'], + contrast_item['label']) + if self.loss_track_aux is not None: + loss_track_aux += self.loss_track_aux( + contrast_item['aux_consin'], contrast_item['aux_label']) + loss_track = loss_track / len(contrast_items) + if self.loss_track_aux is not None: + loss_track_aux = loss_track_aux / len(contrast_items) + + loss_dict = dict() + # loss of proposal generated from encode feature map. + if enc_cls_scores is not None: + for i in range(len(batch_img_metas)): + batch_gt_instances[i].labels = torch.zeros_like( + batch_gt_instances[i].labels) + enc_loss_cls, enc_losses_bbox, enc_losses_iou = \ + self.loss_single(enc_cls_scores, enc_bbox_preds, + batch_gt_instances, batch_img_metas) + loss_dict['enc_loss_cls'] = enc_loss_cls + loss_dict['enc_loss_bbox'] = enc_losses_bbox + loss_dict['enc_loss_iou'] = enc_losses_iou + + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_bbox'] = losses_bbox[-1] + loss_dict['loss_iou'] = losses_iou[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + loss_dict['loss_track'] = loss_track + loss_dict['loss_track_aux'] = loss_track_aux + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_bbox_i, loss_iou_i, loss_mask_i, loss_dice_i in \ + zip( + losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1], + losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i + loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, + mask_preds: Tensor, + batch_key_gt_instances: InstanceList, + batch_ref_gt_instances: InstanceList, + batch_img_metas: List[dict]) -> Tuple[Tensor]: + + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + + cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, + mask_preds_list, + batch_key_gt_instances, + batch_img_metas) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + mask_targets_list, mask_weights_list, num_total_pos, num_total_neg, + min_cost_qid) = cls_reg_targets + + labels = torch.cat(labels_list, 0) + label_weights = torch.cat(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + mask_targets = torch.cat(mask_targets_list, 0) + mask_weights = torch.cat(mask_weights_list, 0) + + # classification loss + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + cls_avg_factor = max(cls_avg_factor, 1) + + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds): + img_h, img_w, = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + bbox_pred.size(0), 1) + factors.append(factor) + factors = torch.cat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds.flatten(0, 1).squeeze(1) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + else: + # dice loss + loss_dice = self.loss_dice( + mask_preds, mask_targets, avg_factor=num_total_pos) + + # mask loss + # FocalLoss support input of shape (n, num_class) + h, w = mask_preds.shape[-2:] + # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1) + mask_preds = mask_preds.reshape(-1, 1) + # shape (num_total_gts, h, w) -> (num_total_gts * h * w) + mask_targets = mask_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_preds, 1 - mask_targets, avg_factor=num_total_pos * h * w) + + return loss_cls, loss_bbox, loss_iou, \ + loss_mask, loss_dice, min_cost_qid + + def get_targets(self, cls_scores_list: List[Tensor], + bbox_preds_list: List[Tensor], + mask_preds_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict]) -> tuple: + """Compute regression and classification targets for a batch image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image with shape [num_query, + cls_out_channels]. + bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_query, 4]. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all images. + - bbox_targets_list (list[Tensor]): BBox targets for all images. + - bbox_weights_list (list[Tensor]): BBox weights for all images. + - num_total_pos (int): Number of positive samples in all images. + - num_total_neg (int): Number of negative samples in all images. + """ + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + mask_targets_list, mask_weights_list, pos_inds_list, neg_inds_list, + min_cost_qid) = multi_apply(self._get_targets_single, cls_scores_list, + bbox_preds_list, mask_preds_list, + batch_gt_instances, batch_img_metas) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, mask_targets_list, mask_weights_list, + num_total_pos, num_total_neg, min_cost_qid) + + def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, + mask_pred: Tensor, gt_instances: InstanceData, + img_meta: dict) -> tuple: + """Compute regression and classification targets for one image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_query, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_query, 4]. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for one image. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + img_h, img_w = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) + num_bboxes = bbox_pred.size(0) + # convert bbox_pred from xywh, normalized to xyxy, unnormalized + bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred) + bbox_pred = bbox_pred * factor + + pred_instances = InstanceData(scores=cls_score, bboxes=bbox_pred) + # assigner and sampler + assign_result = self.assigner.assign( + pred_instances=pred_instances, + gt_instances=gt_instances, + img_meta=img_meta) + + gt_bboxes = gt_instances.bboxes + target_shape = mask_pred.shape[-2:] + gt_masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=gt_bboxes.device) + if gt_masks.shape[0] > 0: + gt_masks_downsampled = F.interpolate( + gt_masks.unsqueeze(1).float(), target_shape, + mode='nearest').squeeze(1).long() + else: + gt_masks_downsampled = gt_masks + + gt_labels = gt_instances.labels + pos_inds = torch.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds.long(), :] + pos_gt_masks = gt_masks_downsampled[pos_assigned_gt_inds.long(), :] + min_cost_qid = assign_result.get_extra_property('query_inds') + + # label targets + labels = gt_bboxes.new_full((num_bboxes, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_bboxes) + + # bbox targets + bbox_targets = torch.zeros_like(bbox_pred) + bbox_weights = torch.zeros_like(bbox_pred) + bbox_weights[pos_inds] = 1.0 + + # mask targets + mask_targets = torch.zeros_like(mask_pred.squeeze(1), dtype=torch.long) + mask_targets[pos_inds] = pos_gt_masks + mask_weights = gt_bboxes.new_zeros(num_bboxes) + mask_weights[pos_inds] = 1.0 + + # DETR regress the relative position of boxes (cxcywh) in the image. + # Thus the learning target should be normalized by the image size, also + # the box format should be converted from defaultly x1y1x2y2 to cxcywh. + pos_gt_bboxes_normalized = pos_gt_bboxes / factor + pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized) + bbox_targets[pos_inds] = pos_gt_bboxes_targets + return (labels, label_weights, bbox_targets, bbox_weights, + mask_targets, mask_weights, pos_inds, neg_inds, min_cost_qid) + + def predict(self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. Over-write + because img_metas are needed as inputs for bbox_head. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): Whether to rescale the results. + Defaults to True. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + outs = self.inference_forward(x, batch_img_metas) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + return predictions + + def predict_by_feat(self, + all_cls_scores: Tensor, + all_bbox_preds: Tensor, + all_mask_preds: Tensor, + all_embeds_preds: Tensor, + enc_cls_scores: Tensor, + enc_bbox_preds: Tensor, + batch_img_metas: List[Dict], + rescale: bool = False) -> InstanceList: + + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score = all_cls_scores[img_id] + bbox_pred = all_bbox_preds[img_id] + mask_pred = all_mask_preds[img_id] + embed_pred = all_embeds_preds[img_id] + img_meta = batch_img_metas[img_id] + results = self._predict_by_feat_single(cls_score, bbox_pred, + mask_pred, embed_pred, + img_meta, rescale) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score: Tensor, + bbox_pred: Tensor, + mask_pred: Tensor, + embed_pred: Tensor, + img_meta: dict, + rescale: bool = True) -> InstanceData: + + assert len(cls_score) == len(bbox_pred) + img_shape = img_meta['img_shape'] + + cls_score = cls_score.sigmoid() + max_score, _ = torch.max(cls_score, 1) + bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred) + indices = torch.nonzero( + max_score > self.inference_select_thres, as_tuple=False).squeeze(1) + if len(indices) == 0: + topkv, indices_top1 = torch.topk(cls_score.max(1)[0], k=1) + indices_top1 = indices_top1[torch.argmax(topkv)] + indices = [indices_top1.tolist()] + else: + nms_scores, idxs = torch.max(cls_score[indices], 1) + boxes_before_nms = bbox_pred[indices] + keep_indices = ops.batched_nms(boxes_before_nms, nms_scores, idxs, + 0.9) + indices = indices[keep_indices] + + scores = torch.max(cls_score[indices], 1)[0] + det_bboxes = bbox_pred[indices] + det_labels = torch.argmax(cls_score[indices], dim=1) + track_feats = embed_pred[indices] + det_masks = mask_pred[indices] + + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1]) + det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0]) + + results = InstanceData() + results.bboxes = det_bboxes + results.scores = scores + results.labels = det_labels + results.masks = det_masks + results.track_feats = track_feats + return results diff --git a/projects/VIS_SOTA/IDOL/idol_src/models/idol_tracker.py b/projects/VIS_SOTA/IDOL/idol_src/models/idol_tracker.py new file mode 100644 index 000000000..5dfd54eb6 --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/models/idol_tracker.py @@ -0,0 +1,328 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmtrack.models.trackers import BaseTracker +from mmtrack.registry import MODELS +from mmtrack.structures import TrackDataSample +from .utils import mask_iou, mask_nms + + +@MODELS.register_module() +class IDOLTracker(BaseTracker): + + def __init__(self, + nms_thr_pre=0.7, + nms_thr_post=0.3, + init_score_thr=0.2, + addnew_score_thr=0.5, + obj_score_thr=0.1, + match_score_thr=0.5, + memo_tracklet_frames=10, + memo_backdrop_frames=1, + memo_momentum=0.5, + nms_conf_thr=0.5, + nms_backdrop_iou_thr=0.5, + nms_class_iou_thr=0.7, + with_cats=True, + match_metric='bisoftmax', + long_match=False, + frame_weight=False, + temporal_weight=False, + memory_len=10, + **kwargs): + super().__init__(**kwargs) + assert 0 <= memo_momentum <= 1.0 + assert memo_tracklet_frames >= 0 + assert memo_backdrop_frames >= 0 + self.memory_len = memory_len + self.temporal_weight = temporal_weight + self.long_match = long_match + self.frame_weight = frame_weight + self.nms_thr_pre = nms_thr_pre + self.nms_thr_post = nms_thr_post + self.init_score_thr = init_score_thr + self.addnew_score_thr = addnew_score_thr + self.obj_score_thr = obj_score_thr + self.match_score_thr = match_score_thr + self.memo_tracklet_frames = memo_tracklet_frames + self.memo_backdrop_frames = memo_backdrop_frames + self.memo_momentum = memo_momentum + self.nms_conf_thr = nms_conf_thr + self.nms_backdrop_iou_thr = nms_backdrop_iou_thr + self.nms_class_iou_thr = nms_class_iou_thr + self.with_cats = with_cats + assert match_metric in ['bisoftmax', 'softmax', 'cosine'] + self.match_metric = match_metric + + self.reset() + + def reset(self): + """Reset the buffer of the tracker.""" + self.num_tracks = 0 + self.tracks = dict() + self.backdrops = [] + + def update(self, ids: Tensor, bboxes: Tensor, embeds: Tensor, + labels: Tensor, scores: Tensor, frame_id: int) -> None: + """Tracking forward function. + + Args: + ids (Tensor): of shape(N, ). + bboxes (Tensor): of shape (N, 5). + embeds (Tensor): of shape (N, 256). + labels (Tensor): of shape (N, ). + scores (Tensor): of shape (N, ). + frame_id (int): The id of current frame, 0-index. + """ + tracklet_inds = ids > -1 + + for id, bbox, embed, label, score in zip(ids[tracklet_inds], + bboxes[tracklet_inds], + embeds[tracklet_inds], + labels[tracklet_inds], + scores[tracklet_inds]): + id = int(id) + # update the tracked ones and initialize new tracks + if id in self.tracks.keys(): + velocity = (bbox - self.tracks[id]['bbox']) / ( + frame_id - self.tracks[id]['last_frame']) + self.tracks[id]['bbox'] = bbox + self.tracks[id]['long_score'].append(score) + self.tracks[id]['embed'] = ( + 1 - self.memo_momentum + ) * self.tracks[id]['embed'] + self.memo_momentum * embed + self.tracks[id]['long_embed'].append(embed) + self.tracks[id]['last_frame'] = frame_id + self.tracks[id]['label'] = label + self.tracks[id]['score'] = score + self.tracks[id]['velocity'] = ( + self.tracks[id]['velocity'] * self.tracks[id]['acc_frame'] + + velocity) / ( + self.tracks[id]['acc_frame'] + 1) + self.tracks[id]['acc_frame'] += 1 + self.tracks[id]['exist_frame'] += 1 + else: + self.tracks[id] = dict( + bbox=bbox, + embed=embed, + label=label, + long_embed=[embed], + long_score=[score], + score=score, + last_frame=frame_id, + velocity=torch.zeros_like(bbox), + acc_frame=0, + exist_frame=1) + # backdrop update according to IoU + backdrop_inds = torch.nonzero(ids == -1, as_tuple=False).squeeze(1) + self.backdrops.insert( + 0, + dict( + bboxes=bboxes[backdrop_inds], + embeds=embeds[backdrop_inds], + labels=labels[backdrop_inds])) + + # pop memo + invalid_ids = [] + for k, v in self.tracks.items(): + if frame_id - v['last_frame'] >= self.memo_tracklet_frames: + invalid_ids.append(k) + if len(v['long_embed']) > self.memory_len: + v['long_embed'].pop(0) + if len(v['long_score']) > self.memory_len: + v['long_score'].pop(0) + for invalid_id in invalid_ids: + self.tracks.pop(invalid_id) + + if len(self.backdrops) > self.memo_backdrop_frames: + self.backdrops.pop() + + @property + def memo(self) -> Tuple[Tensor, ...]: + """Get tracks memory.""" + memo_embeds = [] + memo_ids = [] + memo_bboxes = [] + memo_labels = [] + # velocity of tracks + memo_vs = [] + # for long term tracking + memo_long_embeds = [] + memo_long_score = [] + memo_exist_frame = [] + # get tracks + for k, v in self.tracks.items(): + memo_bboxes.append(v['bbox'][None, :]) + if self.long_match: + weights = torch.stack(v['long_score']) + if self.temporal_weight: + length = len(weights) + temporal_weight = torch.range(0.0, 1, + 1 / length)[1:].to(weights) + weights = weights + temporal_weight + sum_embed = (torch.stack(v['long_embed']) * + weights.unsqueeze(1)).sum(0) / weights.sum() + memo_embeds.append(sum_embed[None, :]) + else: + memo_embeds.append(v['embed'][None, :]) + + memo_long_embeds.append(torch.stack(v['long_embed'])) + memo_long_score.append(torch.stack(v['long_score'])) + memo_exist_frame.append(v['exist_frame']) + + memo_ids.append(k) + memo_labels.append(v['label'].view(1, 1)) + memo_vs.append(v['velocity'][None, :]) + memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1) + memo_exist_frame = torch.tensor(memo_exist_frame, dtype=torch.long) + + memo_bboxes = torch.cat(memo_bboxes, dim=0) + memo_embeds = torch.cat(memo_embeds, dim=0) + memo_labels = torch.cat(memo_labels, dim=0).squeeze(1) + memo_vs = torch.cat(memo_vs, dim=0) + return memo_bboxes, memo_labels, memo_embeds, memo_ids.squeeze( + 0), memo_vs, memo_long_embeds, memo_long_score, memo_exist_frame + + def track(self, + data_sample: TrackDataSample, + rescale=True, + **kwargs) -> InstanceData: + + metainfo = data_sample.metainfo + bboxes = data_sample.pred_det_instances.bboxes + masks = data_sample.pred_det_instances.masks + labels = data_sample.pred_det_instances.labels + scores = data_sample.pred_det_instances.scores + embeds = data_sample.pred_det_instances.track_feats + + valids = mask_nms(masks, scores, self.nms_thr_pre) + frame_id = metainfo.get('frame_id', -1) + + bboxes = bboxes[valids, :] + labels = labels[valids] + masks = masks[valids] + scores = scores[valids] + embeds = embeds[valids, :] + # create pred_track_instances + pred_track_instances = InstanceData(metainfo=metainfo) + + # return zero bboxes if there is no track targets + if bboxes.shape[0] == 0: + ids = torch.zeros_like(labels) + pred_track_instances = data_sample.pred_det_instances.clone() + pred_track_instances.instances_id = ids + return pred_track_instances + + # init ids container + ids = torch.full((bboxes.size(0), ), -2, dtype=torch.long) + + # match if buffer is not empty + if bboxes.size(0) > 0 and not self.empty: + (memo_bboxes, memo_labels, memo_embeds, memo_ids, memo_vs, + memo_long_embeds, memo_long_score, memo_exist_frame) = self.memo + + memo_exist_frame = memo_exist_frame.to(memo_embeds) + memo_ids = memo_ids.to(memo_embeds) + + if self.match_metric == 'bisoftmax': + feats = torch.mm(embeds, memo_embeds.t()) + d2t_scores = feats.softmax(dim=1) + t2d_scores = feats.softmax(dim=0) + match_scores = (d2t_scores + t2d_scores) / 2 + elif self.match_metric == 'softmax': + feats = torch.mm(embeds, memo_embeds.t()) + match_scores = feats.softmax(dim=1) + elif self.match_metric == 'cosine': + match_scores = torch.mm( + F.normalize(embeds, p=2, dim=1), + F.normalize(memo_embeds, p=2, dim=1).t()) + else: + raise NotImplementedError + # track according to match_scores + for i in range(bboxes.size(0)): + if self.frame_weight: + non_backs = (memo_ids > -1) & (match_scores[i, :] > 0.5) + if (match_scores[i, non_backs] > 0.5).sum() > 1: + wighted_scores = match_scores.clone() + frame_weight = memo_exist_frame[ + match_scores[i, :][memo_ids > -1] > 0.5] + wighted_scores[i, non_backs] = wighted_scores[ + i, non_backs] * frame_weight + wighted_scores[i, ~non_backs] = wighted_scores[ + i, ~non_backs] * frame_weight.mean() + conf, memo_ind = torch.max(wighted_scores[i, :], dim=0) + else: + conf, memo_ind = torch.max(match_scores[i, :], dim=0) + else: + conf, memo_ind = torch.max(match_scores[i, :], dim=0) + id = memo_ids[memo_ind] + if conf > self.match_score_thr: + if id > -1: + ids[i] = id + match_scores[:i, memo_ind] = 0 + match_scores[i + 1:, memo_ind] = 0 + # initialize new tracks + new_inds = (ids == -2) & (scores > self.addnew_score_thr).cpu() + num_news = new_inds.sum() + ids[new_inds] = torch.arange( + self.num_tracks, self.num_tracks + num_news, dtype=torch.long) + self.num_tracks += num_news + + # get backdrops + unselected_inds = torch.nonzero( + ids == -2, as_tuple=False).squeeze(1) + mask_ious = mask_iou(masks[unselected_inds].sigmoid() > 0.5, + masks.permute(1, 0, 2, 3).sigmoid() > 0.5) + for i, ind in enumerate(unselected_inds): + if (mask_ious[i, :ind] < self.nms_thr_post).all(): + ids[ind] = -1 + + self.update(ids, bboxes, embeds, labels, scores, frame_id) + + elif self.empty: + init_inds = (ids == -2) & (scores > self.init_score_thr).cpu() + num_news = init_inds.sum() + ids[init_inds] = torch.arange( + self.num_tracks, self.num_tracks + num_news, dtype=torch.long) + self.num_tracks += num_news + unselected_inds = torch.nonzero( + ids == -2, as_tuple=False).squeeze(1) + mask_ious = mask_iou(masks[unselected_inds].sigmoid() > 0.5, + masks.permute(1, 0, 2, 3).sigmoid() > 0.5) + for i, ind in enumerate(unselected_inds): + if (mask_ious[i, :ind] < self.nms_thr_post).all(): + ids[ind] = -1 + self.update(ids, bboxes, embeds, labels, scores, frame_id) + + tracklet_inds = ids > -1 + + if rescale: + # return result in original resolution + # rz_*: the resize shape + pad_height, pad_width = metainfo['pad_shape'] + rz_height, rz_width = metainfo['img_shape'] + masks = F.interpolate( + masks, + size=(pad_height, pad_width), + mode='bilinear', + align_corners=False).sigmoid() + # crop the padding area + masks = masks[:, :, :rz_height, :rz_width] + ori_height, ori_width = metainfo['ori_shape'] + masks = ( + F.interpolate( + masks, size=(ori_height, ori_width), mode='nearest')) + + # update pred_track_instances + pred_track_instances.bboxes = bboxes[tracklet_inds] + pred_track_instances.masks = masks[tracklet_inds].squeeze(1) > 0.5 + pred_track_instances.labels = labels[tracklet_inds] + pred_track_instances.scores = scores[tracklet_inds] + pred_track_instances.instances_id = ids[tracklet_inds] + + return pred_track_instances diff --git a/projects/VIS_SOTA/IDOL/idol_src/models/pos_neg_select.py b/projects/VIS_SOTA/IDOL/idol_src/models/pos_neg_select.py new file mode 100644 index 000000000..239336a91 --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/models/pos_neg_select.py @@ -0,0 +1,232 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import List + +import torch +import torch.nn as nn +import torchvision.ops as ops +from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh + +from .utils import generalized_box_iou + + +def get_contrast_items(key_embeds, ref_embeds, key_gt_instances, + ref_gt_instances, ref_bbox, ref_cls, query_inds, + batch_img_metas) -> List: + + one = torch.tensor(1).to(ref_embeds) + zero = torch.tensor(0).to(ref_embeds) + contrast_items = [] + + for bid, (key_gt, ref_gt, indices) in enumerate( + zip(key_gt_instances, ref_gt_instances, query_inds)): + + key_ins_ids = key_gt.instances_id + ref_ins_ids = ref_gt.instances_id + valid = torch.tensor([(ref_ins_id in key_ins_ids) + for ref_ins_id in ref_ins_ids], + dtype=torch.bool) + + img_h, img_w, = batch_img_metas[bid]['img_shape'] + gt_bboxes = bbox_xyxy_to_cxcywh(ref_gt['bboxes']) + factor = gt_bboxes.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + gt_bboxes.size(0), 1) + gt_bboxes /= factor + gt_labels = ref_gt['labels'] + contrastive_pos, contrastive_neg = get_pos_idx(ref_bbox[bid], + ref_cls[bid], gt_bboxes, + gt_labels, valid) + + for inst_i, (is_valid, matched_query_id) in enumerate( + zip(valid, query_inds[bid])): + + if not is_valid: + continue + # key_embeds: (bs, num_queries, c) + key_embed_i = key_embeds[bid, matched_query_id].unsqueeze(0) + + pos_embed = ref_embeds[bid][contrastive_pos[inst_i]] + neg_embed = ref_embeds[bid][~contrastive_neg[inst_i]] + contrastive_embed = torch.cat([pos_embed, neg_embed], dim=0) + contrastive_label = torch.cat( + [one.repeat(len(pos_embed)), + zero.repeat(len(neg_embed))], + dim=0).unsqueeze(1) + + contrast = torch.einsum('nc,kc->nk', + [contrastive_embed, key_embed_i]) + + if len(pos_embed) == 0: + num_sample_neg = 10 + elif len(pos_embed) * 10 >= len(neg_embed): + num_sample_neg = len(neg_embed) + else: + num_sample_neg = len(pos_embed) * 10 + + # for aux loss + sample_ids = random.sample( + list(range(0, len(neg_embed))), num_sample_neg) + aux_contrastive_embed = torch.cat( + [pos_embed, neg_embed[sample_ids]], dim=0) + aux_contrastive_label = torch.cat( + [one.repeat(len(pos_embed)), + zero.repeat(num_sample_neg)], + dim=0).unsqueeze(1) + aux_contrastive_embed = nn.functional.normalize( + aux_contrastive_embed.float(), dim=1) + key_embed_i = nn.functional.normalize(key_embed_i.float(), dim=1) + cosine = torch.einsum('nc,kc->nk', + [aux_contrastive_embed, key_embed_i]) + + contrast_items.append({ + 'contrast': contrast, + 'label': contrastive_label, + 'aux_consin': cosine, + 'aux_label': aux_contrastive_label + }) + + return contrast_items + + +def get_pos_idx(ref_bbox, ref_cls, gt_bbox, gt_cls, valid): + + with torch.no_grad(): + if False in valid: + gt_bbox = gt_bbox[valid] + gt_cls = gt_cls[valid] + + fg_mask, is_in_boxes_and_center = \ + get_in_gt_and_in_center_info( + ref_bbox, + gt_bbox, + expanded_strides=32) + pair_wise_ious = ops.box_iou( + bbox_cxcywh_to_xyxy(ref_bbox), bbox_cxcywh_to_xyxy(gt_bbox)) + + # Compute the classification cost. + alpha = 0.25 + gamma = 2.0 + neg_cost_class = (1 - alpha) * (ref_cls ** gamma) * \ + (-(1 - ref_cls + 1e-8).log()) + pos_cost_class = alpha * ( + (1 - ref_cls)**gamma) * (-(ref_cls + 1e-8).log()) + cost_class = pos_cost_class[:, gt_cls] - neg_cost_class[:, gt_cls] + cost_giou = - \ + generalized_box_iou(bbox_cxcywh_to_xyxy( + ref_bbox), bbox_cxcywh_to_xyxy(gt_bbox)) + + cost = ( + cost_class + 3.0 * cost_giou + 100.0 * (~is_in_boxes_and_center)) + + cost[~fg_mask] = cost[~fg_mask] + 10000.0 + + if False in valid: + pos_indices = [] + neg_indices = [] + if valid.sum() > 0: + # Select positive sample on reference frame, k = 10 + pos_matched = dynamic_k_matching(cost, pair_wise_ious, + int(valid.sum()), 10) + # Select positive sample on reference frame, k = 100 + neg_matched = dynamic_k_matching(cost, pair_wise_ious, + int(valid.sum()), 100) + valid_idx = 0 + valid_list = valid.tolist() + for istrue in valid_list: + if istrue: + pos_indices.append(pos_matched[valid_idx]) + neg_indices.append(neg_matched[valid_idx]) + valid_idx += 1 + else: + pos_indices.append(None) + neg_indices.append(None) + + else: + if valid.sum() > 0: + pos_indices = dynamic_k_matching(cost, pair_wise_ious, + gt_bbox.shape[0], 10) + neg_indices = dynamic_k_matching(cost, pair_wise_ious, + gt_bbox.shape[0], 100) + else: + pos_indices = [None] + neg_indices = [None] + + return (pos_indices, neg_indices) + + +def get_in_gt_and_in_center_info(bboxes, target_gts, expanded_strides): + + xy_target_gts = bbox_cxcywh_to_xyxy(target_gts) + + anchor_center_x = bboxes[:, 0].unsqueeze(1) + anchor_center_y = bboxes[:, 1].unsqueeze(1) + + b_l = anchor_center_x > xy_target_gts[:, 0].unsqueeze(0) + b_r = anchor_center_x < xy_target_gts[:, 2].unsqueeze(0) + b_t = anchor_center_y > xy_target_gts[:, 1].unsqueeze(0) + b_b = anchor_center_y < xy_target_gts[:, 3].unsqueeze(0) + is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4) + is_in_boxes_all = is_in_boxes.sum(1) > 0 + + # in fixed center + center_radius = 2.5 + b_l = anchor_center_x > ( + target_gts[:, 0] - (1 * center_radius / expanded_strides)).unsqueeze(0) + b_r = anchor_center_x < ( + target_gts[:, 0] + (1 * center_radius / expanded_strides)).unsqueeze(0) + b_t = anchor_center_y > ( + target_gts[:, 1] - (1 * center_radius / expanded_strides)).unsqueeze(0) + b_b = anchor_center_y < ( + target_gts[:, 1] + (1 * center_radius / expanded_strides)).unsqueeze(0) + is_in_centers = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4) + is_in_centers_all = is_in_centers.sum(1) > 0 + + # in boxes or in centers, shape: [num_priors] + is_in_boxes_or_centers = is_in_boxes_all | is_in_centers_all + # both in boxes and centers, shape: [num_fg, num_gt] + is_in_boxes_and_center = (is_in_boxes & is_in_centers) + + return is_in_boxes_or_centers, is_in_boxes_and_center + + +def dynamic_k_matching(cost, pair_wise_ious, num_gt, n_candidate_k): + matching_matrix = torch.zeros_like(cost) + ious_in_boxes_matrix = pair_wise_ious + + topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=0) + dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1) + for gt_idx in range(num_gt): + _, pos_idx = torch.topk( + cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False) + matching_matrix[:, gt_idx][pos_idx] = 1.0 + + del topk_ious, dynamic_ks, pos_idx + + anchor_matching_gt = matching_matrix.sum(1) + + if (anchor_matching_gt > 1).sum() > 0: + _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1) + matching_matrix[anchor_matching_gt > 1] *= 0 + matching_matrix[anchor_matching_gt > 1, cost_argmin, ] = 1 + + while (matching_matrix.sum(0) == 0).any(): + matched_query_id = matching_matrix.sum(1) > 0 + cost[matched_query_id] += 100000.0 + unmatch_id = torch.nonzero( + matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1) + for gt_idx in unmatch_id: + pos_idx = torch.argmin(cost[:, gt_idx]) + matching_matrix[:, gt_idx][pos_idx] = 1.0 + if (matching_matrix.sum(1) > 1).sum() > 0: + _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1) + matching_matrix[anchor_matching_gt > 1] *= 0 + matching_matrix[anchor_matching_gt > 1, cost_argmin, ] = 1 + + assert not (matching_matrix.sum(0) == 0).any() + + matched_pos = [] + for gt_idx in range(num_gt): + matched_pos.append(matching_matrix[:, gt_idx] > 0) + + return matched_pos diff --git a/projects/VIS_SOTA/IDOL/idol_src/models/sim_ota_assigner.py b/projects/VIS_SOTA/IDOL/idol_src/models/sim_ota_assigner.py new file mode 100644 index 000000000..e88b06743 --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/models/sim_ota_assigner.py @@ -0,0 +1,211 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torchvision.ops as ops +from mmdet.models.task_modules import AssignResult +from mmdet.models.task_modules import SimOTAAssigner as MMDET_SimOTAAssigner +from mmdet.structures.bbox import bbox_cxcywh_to_xyxy +from mmengine import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmtrack.registry import TASK_UTILS + +INF = 10000.0 +EPS = 1.0e-7 + + +@TASK_UTILS.register_module() +class SimOTAAssigner(MMDET_SimOTAAssigner): + """Computes matching between predictions and ground truth. + + Args: + center_radius (float): Ground truth center size + to judge whether a prior is in center. Defaults to 2.5. + candidate_topk (int): The candidate top-k which used to + get top-k ious to calculate dynamic-k. Defaults to 10. + iou_weight (float): The scale factor for regression + iou cost. Defaults to 3.0. + cls_weight (float): The scale factor for classification + cost. Defaults to 1.0. + iou_calculator (ConfigType): Config of overlaps Calculator. + Defaults to dict(type='BboxOverlaps2D'). + """ + + def __init__(self, + match_costs: Union[List[Union[dict, ConfigDict]], dict, + ConfigDict], + center_radius: float = 2.5, + candidate_topk: int = 10): + + if isinstance(match_costs, dict): + match_costs = [match_costs] + elif isinstance(match_costs, list): + assert len(match_costs) > 0, \ + 'match_costs must not be a empty list.' + + self.center_radius = center_radius + self.candidate_topk = candidate_topk + self.match_costs = [ + TASK_UTILS.build(match_cost) for match_cost in match_costs + ] + + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> AssignResult: + + gt_bboxes = gt_instances.bboxes + gt_labels = gt_instances.labels + num_gt = gt_bboxes.size(0) + + pred_bboxes = pred_instances.bboxes + priors = pred_instances.bboxes + num_bboxes = pred_bboxes.size(0) + + # assign 0 by default + assigned_gt_inds = pred_bboxes.new_full((num_bboxes, ), + 0, + dtype=torch.long) + assigned_labels = pred_bboxes.new_full((num_bboxes, ), + 0, + dtype=torch.long) + if num_gt == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + if num_gt == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gts=num_gt, + gt_inds=assigned_gt_inds, + max_overlaps=None, + labels=assigned_labels) + + valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info( + priors, gt_bboxes) + pairwise_ious = ops.box_iou( + bbox_cxcywh_to_xyxy(pred_bboxes), bbox_cxcywh_to_xyxy(gt_bboxes)) + + # compute weighted cost + cost_list = [] + for match_cost in self.match_costs: + cost = match_cost( + pred_instances=pred_instances, + gt_instances=gt_instances, + img_meta=img_meta) + cost_list.append(cost) + cost = torch.stack(cost_list).sum(dim=0) + cost += 100 * (~is_in_boxes_and_center) + cost[~valid_mask] = cost[~valid_mask] + INF + + fg_mask_inboxes, matched_gt_inds, min_cost_qid = \ + self.dynamic_k_matching( + cost, pairwise_ious, num_gt) + + # convert to AssignResult format + assigned_gt_inds[fg_mask_inboxes] = matched_gt_inds + 1 + assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].long() + + assign_res = AssignResult( + num_gt, assigned_gt_inds, None, labels=assigned_labels) + assign_res.set_extra_property('query_inds', min_cost_qid) + return assign_res + + def get_in_gt_and_in_center_info( + self, priors: Tensor, gt_bboxes: Tensor) -> Tuple[Tensor, Tensor]: + """Get the information of which prior is in gt bboxes and gt center + priors.""" + num_gt = gt_bboxes.size(0) + + repeated_x = priors[:, 0].unsqueeze(1).repeat(1, num_gt) + repeated_y = priors[:, 1].unsqueeze(1).repeat(1, num_gt) + repeated_stride_x = priors[:, 2].unsqueeze(1).repeat(1, num_gt) + repeated_stride_y = priors[:, 3].unsqueeze(1).repeat(1, num_gt) + + # is prior centers in gt bboxes, shape: [n_prior, n_gt] + l_ = repeated_x - gt_bboxes[:, 0] + t_ = repeated_y - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - repeated_x + b_ = gt_bboxes[:, 3] - repeated_y + + deltas = torch.stack([l_, t_, r_, b_], dim=1) + is_in_gts = deltas.min(dim=1).values > 0 + is_in_gts_all = is_in_gts.sum(dim=1) > 0 + + # is prior centers in gt centers + gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 + gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 + ct_box_l = gt_cxs - self.center_radius * repeated_stride_x + ct_box_t = gt_cys - self.center_radius * repeated_stride_y + ct_box_r = gt_cxs + self.center_radius * repeated_stride_x + ct_box_b = gt_cys + self.center_radius * repeated_stride_y + + cl_ = repeated_x - ct_box_l + ct_ = repeated_y - ct_box_t + cr_ = ct_box_r - repeated_x + cb_ = ct_box_b - repeated_y + + ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1) + is_in_cts = ct_deltas.min(dim=1).values > 0 + is_in_cts_all = is_in_cts.sum(dim=1) > 0 + + # in boxes or in centers, shape: [num_priors] + is_in_gts_or_centers = is_in_gts_all | is_in_cts_all + + # both in boxes and centers, shape: [num_fg, num_gt] + is_in_boxes_and_centers = (is_in_gts & is_in_cts) + return is_in_gts_or_centers, is_in_boxes_and_centers + + def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor, + num_gt: int) -> Tuple[Tensor, Tensor]: + """Use IoU and matching cost to calculate the dynamic top-k positive + targets.""" + matching_matrix = torch.zeros_like(cost, dtype=torch.uint8) + # select candidate topk ious for dynamic-k calculation + candidate_topk = min(self.candidate_topk, pairwise_ious.size(0)) + topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0) + # calculate dynamic k for each gt + dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1) + for gt_idx in range(num_gt): + _, pos_idx = torch.topk( + cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False) + matching_matrix[:, gt_idx][pos_idx] = 1 + + del topk_ious, dynamic_ks, pos_idx + + prior_match_gt_mask = matching_matrix.sum(1) > 1 + if prior_match_gt_mask.sum() > 0: + cost_min, cost_argmin = torch.min( + cost[prior_match_gt_mask, :], dim=1) + matching_matrix[prior_match_gt_mask, :] *= 0 + matching_matrix[prior_match_gt_mask, cost_argmin] = 1 + + while (matching_matrix.sum(0) == 0).any(): + matched_query_id = matching_matrix.sum(1) > 0 + cost[matched_query_id] += 100000.0 + unmatch_id = torch.nonzero( + matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1) + for gt_idx in unmatch_id: + pos_idx = torch.argmin(cost[:, gt_idx]) + matching_matrix[:, gt_idx][pos_idx] = 1.0 + if (matching_matrix.sum(1) > + 1).sum() > 0: # If a query matches more than one gt + # find gt for these queries with minimal cost + _, cost_argmin = torch.min( + cost[prior_match_gt_mask > 1], dim=1) + # reset mapping relationship + matching_matrix[prior_match_gt_mask > 1] *= 0 + # keep gt with minimal cost + matching_matrix[prior_match_gt_mask > 1, cost_argmin, ] = 1 + + assert not (matching_matrix.sum(0) == 0).any() + # get foreground mask inside box and center prior + fg_mask_inboxes = matching_matrix.sum(1) > 0 + matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1) + + cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf') + min_cost_query_id = torch.min(cost, dim=0)[1] + + return fg_mask_inboxes, matched_gt_inds, min_cost_query_id diff --git a/projects/VIS_SOTA/IDOL/idol_src/models/transformer.py b/projects/VIS_SOTA/IDOL/idol_src/models/transformer.py new file mode 100644 index 000000000..037f7a827 --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/models/transformer.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmdet.models.layers.transformer import \ + DeformableDetrTransformer as MMDET_DeformableDetrTransformer + +from mmtrack.registry import MODELS + + +@MODELS.register_module() +class DeformableDetrTransformer(MMDET_DeformableDetrTransformer): + """Implements the DeformableDETR transformer. + + Rewritten the forward function to return `memory`. + """ + + def forward(self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + reg_branches=None, + cls_branches=None, + **kwargs): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - memory: Output results from encoder, with shape \ + (h*w, bs, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. + """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = \ + self.get_reference_points(spatial_shapes, + valid_ratios, + device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( + 1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs) + + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = \ + reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + # We only use the first channel in enc_outputs_class as foreground, + # the other (num_classes - 1) channels are actually not used. + # Its targets are set to be 0s, which indicates the first + # class (foreground) because we use [0, num_classes - 1] to + # indicate class labels, background class is indicated by + # num_classes (similar convention in RPN). + # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa + # This follows the official implementation of Deformable DETR. + topk_proposals = torch.topk( + enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm( + self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs) + + inter_references_out = inter_references + if self.as_two_stage: + return inter_states, memory, init_reference_out,\ + inter_references_out, enc_outputs_class,\ + enc_outputs_coord_unact + return inter_states, memory, init_reference_out, \ + inter_references_out, None, None diff --git a/projects/VIS_SOTA/IDOL/idol_src/models/utils.py b/projects/VIS_SOTA/IDOL/idol_src/models/utils.py new file mode 100644 index 000000000..a9d5ee7cf --- /dev/null +++ b/projects/VIS_SOTA/IDOL/idol_src/models/utils.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Linear +from torchvision.ops.boxes import box_area + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def _expand(tensor, length: int): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + +class MaskHeadSmallConv(nn.Module): + """Simple convolutional head, using group norm. + + Upsampling is done using a FPN approach + """ + + def __init__(self, dim, fpn_dims, context_dim): + super().__init__() + + inter_dims = [ + dim, context_dim, context_dim, context_dim, context_dim, + context_dim + ] + + # used after upsampling to reduce dimension of fused features! + self.lay1 = torch.nn.Conv2d(dim, dim // 4, 3, padding=1) + self.lay2 = torch.nn.Conv2d(dim // 4, dim // 32, 3, padding=1) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.dcn = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.dim = dim + + if fpn_dims is not None: + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for name, m in self.named_modules(): + if name == 'conv_offset': + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + else: + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x, fpns): + + if fpns is not None: + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x[-1].size(0): + cur_fpn = _expand(cur_fpn, x[-1].size(0) // cur_fpn.size(0)) + fused_x = (cur_fpn + x[-1]) / 2 + else: + fused_x = x[-1] + fused_x = self.lay3(fused_x) + fused_x = F.relu(fused_x) + + if fpns is not None: + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x[-2].size(0): + cur_fpn = _expand(cur_fpn, x[-2].size(0) // cur_fpn.size(0)) + fused_x = (cur_fpn + x[-2]) / 2 + F.interpolate( + fused_x, size=cur_fpn.shape[-2:], mode='nearest') + else: + fused_x = x[-2] + F.interpolate( + fused_x, size=x[-2].shape[-2:], mode='nearest') + fused_x = self.lay4(fused_x) + fused_x = F.relu(fused_x) + + if fpns is not None: + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x[-3].size(0): + cur_fpn = _expand(cur_fpn, x[-3].size(0) // cur_fpn.size(0)) + fused_x = (cur_fpn + x[-3]) / 2 + F.interpolate( + fused_x, size=cur_fpn.shape[-2:], mode='nearest') + else: + fused_x = x[-3] + F.interpolate( + fused_x, size=x[-3].shape[-2:], mode='nearest') + fused_x = self.dcn(fused_x) + fused_x = F.relu(fused_x) + fused_x = self.lay1(fused_x) + fused_x = F.relu(fused_x) + fused_x = self.lay2(fused_x) + fused_x = F.relu(fused_x) + + return fused_x + + +def compute_locations(h, w, device, stride=1): + + shifts_x = torch.arange( + 0, w * stride, step=stride, dtype=torch.float32, device=device) + + shifts_y = torch.arange( + 0, h * stride, step=stride, dtype=torch.float32, device=device) + + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 + + return locations + + +def parse_dynamic_params(params, channels, weight_nums, bias_nums): + + assert params.dim() == 2 + assert len(weight_nums) == len(bias_nums) + assert params.size(1) == sum(weight_nums) + sum(bias_nums) + + num_insts = params.size(0) + num_layers = len(weight_nums) + + params_splits = list( + torch.split_with_sizes(params, weight_nums + bias_nums, dim=1)) + + weight_splits = params_splits[:num_layers] + bias_splits = params_splits[num_layers:] + + for layer in range(num_layers): + if layer < num_layers - 1: + # out_channels x in_channels x 1 x 1 + weight_splits[layer] = weight_splits[layer].reshape( + num_insts * channels, -1, 1, 1) + bias_splits[layer] = bias_splits[layer].reshape(num_insts * + channels) + else: + # out_channels x in_channels x 1 x 1 + weight_splits[layer] = weight_splits[layer].reshape( + num_insts * 1, -1, 1, 1) + bias_splits[layer] = bias_splits[layer].reshape(num_insts) + + return weight_splits, bias_splits + + +def aligned_bilinear(tensor, factor): + + assert tensor.dim() == 4 + assert factor >= 1 + assert int(factor) == factor + + if factor == 1: + return tensor + + h, w = tensor.size()[2:] + tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode='replicate') + oh = factor * h + 1 + ow = factor * w + 1 + tensor = F.interpolate( + tensor, size=(oh, ow), mode='bilinear', align_corners=True) + tensor = F.pad( + tensor, pad=(factor // 2, 0, factor // 2, 0), mode='replicate') + + return tensor[:, :, :oh - 1, :ow - 1] + + +def mask_iou(mask1, mask2): + mask1 = mask1.char() + mask2 = mask2.char() + + intersection = (mask1[:, :, :] * mask2[:, :, :]).sum(-1).sum(-1) + union = (mask1[:, :, :] + mask2[:, :, :] - + mask1[:, :, :] * mask2[:, :, :]).sum(-1).sum(-1) + + return (intersection + 1e-6) / (union + 1e-6) + + +def mask_nms(seg_masks, scores, nms_thr=0.5): + n_samples = len(scores) + if n_samples == 0: + return [] + keep = [True for i in range(n_samples)] + seg_masks = seg_masks.sigmoid() > 0.5 + + for i in range(n_samples - 1): + if not keep[i]: + continue + mask_i = seg_masks[i] + for j in range(i + 1, n_samples, 1): + if not keep[j]: + continue + mask_j = seg_masks[j] + + iou = mask_iou(mask_i, mask_j)[0] + if iou > nms_thr: + keep[j] = False + return keep + + +def box_iou(boxes1, boxes2): + # modified from torchvision to also return the union + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / (area + 1e-7) diff --git a/projects/VIS_SOTA/README.md b/projects/VIS_SOTA/README.md new file mode 100644 index 000000000..7fcc1e53a --- /dev/null +++ b/projects/VIS_SOTA/README.md @@ -0,0 +1,42 @@ +# VIS (Video Instance Segmentation) + +## Introduction + +The goal of VIS task is simultaneous detection, segmentation and tracking of instances in videos. Here we implements `IDOL` based on contrastive learning and `VITA` series algorithms based on `Mask2Former`. Currently it provides advanced online and offline video instance segmentation algorithms. With a commitment to advancing the field of video instance segmentation, we will continually refine and enhance our framework to ensure it is both unified and efficient, providing the necessary nourishment for growth and development in this area. + +In recent years, the online methods for video instance segmentation have witnessed significant advancements, largely attributed to the improvements in image-level object detection algorithms. Meanwhile, semi-online and offline paradigms are tapping into the vast potential offered by the temporal context in multiple frames, offering a more comprehensive approach to video analysis. + +## Requirements + +At the outset of this project, the dependencies used were as follows. Of course, this does not mean that you must strictly use the libraries and algorithms dependencies with the following versions. This is just a recommendation to make your use easier. + +``` +mmcv==2.0.0rc4 +mmdet==3.0.0rc4 +mmengine==0.4.0 +``` + +## Citation + +```BibTeX +@inproceedings{IDOL, + title={In Defense of Online Models for Video Instance Segmentation}, + author={Wu, Junfeng and Liu, Qihao and Jiang, Yi and Bai, Song and Yuille, Alan and Bai, Xiang}, + booktitle={ECCV}, + year={2022}, +} + +@inproceedings{GenVIS, + title={A Generalized Framework for Video Instance Segmentation}, + author={Heo, Miran and Hwang, Sukjun and Hyun, Jeongseok and Kim, Hanjung and Oh, Seoung Wug and Lee, Joon-Young and Kim, Seon Joo}, + booktitle={arXiv preprint arXiv:2211.08834}, + year={2022} +} + +@inproceedings{VITA, + title={VITA: Video Instance Segmentation via Object Token Association}, + author={Heo, Miran and Hwang, Sukjun and Oh, Seoung Wug and Lee, Joon-Young and Kim, Seon Joo}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022} +} +``` diff --git a/projects/VIS_SOTA/VITA/README.md b/projects/VIS_SOTA/VITA/README.md new file mode 100644 index 000000000..5ddae19c1 --- /dev/null +++ b/projects/VIS_SOTA/VITA/README.md @@ -0,0 +1,120 @@ +# VITA: Video Instance Segmentation via Object Token Association + +## Description + +This is an implementation of [VITA](https://github.com/sukjunhwang/VITA) based on [MMTracking](https://github.com/open-mmlab/mmtracking/tree/1.x), [MMDetection](https://github.com/open-mmlab/mmdetection/tree/3.x), [MMCV](https://github.com/open-mmlab/mmcv), and [MMEngine](https://github.com/open-mmlab/mmengine). + +We introduce a novel paradigm for offline Video Instance Segmentation (VIS), based on the hypothesis that explicit object-oriented information can be a strong clue for understanding the context of the entire sequence. To this end, we proposeVITA, a simple structure built on top of an off-the-shelf Transformer-based image instance segmentation model. Specifically, we use an image object detector as a means of distilling object-specific contexts into object tokens. VITA accomplishes video-level understanding by associating frame-level object tokens without using spatio-temporal backbone features. By effectively building relationships between objects using the condensed information, VITA achieves the state-of-the-art on VIS benchmarks with a ResNet-50 backbone: 49.8 AP, 45.7 AP on YouTube-VIS 2019 & 2021 and 19.6 AP on OVIS. Moreover, thanks to its object token-based structure that is disjoint from the backbone features, VITA shows several practical advantages that previous offline VIS methods have not explored - handling long and high-resolution videos with a common GPU and freezing a frame-level detector trained on image domain. + +
+ +
+ +## Usage + + + +### Training commands + +In MMTracking's root directory, run the following command to train the model: + +```bash +python tools/train.py projects/VIS_SOTA/IDOL/configs/vita_r50_8xb2-8e_youtubevis2019.py +``` + +For multi-gpu training, run: + +```bash +python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/VIS_SOTA/IDOL/configs/vita_r50_8xb2-8e_youtubevis2019.py +``` + +### Testing commands + +In MMTracking's root directory, run the following command to test the model: + +```bash +python tools/test.py projects/VIS_SOTA/IDOL/configs/vita_r50_8xb2-8e_youtubevis2019.py ${CHECKPOINT_PATH} +``` + +## Results + +#### YouTubeVIS-2019 + +| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download | +| :----: | :------: | :-----: | :-----: | :------: | :------------: | :--: | :-------------------------------------------------------------------------: | :----------------------: | +| VITA | R-50 | pytorch | 140k | 10.0 | - | 50.3 | [config](projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py) | [model](<>) \| [log](<>) | + +#### OVIS + +| Method | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | AP | Config | Download | +| :----: | :------: | :-----: | :-----: | :------: | :------------: | :--: | :----------: | :----------------------: | +| VITA | R-50 | pytorch | 150k | 10.0 | - | 19.2 | [config](<>) | [model](<>) \| [log](<>) | + +## Citation + +If you find VITA is useful in your research or applications, please consider giving a star 🌟 to the [official repository](https://github.com/sukjunhwang/VITA) and citing VITA by the following BibTeX entry. + +```BibTeX +@inproceedings{VITA, + title={VITA: Video Instance Segmentation via Object Token Association}, + author={Heo, Miran and Hwang, Sukjun and Oh, Seoung Wug and Lee, Joon-Young and Kim, Seon Joo}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022} +} + +``` + +## Checklist + + + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + + + + - [x] Basic docstrings & proper citation + + + + - [x] Test-time correctness + + + + - [x] A full README + + + +- [x] Milestone 2: Indicates a successful model implementation. + + - [x] Training-time correctness + + + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + + + + - [ ] Unit tests + + + + - [ ] Code polishing + + + + - [ ] Metafile.yml + + + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + + + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py b/projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py new file mode 100644 index 000000000..d8434468b --- /dev/null +++ b/projects/VIS_SOTA/VITA/configs/vita_r50_8xb2-8e_youtubevis2019.py @@ -0,0 +1,279 @@ +_base_ = [ + '../../../../configs/_base_/datasets/youtube_vis.py', + '../../../../configs/_base_/default_runtime.py', +] + +custom_imports = dict(imports=['projects.VIS_SOTA.VITA.vita_src'], ) + +num_classes = 40 +num_frames = 2 +model = dict( + type='VITA', + data_preprocessor=dict( + type='TrackDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32), + backbone=dict( + _scope_='mmdet', + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + seg_head=dict( + type='mmtrack.VITASegHead', + in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_classes, + num_stuff_classes=0, + num_queries=100, + num_frames=num_frames, + num_transformer_feat_level=3, + pixel_decoder=dict( + type='mmtrack.VITAPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='mmdet.DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=128, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='mmdet.SinePositionalEncoding', + num_feats=128, + normalize=True), + init_cfg=None), + enforce_decoder_input_project=False, + positional_encoding=dict( + type='mmdet.SinePositionalEncoding', num_feats=128, + normalize=True), + transformer_decoder=dict( + type='mmdet.DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='mmdet.DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=2048, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * num_classes + [0.1]), + loss_mask=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type='mmdet.DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0)), + track_head=dict( + type='VITATrackHead', + num_classes=num_classes, + frame_query_encoder=dict( + type='mmdet.DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=False), + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + frame_query_decoder=dict( + type='mmdet.DetrTransformerDecoder', + return_intermediate=True, + num_layers=3, + transformerlayers=dict( + type='mmdet.DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=2048, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + ), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type='mmdet.HungarianAssigner', + match_costs=[ + dict(type='mmdet.ClassificationCost', weight=2.0), + dict( + type='mmdet.CrossEntropyLossCost', + weight=5.0, + use_sigmoid=True), + dict( + type='mmdet.DiceCost', weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type='mmdet.MaskPseudoSampler')), + test_cfg=dict( + test_run_chunk_size=18, + test_interpolate_chunk_size=5, + max_per_image=100, + )) + +# optimizer +embed_multi = dict(lr_mult=1.0, decay_mult=0.0) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=0.0001, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + custom_keys={ + 'backbone': dict(lr_mult=0.1, decay_mult=1.0), + 'query_embed': embed_multi, + 'query_feat': embed_multi, + 'level_embed': embed_multi, + }, + norm_decay_mult=0.0), + clip_grad=dict(max_norm=0.01, norm_type=2)) + +# learning policy +max_iters = 6000 +param_scheduler = dict( + type='MultiStepLR', + begin=0, + end=max_iters, + by_epoch=False, + milestones=[ + 4000, + ], + gamma=0.1) +# runtime settings +train_cfg = dict( + type='IterBasedTrainLoop', max_iters=max_iters, val_interval=6001) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', by_epoch=False, save_last=True, interval=2000)) +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False) + +train_pipeline = [ + dict( + type='TransformBroadcaster', + share_random_params=True, + transforms=[ + dict(type='LoadImageFromFile'), + dict( + type='LoadTrackAnnotations', + with_instance_id=True, + with_mask=True, + with_bbox=True), + dict(type='mmdet.Resize', scale=(640, 360), keep_ratio=True), + dict(type='mmdet.RandomFlip', prob=0.5), + ]), + dict(type='PackTrackInputs', num_key_frames=num_frames) +] +# dataloader +train_dataloader = dict( + batch_size=2, + num_workers=2, + dataset=dict( + pipeline=train_pipeline, + ref_img_sampler=dict( + num_ref_imgs=1, + frame_range=5, + filter_key_img=True, + method='uniform'))) +val_dataloader = dict( + num_workers=2, + sampler=dict(type='VideoSampler'), + batch_sampler=dict(type='EntireVideoBatchSampler'), +) +test_dataloader = val_dataloader + +# evaluator +val_evaluator = dict( + type='YouTubeVISMetric', + metric='youtube_vis_ap', + outfile_prefix='./youtube_vis_results', + format_only=True) +test_evaluator = val_evaluator diff --git a/projects/VIS_SOTA/VITA/vita_src/__init__.py b/projects/VIS_SOTA/VITA/vita_src/__init__.py new file mode 100644 index 000000000..3fa61d766 --- /dev/null +++ b/projects/VIS_SOTA/VITA/vita_src/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .models import VITATrackHead +from .vita import VITA + +__all__ = ['VITA', 'VITATrackHead'] diff --git a/projects/VIS_SOTA/VITA/vita_src/models/__init__.py b/projects/VIS_SOTA/VITA/vita_src/models/__init__.py new file mode 100644 index 000000000..1a3f9cdb9 --- /dev/null +++ b/projects/VIS_SOTA/VITA/vita_src/models/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .vita_pixel_decoder import VITAPixelDecoder +from .vita_query_track_head import VITATrackHead +from .vita_seg_head import VITASegHead + +__all__ = ['VITATrackHead', 'VITASegHead', 'VITAPixelDecoder'] diff --git a/projects/VIS_SOTA/VITA/vita_src/models/vita_pixel_decoder.py b/projects/VIS_SOTA/VITA/vita_src/models/vita_pixel_decoder.py new file mode 100644 index 000000000..dd10c5c29 --- /dev/null +++ b/projects/VIS_SOTA/VITA/vita_src/models/vita_pixel_decoder.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmdet.models.layers import \ + MSDeformAttnPixelDecoder as MMDET_MSDeformAttnPixelDecoder +from torch import Tensor + +from mmtrack.registry import MODELS + + +@MODELS.register_module() +class VITAPixelDecoder(MMDET_MSDeformAttnPixelDecoder): + + def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor]: + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). + + Returns: + tuple: A tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - multi_scale_features (list[Tensor]): Multi scale \ + features, each in shape (batch_size, c, h, w). + """ + # generate padding mask for each level, for each image + batch_size = feats[0].shape[0] + encoder_input_list = [] + padding_mask_list = [] + level_positional_encoding_list = [] + spatial_shapes = [] + reference_points_list = [] + for i in range(self.num_encoder_levels): + level_idx = self.num_input_levels - i - 1 + feat = feats[level_idx] + feat_projected = self.input_convs[i](feat) + h, w = feat.shape[-2:] + + # no padding + padding_mask_resized = feat.new_zeros( + (batch_size, ) + feat.shape[-2:], dtype=torch.bool) + pos_embed = self.postional_encoding(padding_mask_resized) + level_embed = self.level_encoding.weight[i] + level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed + # (h_i * w_i, 2) + reference_points = self.point_generator.single_level_grid_priors( + feat.shape[-2:], level_idx, device=feat.device) + # normalize + factor = feat.new_tensor([[w, h]]) * self.strides[level_idx] + reference_points = reference_points / factor + + # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) + feat_projected = feat_projected.flatten(2).permute(2, 0, 1) + level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1) + padding_mask_resized = padding_mask_resized.flatten(1) + + encoder_input_list.append(feat_projected) + padding_mask_list.append(padding_mask_resized) + level_positional_encoding_list.append(level_pos_embed) + spatial_shapes.append(feat.shape[-2:]) + reference_points_list.append(reference_points) + # shape (batch_size, total_num_query), + # total_num_query=sum([., h_i * w_i,.]) + padding_masks = torch.cat(padding_mask_list, dim=1) + # shape (total_num_query, batch_size, c) + encoder_inputs = torch.cat(encoder_input_list, dim=0) + level_positional_encodings = torch.cat( + level_positional_encoding_list, dim=0) + device = encoder_inputs.device + # shape (num_encoder_levels, 2), from low + # resolution to high resolution + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=device) + # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = torch.cat(reference_points_list, dim=0) + reference_points = reference_points[None, :, None].repeat( + batch_size, 1, self.num_encoder_levels, 1) + valid_radios = reference_points.new_ones( + (batch_size, self.num_encoder_levels, 2)) + # shape (num_total_query, batch_size, c) + memory = self.encoder( + query=encoder_inputs, + key=None, + value=None, + query_pos=level_positional_encodings, + key_pos=None, + attn_masks=None, + key_padding_mask=None, + query_key_padding_mask=padding_masks, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_radios=valid_radios) + # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query) + memory = memory.permute(1, 2, 0) + + # from low resolution to high resolution + num_query_per_level = [e[0] * e[1] for e in spatial_shapes] + outs = torch.split(memory, num_query_per_level, dim=-1) + outs = [ + x.reshape(batch_size, -1, spatial_shapes[i][0], + spatial_shapes[i][1]) for i, x in enumerate(outs) + ] + + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, + -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + F.interpolate( + outs[-1], + size=cur_feat.shape[-2:], + mode='bilinear', + align_corners=False) + y = self.output_convs[i](y) + outs.append(y) + multi_scale_features = outs[:self.num_outs] + + clip_mask_feature = outs[-1] + mask_feature = self.mask_feature(outs[-1]) + return mask_feature, clip_mask_feature, multi_scale_features diff --git a/projects/VIS_SOTA/VITA/vita_src/models/vita_query_track_head.py b/projects/VIS_SOTA/VITA/vita_src/models/vita_query_track_head.py new file mode 100644 index 000000000..49a2aa8cd --- /dev/null +++ b/projects/VIS_SOTA/VITA/vita_src/models/vita_query_track_head.py @@ -0,0 +1,390 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from math import ceil +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d +from mmdet.structures import SampleList +from mmdet.structures.mask import mask2bbox +from mmdet.utils import ConfigType, InstanceList, OptConfigType +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import Tensor + +from mmtrack.registry import MODELS + + +@MODELS.register_module() +class VITATrackHead(BaseModule): + + def __init__(self, + mask_dim: int = 256, + enc_window_size: int = 6, + use_sim: bool = True, + enforce_input_project: bool = True, + sim_use_clip: bool = True, + num_heads: int = 8, + hidden_dim: int = 256, + num_queries: int = 100, + num_classes: int = 40, + num_frame_queries: int = 100, + frame_query_encoder: ConfigType = ..., + frame_query_decoder: ConfigType = ..., + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + *args) -> None: + super(VITATrackHead, self).__init__() + self.window_size = enc_window_size + self.vita_mask_features = Conv2d( + in_channels=mask_dim, + out_channels=mask_dim, + kernel_size=1, + stride=1, + bias=True) + self.frame_query_encoder = MODELS.build(frame_query_encoder) + self.frame_query_decoder = MODELS.build(frame_query_decoder) + self.num_heads = num_heads + self.sim_use_clip = sim_use_clip + self.use_sim = use_sim + self.num_queries = num_queries + self.num_classes = num_classes + self.num_transformer_decoder_layers = frame_query_decoder.num_layers + + # learnable query features + self.query_feat = nn.Embedding(num_queries, hidden_dim) + # learnable query p.e. + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + self.fq_pos = nn.Embedding(num_frame_queries, hidden_dim) + + self.cls_embed = nn.Linear(hidden_dim, num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim)) + + if self.use_sim: + self.sim_embed_frame = nn.Linear(hidden_dim, hidden_dim) + if self.sim_use_clip: + self.sim_embed_clip = nn.Linear(hidden_dim, hidden_dim) + + if enforce_input_project: + self.input_proj_dec = nn.Linear(hidden_dim, hidden_dim) + else: + self.input_proj_dec = nn.Sequential() + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def loss( + self, + x: Tuple[Tensor], + data_samples: SampleList, + ) -> Dict[str, Tensor]: + """Perform forward propagation and loss calculation of the track head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + data_samples (List[:obj:`TrackDataSample`]): The Data + Samples. It usually includes information such as `gt_instance`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + batch_img_metas = [] + batch_gt_instances = [] + + for data_sample in data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + # forward + all_cls_scores, all_mask_preds = self(x, data_samples) + + # preprocess ground truth + batch_gt_instances = self.preprocess_gt(batch_gt_instances) + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def forward(self, frame_queries: Tensor) -> Tuple[Tensor, ...]: + """Forward function. + + L: Number of Layers. + B: Batch size. + T: Temporal window size. Number of frames per video. + C: Channel size. + fQ: Number of frame-wise queries from IFC. + cQ: Number of clip-wise queries to decode Q. + """ + L, BT, fQ, C = frame_queries.shape + B = BT // self.num_frames if self.training else 1 + T = self.num_frames if self.training else BT // B + + frame_queries = frame_queries.reshape(L * B, T, fQ, C) + frame_queries = frame_queries.permute(1, 2, 0, 3).contiguous() + frame_queries = self.input_proj_dec(frame_queries) + + # for window attention + if self.window_size > 0: + pad = int(ceil(T / self.window_size)) * self.window_size - T + _T = pad + T + frame_queries = F.pad(frame_queries, (0, 0, 0, 0, 0, 0, 0, pad)) + enc_mask = frame_queries.new_ones(L * B, _T).bool() + enc_mask[:, :T] = False + else: + enc_mask = None + + frame_queries = self.encode_frame_query(frame_queries, enc_mask) + # (LB, T*fQ, C) + frame_queries = frame_queries[:T].flatten(0, 1) + + if self.use_sim: + fq_embed = self.sim_embed_frame(frame_queries) + fq_embed = fq_embed.transpose(0, 1).reshape(L, B, T, fQ, C) + else: + fq_embed = None + + dec_pos = self.fq_pos.weight[None, :, None, :].repeat(T, 1, L * B, + 1).flatten(0, 1) + + query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, L * B, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat( + (1, L * B, 1)) + + decoder_outputs = [] + for i in range(self.num_transformer_decoder_layers): + # cross_attn + self_attn + layer = self.frame_query_decoder.layers[i] + query_feat = layer( + query=query_feat, + key=frame_queries, + value=frame_queries, + query_pos=query_embed, + key_pos=dec_pos, + attn_masks=None, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + if self.training or (i == self.num_transformer_decoder_layers - 1): + decoder_out = self.frame_query_decoder.post_norm(query_feat) + decoder_out = decoder_out.transpose(0, 1) + decoder_outputs.append( + decoder_out.view(L, B, self.num_queries, C)) + + decoder_outputs = torch.stack(decoder_outputs, dim=0) + + all_cls_pred = self.cls_embed(decoder_outputs) + all_mask_embed = self.mask_embed(decoder_outputs) + if self.use_sim and self.sim_use_clip: + all_cq_embed = self.sim_embed_clip(decoder_outputs) + else: + all_cq_embed = [None] * self.num_transformer_decoder_layers + + return all_cls_pred, all_mask_embed, all_cq_embed, fq_embed + + def encode_frame_query(self, frame_queries, attn_mask): + # Not using window-based attention if self.window_size == 0. + if self.window_size == 0: + return_shape = frame_queries.shape + # (T, fQ, LB, C) -> (T*fQ, LB, C) + frame_queries = frame_queries.flatten(0, 1) + + # TODO: add + frame_queries = frame_queries.view(return_shape) + return frame_queries + # Using window-based attention if self.window_size > 0. + else: + T, fQ, LB, C = frame_queries.shape + win_s = self.window_size + num_win = T // win_s + half_win_s = int(ceil(win_s / 2)) + + window_mask = attn_mask.view(LB * num_win, + win_s)[..., + None].repeat(1, 1, + fQ).flatten(1) + + _attn_mask = torch.roll(attn_mask, half_win_s, 1) + _attn_mask = _attn_mask.view(LB, num_win, + win_s)[..., + None].repeat(1, 1, 1, win_s) + _attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose( + -2, -1) + _attn_mask[:, + -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose( + -2, -1) + _attn_mask[:, 0, :half_win_s, half_win_s:] = True + _attn_mask[:, 0, half_win_s:, :half_win_s] = True + _attn_mask = _attn_mask.view( + LB * num_win, 1, win_s, 1, win_s, + 1).repeat(1, self.num_heads, 1, fQ, 1, + fQ).view(LB * num_win * self.num_heads, win_s * fQ, + win_s * fQ) + shift_window_mask = _attn_mask.float() * -1000 + + for layer_idx in range(self.frame_query_encoder.num_layers): + if self.training or layer_idx % 2 == 0: + frame_queries = self._window_attn(frame_queries, + window_mask, layer_idx) + else: + frame_queries = self._shift_window_attn( + frame_queries, shift_window_mask, layer_idx) + return frame_queries + + def _window_attn(self, frame_queries, attn_mask, layer_idx): + T, fQ, LB, C = frame_queries.shape + + win_s = self.window_size + num_win = T // win_s + + frame_queries = frame_queries.view(num_win, win_s, fQ, LB, C) + frame_queries = frame_queries.permute(1, 2, 3, 0, 4).reshape( + win_s * fQ, LB * num_win, C) + + frame_queries = self.frame_query_encoder.layers[layer_idx]( + frame_queries, query_key_padding_mask=attn_mask) + + frame_queries = frame_queries.reshape(win_s, fQ, LB, num_win, + C).permute(3, 0, 1, 2, + 4).reshape( + T, fQ, LB, C) + return frame_queries + + def _shift_window_attn(self, frame_queries, attn_mask, layer_idx): + T, fQ, LB, C = frame_queries.shape + + win_s = self.window_size + num_win = T // win_s + half_win_s = int(ceil(win_s / 2)) + + frame_queries = torch.roll(frame_queries, half_win_s, 0) + frame_queries = frame_queries.view(num_win, win_s, fQ, LB, C) + frame_queries = frame_queries.permute(1, 2, 3, 0, 4).reshape( + win_s * fQ, LB * num_win, C) + + frame_queries = self.frame_query_encoder.layers[layer_idx]( + frame_queries, attn_masks=attn_mask) + frame_queries = frame_queries.reshape(win_s, fQ, LB, num_win, + C).permute(3, 0, 1, 2, + 4).reshape( + T, fQ, LB, C) + + frame_queries = torch.roll(frame_queries, -half_win_s, 0) + + return frame_queries + + def predict(self, + mask_features: Tensor, + frame_queries: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + all_cls_pred, all_mask_embed, _, _ = self(frame_queries) + + results = self.predict_by_feat( + all_cls_pred, + all_mask_embed, + mask_features, + batch_img_metas=batch_img_metas, + rescale=rescale) + return results + + def predict_by_feat(self, + all_cls_pred: Tensor, + all_mask_embed: Tensor, + mask_features: Tensor, + batch_img_metas: List[Dict], + rescale: bool = True) -> InstanceList: + + cls_pred = all_cls_pred[-1, -1, 0] + mask_embed = all_mask_embed[-1, -1, 0] + # The input is a video, and a batch is a video, + # so the img shape of each image is the same. + # Here is the first image. + img_meta = batch_img_metas[0] + + scores = F.softmax(cls_pred, dim=-1)[:, :-1] + + max_per_image = self.test_cfg.get('max_per_video', 10) + test_interpolate_chunk_size = self.test_cfg.get( + 'test_interpolate_chunk_size', 5) + scores_per_video, topk_indices = scores.flatten(0, 1).topk( + max_per_image, sorted=False) + labels = torch.arange( + self.num_classes, + device=cls_pred.device).unsqueeze(0).repeat(self.num_queries, + 1).flatten(0, 1) + labels_per_video = labels[topk_indices] + + query_indices = topk_indices // self.num_classes + mask_embed = mask_embed[query_indices] + + masks_per_video = [] + numerator = torch.zeros( + len(mask_embed), dtype=torch.float, device=cls_pred.device) + denominator = torch.zeros( + len(mask_embed), dtype=torch.float, device=cls_pred.device) + for i in range(ceil(len(mask_features) / test_interpolate_chunk_size)): + mask_feat = mask_features[i * test_interpolate_chunk_size:(i + 1) * + test_interpolate_chunk_size].to( + cls_pred.device) + + mask_pred = torch.einsum('qc,tchw->qthw', mask_embed, mask_feat) + + pad_height, pad_width = img_meta['pad_shape'] + rz_height, rz_width = img_meta['img_shape'] + # upsample masks + mask_pred = F.interpolate( + mask_pred, + size=(pad_height, pad_width), + mode='bilinear', + align_corners=False) + + # crop the padding area + mask_pred = mask_pred[:, :, :rz_height, :rz_width] + ori_height, ori_width = img_meta['ori_shape'] + + interim_mask_soft = mask_pred.sigmoid() + interim_mask_hard = interim_mask_soft > 0.5 + + numerator += (interim_mask_soft.flatten(1) * + interim_mask_hard.flatten(1)).sum(1) + denominator += interim_mask_hard.flatten(1).sum(1) + + mask_pred = F.interpolate( + mask_pred, + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False) > 0. + + masks_per_video.append(mask_pred) + + masks_per_video = torch.cat(masks_per_video, dim=1) + scores_per_video *= (numerator / (denominator + 1e-6)) + + # format top-10 predictions + results = [] + for img_idx in range(len(batch_img_metas)): + pred_track_instances = InstanceData() + + pred_track_instances.masks = masks_per_video[:, img_idx] + pred_track_instances.bboxes = mask2bbox(masks_per_video[:, + img_idx]) + pred_track_instances.labels = labels_per_video + pred_track_instances.scores = scores_per_video + pred_track_instances.instances_id = torch.arange(10) + + results.append(pred_track_instances) + + return results diff --git a/projects/VIS_SOTA/VITA/vita_src/models/vita_seg_head.py b/projects/VIS_SOTA/VITA/vita_src/models/vita_seg_head.py new file mode 100644 index 000000000..75f6c59c6 --- /dev/null +++ b/projects/VIS_SOTA/VITA/vita_src/models/vita_seg_head.py @@ -0,0 +1,221 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch +from mmdet.models.dense_heads import Mask2FormerHead as MMDET_Mask2FormerHead +from mmdet.structures import SampleList +from mmengine.structures import InstanceData +from torch import Tensor +from torch.nn import functional as F + +from mmtrack.registry import MODELS +from mmtrack.utils import InstanceList + + +@MODELS.register_module() +class VITASegHead(MMDET_Mask2FormerHead): + + def __init__(self, num_frames: int = 2, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.num_frames = num_frames + + def preprocess_gt(self, batch_gt_instances: InstanceList) -> InstanceList: + """Preprocess the ground truth for all images.""" + final_batch_gt_instances = [] + for gt_instances in batch_gt_instances: + _device = gt_instances.labels.device + gt_instances.masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=_device) + # a list used to record which image each instance belongs to + map_info = gt_instances.map_instances_to_img_idx + for frame_id in range(self.num_frames): + ins_index = (map_info == frame_id) + per_frame_gt = gt_instances[ins_index] + tmp_instances = InstanceData( + labels=per_frame_gt.labels, + masks=per_frame_gt.masks.long()) + final_batch_gt_instances.append(tmp_instances) + + return final_batch_gt_instances + + def forward(self, x: List[Tensor]) -> Tuple[List[Tensor]]: + """Forward function. + + Overwriting here is mainly for VITA. + """ + batch_size = x[0].size(0) + mask_features, clip_mask_features, multi_scale_memorys = \ + self.pixel_decoder(x) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros( + (batch_size, ) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.flatten( + 2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + + frame_query_list = [] + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask, frame_query = self._forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + cls_pred, mask_pred, attn_mask, frame_query = self._forward_head( + query_feat, mask_features, multi_scale_memorys[ + (i + 1) % self.num_transformer_feat_level].shape[-2:]) + + frame_query_list.append(frame_query) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + # we only need frame query for VITA + return cls_pred_list, mask_pred_list, \ + frame_query_list, clip_mask_features + + def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor, + attn_mask_target_size: Tuple[int, int]) -> Tuple[Tensor]: + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask, decoder_out + + def loss( + self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + ) -> Dict[str, Tensor]: + """Perform forward propagation and loss calculation of the panoptic + head on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + batch_img_metas = [] + batch_gt_instances = [] + for data_sample in batch_data_samples: + for _ in range(self.num_frames): + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + # fix batch_img_metas, these keys will be used in self.loss_by_feat + key_list = [ + 'batch_input_shape', 'pad_shape', 'img_shape', 'scale_factor' + ] + for key in key_list: + for batch_id in range(len(batch_img_metas)): + batch_img_metas[batch_id][key] = batch_img_metas[batch_id][ + key][0] + + # forward + all_cls_scores, all_mask_preds, all_frame_queries, mask_features = \ + self(x) + + # preprocess ground truth + batch_gt_instances = self.preprocess_gt(batch_gt_instances) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, x: Tuple[Tensor]) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + + Returns: + tuple[Tensor]: A tuple contains two tensors. + + - mask_cls_results (Tensor): Mask classification logits,\ + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should includes background. + - mask_pred_results (Tensor): Mask logits, shape \ + (batch_size, num_queries, h, w). + """ + all_frame_queries, mask_features = self(x) + + frame_queries = all_frame_queries[-1] + + return frame_queries, mask_features diff --git a/projects/VIS_SOTA/VITA/vita_src/vita.py b/projects/VIS_SOTA/VITA/vita_src/vita.py new file mode 100644 index 000000000..5f56e165a --- /dev/null +++ b/projects/VIS_SOTA/VITA/vita_src/vita.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, Optional, Union + +import torch +from torch import Tensor + +from mmtrack.models.mot import BaseMultiObjectTracker +from mmtrack.registry import MODELS +from mmtrack.utils import OptConfigType, OptMultiConfig, SampleList + + +@MODELS.register_module() +class VITA(BaseMultiObjectTracker): + + def __init__(self, + backbone: Optional[dict] = None, + seg_head: Optional[dict] = None, + track_head: Optional[dict] = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__(data_preprocessor, init_cfg) + + seg_head_ = seg_head.deepcopy() + track_head_ = track_head.deepcopy() + seg_head_.update(train_cfg=train_cfg) + track_head_.update(test_cfg=test_cfg) + + if backbone is not None: + self.backbone = MODELS.build(backbone) + + if seg_head is not None: + self.seg_head = MODELS.build(seg_head_) + + if track_head is not None: + self.track_head = MODELS.build(track_head_) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def loss(self, inputs: Dict[str, Tensor], data_samples: SampleList, + **kwargs) -> Union[dict, tuple]: + """ + Args: + inputs (Tensor): Input images of shape (N, T, C, H, W). + These should usually be mean centered and std scaled. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + img = inputs['img'] + assert img.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + # shape (N * T, C, H, W) + img = img.flatten(0, 1) + + feats = self.backbone(img) + losses = self.seg_head.loss(feats, data_samples) + + return losses + + def predict(self, + inputs: dict, + data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with + postprocessing. + + Args: + inputs (Dict[str, Tensor]): of shape (N, T, C, H, W) + encoding input images. Typically, these should be mean centered + and std scaled. The N denotes batch size. The T denotes the + number of key/reference frames. + - img (Tensor) : The key images. + - ref_img (Tensor): The reference images. + In test mode, T = 1 and there is only ``img`` and no + ``ref_img``. + data_samples (list[:obj:`TrackDataSample`]): The batch + data samples. It usually includes information such + as ``gt_instances`` and 'metainfo'. + rescale (bool, Optional): If False, then returned bboxes and masks + will fit the scale of img, otherwise, returned bboxes and masks + will fit the scale of original image shape. Defaults to True. + + Returns: + SampleList: Tracking results of the input images. + Each TrackDataSample usually contains ``pred_track_instances``. + """ + img = inputs['img'] + assert img.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' + # the "T" is 1 + img = img.squeeze(1) + num_frames = img.size(0) + frame_queries, mask_features = [], [] + + test_run_chunk_size = self.test_cfg.get('test_run_chunk_size', 18) + for i in range(math.ceil(num_frames / test_run_chunk_size)): + clip_imgs = img[i * test_run_chunk_size:(i + 1) * + test_run_chunk_size] + + feats = self.backbone(clip_imgs) + _frame_queries, _mask_features = self.seg_head.predict(feats) + # just a conv2d + _mask_features = self.track_head.vita_mask_features(_mask_features) + + frame_queries.append(_frame_queries) + mask_features.append(_mask_features) + + frame_queries = torch.cat(frame_queries)[None] + mask_features = torch.cat(mask_features) + pred_track_ins_list = self.track_head.predict(mask_features, + frame_queries, + data_samples, rescale) + + results = [] + for idx, pred_track_ins in enumerate(pred_track_ins_list): + track_data_sample = data_samples[idx] + track_data_sample.pred_track_instances = pred_track_ins + results.append(track_data_sample) + + return results