diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5de608fa2..f9d67c92c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
     hooks:
       - id: black
-  - repo: https://gitlab.com/pycqa/flake8
+  - repo: https://github.com/pycqa/flake8
     rev: 5.0.4
     hooks:
       - id: flake8
diff --git a/docs/config_file_detail.md b/docs/config_file_detail.md
index 65895dadd..79e7a6718 100644
--- a/docs/config_file_detail.md
+++ b/docs/config_file_detail.md
@@ -67,6 +67,7 @@ head:
     scales_per_octave: 1
     strides: [8, 16, 32]
     reg_max: 7
+    ignore_iof_thr: -1
     norm_cfg:
       type: BN
     loss:
@@ -92,6 +93,8 @@ head:
 
 `reg_max`: max value of per-level l-r-t-b distance
 
+`ignore_iof_thr`: IoF threshold for matching ignore boxes, default value -1 (disabled)
+
 `norm_cfg`: normalization layer setting
 
 `loss`: adjust loss functions and weights
diff --git a/nanodet/data/dataset/coco.py b/nanodet/data/dataset/coco.py
index 3c46b14df..63147d9ed 100644
--- a/nanodet/data/dataset/coco.py
+++ b/nanodet/data/dataset/coco.py
@@ -77,15 +77,13 @@ def get_img_annotation(self, idx):
         if self.use_keypoint:
             gt_keypoints = []
         for ann in anns:
-            if ann.get("ignore", False):
-                continue
             x1, y1, w, h = ann["bbox"]
             if ann["area"] <= 0 or w < 1 or h < 1:
                 continue
             if ann["category_id"] not in self.cat_ids:
                 continue
             bbox = [x1, y1, x1 + w, y1 + h]
-            if ann.get("iscrowd", False):
+            if ann.get("iscrowd", False) or ann.get("ignore", False):
                 gt_bboxes_ignore.append(bbox)
             else:
                 gt_bboxes.append(bbox)
@@ -131,7 +129,11 @@ def get_train_data(self, idx):
             raise FileNotFoundError("Cant load image! Please check image path!")
         ann = self.get_img_annotation(idx)
         meta = dict(
-            img=img, img_info=img_info, gt_bboxes=ann["bboxes"], gt_labels=ann["labels"]
+            img=img,
+            img_info=img_info,
+            gt_bboxes=ann["bboxes"],
+            gt_labels=ann["labels"],
+            gt_bboxes_ignore=ann["bboxes_ignore"],
         )
         if self.use_instance_mask:
             meta["gt_masks"] = ann["masks"]
diff --git a/nanodet/data/transform/warp.py b/nanodet/data/transform/warp.py
index a102348f8..815c0c28d 100644
--- a/nanodet/data/transform/warp.py
+++ b/nanodet/data/transform/warp.py
@@ -185,6 +185,11 @@ def warp_and_resize(
     if "gt_bboxes" in meta:
         boxes = meta["gt_bboxes"]
         meta["gt_bboxes"] = warp_boxes(boxes, M, dst_shape[0], dst_shape[1])
+    if "gt_bboxes_ignore" in meta:
+        bboxes_ignore = meta["gt_bboxes_ignore"]
+        meta["gt_bboxes_ignore"] = warp_boxes(
+            bboxes_ignore, M, dst_shape[0], dst_shape[1]
+        )
     if "gt_masks" in meta:
         for i, mask in enumerate(meta["gt_masks"]):
             meta["gt_masks"][i] = cv2.warpPerspective(mask, M, dsize=tuple(dst_shape))
@@ -343,6 +348,11 @@ def __call__(self, meta_data, dst_shape):
         if "gt_bboxes" in meta_data:
             boxes = meta_data["gt_bboxes"]
             meta_data["gt_bboxes"] = warp_boxes(boxes, M, dst_shape[0], dst_shape[1])
+        if "gt_bboxes_ignore" in meta_data:
+            bboxes_ignore = meta_data["gt_bboxes_ignore"]
+            meta_data["gt_bboxes_ignore"] = warp_boxes(
+                bboxes_ignore, M, dst_shape[0], dst_shape[1]
+            )
         if "gt_masks" in meta_data:
             for i, mask in enumerate(meta_data["gt_masks"]):
                 meta_data["gt_masks"][i] = cv2.warpPerspective(
diff --git a/nanodet/model/head/assigner/atss_assigner.py b/nanodet/model/head/assigner/atss_assigner.py
index c182bff44..2c3cf0944 100644
--- a/nanodet/model/head/assigner/atss_assigner.py
+++ b/nanodet/model/head/assigner/atss_assigner.py
@@ -23,18 +23,21 @@ class ATSSAssigner(BaseAssigner):
     """Assign a corresponding gt bbox or background to each bbox.
 
-    Each proposals will be assigned with `0` or a positive integer
+    Each proposal will be assigned `-1`, `0`, or a positive integer
     indicating the ground truth index.
-
+    - -1: ignore sample, will be masked in loss calculation
     - 0: negative sample, no assigned gt
     - positive integer: positive sample, index (1-based) of assigned gt
 
     Args:
         topk (float): number of bbox selected in each level
+        ignore_iof_thr (float): IoF threshold above which a proposal overlapping
+            an ignore box is assigned -1; use (0, 1] to enable, default -1 (off).
     """
 
-    def __init__(self, topk):
+    def __init__(self, topk, ignore_iof_thr=-1):
         self.topk = topk
+        self.ignore_iof_thr = ignore_iof_thr
 
     # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
@@ -105,6 +108,18 @@ def assign(
             (bboxes_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()
         )
 
+        if (
+            self.ignore_iof_thr > 0
+            and gt_bboxes_ignore is not None
+            and gt_bboxes_ignore.numel() > 0
+            and bboxes.numel() > 0
+        ):
+            ignore_overlaps = bbox_overlaps(bboxes, gt_bboxes_ignore, mode="iof")
+            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+            ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
+            distances[ignore_idxs, :] = INF
+            assigned_gt_inds[ignore_idxs] = -1
+
         # Selecting candidates based on the center distance
         candidate_idxs = []
         start_idx = 0
diff --git a/nanodet/model/head/assigner/dsl_assigner.py b/nanodet/model/head/assigner/dsl_assigner.py
index e74dc0854..17df54d26 100644
--- a/nanodet/model/head/assigner/dsl_assigner.py
+++ b/nanodet/model/head/assigner/dsl_assigner.py
@@ -14,11 +14,14 @@ class DynamicSoftLabelAssigner(BaseAssigner):
         topk (int): Select top-k predictions to calculate dynamic k
             best matchs for each gt. Default 13.
         iou_factor (float): The scale factor of iou cost. Default 3.0.
+        ignore_iof_thr (float): IoF threshold above which a prior overlapping
+            an ignore box is assigned -1; use (0, 1] to enable, default -1 (off).
     """
 
-    def __init__(self, topk=13, iou_factor=3.0):
+    def __init__(self, topk=13, iou_factor=3.0, ignore_iof_thr=-1):
         self.topk = topk
         self.iou_factor = iou_factor
+        self.ignore_iof_thr = ignore_iof_thr
 
     def assign(
         self,
@@ -27,6 +30,7 @@ def assign(
         decoded_bboxes,
         gt_bboxes,
         gt_labels,
+        gt_bboxes_ignore=None,
     ):
         """Assign gt to priors with dynamic soft label assignment.
         Args:
@@ -38,6 +42,8 @@ def assign(
                 [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
             gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                 with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`, e.g., crowd boxes in COCO.
             gt_labels (Tensor): Ground truth labels of one image, a Tensor
                 with shape [num_gts].
@@ -113,6 +119,20 @@ def assign(
             (num_bboxes,), -INF, dtype=torch.float32
         )
         max_overlaps[valid_mask] = matched_pred_ious
+
+        if (
+            self.ignore_iof_thr > 0
+            and gt_bboxes_ignore is not None
+            and gt_bboxes_ignore.numel() > 0
+            and num_bboxes > 0
+        ):
+            ignore_overlaps = bbox_overlaps(
+                valid_decoded_bbox, gt_bboxes_ignore, mode="iof"
+            )
+            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+            ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
+            assigned_gt_inds[ignore_idxs] = -1
+
         return AssignResult(
             num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
         )
diff --git a/nanodet/model/head/gfl_head.py b/nanodet/model/head/gfl_head.py
index ee5409c7e..65c457fde 100644
--- a/nanodet/model/head/gfl_head.py
+++ b/nanodet/model/head/gfl_head.py
@@ -105,6 +105,7 @@ def __init__(
         conv_cfg=None,
         norm_cfg=dict(type="GN", num_groups=32, requires_grad=True),
         reg_max=16,
+        ignore_iof_thr=-1,
         **kwargs
     ):
         super(GFLHead, self).__init__()
@@ -120,12 +121,13 @@ def __init__(
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
         self.use_sigmoid = self.loss_cfg.loss_qfl.use_sigmoid
+        self.ignore_iof_thr = ignore_iof_thr
         if self.use_sigmoid:
             self.cls_out_channels = num_classes
         else:
             self.cls_out_channels = num_classes + 1
 
-        self.assigner = ATSSAssigner(topk=9)
+        self.assigner = ATSSAssigner(topk=9, ignore_iof_thr=ignore_iof_thr)
         self.distribution_project = Integral(self.reg_max)
 
         self.loss_qfl = QualityFocalLoss(
@@ -209,9 +211,9 @@ def loss(self, preds, gt_meta):
         )
         device = cls_scores.device
         gt_bboxes = gt_meta["gt_bboxes"]
+        gt_bboxes_ignore = gt_meta["gt_bboxes_ignore"]
         gt_labels = gt_meta["gt_labels"]
         input_height, input_width = gt_meta["img"].shape[2:]
-        gt_bboxes_ignore = None
 
         featmap_sizes = [
             (math.ceil(input_height / stride), math.ceil(input_width) / stride)
@@ -465,6 +467,9 @@ def target_assign_single_img(
         gt_bboxes = torch.from_numpy(gt_bboxes).to(device)
         gt_labels = torch.from_numpy(gt_labels).to(device)
 
+        if gt_bboxes_ignore is not None:
+            gt_bboxes_ignore = torch.from_numpy(gt_bboxes_ignore).to(device)
+
         assign_result = self.assigner.assign(
             grid_cells, num_level_cells, gt_bboxes, gt_bboxes_ignore, gt_labels
         )
diff --git a/nanodet/model/head/nanodet_plus_head.py b/nanodet/model/head/nanodet_plus_head.py
index 0a6769199..e7e71b0f1 100644
--- a/nanodet/model/head/nanodet_plus_head.py
+++ b/nanodet/model/head/nanodet_plus_head.py
@@ -158,10 +158,15 @@ def loss(self, preds, gt_meta, aux_preds=None):
             loss (Tensor): Loss tensor.
             loss_states (dict): State dict of each loss.
""" - gt_bboxes = gt_meta["gt_bboxes"] - gt_labels = gt_meta["gt_labels"] device = preds.device batch_size = preds.shape[0] + gt_bboxes = gt_meta["gt_bboxes"] + gt_labels = gt_meta["gt_labels"] + + gt_bboxes_ignore = gt_meta["gt_bboxes_ignore"] + if gt_bboxes_ignore is None: + gt_bboxes_ignore = [None for _ in range(batch_size)] + input_height, input_width = gt_meta["img"].shape[2:] featmap_sizes = [ (math.ceil(input_height / stride), math.ceil(input_width) / stride) @@ -202,6 +207,7 @@ def loss(self, preds, gt_meta, aux_preds=None): aux_decoded_bboxes.detach(), gt_bboxes, gt_labels, + gt_bboxes_ignore, ) else: # use self prediction to assign @@ -212,6 +218,7 @@ def loss(self, preds, gt_meta, aux_preds=None): decoded_bboxes.detach(), gt_bboxes, gt_labels, + gt_bboxes_ignore, ) loss, loss_states = self._get_loss_from_assign( @@ -229,19 +236,30 @@ def loss(self, preds, gt_meta, aux_preds=None): def _get_loss_from_assign(self, cls_preds, reg_preds, decoded_bboxes, assign): device = cls_preds.device - labels, label_scores, bbox_targets, dist_targets, num_pos = assign + ( + labels, + label_scores, + label_weights, + bbox_targets, + dist_targets, + num_pos, + ) = assign num_total_samples = max( reduce_mean(torch.tensor(sum(num_pos)).to(device)).item(), 1.0 ) labels = torch.cat(labels, dim=0) label_scores = torch.cat(label_scores, dim=0) + label_weights = torch.cat(label_weights, dim=0) bbox_targets = torch.cat(bbox_targets, dim=0) cls_preds = cls_preds.reshape(-1, self.num_classes) reg_preds = reg_preds.reshape(-1, 4 * (self.reg_max + 1)) decoded_bboxes = decoded_bboxes.reshape(-1, 4) loss_qfl = self.loss_qfl( - cls_preds, (labels, label_scores), avg_factor=num_total_samples + cls_preds, + (labels, label_scores), + weight=label_weights, + avg_factor=num_total_samples, ) pos_inds = torch.nonzero( @@ -276,7 +294,13 @@ def _get_loss_from_assign(self, cls_preds, reg_preds, decoded_bboxes, assign): @torch.no_grad() def target_assign_single_img( - self, cls_preds, center_priors, decoded_bboxes, gt_bboxes, gt_labels + self, + cls_preds, + center_priors, + decoded_bboxes, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, ): """Compute classification, regression, and objectness targets for priors in a single image. @@ -292,31 +316,40 @@ def target_assign_single_img( with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format. gt_labels (Tensor): Ground truth labels of one image, a Tensor with shape [num_gts]. + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`, e.g., crowd boxes in COCO. 
""" - num_priors = center_priors.size(0) device = center_priors.device gt_bboxes = torch.from_numpy(gt_bboxes).to(device) gt_labels = torch.from_numpy(gt_labels).to(device) - num_gts = gt_labels.size(0) gt_bboxes = gt_bboxes.to(decoded_bboxes.dtype) + if gt_bboxes_ignore is not None: + gt_bboxes_ignore = torch.from_numpy(gt_bboxes_ignore).to(device) + gt_bboxes_ignore = gt_bboxes_ignore.to(decoded_bboxes.dtype) + + assign_result = self.assigner.assign( + cls_preds.sigmoid(), + center_priors, + decoded_bboxes, + gt_bboxes, + gt_labels, + gt_bboxes_ignore, + ) + pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.sample( + assign_result, gt_bboxes + ) + + num_priors = center_priors.size(0) bbox_targets = torch.zeros_like(center_priors) dist_targets = torch.zeros_like(center_priors) labels = center_priors.new_full( (num_priors,), self.num_classes, dtype=torch.long ) + label_weights = center_priors.new_zeros(num_priors, dtype=torch.float) label_scores = center_priors.new_zeros(labels.shape, dtype=torch.float) - # No target - if num_gts == 0: - return labels, label_scores, bbox_targets, dist_targets, 0 - assign_result = self.assigner.assign( - cls_preds.sigmoid(), center_priors, decoded_bboxes, gt_bboxes, gt_labels - ) - pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.sample( - assign_result, gt_bboxes - ) num_pos_per_img = pos_inds.size(0) pos_ious = assign_result.max_overlaps[pos_inds] @@ -329,9 +362,13 @@ def target_assign_single_img( dist_targets = dist_targets.clamp(min=0, max=self.reg_max - 0.1) labels[pos_inds] = gt_labels[pos_assigned_gt_inds] label_scores[pos_inds] = pos_ious + label_weights[pos_inds] = 1.0 + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 return ( labels, label_scores, + label_weights, bbox_targets, dist_targets, num_pos_per_img, diff --git a/tests/test_models/test_head/test_gfl_head.py b/tests/test_models/test_head/test_gfl_head.py index aa1988627..eb55ddf2a 100644 --- a/tests/test_models/test_head/test_gfl_head.py +++ b/tests/test_models/test_head/test_gfl_head.py @@ -31,6 +31,7 @@ def test_gfl_head_loss(): meta = dict( img=torch.rand((2, 3, 64, 64)), gt_bboxes=[np.random.random((0, 4))], + gt_bboxes_ignore=[np.random.random((0, 4))], gt_labels=[np.array([])], ) loss, empty_gt_losses = head.loss(preds, meta) @@ -52,9 +53,15 @@ def test_gfl_head_loss(): gt_bboxes = [ np.array([[23.6667, 23.8757, 238.6326, 151.8874]], dtype=np.float32), ] + gt_bboxes_ignore = [ + np.array([[29.6667, 29.8757, 244.6326, 160.8874]], dtype=np.float32), + ] gt_labels = [np.array([2])] meta = dict( - img=torch.rand((2, 3, 64, 64)), gt_bboxes=gt_bboxes, gt_labels=gt_labels + img=torch.rand((2, 3, 64, 64)), + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + gt_bboxes_ignore=gt_bboxes_ignore, ) loss, one_gt_losses = head.loss(preds, meta) onegt_qfl_loss = one_gt_losses["loss_qfl"] diff --git a/tests/test_models/test_head/test_nanodet_plus_head.py b/tests/test_models/test_head/test_nanodet_plus_head.py index eb529f575..84a9578b5 100644 --- a/tests/test_models/test_head/test_nanodet_plus_head.py +++ b/tests/test_models/test_head/test_nanodet_plus_head.py @@ -61,6 +61,7 @@ def test_nanodet_plus_head_loss(): meta = dict( img=torch.rand((1, 3, 320, 320)), gt_bboxes=[np.random.random((0, 4))], + gt_bboxes_ignore=[np.random.random((0, 4))], gt_labels=[np.array([])], ) loss, empty_gt_losses = head.loss(preds, meta) @@ -82,9 +83,15 @@ def test_nanodet_plus_head_loss(): gt_bboxes = [ np.array([[23.6667, 23.8757, 238.6326, 151.8874]], dtype=np.float32), ] + gt_bboxes_ignore = 
+        np.array([[29.6667, 29.8757, 244.6326, 160.8874]], dtype=np.float32),
+    ]
     gt_labels = [np.array([2])]
     meta = dict(
-        img=torch.rand((1, 3, 320, 320)), gt_bboxes=gt_bboxes, gt_labels=gt_labels
+        img=torch.rand((1, 3, 320, 320)),
+        gt_bboxes=gt_bboxes,
+        gt_labels=gt_labels,
+        gt_bboxes_ignore=gt_bboxes_ignore,
     )
     loss, one_gt_losses = head.loss(preds, meta)
     onegt_qfl_loss = one_gt_losses["loss_qfl"]
@@ -98,9 +105,15 @@ def test_nanodet_plus_head_loss():
     gt_bboxes = [
         np.array([[23.6667, 23.8757, 238.6326, 151.8874]], dtype=np.float32),
     ]
+    gt_bboxes_ignore = [
+        np.array([[29.6667, 29.8757, 244.6326, 160.8874]], dtype=np.float32),
+    ]
     gt_labels = [np.array([2])]
     meta = dict(
-        img=torch.rand((1, 3, 320, 320)), gt_bboxes=gt_bboxes, gt_labels=gt_labels
+        img=torch.rand((1, 3, 320, 320)),
+        gt_bboxes=gt_bboxes,
+        gt_labels=gt_labels,
+        gt_bboxes_ignore=gt_bboxes_ignore,
     )
     loss, one_gt_losses = head.loss(preds, meta, aux_preds=preds)
     onegt_qfl_loss = one_gt_losses["loss_qfl"]
diff --git a/tests/test_trainer/test_lightning_task.py b/tests/test_trainer/test_lightning_task.py
index f21908e0f..58618d2b4 100644
--- a/tests/test_trainer/test_lightning_task.py
+++ b/tests/test_trainer/test_lightning_task.py
@@ -42,6 +42,12 @@ def test(self):
                 [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], dtype=np.float32
             ),
         ],
+        "gt_bboxes_ignore": [
+            np.array([[3.0, 4.0, 5.0, 6.0]], dtype=np.float32),
+            np.array(
+                [[7.0, 8.0, 9.0, 10.0], [7.0, 8.0, 9.0, 10.0]], dtype=np.float32
+            ),
+        ],
         "gt_labels": [np.array([1]), np.array([1, 2])],
         "warp_matrix": [np.eye(3), np.eye(3)],
     }
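For reference, both assigner hunks apply the same rule: compute IoF (intersection over the proposal's own area, mode="iof") between each proposal/prior and every ignore box, take the maximum per proposal, and assign -1 wherever that maximum exceeds ignore_iof_thr; samples assigned -1 end up with zero label_weights in the NanoDet-Plus head, so they are masked out of the QFL loss. The snippet below is a minimal, standalone sketch of that rule only; its iof() and mark_ignored() helpers are hypothetical stand-ins written for this note, not the repository's bbox_overlaps or assigner API.

# Standalone sketch of the IoF-based ignore rule (toy helpers, not nanodet code).
import torch


def iof(boxes, ignore_boxes):
    # Intersection-over-foreground: intersection area divided by the area of
    # `boxes` themselves; result has shape [num_boxes, num_ignore].
    tl = torch.max(boxes[:, None, :2], ignore_boxes[None, :, :2])
    br = torch.min(boxes[:, None, 2:], ignore_boxes[None, :, 2:])
    wh = (br - tl).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / area[:, None].clamp(min=1e-6)


def mark_ignored(assigned_gt_inds, boxes, gt_bboxes_ignore, ignore_iof_thr=-1):
    # Mirror of the guard used in both assigners: do nothing unless the
    # threshold is enabled and there are ignore boxes to compare against.
    no_ignore = gt_bboxes_ignore is None or gt_bboxes_ignore.numel() == 0
    if ignore_iof_thr <= 0 or no_ignore:
        return assigned_gt_inds
    ignore_max_overlaps, _ = iof(boxes, gt_bboxes_ignore).max(dim=1)
    assigned_gt_inds[ignore_max_overlaps > ignore_iof_thr] = -1  # mark as ignored
    return assigned_gt_inds


if __name__ == "__main__":
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [50.0, 50.0, 60.0, 60.0]])
    ignore = torch.tensor([[0.0, 0.0, 12.0, 12.0]])
    inds = torch.zeros(2, dtype=torch.long)  # 0 = negative sample
    print(mark_ignored(inds, boxes, ignore, ignore_iof_thr=0.5))  # tensor([-1, 0])

With the key added to docs/config_file_detail.md, the behaviour is enabled per head by setting head.ignore_iof_thr to a positive value (the code only checks ignore_iof_thr > 0, and IoF lies in [0, 1]); the default -1 keeps the previous behaviour.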