diff --git a/configs/boxinst/README.md b/configs/boxinst/README.md new file mode 100644 index 00000000000..6f015a1d16b --- /dev/null +++ b/configs/boxinst/README.md @@ -0,0 +1,31 @@ +# BoxInst + +> [BoxInst: High-Performance Instance Segmentation with Box Annotations](https://arxiv.org/pdf/2012.02310.pdf) + + + +## Abstract + +We present a high-performance method that can achieve mask-level instance segmentation with only bounding-box annotations for training. While this setting has been studied in the literature, here we show significantly stronger performance with a simple design (e.g., dramatically improving previous best reported mask AP of 21.1% to 31.6% on the COCO dataset). Our core idea is to redesign the loss +of learning masks in instance segmentation, with no modification to the segmentation network itself. The new loss functions can supervise the mask training without relying on mask annotations. This is made possible with two loss terms, namely, 1) a surrogate term that minimizes the discrepancy between the projections of the ground-truth box and the predicted mask; 2) a pairwise loss that can exploit the prior that proximal pixels with similar colors are very likely to have the same category label. Experiments demonstrate that the redesigned mask loss can yield surprisingly high-quality instance masks with only box annotations. For example, without using any mask annotations, with a ResNet-101 backbone and 3× training schedule, we achieve 33.2% mask AP on COCO test-dev split (vs. 39.1% of the fully supervised counterpart). Our excellent experiment results on COCO and Pascal VOC indicate that our method dramatically narrows the performance gap between weakly and fully supervised instance segmentation. + +
+ +
+ +## Results and Models + +| Backbone | Style | MS train | Lr schd | bbox AP | mask AP | Config | Download | +| :------: | :-----: | :------: | :-----: | :-----: | :-----: | :----------------------------------------: | :----------------------: | +| R-50 | pytorch | Y | 1x | 39.4 | 30.8 | [config](./boxinst_r50_fpn_ms-90k_coco.py) | [model](<>) \| [log](<>) | + +## Citation + +```latex +@inproceedings{tian2020boxinst, + title = {{BoxInst}: High-Performance Instance Segmentation with Box Annotations}, + author = {Tian, Zhi and Shen, Chunhua and Wang, Xinlong and Chen, Hao}, + booktitle = {Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR)}, + year = {2021} +} +``` diff --git a/configs/boxinst/boxinst_r50_fpn_ms-90k_coco.py b/configs/boxinst/boxinst_r50_fpn_ms-90k_coco.py new file mode 100644 index 00000000000..371f252a153 --- /dev/null +++ b/configs/boxinst/boxinst_r50_fpn_ms-90k_coco.py @@ -0,0 +1,93 @@ +_base_ = '../common/ms-90k_coco.py' + +# model settings +model = dict( + type='BoxInst', + data_preprocessor=dict( + type='BoxInstDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + mask_stride=4, + pairwise_size=3, + pairwise_dilation=2, + pairwise_color_thresh=0.3, + bottom_pixels_removed=10), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', # use P5 + num_outs=5, + relu_before_extra_convs=True), + bbox_head=dict( + type='BoxInstBboxHead', + num_params=593, + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + strides=[8, 16, 32, 64, 128], + norm_on_bbox=True, + centerness_on_reg=True, + dcn_on_last_conv=False, + center_sampling=True, + conv_bias=True, + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=1.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + mask_head=dict( + type='BoxInstMaskHead', + num_layers=3, + feat_channels=16, + size_of_interest=8, + mask_out_stride=4, + topk_masks_per_img=64, + mask_feature_head=dict( + in_channels=256, + feat_channels=128, + start_level=0, + end_level=2, + out_channels=16, + mask_stride=8, + num_stacked_convs=4, + norm_cfg=dict(type='BN', requires_grad=True)), + loss_mask=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + eps=5e-6, + loss_weight=1.0)), + # model training and testing settings + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100, + mask_thr=0.5)) + +# optimizer +optim_wrapper = dict(optimizer=dict(lr=0.01)) + +# evaluator +val_evaluator = dict(metric=['bbox', 'segm']) +test_evaluator = val_evaluator diff --git a/mmdet/models/data_preprocessors/__init__.py b/mmdet/models/data_preprocessors/__init__.py index 58c28f25b36..a5077e03c96 100644 --- a/mmdet/models/data_preprocessors/__init__.py +++ b/mmdet/models/data_preprocessors/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .data_preprocessor import (BatchFixedSizePad, BatchResize, - BatchSyncRandomResize, DetDataPreprocessor, + BatchSyncRandomResize, BoxInstDataPreprocessor, + DetDataPreprocessor, MultiBranchDataPreprocessor) __all__ = [ 'DetDataPreprocessor', 'BatchSyncRandomResize', 'BatchFixedSizePad', - 'MultiBranchDataPreprocessor', 'BatchResize' + 'MultiBranchDataPreprocessor', 'BatchResize', 'BoxInstDataPreprocessor' ] diff --git a/mmdet/models/data_preprocessors/data_preprocessor.py b/mmdet/models/data_preprocessors/data_preprocessor.py index 848174170eb..10d97e408f7 100644 --- a/mmdet/models/data_preprocessors/data_preprocessor.py +++ b/mmdet/models/data_preprocessors/data_preprocessor.py @@ -14,11 +14,18 @@ from mmengine.utils import is_list_of from torch import Tensor +from mmdet.models.utils import unfold_wo_center from mmdet.models.utils.misc import samplelist_boxtype2tensor from mmdet.registry import MODELS from mmdet.structures import DetDataSample +from mmdet.structures.mask import BitmapMasks from mmdet.utils import ConfigType +try: + import skimage +except ImportError: + skimage = None + @MODELS.register_module() class DetDataPreprocessor(ImgDataPreprocessor): @@ -645,3 +652,138 @@ def get_padded_tensor(self, tensor: Tensor, pad_value: int) -> Tensor: padded_tensor = padded_tensor.type_as(tensor) padded_tensor[:, :, :target_height, :target_width] = tensor return padded_tensor + + +@MODELS.register_module() +class BoxInstDataPreprocessor(DetDataPreprocessor): + """Pseudo mask pre-processor for BoxInst. + + Comparing with the :class:`mmdet.DetDataPreprocessor`, + + 1. It generates masks using box annotations. + 2. It computes the images color similarity in LAB color space. + + Args: + mask_stride (int): The mask output stride in boxinst. Defaults to 4. + pairwise_size (int): The size of neighborhood for each pixel. + Defaults to 3. + pairwise_dilation (int): The dilation of neighborhood for each pixel. + Defaults to 2. + pairwise_color_thresh (float): The thresh of image color similarity. + Defaults to 0.3. + bottom_pixels_removed (int): The length of removed pixels in bottom. + It is caused by the annotation error in coco dataset. + Defaults to 10. + """ + + def __init__(self, + *arg, + mask_stride: int = 4, + pairwise_size: int = 3, + pairwise_dilation: int = 2, + pairwise_color_thresh: float = 0.3, + bottom_pixels_removed: int = 10, + **kwargs) -> None: + super().__init__(*arg, **kwargs) + self.mask_stride = mask_stride + self.pairwise_size = pairwise_size + self.pairwise_dilation = pairwise_dilation + self.pairwise_color_thresh = pairwise_color_thresh + self.bottom_pixels_removed = bottom_pixels_removed + + if skimage is None: + raise RuntimeError('skimage is not installed,\ + please install it by: pip install scikit-image') + + def get_images_color_similarity(self, inputs: Tensor, + image_masks: Tensor) -> Tensor: + """Compute the image color similarity in LAB color space.""" + assert inputs.dim() == 4 + assert inputs.size(0) == 1 + + unfolded_images = unfold_wo_center( + inputs, + kernel_size=self.pairwise_size, + dilation=self.pairwise_dilation) + diff = inputs[:, :, None] - unfolded_images + similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5) + + unfolded_weights = unfold_wo_center( + image_masks[None, None], + kernel_size=self.pairwise_size, + dilation=self.pairwise_dilation) + unfolded_weights = torch.max(unfolded_weights, dim=1)[0] + + return similarity * unfolded_weights + + def forward(self, data: dict, training: bool = False) -> dict: + """Get pseudo mask labels using color similarity.""" + det_data = super().forward(data, training) + inputs, data_samples = det_data['inputs'], det_data['data_samples'] + + if training: + # get image masks and remove bottom pixels + b_img_h, b_img_w = data_samples[0].batch_input_shape + img_masks = [] + for i in range(inputs.shape[0]): + img_h, img_w = data_samples[i].img_shape + img_mask = inputs.new_ones((img_h, img_w)) + pixels_removed = int(self.bottom_pixels_removed * + float(img_h) / float(b_img_h)) + if pixels_removed > 0: + img_mask[-pixels_removed:, :] = 0 + pad_w = b_img_w - img_w + pad_h = b_img_h - img_h + img_mask = F.pad(img_mask, (0, pad_w, 0, pad_h), 'constant', + 0.) + img_masks.append(img_mask) + img_masks = torch.stack(img_masks, dim=0) + start = int(self.mask_stride // 2) + img_masks = img_masks[:, start::self.mask_stride, + start::self.mask_stride] + + # Get origin rgb image for color similarity + ori_imgs = inputs * self.std + self.mean + downsampled_imgs = F.avg_pool2d( + ori_imgs.float(), + kernel_size=self.mask_stride, + stride=self.mask_stride, + padding=0) + + # Compute color similarity for pseudo mask generation + for im_i, data_sample in enumerate(data_samples): + # TODO: Support rgb2lab in mmengine? + images_lab = skimage.color.rgb2lab( + downsampled_imgs[im_i].byte().permute(1, 2, + 0).cpu().numpy()) + images_lab = torch.as_tensor( + images_lab, device=ori_imgs.device, dtype=torch.float32) + images_lab = images_lab.permute(2, 0, 1)[None] + images_color_similarity = self.get_images_color_similarity( + images_lab, img_masks[im_i]) + pairwise_mask = (images_color_similarity >= + self.pairwise_color_thresh).float() + + per_im_bboxes = data_sample.gt_instances.bboxes + if per_im_bboxes.shape[0] > 0: + per_im_masks = [] + for per_box in per_im_bboxes: + mask_full = torch.zeros((b_img_h, b_img_w), + device=self.device).float() + mask_full[int(per_box[1]):int(per_box[3] + 1), + int(per_box[0]):int(per_box[2] + 1)] = 1.0 + per_im_masks.append(mask_full) + per_im_masks = torch.stack(per_im_masks, dim=0) + pairwise_masks = torch.cat( + [pairwise_mask for _ in range(per_im_bboxes.shape[0])], + dim=0) + else: + per_im_masks = torch.zeros((0, b_img_h, b_img_w)) + pairwise_masks = torch.zeros( + (0, self.pairwise_size**2 - 1, b_img_h, b_img_w)) + + # TODO: Support BitmapMasks with tensor? + data_sample.gt_instances.masks = BitmapMasks( + per_im_masks.cpu().numpy(), b_img_h, b_img_w) + data_sample.gt_instances.pairwise_masks = pairwise_masks + return {'inputs': inputs, 'data_samples': data_samples} diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index 469f5cc69d8..a4fcd4be62b 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -3,6 +3,7 @@ from .anchor_head import AnchorHead from .atss_head import ATSSHead from .autoassign_head import AutoAssignHead +from .boxinst_head import BoxInstBboxHead, BoxInstMaskHead from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead from .centernet_head import CenterNetHead from .centernet_update_head import CenterNetUpdateHead @@ -59,5 +60,6 @@ 'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead', 'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead', 'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead', - 'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead' + 'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead', + 'BoxInstBboxHead', 'BoxInstMaskHead' ] diff --git a/mmdet/models/dense_heads/boxinst_head.py b/mmdet/models/dense_heads/boxinst_head.py new file mode 100644 index 00000000000..7d6e8f7777a --- /dev/null +++ b/mmdet/models/dense_heads/boxinst_head.py @@ -0,0 +1,252 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn.functional as F +from mmengine import MessageHub +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.utils import InstanceList +from ..utils.misc import unfold_wo_center +from .condinst_head import CondInstBboxHead, CondInstMaskHead + + +@MODELS.register_module() +class BoxInstBboxHead(CondInstBboxHead): + """BoxInst box head used in https://arxiv.org/abs/2012.02310.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + +@MODELS.register_module() +class BoxInstMaskHead(CondInstMaskHead): + """BoxInst mask head used in https://arxiv.org/abs/2012.02310. + + This head outputs the mask for BoxInst. + + Args: + pairwise_size (dict): The size of neighborhood for each pixel. + Defaults to 3. + pairwise_dilation (int): The dilation of neighborhood for each pixel. + Defaults to 2. + warmup_iters (int): Warmup iterations for pair-wise loss. + Defaults to 10000. + """ + + def __init__(self, + *arg, + pairwise_size: int = 3, + pairwise_dilation: int = 2, + warmup_iters: int = 10000, + **kwargs) -> None: + self.pairwise_size = pairwise_size + self.pairwise_dilation = pairwise_dilation + self.warmup_iters = warmup_iters + super().__init__(*arg, **kwargs) + + def get_pairwise_affinity(self, mask_logits: Tensor) -> Tensor: + """Compute the pairwise affinity for each pixel.""" + log_fg_prob = F.logsigmoid(mask_logits).unsqueeze(1) + log_bg_prob = F.logsigmoid(-mask_logits).unsqueeze(1) + + log_fg_prob_unfold = unfold_wo_center( + log_fg_prob, + kernel_size=self.pairwise_size, + dilation=self.pairwise_dilation) + log_bg_prob_unfold = unfold_wo_center( + log_bg_prob, + kernel_size=self.pairwise_size, + dilation=self.pairwise_dilation) + + # the probability of making the same prediction: + # p_i * p_j + (1 - p_i) * (1 - p_j) + # we compute the the probability in log space + # to avoid numerical instability + log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold + log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold + + # TODO: Figure out the difference between it and directly sum + max_ = torch.max(log_same_fg_prob, log_same_bg_prob) + log_same_prob = torch.log( + torch.exp(log_same_fg_prob - max_) + + torch.exp(log_same_bg_prob - max_)) + max_ + + return -log_same_prob[:, 0] + + def loss_by_feat(self, mask_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], positive_infos: InstanceList, + **kwargs) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mask_preds (list[Tensor]): List of predicted masks, each has + shape (num_classes, H, W). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``masks``, + and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of multiple images. + positive_infos (List[:obj:``InstanceData``]): Information of + positive samples of each image that are assigned in detection + head. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert positive_infos is not None, \ + 'positive_infos should not be None in `BoxInstMaskHead`' + losses = dict() + + loss_mask_project = 0. + loss_mask_pairwise = 0. + num_imgs = len(mask_preds) + total_pos = 0. + avg_fatcor = 0. + + for idx in range(num_imgs): + (mask_pred, pos_mask_targets, pos_pairwise_masks, num_pos) = \ + self._get_targets_single( + mask_preds[idx], batch_gt_instances[idx], + positive_infos[idx]) + # mask loss + total_pos += num_pos + if num_pos == 0 or pos_mask_targets is None: + loss_project = mask_pred.new_zeros(1).mean() + loss_pairwise = mask_pred.new_zeros(1).mean() + avg_fatcor += 0. + else: + # compute the project term + loss_project_x = self.loss_mask( + mask_pred.max(dim=1, keepdim=True)[0], + pos_mask_targets.max(dim=1, keepdim=True)[0], + reduction_override='none').sum() + loss_project_y = self.loss_mask( + mask_pred.max(dim=2, keepdim=True)[0], + pos_mask_targets.max(dim=2, keepdim=True)[0], + reduction_override='none').sum() + loss_project = loss_project_x + loss_project_y + # compute the pairwise term + pairwise_affinity = self.get_pairwise_affinity(mask_pred) + avg_fatcor += pos_pairwise_masks.sum().clamp(min=1.0) + loss_pairwise = (pairwise_affinity * pos_pairwise_masks).sum() + + loss_mask_project += loss_project + loss_mask_pairwise += loss_pairwise + + if total_pos == 0: + total_pos += 1 # avoid nan + if avg_fatcor == 0: + avg_fatcor += 1 # avoid nan + loss_mask_project = loss_mask_project / total_pos + loss_mask_pairwise = loss_mask_pairwise / avg_fatcor + message_hub = MessageHub.get_current_instance() + iter = message_hub.get_info('iter') + warmup_factor = min(iter / float(self.warmup_iters), 1.0) + loss_mask_pairwise *= warmup_factor + + losses.update( + loss_mask_project=loss_mask_project, + loss_mask_pairwise=loss_mask_pairwise) + return losses + + def _get_targets_single(self, mask_preds: Tensor, + gt_instances: InstanceData, + positive_info: InstanceData): + """Compute targets for predictions of single image. + + Args: + mask_preds (Tensor): Predicted prototypes with shape + (num_classes, H, W). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes``, ``labels``, + and ``masks`` attributes. + positive_info (:obj:`InstanceData`): Information of positive + samples that are assigned in detection head. It usually + contains following keys. + + - pos_assigned_gt_inds (Tensor): Assigner GT indexes of + positive proposals, has shape (num_pos, ) + - pos_inds (Tensor): Positive index of image, has + shape (num_pos, ). + - param_pred (Tensor): Positive param preditions + with shape (num_pos, num_params). + + Returns: + tuple: Usually returns a tuple containing learning targets. + + - mask_preds (Tensor): Positive predicted mask with shape + (num_pos, mask_h, mask_w). + - pos_mask_targets (Tensor): Positive mask targets with shape + (num_pos, mask_h, mask_w). + - pos_pairwise_masks (Tensor): Positive pairwise masks with + shape: (num_pos, num_neighborhood, mask_h, mask_w). + - num_pos (int): Positive numbers. + """ + gt_bboxes = gt_instances.bboxes + device = gt_bboxes.device + # Note that gt_masks are generated by full box + # from BoxInstDataPreprocessor + gt_masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=device).float() + # Note that pairwise_masks are generated by image color similarity + # from BoxInstDataPreprocessor + pairwise_masks = gt_instances.pairwise_masks + pairwise_masks = pairwise_masks.to(device=device) + + # process with mask targets + pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds') + scores = positive_info.get('scores') + centernesses = positive_info.get('centernesses') + num_pos = pos_assigned_gt_inds.size(0) + + if gt_masks.size(0) == 0 or num_pos == 0: + return mask_preds, None, None, 0 + # Since we're producing (near) full image masks, + # it'd take too much vram to backprop on every single mask. + # Thus we select only a subset. + if (self.max_masks_to_train != -1) and \ + (num_pos > self.max_masks_to_train): + perm = torch.randperm(num_pos) + select = perm[:self.max_masks_to_train] + mask_preds = mask_preds[select] + pos_assigned_gt_inds = pos_assigned_gt_inds[select] + num_pos = self.max_masks_to_train + elif self.topk_masks_per_img != -1: + unique_gt_inds = pos_assigned_gt_inds.unique() + num_inst_per_gt = max( + int(self.topk_masks_per_img / len(unique_gt_inds)), 1) + + keep_mask_preds = [] + keep_pos_assigned_gt_inds = [] + for gt_ind in unique_gt_inds: + per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind) + mask_preds_per_inst = mask_preds[per_inst_pos_inds] + gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds] + if sum(per_inst_pos_inds) > num_inst_per_gt: + per_inst_scores = scores[per_inst_pos_inds].sigmoid().max( + dim=1)[0] + per_inst_centerness = centernesses[ + per_inst_pos_inds].sigmoid().reshape(-1, ) + select = (per_inst_scores * per_inst_centerness).topk( + k=num_inst_per_gt, dim=0)[1] + mask_preds_per_inst = mask_preds_per_inst[select] + gt_inds_per_inst = gt_inds_per_inst[select] + keep_mask_preds.append(mask_preds_per_inst) + keep_pos_assigned_gt_inds.append(gt_inds_per_inst) + mask_preds = torch.cat(keep_mask_preds) + pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds) + num_pos = pos_assigned_gt_inds.size(0) + + # Follow the origin implement + start = int(self.mask_out_stride // 2) + gt_masks = gt_masks[:, start::self.mask_out_stride, + start::self.mask_out_stride] + gt_masks = gt_masks.gt(0.5).float() + pos_mask_targets = gt_masks[pos_assigned_gt_inds] + pos_pairwise_masks = pairwise_masks[pos_assigned_gt_inds] + pos_pairwise_masks = pos_pairwise_masks * pos_mask_targets.unsqueeze(1) + + return (mask_preds, pos_mask_targets, pos_pairwise_masks, num_pos) diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index f6f7e7f78d0..6df092a025a 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -2,6 +2,7 @@ from .atss import ATSS from .autoassign import AutoAssign from .base import BaseDetector +from .boxinst import BoxInst from .cascade_rcnn import CascadeRCNN from .centernet import CenterNet from .condinst import CondInst @@ -61,5 +62,5 @@ 'SOLOv2', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX', 'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD', 'MaskFormer', 'DDOD', 'Mask2Former', 'SemiBaseDetector', 'SoftTeacher', - 'RTMDet', 'Detectron2Wrapper', 'RTMDet', 'CrowdDet', 'CondInst' + 'RTMDet', 'Detectron2Wrapper', 'RTMDet', 'CrowdDet', 'CondInst', 'BoxInst' ] diff --git a/mmdet/models/detectors/boxinst.py b/mmdet/models/detectors/boxinst.py new file mode 100644 index 00000000000..ca6b0bdd90a --- /dev/null +++ b/mmdet/models/detectors/boxinst.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from .single_stage_instance_seg import SingleStageInstanceSegmentor + + +@MODELS.register_module() +class BoxInst(SingleStageInstanceSegmentor): + """Implementation of `BoxInst `_""" + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + mask_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py index b84889a494d..af3b2448dbe 100644 --- a/mmdet/models/utils/__init__.py +++ b/mmdet/models/utils/__init__.py @@ -9,8 +9,8 @@ levels_to_images, mask2ndarray, multi_apply, relative_coordinate_maps, rename_loss_dict, reweight_loss_dict, samplelist_boxtype2tensor, - select_single_mlvl, sigmoid_geometric_mean, unmap, - unpack_gt_instances) + select_single_mlvl, sigmoid_geometric_mean, + unfold_wo_center, unmap, unpack_gt_instances) from .panoptic_gt_processing import preprocess_panoptic_gt from .point_sample import (get_uncertain_point_coords_with_randomness, get_uncertainty) @@ -25,5 +25,6 @@ 'generate_coordinate', 'levels_to_images', 'mask2ndarray', 'multi_apply', 'select_single_mlvl', 'unmap', 'images_to_levels', 'samplelist_boxtype2tensor', 'filter_gt_instances', 'rename_loss_dict', - 'reweight_loss_dict', 'relative_coordinate_maps', 'aligned_bilinear' + 'reweight_loss_dict', 'relative_coordinate_maps', 'aligned_bilinear', + 'unfold_wo_center' ] diff --git a/mmdet/models/utils/misc.py b/mmdet/models/utils/misc.py index 93a885d3d01..823d73c0ac3 100644 --- a/mmdet/models/utils/misc.py +++ b/mmdet/models/utils/misc.py @@ -625,3 +625,28 @@ def aligned_bilinear(tensor: Tensor, factor: int) -> Tensor: tensor, pad=(factor // 2, 0, factor // 2, 0), mode='replicate') return tensor[:, :, :oh - 1, :ow - 1] + + +def unfold_wo_center(x, kernel_size: int, dilation: int) -> Tensor: + """unfold_wo_center, used in original implement in BoxInst: + + https://github.com/aim-uofa/AdelaiDet/blob/\ + 4a3a1f7372c35b48ebf5f6adc59f135a0fa28d60/\ + adet/modeling/condinst/condinst.py#L53 + """ + assert x.dim() == 4 + assert kernel_size % 2 == 1 + + # using SAME padding + padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 + unfolded_x = F.unfold( + x, kernel_size=kernel_size, padding=padding, dilation=dilation) + unfolded_x = unfolded_x.reshape( + x.size(0), x.size(1), -1, x.size(2), x.size(3)) + # remove the center pixels + size = kernel_size**2 + unfolded_x = torch.cat( + (unfolded_x[:, :, :size // 2], unfolded_x[:, :, size // 2 + 1:]), + dim=2) + + return unfolded_x diff --git a/tests/test_models/test_data_preprocessors/test_boxinst_preprocessor.py b/tests/test_models/test_data_preprocessors/test_boxinst_preprocessor.py new file mode 100644 index 00000000000..57038fe80dd --- /dev/null +++ b/tests/test_models/test_data_preprocessors/test_boxinst_preprocessor.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmdet.models.data_preprocessors import BoxInstDataPreprocessor +from mmdet.structures import DetDataSample +from mmdet.testing import demo_mm_inputs + + +class TestBoxInstDataPreprocessor(TestCase): + + def test_forward(self): + processor = BoxInstDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1]) + + data = { + 'inputs': [torch.randint(0, 256, (3, 256, 256))], + 'data_samples': [DetDataSample()] + } + + # Test evaluation mode + out_data = processor(data) + batch_inputs, batch_data_samples = out_data['inputs'], out_data[ + 'data_samples'] + + self.assertEqual(batch_inputs.shape, (1, 3, 256, 256)) + self.assertEqual(len(batch_data_samples), 1) + + # Test traning mode without gt bboxes + packed_inputs = demo_mm_inputs( + 2, [[3, 256, 256], [3, 128, 128]], num_items=[0, 0]) + out_data = processor(packed_inputs, training=True) + batch_inputs, batch_data_samples = out_data['inputs'], out_data[ + 'data_samples'] + + self.assertEqual(batch_inputs.shape, (2, 3, 256, 256)) + self.assertEqual(len(batch_data_samples), 2) + self.assertEqual(len(batch_data_samples[0].gt_instances.masks), 0) + self.assertEqual( + len(batch_data_samples[0].gt_instances.pairwise_masks), 0) + self.assertEqual(len(batch_data_samples[1].gt_instances.masks), 0) + self.assertEqual( + len(batch_data_samples[1].gt_instances.pairwise_masks), 0) + + # Test traning mode with gt bboxes + packed_inputs = demo_mm_inputs( + 2, [[3, 256, 256], [3, 128, 128]], num_items=[2, 1]) + out_data = processor(packed_inputs, training=True) + batch_inputs, batch_data_samples = out_data['inputs'], out_data[ + 'data_samples'] + + self.assertEqual(batch_inputs.shape, (2, 3, 256, 256)) + self.assertEqual(len(batch_data_samples), 2) + self.assertEqual(len(batch_data_samples[0].gt_instances.masks), 2) + self.assertEqual( + len(batch_data_samples[0].gt_instances.pairwise_masks), 2) + self.assertEqual(len(batch_data_samples[1].gt_instances.masks), 1) + self.assertEqual( + len(batch_data_samples[1].gt_instances.pairwise_masks), 1) diff --git a/tests/test_models/test_dense_heads/test_boxinst_head.py b/tests/test_models/test_dense_heads/test_boxinst_head.py new file mode 100644 index 00000000000..b5fe30695ac --- /dev/null +++ b/tests/test_models/test_dense_heads/test_boxinst_head.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +import torch +from mmengine import MessageHub +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData + +from mmdet.models.dense_heads import BoxInstBboxHead, BoxInstMaskHead +from mmdet.structures.mask import BitmapMasks + + +def _rand_masks(num_items, bboxes, img_w, img_h): + rng = np.random.RandomState(0) + masks = np.zeros((num_items, img_h, img_w), dtype=np.float32) + for i, bbox in enumerate(bboxes): + bbox = bbox.astype(np.int32) + mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) > + 0.3).astype(np.int) + masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask + return BitmapMasks(masks, height=img_h, width=img_w) + + +def _fake_mask_feature_head(): + mask_feature_head = ConfigDict( + in_channels=1, + feat_channels=1, + start_level=0, + end_level=2, + out_channels=8, + mask_stride=8, + num_stacked_convs=4, + norm_cfg=dict(type='BN', requires_grad=True)) + return mask_feature_head + + +class TestBoxInstHead(TestCase): + + def test_boxinst_maskhead_loss(self): + """Tests boxinst maskhead loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'pad_shape': (s, s, 3), + 'scale_factor': 1, + }] + boxinst_bboxhead = BoxInstBboxHead( + num_classes=4, + in_channels=1, + feat_channels=1, + stacked_convs=1, + norm_cfg=None) + + mask_feature_head = _fake_mask_feature_head() + boxinst_maskhead = BoxInstMaskHead( + mask_feature_head=mask_feature_head, + loss_mask=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + eps=5e-6, + loss_weight=1.0)) + + # Fcos head expects a multiple levels of features per image + feats = [] + for i in range(len(boxinst_bboxhead.strides)): + feats.append( + torch.rand(1, 1, s // (2**(i + 3)), s // (2**(i + 3)))) + feats = tuple(feats) + cls_scores, bbox_preds, centernesses, param_preds =\ + boxinst_bboxhead.forward(feats) + + # Test that empty ground truth encourages the network to + # predict background + gt_instances = InstanceData() + gt_instances.bboxes = torch.empty((0, 4)) + gt_instances.labels = torch.LongTensor([]) + gt_instances.masks = _rand_masks(0, gt_instances.bboxes.numpy(), s, s) + gt_instances.pairwise_masks = _rand_masks( + 0, gt_instances.bboxes.numpy(), s // 4, s // 4).to_tensor( + dtype=torch.float32, + device='cpu').unsqueeze(1).repeat(1, 8, 1, 1) + message_hub = MessageHub.get_instance('runtime_info') + message_hub.update_info('iter', 1) + _ = boxinst_bboxhead.loss_by_feat(cls_scores, bbox_preds, centernesses, + param_preds, [gt_instances], + img_metas) + # When truth is empty then all mask loss + # should be zero for random inputs + positive_infos = boxinst_bboxhead.get_positive_infos() + mask_outs = boxinst_maskhead.forward(feats, positive_infos) + empty_gt_mask_losses = boxinst_maskhead.loss_by_feat( + *mask_outs, [gt_instances], img_metas, positive_infos) + loss_mask_project = empty_gt_mask_losses['loss_mask_project'] + loss_mask_pairwise = empty_gt_mask_losses['loss_mask_pairwise'] + self.assertEqual(loss_mask_project, 0, + 'mask project loss should be zero') + self.assertEqual(loss_mask_pairwise, 0, + 'mask pairwise loss should be zero') + + # When truth is non-empty then all cls, box loss and centerness loss + # should be nonzero for random inputs + gt_instances = InstanceData() + gt_instances.bboxes = torch.Tensor([[0.111, 0.222, 25.6667, 29.8757]]) + gt_instances.labels = torch.LongTensor([2]) + gt_instances.masks = _rand_masks(1, gt_instances.bboxes.numpy(), s, s) + gt_instances.pairwise_masks = _rand_masks( + 1, gt_instances.bboxes.numpy(), s // 4, s // 4).to_tensor( + dtype=torch.float32, + device='cpu').unsqueeze(1).repeat(1, 8, 1, 1) + + _ = boxinst_bboxhead.loss_by_feat(cls_scores, bbox_preds, centernesses, + param_preds, [gt_instances], + img_metas) + positive_infos = boxinst_bboxhead.get_positive_infos() + mask_outs = boxinst_maskhead.forward(feats, positive_infos) + one_gt_mask_losses = boxinst_maskhead.loss_by_feat( + *mask_outs, [gt_instances], img_metas, positive_infos) + loss_mask_project = one_gt_mask_losses['loss_mask_project'] + loss_mask_pairwise = one_gt_mask_losses['loss_mask_pairwise'] + self.assertGreater(loss_mask_project, 0, + 'mask project loss should be nonzero') + self.assertGreater(loss_mask_pairwise, 0, + 'mask pairwise loss should be nonzero')